
Source code for flash.text.question_answering.data

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Type

from flash.core.data.data_module import DataModule
from flash.core.data.io.input import Input
from flash.core.data.io.input_transform import InputTransform
from flash.core.data.utilities.paths import PATH_TYPE
from flash.core.utilities.imports import _TEXT_AVAILABLE
from flash.core.utilities.stages import RunningStage
from flash.core.utilities.types import INPUT_TRANSFORM_TYPE
from flash.text.question_answering.input import (
    QuestionAnsweringCSVInput,
    QuestionAnsweringDictionaryInput,
    QuestionAnsweringJSONInput,
    QuestionAnsweringSQuADInput,
)

# Skip doctests if requirements aren't available
if not _TEXT_AVAILABLE:
    __doctest_skip__ = ["QuestionAnsweringData", "QuestionAnsweringData.*"]


class QuestionAnsweringData(DataModule):
    """The ``QuestionAnsweringData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of
    classmethods for loading data for extractive question answering."""

    input_transform_cls = InputTransform
    @classmethod
    def from_csv(
        cls,
        train_file: Optional[PATH_TYPE] = None,
        val_file: Optional[PATH_TYPE] = None,
        test_file: Optional[PATH_TYPE] = None,
        predict_file: Optional[PATH_TYPE] = None,
        train_transform: INPUT_TRANSFORM_TYPE = InputTransform,
        val_transform: INPUT_TRANSFORM_TYPE = InputTransform,
        test_transform: INPUT_TRANSFORM_TYPE = InputTransform,
        predict_transform: INPUT_TRANSFORM_TYPE = InputTransform,
        input_cls: Type[Input] = QuestionAnsweringCSVInput,
        transform_kwargs: Optional[Dict] = None,
        question_column_name: str = "question",
        context_column_name: str = "context",
        answer_column_name: str = "answer",
        **data_module_kwargs: Any,
    ) -> "QuestionAnsweringData":
        """Load the :class:`~flash.text.question_answering.data.QuestionAnsweringData` from CSV files containing
        questions, contexts and their corresponding answers.

        Question snippets will be extracted from the ``question_column_name`` column in the CSV files.
        Context snippets will be extracted from the ``context_column_name`` column in the CSV files.
        Answer snippets will be extracted from the ``answer_column_name`` column in the CSV files.

        To learn how to customize the transforms applied for each stage, read our
        :ref:`customizing transforms guide <customizing_transforms>`.

        Args:
            train_file: The CSV file containing the training data.
            val_file: The CSV file containing the validation data.
            test_file: The CSV file containing the testing data.
            predict_file: The CSV file containing the data to use when predicting.
            train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
                training.
            val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
                validating.
            test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing.
            predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
                predicting.
            input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
            transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
            question_column_name: The name of the column in the CSV files containing the questions.
            context_column_name: The name of the column in the CSV files containing the contexts.
            answer_column_name: The name of the column in the CSV files containing the answers.

        Returns:
            The constructed :class:`~flash.text.question_answering.data.QuestionAnsweringData`.

        Examples
        ________

        .. testsetup::

            >>> import os
            >>> from pandas import DataFrame
            >>> DataFrame.from_dict({
            ...     "id": ["12345", "12346", "12347", "12348"],
            ...     "context": [
            ...         "this is an answer one. this is a context one",
            ...         "this is an answer two. this is a context two",
            ...         "this is an answer three. this is a context three",
            ...         "this is an answer four. this is a context four",
            ...     ],
            ...     "question": [
            ...         "this is a question one",
            ...         "this is a question two",
            ...         "this is a question three",
            ...         "this is a question four",
            ...     ],
            ...     "answer_text": [
            ...         "this is an answer one",
            ...         "this is an answer two",
            ...         "this is an answer three",
            ...         "this is an answer four",
            ...     ],
            ...     "answer_start": [0, 0, 0, 0],
            ... }).to_csv("train_data.csv", index=False)
            >>> DataFrame.from_dict({
            ...     "id": ["12349", "12350"],
            ...     "context": [
            ...         "this is an answer five. this is a context five",
            ...         "this is an answer six. this is a context six",
            ...     ],
            ...     "question": [
            ...         "this is a question five",
            ...         "this is a question six",
            ...     ],
            ... }).to_csv("predict_data.csv", index=False)

        The file ``train_data.csv`` contains the following:

        .. code-block::

            id,context,question,answer_text,answer_start
            12345,this is an answer one. this is a context one,this is a question one,this is an answer one,0
            12346,this is an answer two. this is a context two,this is a question two,this is an answer two,0
            12347,this is an answer three. this is a context three,this is a question three,this is an answer three,0
            12348,this is an answer four. this is a context four,this is a question four,this is an answer four,0

        The file ``predict_data.csv`` contains the following:

        .. code-block::

            id,context,question
            12349,this is an answer five. this is a context five,this is a question five
            12350,this is an answer six. this is a context six,this is a question six

        .. doctest::

            >>> from flash import Trainer
            >>> from flash.text import QuestionAnsweringData, QuestionAnsweringTask
            >>> datamodule = QuestionAnsweringData.from_csv(
            ...     train_file="train_data.csv",
            ...     predict_file="predict_data.csv",
            ...     batch_size=2,
            ... ) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            Downloading...
            >>> model = QuestionAnsweringTask(max_source_length=32, max_target_length=32)
            >>> trainer = Trainer(fast_dev_run=True)
            >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            Training...
            >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            Predicting...

        .. testcleanup::

            >>> os.remove("train_data.csv")
            >>> os.remove("predict_data.csv")
        """
        ds_kw = dict(
            question_column_name=question_column_name,
            context_column_name=context_column_name,
            answer_column_name=answer_column_name,
            transform_kwargs=transform_kwargs,
            input_transforms_registry=cls.input_transforms_registry,
        )

        return cls(
            input_cls(RunningStage.TRAINING, train_file, transform=train_transform, **ds_kw),
            input_cls(RunningStage.VALIDATING, val_file, transform=val_transform, **ds_kw),
            input_cls(RunningStage.TESTING, test_file, transform=test_transform, **ds_kw),
            input_cls(RunningStage.PREDICTING, predict_file, transform=predict_transform, **ds_kw),
            **data_module_kwargs,
        )
    @classmethod
    def from_json(
        cls,
        train_file: Optional[PATH_TYPE] = None,
        val_file: Optional[PATH_TYPE] = None,
        test_file: Optional[PATH_TYPE] = None,
        predict_file: Optional[PATH_TYPE] = None,
        train_transform: INPUT_TRANSFORM_TYPE = InputTransform,
        val_transform: INPUT_TRANSFORM_TYPE = InputTransform,
        test_transform: INPUT_TRANSFORM_TYPE = InputTransform,
        predict_transform: INPUT_TRANSFORM_TYPE = InputTransform,
        input_cls: Type[Input] = QuestionAnsweringJSONInput,
        transform_kwargs: Optional[Dict] = None,
        field: Optional[str] = None,
        question_column_name: str = "question",
        context_column_name: str = "context",
        answer_column_name: str = "answer",
        **data_module_kwargs: Any,
    ) -> "QuestionAnsweringData":
        """Load the :class:`~flash.text.question_answering.data.QuestionAnsweringData` from JSON files containing
        questions, contexts and their corresponding answers.

        Question snippets will be extracted from the ``question_column_name`` column in the JSON files.
        Context snippets will be extracted from the ``context_column_name`` column in the JSON files.
        Answer snippets will be extracted from the ``answer_column_name`` column in the JSON files.

        To learn how to customize the transforms applied for each stage, read our
        :ref:`customizing transforms guide <customizing_transforms>`.

        Args:
            train_file: The JSON file containing the training data.
            val_file: The JSON file containing the validation data.
            test_file: The JSON file containing the testing data.
            predict_file: The JSON file containing the data to use when predicting.
            train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
                training.
            val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
                validating.
            test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing.
            predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
                predicting.
            input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
            transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
            field: The field that holds the data in the JSON file.
            question_column_name: The key in the JSON file to recognize the question field.
            context_column_name: The key in the JSON file to recognize the context field.
            answer_column_name: The key in the JSON file to recognize the answer field.

        Returns:
            The constructed :class:`~flash.text.question_answering.data.QuestionAnsweringData`.

        Examples
        ________

        .. testsetup::

            >>> import os
            >>> from pandas import DataFrame
            >>> DataFrame.from_dict({
            ...     "id": ["12345", "12346", "12347", "12348"],
            ...     "context": [
            ...         "this is an answer one. this is a context one",
            ...         "this is an answer two. this is a context two",
            ...         "this is an answer three. this is a context three",
            ...         "this is an answer four. this is a context four",
            ...     ],
            ...     "question": [
            ...         "this is a question one",
            ...         "this is a question two",
            ...         "this is a question three",
            ...         "this is a question four",
            ...     ],
            ...     "answer_text": [
            ...         "this is an answer one",
            ...         "this is an answer two",
            ...         "this is an answer three",
            ...         "this is an answer four",
            ...     ],
            ...     "answer_start": [0, 0, 0, 0],
            ... }).to_json("train_data.json", orient="records", lines=True)
            >>> DataFrame.from_dict({
            ...     "id": ["12349", "12350"],
            ...     "context": [
            ...         "this is an answer five. this is a context five",
            ...         "this is an answer six. this is a context six",
            ...     ],
            ...     "question": [
            ...         "this is a question five",
            ...         "this is a question six",
            ...     ],
            ... }).to_json("predict_data.json", orient="records", lines=True)

        The file ``train_data.json`` contains the following:

        .. code-block::

            {"id":"12345","context":"this is an answer one. this is a context one","question":"this is a question one","answer_text":"this is an answer one","answer_start":0}
            {"id":"12346","context":"this is an answer two. this is a context two","question":"this is a question two","answer_text":"this is an answer two","answer_start":0}
            {"id":"12347","context":"this is an answer three. this is a context three","question":"this is a question three","answer_text":"this is an answer three","answer_start":0}
            {"id":"12348","context":"this is an answer four. this is a context four","question":"this is a question four","answer_text":"this is an answer four","answer_start":0}

        The file ``predict_data.json`` contains the following:

        .. code-block::

            {"id":"12349","context":"this is an answer five. this is a context five","question":"this is a question five"}
            {"id":"12350","context":"this is an answer six. this is a context six","question":"this is a question six"}

        .. doctest::

            >>> from flash import Trainer
            >>> from flash.text import QuestionAnsweringData, QuestionAnsweringTask
            >>> datamodule = QuestionAnsweringData.from_json(
            ...     train_file="train_data.json",
            ...     predict_file="predict_data.json",
            ...     batch_size=2,
            ... ) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            Downloading...
            >>> model = QuestionAnsweringTask(max_source_length=32, max_target_length=32)
            >>> trainer = Trainer(fast_dev_run=True)
            >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            Training...
            >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            Predicting...

        .. testcleanup::

            >>> os.remove("train_data.json")
            >>> os.remove("predict_data.json")
        """
        ds_kw = dict(
            field=field,
            question_column_name=question_column_name,
            context_column_name=context_column_name,
            answer_column_name=answer_column_name,
            transform_kwargs=transform_kwargs,
            input_transforms_registry=cls.input_transforms_registry,
        )

        return cls(
            input_cls(RunningStage.TRAINING, train_file, transform=train_transform, **ds_kw),
            input_cls(RunningStage.VALIDATING, val_file, transform=val_transform, **ds_kw),
            input_cls(RunningStage.TESTING, test_file, transform=test_transform, **ds_kw),
            input_cls(RunningStage.PREDICTING, predict_file, transform=predict_transform, **ds_kw),
            **data_module_kwargs,
        )
    @classmethod
    def from_squad_v2(
        cls,
        train_file: Optional[str] = None,
        val_file: Optional[str] = None,
        test_file: Optional[str] = None,
        predict_file: Optional[str] = None,
        train_transform: INPUT_TRANSFORM_TYPE = InputTransform,
        val_transform: INPUT_TRANSFORM_TYPE = InputTransform,
        test_transform: INPUT_TRANSFORM_TYPE = InputTransform,
        predict_transform: INPUT_TRANSFORM_TYPE = InputTransform,
        input_cls: Type[Input] = QuestionAnsweringSQuADInput,
        transform_kwargs: Optional[Dict] = None,
        question_column_name: str = "question",
        context_column_name: str = "context",
        answer_column_name: str = "answer",
        **data_module_kwargs: Any,
    ) -> "QuestionAnsweringData":
        """Load the :class:`~flash.text.question_answering.data.QuestionAnsweringData` from JSON files containing
        questions, contexts and their corresponding answers in the SQuAD2.0 format.

        Question snippets will be extracted from the ``question_column_name`` column in the JSON files.
        Context snippets will be extracted from the ``context_column_name`` column in the JSON files.
        Answer snippets will be extracted from the ``answer_column_name`` column in the JSON files.

        To learn how to customize the transforms applied for each stage, read our
        :ref:`customizing transforms guide <customizing_transforms>`.

        Args:
            train_file: The JSON file containing the training data.
            val_file: The JSON file containing the validation data.
            test_file: The JSON file containing the testing data.
            predict_file: The JSON file containing the data to use when predicting.
            train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
                training.
            val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
                validating.
            test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing.
            predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
                predicting.
            input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
            transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
            question_column_name: The key in the JSON file to recognize the question field.
            context_column_name: The key in the JSON file to recognize the context field.
            answer_column_name: The key in the JSON file to recognize the answer field.

        Returns:
            The constructed data module.

        Examples
        ________

        .. testsetup::

            >>> import json
            >>> from pathlib import Path
            >>> train_data = Path("train_data.json")
            >>> predict_data = Path("predict_data.json")
            >>> _ = train_data.write_text(
            ...     json.dumps(
            ...         {
            ...             "version": "v2.0",
            ...             "data": [
            ...                 {
            ...                     "title": "ExampleSet1",
            ...                     "paragraphs": [
            ...                         {
            ...                             "qas": [
            ...                                 {
            ...                                     "question": "this is a question one",
            ...                                     "id": "12345",
            ...                                     "answers": [{"text": "this is an answer one", "answer_start": 0}],
            ...                                     "is_impossible": False,
            ...                                 }
            ...                             ],
            ...                             "context": "this is an answer one. this is a context one",
            ...                         },
            ...                         {
            ...                             "qas": [
            ...                                 {
            ...                                     "question": "this is a question two",
            ...                                     "id": "12346",
            ...                                     "answers": [{"text": "this is an answer two", "answer_start": 0}],
            ...                                     "is_impossible": False,
            ...                                 }
            ...                             ],
            ...                             "context": "this is an answer two. this is a context two",
            ...                         },
            ...                     ],
            ...                 },
            ...                 {
            ...                     "title": "ExampleSet2",
            ...                     "paragraphs": [
            ...                         {
            ...                             "qas": [
            ...                                 {
            ...                                     "question": "this is a question three",
            ...                                     "id": "12347",
            ...                                     "answers": [{"text": "this is an answer three", "answer_start": 0}],
            ...                                     "is_impossible": False,
            ...                                 }
            ...                             ],
            ...                             "context": "this is an answer three. this is a context three",
            ...                         },
            ...                         {
            ...                             "qas": [
            ...                                 {
            ...                                     "question": "this is a question four",
            ...                                     "id": "12348",
            ...                                     "answers": [{"text": "this is an answer four", "answer_start": 0}],
            ...                                     "is_impossible": False,
            ...                                 }
            ...                             ],
            ...                             "context": "this is an answer four. this is a context four",
            ...                         },
            ...                     ],
            ...                 },
            ...             ],
            ...         }
            ...     )
            ... )
            >>> _ = predict_data.write_text(
            ...     json.dumps(
            ...         {
            ...             "version": "v2.0",
            ...             "data": [
            ...                 {
            ...                     "title": "ExampleSet3",
            ...                     "paragraphs": [
            ...                         {
            ...                             "qas": [
            ...                                 {
            ...                                     "question": "this is a question five",
            ...                                     "id": "12349",
            ...                                     "is_impossible": False,
            ...                                 }
            ...                             ],
            ...                             "context": "this is an answer five. this is a context five",
            ...                         },
            ...                         {
            ...                             "qas": [
            ...                                 {
            ...                                     "question": "this is a question six",
            ...                                     "id": "12350",
            ...                                     "is_impossible": False,
            ...                                 }
            ...                             ],
            ...                             "context": "this is an answer six. this is a context six",
            ...                         },
            ...                     ],
            ...                 }
            ...             ],
            ...         }
            ...     )
            ... )

        The file ``train_data.json`` contains the following:

        .. code-block::

            {
                "version": "v2.0",
                "data": [
                    {
                        "title": "ExampleSet1",
                        "paragraphs": [
                            {
                                "qas": [
                                    {
                                        "question": "this is a question one",
                                        "id": "12345",
                                        "answers": [{"text": "this is an answer one", "answer_start": 0}],
                                        "is_impossible": false
                                    }
                                ],
                                "context": "this is an answer one. this is a context one"
                            },
                            {
                                "qas": [
                                    {
                                        "question": "this is a question two",
                                        "id": "12346",
                                        "answers": [{"text": "this is an answer two", "answer_start": 0}],
                                        "is_impossible": false
                                    }
                                ],
                                "context": "this is an answer two. this is a context two"
                            }
                        ]
                    },
                    {
                        "title": "ExampleSet2",
                        "paragraphs": [
                            {
                                "qas": [
                                    {
                                        "question": "this is a question three",
                                        "id": "12347",
                                        "answers": [{"text": "this is an answer three", "answer_start": 0}],
                                        "is_impossible": false
                                    }
                                ],
                                "context": "this is an answer three. this is a context three"
                            },
                            {
                                "qas": [
                                    {
                                        "question": "this is a question four",
                                        "id": "12348",
                                        "answers": [{"text": "this is an answer four", "answer_start": 0}],
                                        "is_impossible": false
                                    }
                                ],
                                "context": "this is an answer four. this is a context four"
                            }
                        ]
                    }
                ]
            }

        The file ``predict_data.json`` contains the following:

        .. code-block::

            {
                "version": "v2.0",
                "data": [
                    {
                        "title": "ExampleSet3",
                        "paragraphs": [
                            {
                                "qas": [
                                    {
                                        "question": "this is a question five",
                                        "id": "12349",
                                        "is_impossible": false
                                    }
                                ],
                                "context": "this is an answer five. this is a context five"
                            },
                            {
                                "qas": [
                                    {
                                        "question": "this is a question six",
                                        "id": "12350",
                                        "is_impossible": false
                                    }
                                ],
                                "context": "this is an answer six. this is a context six"
                            }
                        ]
                    }
                ]
            }

        .. doctest::

            >>> from flash import Trainer
            >>> from flash.text import QuestionAnsweringData, QuestionAnsweringTask
            >>> datamodule = QuestionAnsweringData.from_squad_v2(
            ...     train_file="train_data.json",
            ...     predict_file="predict_data.json",
            ...     batch_size=2,
            ... ) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            >>> model = QuestionAnsweringTask(max_source_length=32, max_target_length=32)
            >>> trainer = Trainer(fast_dev_run=True)
            >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            Training...
            >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            Predicting...

        .. testcleanup::

            >>> import os
            >>> os.remove("train_data.json")
            >>> os.remove("predict_data.json")
        """
        ds_kw = dict(
            question_column_name=question_column_name,
            context_column_name=context_column_name,
            answer_column_name=answer_column_name,
            transform_kwargs=transform_kwargs,
            input_transforms_registry=cls.input_transforms_registry,
        )

        return cls(
            input_cls(RunningStage.TRAINING, train_file, transform=train_transform, **ds_kw),
            input_cls(RunningStage.VALIDATING, val_file, transform=val_transform, **ds_kw),
            input_cls(RunningStage.TESTING, test_file, transform=test_transform, **ds_kw),
            input_cls(RunningStage.PREDICTING, predict_file, transform=predict_transform, **ds_kw),
            **data_module_kwargs,
        )
    @classmethod
    def from_dicts(
        cls,
        train_data: Optional[Dict[str, Any]] = None,
        val_data: Optional[Dict[str, Any]] = None,
        test_data: Optional[Dict[str, Any]] = None,
        predict_data: Optional[Dict[str, Any]] = None,
        train_transform: INPUT_TRANSFORM_TYPE = InputTransform,
        val_transform: INPUT_TRANSFORM_TYPE = InputTransform,
        test_transform: INPUT_TRANSFORM_TYPE = InputTransform,
        predict_transform: INPUT_TRANSFORM_TYPE = InputTransform,
        input_cls: Type[Input] = QuestionAnsweringDictionaryInput,
        transform_kwargs: Optional[Dict] = None,
        question_column_name: str = "question",
        context_column_name: str = "context",
        answer_column_name: str = "answer",
        **data_module_kwargs: Any,
    ) -> "QuestionAnsweringData":
        """Load the :class:`~flash.text.question_answering.data.QuestionAnsweringData` from Python dictionary
        objects containing questions, contexts and their corresponding answers.

        Question snippets will be extracted from the ``question_column_name`` field in the dictionaries.
        Context snippets will be extracted from the ``context_column_name`` field in the dictionaries.
        Answer snippets will be extracted from the ``answer_column_name`` field in the dictionaries.

        To learn how to customize the transforms applied for each stage, read our
        :ref:`customizing transforms guide <customizing_transforms>`.

        Args:
            train_data: The dictionary containing the training data.
            val_data: The dictionary containing the validation data.
            test_data: The dictionary containing the testing data.
            predict_data: The dictionary containing the data to use when predicting.
            train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
                training.
            val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
                validating.
            test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing.
            predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
                predicting.
            input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
            transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
            question_column_name: The key in the dictionaries containing the questions.
            context_column_name: The key in the dictionaries containing the contexts.
            answer_column_name: The key in the dictionaries containing the answers.

        Returns:
            The constructed :class:`~flash.text.question_answering.data.QuestionAnsweringData`.

        Examples
        ________

        .. doctest::

            >>> from flash import Trainer
            >>> from flash.text import QuestionAnsweringData, QuestionAnsweringTask
            >>> train_data = {
            ...     "id": ["12345", "12346", "12347", "12348"],
            ...     "context": [
            ...         "this is an answer one. this is a context one",
            ...         "this is an answer two. this is a context two",
            ...         "this is an answer three. this is a context three",
            ...         "this is an answer four. this is a context four",
            ...     ],
            ...     "question": [
            ...         "this is a question one",
            ...         "this is a question two",
            ...         "this is a question three",
            ...         "this is a question four",
            ...     ],
            ...     "answer_text": [
            ...         "this is an answer one",
            ...         "this is an answer two",
            ...         "this is an answer three",
            ...         "this is an answer four",
            ...     ],
            ...     "answer_start": [0, 0, 0, 0],
            ... }
            >>> predict_data = {
            ...     "id": ["12349", "12350"],
            ...     "context": [
            ...         "this is an answer five. this is a context five",
            ...         "this is an answer six. this is a context six",
            ...     ],
            ...     "question": [
            ...         "this is a question five",
            ...         "this is a question six",
            ...     ],
            ... }
            >>> datamodule = QuestionAnsweringData.from_dicts(
            ...     train_data=train_data,
            ...     predict_data=predict_data,
            ...     batch_size=2,
            ... ) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            >>> model = QuestionAnsweringTask(max_source_length=32, max_target_length=32)
            >>> trainer = Trainer(fast_dev_run=True)
            >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            Training...
            >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
            Predicting...

        .. testcleanup::

            >>> del train_data
            >>> del predict_data
        """
        ds_kw = dict(
            question_column_name=question_column_name,
            context_column_name=context_column_name,
            answer_column_name=answer_column_name,
            transform_kwargs=transform_kwargs,
            input_transforms_registry=cls.input_transforms_registry,
        )

        return cls(
            input_cls(RunningStage.TRAINING, train_data, transform=train_transform, **ds_kw),
            input_cls(RunningStage.VALIDATING, val_data, transform=val_transform, **ds_kw),
            input_cls(RunningStage.TESTING, test_data, transform=test_transform, **ds_kw),
            input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw),
            **data_module_kwargs,
        )
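As a supplement to the doctests embedded in the docstrings above, the short sketch below shows one way these constructors are typically combined with Flash's finetuning loop. It is a minimal, illustrative example and not part of the module source: the ``"distilbert-base-uncased"`` backbone identifier and the ``strategy="freeze"`` argument are assumptions that may need adjusting for your installed Flash and Hugging Face versions.

# Hedged usage sketch -- not part of flash.text.question_answering.data.
from flash import Trainer
from flash.text import QuestionAnsweringData, QuestionAnsweringTask

# Reuse the CSV layout shown in the ``from_csv`` doctest above.
datamodule = QuestionAnsweringData.from_csv(
    train_file="train_data.csv",
    predict_file="predict_data.csv",
    batch_size=2,
)

# The backbone name is an assumed Hugging Face model identifier.
model = QuestionAnsweringTask(backbone="distilbert-base-uncased", max_source_length=32, max_target_length=32)

# Finetune with the backbone frozen ("freeze" is assumed here), then predict on the predict split.
trainer = Trainer(max_epochs=1)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)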
