QuestionAnsweringData
- class flash.text.question_answering.data.QuestionAnsweringData(train_input=None, val_input=None, test_input=None, predict_input=None, data_fetcher=None, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, val_split=None, batch_size=None, num_workers=0, sampler=None, pin_memory=True, persistent_workers=False)
The QuestionAnsweringData class is a DataModule with a set of classmethods for loading data for extractive question answering.

- classmethod from_csv(train_file=None, val_file=None, test_file=None, predict_file=None, input_cls=<class 'flash.text.question_answering.input.QuestionAnsweringCSVInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, question_column_name='question', context_column_name='context', answer_column_name='answer', **data_module_kwargs)
Load the QuestionAnsweringData from CSV files containing questions, contexts and their corresponding answers.

Question snippets will be extracted from the question_column_name column in the CSV files. Context snippets will be extracted from the context_column_name column in the CSV files. Answer snippets will be extracted from the answer_column_name column in the CSV files. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

- Parameters
  - train_file (Union[str, bytes, PathLike, None]) – The CSV file containing the training data.
  - val_file (Union[str, bytes, PathLike, None]) – The CSV file containing the validation data.
  - test_file (Union[str, bytes, PathLike, None]) – The CSV file containing the testing data.
  - predict_file (Union[str, bytes, PathLike, None]) – The CSV file containing the data to use when predicting.
  - input_cls (Type[Input]) – The Input type to use for loading the data.
  - transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.
  - transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
  - question_column_name (str) – The name of the column in the CSV files that holds the question.
  - context_column_name (str) – The name of the column in the CSV files that holds the context.
  - answer_column_name (str) – The name of the column in the CSV files that holds the answer.
- Return type
  QuestionAnsweringData
- Returns
  The constructed QuestionAnsweringData.
Examples

The files can be in Comma Separated Values (CSV) format with either a .csv or .txt extension.

The file train_data.csv contains the following:

    id,context,question,answer_text,answer_start
    1,I am three years old,How old are you?,three,0
    2,I am six feet tall,How tall are you?,six,0
    3,I am eight years old,How old are you?,eight,0

The file predict_data.csv contains the following:

    id,context,question
    4,I am five feet tall,How tall are you?

    >>> from flash import Trainer
    >>> from flash.text import QuestionAnsweringData, QuestionAnsweringTask
    >>> datamodule = QuestionAnsweringData.from_csv(
    ...     train_file="train_data.csv",
    ...     predict_file="predict_data.csv",
    ...     batch_size=2,
    ... )
    >>> model = QuestionAnsweringTask(max_source_length=32, max_target_length=32)
    >>> trainer = Trainer(fast_dev_run=True)
    >>> trainer.fit(model, datamodule=datamodule)
    Training...
    >>> trainer.predict(model, datamodule=datamodule)
    Predicting...
Alternatively, the files can be in Tab Separated Values (TSV) format with a .tsv extension.

The file train_data.tsv contains the following:

    id	context	question	answer_text	answer_start
    1	I am three years old	How old are you?	three	0
    2	I am six feet tall	How tall are you?	six	0
    3	I am eight years old	How old are you?	eight	0

The file predict_data.tsv contains the following:

    id	context	question
    4	I am five feet tall	How tall are you?

    >>> from flash import Trainer
    >>> from flash.text import QuestionAnsweringData, QuestionAnsweringTask
    >>> datamodule = QuestionAnsweringData.from_csv(
    ...     train_file="train_data.tsv",
    ...     predict_file="predict_data.tsv",
    ...     batch_size=2,
    ... )
    >>> model = QuestionAnsweringTask(max_source_length=32, max_target_length=32)
    >>> trainer = Trainer(fast_dev_run=True)
    >>> trainer.fit(model, datamodule=datamodule)
    Training...
    >>> trainer.predict(model, datamodule=datamodule)
    Predicting...
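To try the from_csv examples locally, the input files have to exist on disk first. The following is a minimal sketch that writes the train_data rows shown above in both CSV and TSV form using only the Python standard library. The helper name write_delimited and the temporary output directory are illustrative assumptions, not part of Flash.

```python
import csv
import tempfile
from pathlib import Path

# Header and rows copied from the train_data example above.
FIELDNAMES = ["id", "context", "question", "answer_text", "answer_start"]
ROWS = [
    {"id": 1, "context": "I am three years old", "question": "How old are you?",
     "answer_text": "three", "answer_start": 0},
    {"id": 2, "context": "I am six feet tall", "question": "How tall are you?",
     "answer_text": "six", "answer_start": 0},
    {"id": 3, "context": "I am eight years old", "question": "How old are you?",
     "answer_text": "eight", "answer_start": 0},
]


def write_delimited(path, rows, fieldnames, delimiter=","):
    """Hypothetical helper: write rows to path with a header line."""
    with open(path, "w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=fieldnames, delimiter=delimiter)
        writer.writeheader()
        writer.writerows(rows)


# A temporary directory keeps the sketch from touching the working directory.
out_dir = Path(tempfile.mkdtemp())
csv_path = out_dir / "train_data.csv"
tsv_path = out_dir / "train_data.tsv"

write_delimited(csv_path, ROWS, FIELDNAMES)                  # comma separated
write_delimited(tsv_path, ROWS, FIELDNAMES, delimiter="\t")  # tab separated

csv_lines = csv_path.read_text(encoding="utf-8").splitlines()
tsv_lines = tsv_path.read_text(encoding="utf-8").splitlines()
print(csv_lines[0])
```

The resulting paths can then be passed as train_file to from_csv as in the examples above.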
- classmethod from_dicts(train_data=None, val_data=None, test_data=None, predict_data=None, input_cls=<class 'flash.text.question_answering.input.QuestionAnsweringDictionaryInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, question_column_name='question', context_column_name='context', answer_column_name='answer', **data_module_kwargs)
Load the QuestionAnsweringData from Python dictionary objects containing questions, contexts and their corresponding answers.

Question snippets will be extracted from the question_column_name field in the dictionaries. Context snippets will be extracted from the context_column_name field in the dictionaries. Answer snippets will be extracted from the answer_column_name field in the dictionaries. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

- Parameters
  - train_data (Optional[Dict[str, Any]]) – The dictionary containing the training data.
  - val_data (Optional[Dict[str, Any]]) – The dictionary containing the validation data.
  - test_data (Optional[Dict[str, Any]]) – The dictionary containing the testing data.
  - predict_data (Optional[Dict[str, Any]]) – The dictionary containing the data to use when predicting.
  - input_cls (Type[Input]) – The Input type to use for loading the data.
  - transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.
  - transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
  - question_column_name (str) – The key in the dictionaries that holds the question.
  - context_column_name (str) – The key in the dictionaries that holds the context.
  - answer_column_name (str) – The key in the dictionaries that holds the answer.
- Return type
  QuestionAnsweringData
- Returns
  The constructed QuestionAnsweringData.
Examples

    >>> from flash import Trainer
    >>> from flash.text import QuestionAnsweringData, QuestionAnsweringTask
    >>> train_data = {
    ...     "id": ["12345", "12346", "12347", "12348"],
    ...     "context": [
    ...         "this is an answer one. this is a context one",
    ...         "this is an answer two. this is a context two",
    ...         "this is an answer three. this is a context three",
    ...         "this is an answer four. this is a context four",
    ...     ],
    ...     "question": [
    ...         "this is a question one",
    ...         "this is a question two",
    ...         "this is a question three",
    ...         "this is a question four",
    ...     ],
    ...     "answer_text": [
    ...         "this is an answer one",
    ...         "this is an answer two",
    ...         "this is an answer three",
    ...         "this is an answer four",
    ...     ],
    ...     "answer_start": [0, 0, 0, 0],
    ... }
    >>> predict_data = {
    ...     "id": ["12349", "12350"],
    ...     "context": [
    ...         "this is an answer five. this is a context five",
    ...         "this is an answer six. this is a context six",
    ...     ],
    ...     "question": [
    ...         "this is a question five",
    ...         "this is a question six",
    ...     ],
    ... }
    >>> datamodule = QuestionAnsweringData.from_dicts(
    ...     train_data=train_data,
    ...     predict_data=predict_data,
    ...     batch_size=2,
    ... )
    >>> model = QuestionAnsweringTask(max_source_length=32, max_target_length=32)
    >>> trainer = Trainer(fast_dev_run=True)
    >>> trainer.fit(model, datamodule=datamodule)
    Training...
    >>> trainer.predict(model, datamodule=datamodule)
    Predicting...
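In the dictionary layout above, answer_text and answer_start describe a span inside context. The sketch below checks that relationship before the data is handed to from_dicts. It assumes answer_start is a zero-based character offset into context, as in the SQuAD convention; this page does not confirm how Flash interprets the value, and the helper check_answer_offsets is hypothetical.

```python
# A shortened copy of the train_data dictionary from the example above.
train_data = {
    "id": ["12345", "12346"],
    "context": [
        "this is an answer one. this is a context one",
        "this is an answer two. this is a context two",
    ],
    "question": ["this is a question one", "this is a question two"],
    "answer_text": ["this is an answer one", "this is an answer two"],
    "answer_start": [0, 0],
}


def check_answer_offsets(data):
    """Hypothetical helper: return the ids whose answer span does not match.

    Assumes answer_start is a zero-based character offset into context.
    """
    mismatched = []
    columns = zip(data["id"], data["context"], data["answer_text"], data["answer_start"])
    for example_id, context, text, start in columns:
        # Slice the context at the claimed offset and compare with the answer.
        if context[start:start + len(text)] != text:
            mismatched.append(example_id)
    return mismatched


print(check_answer_offsets(train_data))  # → []
```

An empty list means every answer was found at its stated offset under that assumption.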
- classmethod from_json(train_file=None, val_file=None, test_file=None, predict_file=None, input_cls=<class 'flash.text.question_answering.input.QuestionAnsweringJSONInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, field=None, question_column_name='question', context_column_name='context', answer_column_name='answer', **data_module_kwargs)
Load the QuestionAnsweringData from JSON files containing questions, contexts and their corresponding answers.

Question snippets will be extracted from the question_column_name column in the JSON files. Context snippets will be extracted from the context_column_name column in the JSON files. Answer snippets will be extracted from the answer_column_name column in the JSON files. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

- Parameters
  - train_file (Union[str, bytes, PathLike, None]) – The JSON file containing the training data.
  - val_file (Union[str, bytes, PathLike, None]) – The JSON file containing the validation data.
  - test_file (Union[str, bytes, PathLike, None]) – The JSON file containing the testing data.
  - predict_file (Union[str, bytes, PathLike, None]) – The JSON file containing the data to use when predicting.
  - input_cls (Type[Input]) – The Input type to use for loading the data.
  - transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.
  - transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
  - field (Optional[str]) – The field that holds the data in the JSON file.
  - question_column_name (str) – The key in the JSON file to recognize the question field.
  - context_column_name (str) – The key in the JSON file to recognize the context field.
  - answer_column_name (str) – The key in the JSON file to recognize the answer field.
- Return type
  QuestionAnsweringData
- Returns
  The constructed QuestionAnsweringData.
Examples

The file train_data.json contains the following:

    {"id":"12345","context":"this is an answer one. this is a context one","question":"this is a question one","answer_text":"this is an answer one","answer_start":0}
    {"id":"12346","context":"this is an answer two. this is a context two","question":"this is a question two","answer_text":"this is an answer two","answer_start":0}
    {"id":"12347","context":"this is an answer three. this is a context three","question":"this is a question three","answer_text":"this is an answer three","answer_start":0}
    {"id":"12348","context":"this is an answer four. this is a context four","question":"this is a question four","answer_text":"this is an answer four","answer_start":0}

The file predict_data.json contains the following:

    {"id":"12349","context":"this is an answer five. this is a context five","question":"this is a question five"}
    {"id":"12350","context":"this is an answer six. this is a context six","question":"this is a question six"}

    >>> from flash import Trainer
    >>> from flash.text import QuestionAnsweringData, QuestionAnsweringTask
    >>> datamodule = QuestionAnsweringData.from_json(
    ...     train_file="train_data.json",
    ...     predict_file="predict_data.json",
    ...     batch_size=2,
    ... )
    Downloading...
    >>> model = QuestionAnsweringTask(max_source_length=32, max_target_length=32)
    >>> trainer = Trainer(fast_dev_run=True)
    >>> trainer.fit(model, datamodule=datamodule)
    Training...
    >>> trainer.predict(model, datamodule=datamodule)
    Predicting...
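The example files above hold one JSON object per line, while the field parameter is described as naming the field that holds the data. The sketch below writes the same records in both layouts with the standard library so the difference is visible. It assumes that field refers to a top-level key wrapping the list of records (the key name "data" and the file name train_data_nested.json are illustrative choices); this page does not spell out the exact semantics, so treat the nested layout as an assumption.

```python
import json
import tempfile
from pathlib import Path

# Two records copied from the train_data.json example above.
records = [
    {"id": "12345", "context": "this is an answer one. this is a context one",
     "question": "this is a question one",
     "answer_text": "this is an answer one", "answer_start": 0},
    {"id": "12346", "context": "this is an answer two. this is a context two",
     "question": "this is a question two",
     "answer_text": "this is an answer two", "answer_start": 0},
]

out_dir = Path(tempfile.mkdtemp())

# Layout 1: one JSON object per line, as in the example files above.
lines_path = out_dir / "train_data.json"
with open(lines_path, "w", encoding="utf-8") as handle:
    for record in records:
        handle.write(json.dumps(record) + "\n")

# Layout 2 (assumed): records nested under a top-level key, which would
# then be named through the field parameter, e.g. field="data".
nested_path = out_dir / "train_data_nested.json"
nested_path.write_text(json.dumps({"data": records}), encoding="utf-8")

# Read both files back to confirm they carry the same records.
loaded_lines = [
    json.loads(line)
    for line in lines_path.read_text(encoding="utf-8").splitlines()
]
loaded_nested = json.loads(nested_path.read_text(encoding="utf-8"))["data"]
```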
- classmethod from_squad_v2(train_file=None, val_file=None, test_file=None, predict_file=None, input_cls=<class 'flash.text.question_answering.input.QuestionAnsweringSQuADInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, question_column_name='question', context_column_name='context', answer_column_name='answer', **data_module_kwargs)
Load the QuestionAnsweringData from JSON files containing questions, contexts and their corresponding answers in the SQuAD2.0 format.

Question snippets will be extracted from the question_column_name column in the JSON files. Context snippets will be extracted from the context_column_name column in the JSON files. Answer snippets will be extracted from the answer_column_name column in the JSON files. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

- Parameters
  - train_file (Optional[str]) – The JSON file containing the training data.
  - val_file (Optional[str]) – The JSON file containing the validation data.
  - test_file (Optional[str]) – The JSON file containing the testing data.
  - predict_file (Optional[str]) – The JSON file containing the predict data.
  - input_cls (Type[Input]) – The Input type to use for loading the data.
  - transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.
  - transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
  - question_column_name (str) – The key in the JSON file to recognize the question field.
  - context_column_name (str) – The key in the JSON file to recognize the context field.
  - answer_column_name (str) – The key in the JSON file to recognize the answer field.
- Return type
  QuestionAnsweringData
- Returns
  The constructed data module.
Examples

The file train_data.json contains the following:

    {
      "version": "v2.0",
      "data": [
        {
          "title": "ExampleSet1",
          "paragraphs": [
            {
              "qas": [
                {
                  "question": "this is a question one",
                  "id": "12345",
                  "answers": [{"text": "this is an answer one", "answer_start": 0}],
                  "is_impossible": false
                }
              ],
              "context": "this is an answer one. this is a context one"
            },
            {
              "qas": [
                {
                  "question": "this is a question two",
                  "id": "12346",
                  "answers": [{"text": "this is an answer two", "answer_start": 0}],
                  "is_impossible": false
                }
              ],
              "context": "this is an answer two. this is a context two"
            }
          ]
        },
        {
          "title": "ExampleSet2",
          "paragraphs": [
            {
              "qas": [
                {
                  "question": "this is a question three",
                  "id": "12347",
                  "answers": [{"text": "this is an answer three", "answer_start": 0}],
                  "is_impossible": false
                }
              ],
              "context": "this is an answer three. this is a context three"
            },
            {
              "qas": [
                {
                  "question": "this is a question four",
                  "id": "12348",
                  "answers": [{"text": "this is an answer four", "answer_start": 0}],
                  "is_impossible": false
                }
              ],
              "context": "this is an answer four. this is a context four"
            }
          ]
        }
      ]
    }

The file predict_data.json contains the following:

    {
      "version": "v2.0",
      "data": [
        {
          "title": "ExampleSet3",
          "paragraphs": [
            {
              "qas": [
                {
                  "question": "this is a question five",
                  "id": "12349",
                  "is_impossible": false
                }
              ],
              "context": "this is an answer five. this is a context five"
            },
            {
              "qas": [
                {
                  "question": "this is a question six",
                  "id": "12350",
                  "is_impossible": false
                }
              ],
              "context": "this is an answer six. this is a context six"
            }
          ]
        }
      ]
    }

    >>> from flash import Trainer
    >>> from flash.text import QuestionAnsweringData, QuestionAnsweringTask
    >>> datamodule = QuestionAnsweringData.from_squad_v2(
    ...     train_file="train_data.json",
    ...     predict_file="predict_data.json",
    ...     batch_size=2,
    ... )
    >>> model = QuestionAnsweringTask(max_source_length=32, max_target_length=32)
    >>> trainer = Trainer(fast_dev_run=True)
    >>> trainer.fit(model, datamodule=datamodule)
    Training...
    >>> trainer.predict(model, datamodule=datamodule)
    Predicting...
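The SQuAD2.0 files above nest each question under a paragraph, which in turn sits under a titled entry in data. To make that nesting easier to follow, here is a minimal sketch that walks the structure and produces one flat record per question. It only illustrates the file layout; the function flatten_squad and the output keys are hypothetical and are not claimed to match what QuestionAnsweringSQuADInput does internally.

```python
# A one-question excerpt of the train_data.json example above.
squad = {
    "version": "v2.0",
    "data": [
        {
            "title": "ExampleSet1",
            "paragraphs": [
                {
                    "qas": [
                        {
                            "question": "this is a question one",
                            "id": "12345",
                            "answers": [
                                {"text": "this is an answer one", "answer_start": 0}
                            ],
                            "is_impossible": False,
                        }
                    ],
                    "context": "this is an answer one. this is a context one",
                }
            ],
        }
    ],
}


def flatten_squad(document):
    """Hypothetical helper: return one flat dict per question."""
    rows = []
    for article in document["data"]:
        for paragraph in article["paragraphs"]:
            for qa in paragraph["qas"]:
                # Prediction files may omit "answers", so default to empty.
                answers = qa.get("answers", [])
                rows.append({
                    "id": qa["id"],
                    "title": article["title"],
                    "context": paragraph["context"],
                    "question": qa["question"],
                    "answer_text": [answer["text"] for answer in answers],
                    "answer_start": [answer["answer_start"] for answer in answers],
                })
    return rows


rows = flatten_squad(squad)
print(rows[0]["id"], rows[0]["answer_text"])
```

Keeping the answers as lists preserves questions with several reference answers as well as those with none.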
- input_transform_cls