
QuestionAnsweringData

class flash.text.question_answering.data.QuestionAnsweringData(train_input=None, val_input=None, test_input=None, predict_input=None, data_fetcher=None, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, val_split=None, batch_size=None, num_workers=0, sampler=None, pin_memory=True, persistent_workers=False)

The QuestionAnsweringData class is a DataModule with a set of classmethods for loading data for extractive question answering.

classmethod from_csv(train_file=None, val_file=None, test_file=None, predict_file=None, input_cls=<class 'flash.text.question_answering.input.QuestionAnsweringCSVInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, question_column_name='question', context_column_name='context', answer_column_name='answer', **data_module_kwargs)

Load the QuestionAnsweringData from CSV files containing questions, contexts and their corresponding answers.

Question snippets will be extracted from the question_column_name column in the CSV files. Context snippets will be extracted from the context_column_name column in the CSV files. Answer snippets will be extracted from the answer_column_name column in the CSV files. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
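The column-name arguments tell the loader which CSV header holds each snippet. A minimal standard-library sketch of that lookup follows; `extract_snippets` is a hypothetical helper written for illustration, not part of Flash, and it only covers the question and context columns.

```python
import csv
import io


# Hypothetical helper (not part of Flash): selects the question and context
# snippets from each CSV row using configurable column names, mirroring the
# question_column_name / context_column_name arguments of from_csv.
def extract_snippets(csv_text, question_column_name="question", context_column_name="context"):
    reader = csv.DictReader(io.StringIO(csv_text))
    samples = []
    for row in reader:
        samples.append(
            {
                "question": row[question_column_name],
                "context": row[context_column_name],
            }
        )
    return samples


# A file whose header uses non-default column names.
example = "q,passage\nHow old are you?,I am three years old\n"
samples = extract_snippets(example, question_column_name="q", context_column_name="passage")
```

With default names, a header such as `id,context,question` needs no extra arguments.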

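Parameters
  • train_file – The CSV file containing the training data.

  • val_file – The CSV file containing the validation data.

  • test_file – The CSV file containing the testing data.

  • predict_file – The CSV file containing the data to use when predicting.

  • input_cls (Type[Input]) – The Input type to use for loading the data.

  • transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.

  • transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.

  • question_column_name (str) – The column in the CSV files that holds the question.

  • context_column_name (str) – The column in the CSV files that holds the context.

  • answer_column_name (str) – The column in the CSV files that holds the answer.

Return type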

QuestionAnsweringData

Returns

The constructed QuestionAnsweringData.

Examples

The files can be in Comma Separated Values (CSV) format with either a .csv or .txt extension.

The file train_data.csv contains the following:

id,context,question,answer_text,answer_start
1,I am three years old,How old are you?,three,5
2,I am six feet tall,How tall are you?,six,5
3,I am eight years old,How old are you?,eight,5

The file predict_data.csv contains the following:

id,context,question
4,I am five feet tall,How tall are you?
>>> from flash import Trainer
>>> from flash.text import QuestionAnsweringData, QuestionAnsweringTask
>>> datamodule = QuestionAnsweringData.from_csv(
...     train_file="train_data.csv",
...     predict_file="predict_data.csv",
...     batch_size=2,
... )
>>> model = QuestionAnsweringTask(max_source_length=32, max_target_length=32)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...
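The doctest above assumes that train_data.csv and predict_data.csv already exist in the working directory. A minimal standard-library sketch that writes them is shown below. It assumes, as in the SQuAD format, that answer_start is the character offset of answer_text within context, so the offset is computed with `str.find` rather than hard-coded. The helper names are hypothetical.

```python
import csv
import os

# (context, question, answer_text) triples from the train_data.csv example.
TRAIN_EXAMPLES = [
    ("I am three years old", "How old are you?", "three"),
    ("I am six feet tall", "How tall are you?", "six"),
    ("I am eight years old", "How old are you?", "eight"),
]
# (context, question) pairs from the predict_data.csv example.
PREDICT_EXAMPLES = [("I am five feet tall", "How tall are you?")]


def build_train_rows():
    rows = []
    for idx, (context, question, answer_text) in enumerate(TRAIN_EXAMPLES, start=1):
        # Assumption: answer_start is the character offset of answer_text in context.
        answer_start = context.find(answer_text)
        if answer_start < 0:
            raise ValueError(f"answer {answer_text!r} not found in context {context!r}")
        rows.append([idx, context, question, answer_text, answer_start])
    return rows


def write_example_csvs(directory="."):
    train_path = os.path.join(directory, "train_data.csv")
    predict_path = os.path.join(directory, "predict_data.csv")
    with open(train_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["id", "context", "question", "answer_text", "answer_start"])
        writer.writerows(build_train_rows())
    with open(predict_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["id", "context", "question"])
        # Prediction ids continue after the training ids, as in the example.
        for idx, (context, question) in enumerate(PREDICT_EXAMPLES, start=len(TRAIN_EXAMPLES) + 1):
            writer.writerow([idx, context, question])
    return train_path, predict_path
```

Calling `write_example_csvs()` before running the doctest produces both files in the current directory.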

Alternatively, the files can be in Tab Separated Values (TSV) format with a .tsv extension.

The file train_data.tsv contains the following:

id  context                 question            answer_text answer_start
1   I am three years old    How old are you?    three       5
2   I am six feet tall      How tall are you?   six         5
3   I am eight years old    How old are you?    eight       5

The file predict_data.tsv contains the following:

id  context             question
4   I am five feet tall How tall are you?
>>> from flash import Trainer
>>> from flash.text import QuestionAnsweringData, QuestionAnsweringTask
>>> datamodule = QuestionAnsweringData.from_csv(
...     train_file="train_data.tsv",
...     predict_file="predict_data.tsv",
...     batch_size=2,
... )
>>> model = QuestionAnsweringTask(max_source_length=32, max_target_length=32)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...
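In a TSV file the fields are separated by tab characters; the aligned spacing in the listings above is presumably only for display. A minimal sketch using the standard `csv` module with `delimiter="\t"` (the helpers `write_tsv` and `read_tsv` are hypothetical):

```python
import csv


def write_tsv(path, header, rows):
    # delimiter="\t" makes the csv module emit real tab characters between fields.
    with open(path, "w", newline="") as f:
        writer = csv.writer(f, delimiter="\t")
        writer.writerow(header)
        writer.writerows(rows)


def read_tsv(path):
    # Read the file back as one dict per row, keyed by the header.
    with open(path, newline="") as f:
        return list(csv.DictReader(f, delimiter="\t"))
```

For example, `write_tsv("predict_data.tsv", ["id", "context", "question"], [[4, "I am five feet tall", "How tall are you?"]])` writes the prediction file from the example.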
classmethod from_dicts(train_data=None, val_data=None, test_data=None, predict_data=None, input_cls=<class 'flash.text.question_answering.input.QuestionAnsweringDictionaryInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, question_column_name='question', context_column_name='context', answer_column_name='answer', **data_module_kwargs)

Load the QuestionAnsweringData from Python dictionary objects containing questions, contexts and their corresponding answers.

Question snippets will be extracted from the question_column_name field in the dictionaries. Context snippets will be extracted from the context_column_name field in the dictionaries. Answer snippets will be extracted from the answer_column_name field in the dictionaries. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
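In the example for this method, each dictionary is column-oriented: every key maps to a list with one entry per sample, so all lists must have the same length. A short sketch of how such a dictionary lines up into per-sample records; `dict_of_lists_to_records` is a hypothetical helper, not part of Flash.

```python
from typing import Any, Dict, List


def dict_of_lists_to_records(data: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
    # Every column must contribute exactly one value per sample.
    lengths = {key: len(values) for key, values in data.items()}
    if len(set(lengths.values())) > 1:
        raise ValueError(f"All columns must have the same length, got {lengths}")
    keys = list(data)
    # zip(*columns) walks the lists in parallel, yielding one tuple per sample.
    return [dict(zip(keys, row)) for row in zip(*data.values())]
```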

Parameters
  • train_data (Optional[Dict[str, Any]]) – The dictionary containing the training data.

  • val_data (Optional[Dict[str, Any]]) – The dictionary containing the validation data.

  • test_data (Optional[Dict[str, Any]]) – The dictionary containing the testing data.

  • predict_data (Optional[Dict[str, Any]]) – The dictionary containing the data to use when predicting.

  • input_cls (Type[Input]) – The Input type to use for loading the data.

  • transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.

  • transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.

  • question_column_name (str) – The key in the dictionaries to recognize the question field.

  • context_column_name (str) – The key in the dictionaries to recognize the context field.

  • answer_column_name (str) – The key in the dictionaries to recognize the answer field.

Return type

QuestionAnsweringData

Returns

The constructed QuestionAnsweringData.

Examples

>>> from flash import Trainer
>>> from flash.text import QuestionAnsweringData, QuestionAnsweringTask
>>> train_data = {
...     "id": ["12345", "12346", "12347", "12348"],
...     "context": [
...         "this is an answer one. this is a context one",
...         "this is an answer two. this is a context two",
...         "this is an answer three. this is a context three",
...         "this is an answer four. this is a context four",
...     ],
...     "question": [
...         "this is a question one",
...         "this is a question two",
...         "this is a question three",
...         "this is a question four",
...     ],
...     "answer_text": [
...         "this is an answer one",
...         "this is an answer two",
...         "this is an answer three",
...         "this is an answer four",
...     ],
...     "answer_start": [0, 0, 0, 0],
... }
>>> predict_data = {
...     "id": ["12349", "12350"],
...     "context": [
...         "this is an answer five. this is a context five",
...         "this is an answer six. this is a context six",
...     ],
...     "question": [
...         "this is a question five",
...         "this is a question six",
...     ],
... }
>>> datamodule = QuestionAnsweringData.from_dicts(
...     train_data=train_data,
...     predict_data=predict_data,
...     batch_size=2,
... )  
>>> model = QuestionAnsweringTask(max_source_length=32, max_target_length=32)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...
classmethod from_json(train_file=None, val_file=None, test_file=None, predict_file=None, input_cls=<class 'flash.text.question_answering.input.QuestionAnsweringJSONInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, field=None, question_column_name='question', context_column_name='context', answer_column_name='answer', **data_module_kwargs)

Load the QuestionAnsweringData from JSON files containing questions, contexts and their corresponding answers.

Question snippets will be extracted from the question_column_name field in the JSON files. Context snippets will be extracted from the context_column_name field in the JSON files. Answer snippets will be extracted from the answer_column_name field in the JSON files. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
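The example for this method stores one JSON object per line (JSON Lines). The signature also accepts a `field` argument; the sketch below assumes it behaves like the option of the same name in the Hugging Face `datasets` JSON loader, naming the top-level key that holds the list of records when they are wrapped in a single object. `load_json_records` is a hypothetical stdlib helper illustrating both layouts, not Flash's implementation.

```python
import json
from typing import Any, Dict, List, Optional


def load_json_records(text: str, field: Optional[str] = None) -> List[Dict[str, Any]]:
    if field is None:
        # JSON Lines: each non-empty line is one complete record.
        return [json.loads(line) for line in text.splitlines() if line.strip()]
    # Assumed `field` semantics: the text is one JSON object whose `field` key holds the records.
    return json.loads(text)[field]
```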

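Parameters
  • train_file – The JSON file containing the training data.

  • val_file – The JSON file containing the validation data.

  • test_file – The JSON file containing the testing data.

  • predict_file – The JSON file containing the data to use when predicting.

  • input_cls (Type[Input]) – The Input type to use for loading the data.

  • transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.

  • transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.

  • field – The field in the JSON file that holds the data, if the records are nested under a single key.

  • question_column_name (str) – The key in the JSON file to recognize the question field.

  • context_column_name (str) – The key in the JSON file to recognize the context field.

  • answer_column_name (str) – The key in the JSON file to recognize the answer field.

Return type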

QuestionAnsweringData

Returns

The constructed QuestionAnsweringData.

Examples

The file train_data.json contains the following:

{"id":"12345","context":"this is an answer one. this is a context one","question":"this is a question one","answer_text":"this is an answer one","answer_start":0}
{"id":"12346","context":"this is an answer two. this is a context two","question":"this is a question two","answer_text":"this is an answer two","answer_start":0}
{"id":"12347","context":"this is an answer three. this is a context three","question":"this is a question three","answer_text":"this is an answer three","answer_start":0}
{"id":"12348","context":"this is an answer four. this is a context four","question":"this is a question four","answer_text":"this is an answer four","answer_start":0}

The file predict_data.json contains the following:

{"id":"12349","context":"this is an answer five. this is a context five","question":"this is a question five"}
{"id":"12350","context":"this is an answer six. this is a context six","question":"this is a question six"}
>>> from flash import Trainer
>>> from flash.text import QuestionAnsweringData, QuestionAnsweringTask
>>> datamodule = QuestionAnsweringData.from_json(
...     train_file="train_data.json",
...     predict_file="predict_data.json",
...     batch_size=2,
... )
>>> model = QuestionAnsweringTask(max_source_length=32, max_target_length=32)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...
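Each record in the JSON example must sit on a single line. A minimal sketch for producing such files with the standard `json` module; `make_record` and `write_json_lines` are hypothetical helpers, and the sketch assumes answer_start is the character offset of answer_text in context.

```python
import json


def make_record(sample_id, context, question, answer_text=None):
    record = {"id": sample_id, "context": context, "question": question}
    if answer_text is not None:
        # Assumption: answer_start is the character offset of answer_text in context.
        record["answer_text"] = answer_text
        record["answer_start"] = context.index(answer_text)
    return record


def write_json_lines(path, records):
    # JSON Lines: one complete JSON object per line, no enclosing list.
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record) + "\n")
```

Prediction records are created without `answer_text`, which matches predict_data.json above.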
classmethod from_squad_v2(train_file=None, val_file=None, test_file=None, predict_file=None, input_cls=<class 'flash.text.question_answering.input.QuestionAnsweringSQuADInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, question_column_name='question', context_column_name='context', answer_column_name='answer', **data_module_kwargs)

Load the QuestionAnsweringData from JSON files containing questions, contexts and their corresponding answers in the SQuAD2.0 format.

Question snippets will be extracted from the question_column_name field in the JSON files. Context snippets will be extracted from the context_column_name field in the JSON files. Answer snippets will be extracted from the answer_column_name field in the JSON files. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

Parameters
  • train_file (Optional[str]) – The JSON file containing the training data.

  • val_file (Optional[str]) – The JSON file containing the validation data.

  • test_file (Optional[str]) – The JSON file containing the testing data.

  • predict_file (Optional[str]) – The JSON file containing the data to use when predicting.

  • input_cls (Type[Input]) – The Input type to use for loading the data.

  • transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.

  • transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.

  • question_column_name (str) – The key in the JSON file to recognize the question field.

  • context_column_name (str) – The key in the JSON file to recognize the context field.

  • answer_column_name (str) – The key in the JSON file to recognize the answer field.

Return type

QuestionAnsweringData

Returns

The constructed data module.

Examples

The file train_data.json contains the following:

{
    "version": "v2.0",
    "data": [
        {
            "title": "ExampleSet1",
            "paragraphs": [
                {
                    "qas": [
                        {
                            "question": "this is a question one",
                            "id": "12345",
                            "answers": [{"text": "this is an answer one", "answer_start": 0}],
                            "is_impossible": false
                        }
                    ],
                    "context": "this is an answer one. this is a context one"
                }, {
                    "qas": [
                        {
                            "question": "this is a question two",
                            "id": "12346",
                            "answers": [{"text": "this is an answer two", "answer_start": 0}],
                            "is_impossible": false
                        }
                    ],
                    "context": "this is an answer two. this is a context two"
                }
            ]
        }, {
            "title": "ExampleSet2",
            "paragraphs": [
                {
                    "qas": [
                        {
                            "question": "this is a question three",
                            "id": "12347",
                            "answers": [{"text": "this is an answer three", "answer_start": 0}],
                            "is_impossible": false
                        }
                    ],
                    "context": "this is an answer three. this is a context three"
                }, {
                    "qas": [
                        {
                            "question": "this is a question four",
                            "id": "12348",
                            "answers": [{"text": "this is an answer four", "answer_start": 0}],
                            "is_impossible": false
                        }
                    ],
                    "context": "this is an answer four. this is a context four"
                }
            ]
        }
    ]
}

The file predict_data.json contains the following:

{
    "version": "v2.0",
    "data": [
        {
            "title": "ExampleSet3",
            "paragraphs": [
                {
                    "qas": [
                        {
                            "question": "this is a question five",
                            "id": "12349",
                            "is_impossible": false
                        }
                    ],
                    "context": "this is an answer five. this is a context five"
                }, {
                    "qas": [
                        {
                            "question": "this is a question six",
                            "id": "12350",
                            "is_impossible": false
                        }
                    ],
                    "context": "this is an answer six. this is a context six"
                }
            ]
        }
    ]
}
>>> from flash import Trainer
>>> from flash.text import QuestionAnsweringData, QuestionAnsweringTask
>>> datamodule = QuestionAnsweringData.from_squad_v2(
...     train_file="train_data.json",
...     predict_file="predict_data.json",
...     batch_size=2,
... )  
>>> model = QuestionAnsweringTask(max_source_length=32, max_target_length=32)
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...
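In the SQuAD2.0 layout each article holds paragraphs, each paragraph holds one context and a list of question/answer entries, and answers are nested lists of text/offset pairs. The sketch below flattens that structure into one row per question, similar to the flat records used by the other loaders. It is illustrative only and not necessarily how QuestionAnsweringSQuADInput is implemented; `flatten_squad_v2` is a hypothetical helper. Entries without an "answers" key, such as those in predict_data.json, get empty answer lists.

```python
from typing import Any, Dict, List


def flatten_squad_v2(squad: Dict[str, Any]) -> List[Dict[str, Any]]:
    rows = []
    for article in squad["data"]:
        for paragraph in article["paragraphs"]:
            # All questions in a paragraph share the same context.
            context = paragraph["context"]
            for qa in paragraph["qas"]:
                # Prediction files (and unanswerable questions) may carry no answers.
                answers = qa.get("answers", [])
                rows.append(
                    {
                        "id": qa["id"],
                        "title": article.get("title"),
                        "context": context,
                        "question": qa["question"],
                        "answer_text": [a["text"] for a in answers],
                        "answer_start": [a["answer_start"] for a in answers],
                        "is_impossible": qa.get("is_impossible", False),
                    }
                )
    return rows
```

Loading train_data.json with `json.load` and passing the result to `flatten_squad_v2` yields four rows, one per question.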
input_transform_cls

alias of flash.core.data.io.input_transform.InputTransform
