SummarizationData
- class flash.text.seq2seq.summarization.data.SummarizationData(train_input=None, val_input=None, test_input=None, predict_input=None, data_fetcher=None, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, val_split=None, batch_size=None, num_workers=0, sampler=None, pin_memory=True, persistent_workers=False)
The SummarizationData class is a DataModule with a set of classmethods for loading data for text summarization.

- classmethod from_csv(input_field, target_field=None, train_file=None, val_file=None, test_file=None, predict_file=None, input_cls=<class 'flash.text.seq2seq.core.input.Seq2SeqCSVInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)
Load the SummarizationData from CSV files containing input text snippets and their corresponding target text snippets.

Input text snippets will be extracted from the input_field column in the CSV files. Target text snippets will be extracted from the target_field column in the CSV files. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

- Parameters
  - input_field (str) – The field (column name) in the CSV files containing the input text snippets.
  - target_field (Optional[str]) – The field (column name) in the CSV files containing the target text snippets.
  - train_file (Union[str, bytes, PathLike, None]) – The CSV file to use when training.
  - val_file (Union[str, bytes, PathLike, None]) – The CSV file to use when validating.
  - test_file (Union[str, bytes, PathLike, None]) – The CSV file to use when testing.
  - predict_file (Union[str, bytes, PathLike, None]) – The CSV file to use when predicting.
  - input_cls (Type[Input]) – The Input type to use for loading the data.
  - transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.
  - transform_kwargs (Optional[Dict]) – A dict of keyword arguments to pass when instantiating the transforms.
  - data_module_kwargs (Any) – Additional keyword arguments to pass to the DataModule constructor (for example, batch_size or val_split).
- Return type
  SummarizationData
- Returns
  The constructed SummarizationData.
Examples
The files can be in Comma Separated Values (CSV) format with either a .csv or .txt extension.

The file train_data.csv contains the following:

    texts,summaries
    A long paragraph,A short paragraph
    A news article,A news headline

The file predict_data.csv contains the following:

    texts
    A movie review
    A book chapter
    >>> from flash import Trainer
    >>> from flash.text import SummarizationTask, SummarizationData
    >>> datamodule = SummarizationData.from_csv(
    ...     "texts",
    ...     "summaries",
    ...     train_file="train_data.csv",
    ...     predict_file="predict_data.csv",
    ...     batch_size=2,
    ... )
    >>> model = SummarizationTask(backbone="JulesBelveze/t5-small-headline-generator")
    >>> trainer = Trainer(fast_dev_run=True)
    >>> trainer.fit(model, datamodule=datamodule)
    Training...
    >>> trainer.predict(model, datamodule=datamodule)
    Predicting...
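The doctest above assumes that train_data.csv and predict_data.csv already exist in the working directory. The following is a minimal sketch that creates them with Python's standard-library csv module. It is not part of Flash. The header row provides the column names that are later passed as input_field and target_field. csv.writer also quotes any text that itself contains a comma, so the column boundaries stay intact.

```python
import csv

# Rows taken from the example files: (input text, target summary)
train_rows = [
    ("A long paragraph", "A short paragraph"),
    ("A news article", "A news headline"),
]
predict_rows = ["A movie review", "A book chapter"]

# Training file: header row followed by (texts, summaries) pairs
with open("train_data.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["texts", "summaries"])
    writer.writerows(train_rows)

# Prediction file: only the input column is needed
with open("predict_data.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["texts"])
    writer.writerows([text] for text in predict_rows)
```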
Alternatively, the files can be in Tab Separated Values (TSV) format with a .tsv extension.

The file train_data.tsv contains the following (columns separated by a tab character):

    texts	summaries
    A long paragraph	A short paragraph
    A news article	A news headline

The file predict_data.tsv contains the following:

    texts
    A movie review
    A book chapter
    >>> from flash import Trainer
    >>> from flash.text import SummarizationTask, SummarizationData
    >>> datamodule = SummarizationData.from_csv(
    ...     "texts",
    ...     "summaries",
    ...     train_file="train_data.tsv",
    ...     predict_file="predict_data.tsv",
    ...     batch_size=2,
    ... )
    >>> model = SummarizationTask(backbone="JulesBelveze/t5-small-headline-generator")
    >>> trainer = Trainer(fast_dev_run=True)
    >>> trainer.fit(model, datamodule=datamodule)
    Training...
    >>> trainer.predict(model, datamodule=datamodule)
    Predicting...
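The TSV variant differs only in the delimiter. The following is a minimal standard-library sketch that is not part of Flash. Passing delimiter="\t" to csv.writer produces the tab-separated files used in the doctest above.

```python
import csv

# Same example rows as the TSV files shown above
train_rows = [
    ("A long paragraph", "A short paragraph"),
    ("A news article", "A news headline"),
]
predict_rows = ["A movie review", "A book chapter"]

# delimiter="\t" separates columns with tabs instead of commas
with open("train_data.tsv", "w", newline="") as f:
    writer = csv.writer(f, delimiter="\t")
    writer.writerow(["texts", "summaries"])
    writer.writerows(train_rows)

with open("predict_data.tsv", "w", newline="") as f:
    writer = csv.writer(f, delimiter="\t")
    writer.writerow(["texts"])
    writer.writerows([text] for text in predict_rows)
```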
- classmethod from_hf_datasets(input_field, target_field=None, train_hf_dataset=None, val_hf_dataset=None, test_hf_dataset=None, predict_hf_dataset=None, input_cls=<class 'flash.text.seq2seq.core.input.Seq2SeqInputBase'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)
Load the SummarizationData from Hugging Face Dataset objects containing input text snippets and their corresponding target text snippets.

Input text snippets will be extracted from the input_field column in the Dataset objects. Target text snippets will be extracted from the target_field column in the Dataset objects. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

- Parameters
  - input_field (str) – The field (column name) in the Dataset objects containing the input text snippets.
  - target_field (Optional[str]) – The field (column name) in the Dataset objects containing the target text snippets.
  - train_hf_dataset (Optional[object]) – The Dataset to use when training.
  - val_hf_dataset (Optional[object]) – The Dataset to use when validating.
  - test_hf_dataset (Optional[object]) – The Dataset to use when testing.
  - predict_hf_dataset (Optional[object]) – The Dataset to use when predicting.
  - input_cls (Type[Input]) – The Input type to use for loading the data.
  - transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.
  - transform_kwargs (Optional[Dict]) – A dict of keyword arguments to pass when instantiating the transforms.
  - data_module_kwargs (Any) – Additional keyword arguments to pass to the DataModule constructor (for example, batch_size or val_split).
- Return type
  SummarizationData
- Returns
  The constructed SummarizationData.
Examples
    >>> from datasets import Dataset
    >>> from flash import Trainer
    >>> from flash.text import SummarizationTask, SummarizationData
    >>> train_data = Dataset.from_dict(
    ...     {
    ...         "texts": ["A long paragraph", "A news article"],
    ...         "summaries": ["A short paragraph", "A news headline"],
    ...     }
    ... )
    >>> predict_data = Dataset.from_dict(
    ...     {
    ...         "texts": ["A movie review", "A book chapter"],
    ...     }
    ... )
    >>> datamodule = SummarizationData.from_hf_datasets(
    ...     "texts",
    ...     "summaries",
    ...     train_hf_dataset=train_data,
    ...     predict_hf_dataset=predict_data,
    ...     batch_size=2,
    ... )
    >>> model = SummarizationTask(backbone="JulesBelveze/t5-small-headline-generator")
    >>> trainer = Trainer(fast_dev_run=True)
    >>> trainer.fit(model, datamodule=datamodule)
    Training...
    >>> trainer.predict(model, datamodule=datamodule)
    Predicting...
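Dataset.from_dict in the example above expects column-oriented data, with one list per column. Each column name is what from_hf_datasets later looks up through input_field and target_field. Data often starts out as a list of per-example records instead. The following plain-Python sketch reshapes such records into that column layout. records_to_columns is a hypothetical helper for illustration and is not part of Flash or the datasets library.

```python
from typing import Dict, List


def records_to_columns(records: List[Dict[str, str]]) -> Dict[str, List[str]]:
    """Hypothetical helper: [{"texts": ..., "summaries": ...}, ...] -> {"texts": [...], "summaries": [...]}."""
    if not records:
        return {}
    # Column names are taken from the first record
    fields = list(records[0])
    for index, record in enumerate(records):
        # Every record must share the same fields, otherwise the columns would be ragged
        if set(record) != set(fields):
            raise ValueError(f"record {index} has fields {sorted(record)}, expected {sorted(fields)}")
    return {name: [record[name] for record in records] for name in fields}


train_columns = records_to_columns(
    [
        {"texts": "A long paragraph", "summaries": "A short paragraph"},
        {"texts": "A news article", "summaries": "A news headline"},
    ]
)
```

The resulting train_columns has the same shape as the dict passed to Dataset.from_dict in the example. It could therefore be wrapped with Dataset.from_dict(train_columns) before it is handed to from_hf_datasets.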
- classmethod from_json(input_field, target_field=None, train_file=None, val_file=None, test_file=None, predict_file=None, input_cls=<class 'flash.text.seq2seq.core.input.Seq2SeqJSONInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, field=None, **data_module_kwargs)
Load the SummarizationData from JSON files containing input text snippets and their corresponding target text snippets.

Input text snippets will be extracted from the input_field column in the JSON files. Target text snippets will be extracted from the target_field column in the JSON files. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

- Parameters
  - input_field (str) – The field (column name) in the JSON objects containing the input text snippets.
  - target_field (Optional[str]) – The field (column name) in the JSON objects containing the target text snippets.
  - train_file (Union[str, bytes, PathLike, None]) – The JSON file to use when training.
  - val_file (Union[str, bytes, PathLike, None]) – The JSON file to use when validating.
  - test_file (Union[str, bytes, PathLike, None]) – The JSON file to use when testing.
  - predict_file (Union[str, bytes, PathLike, None]) – The JSON file to use when predicting.
  - input_cls (Type[Input]) – The Input type to use for loading the data.
  - transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.
  - transform_kwargs (Optional[Dict]) – A dict of keyword arguments to pass when instantiating the transforms.
  - field (Optional[str]) – The field that holds the data in the JSON file.
  - data_module_kwargs (Any) – Additional keyword arguments to pass to the DataModule constructor (for example, batch_size or val_split).
- Return type
  SummarizationData
- Returns
  The constructed SummarizationData.
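The field parameter matters when the records are not stored one per line but are nested under a key of a single JSON document. The following is a minimal sketch of such a file, written with the standard-library json module. It assumes, as the parameter description suggests, that field names the top-level key holding the list of records. The file name train_nested.json and the key "data" are hypothetical choices made for illustration.

```python
import json

# All training records live under one top-level key ("data" is an arbitrary, hypothetical name)
nested = {
    "data": [
        {"texts": "A long paragraph", "summaries": "A short paragraph"},
        {"texts": "A news article", "summaries": "A news headline"},
    ]
}

with open("train_nested.json", "w") as f:
    json.dump(nested, f)
```

Under that assumption, the file would then be loaded with SummarizationData.from_json("texts", "summaries", train_file="train_nested.json", field="data", batch_size=2).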
Examples
The file train_data.json contains the following (one JSON object per line):

    {"texts":"A long paragraph","summaries":"A short paragraph"}
    {"texts":"A news article","summaries":"A news headline"}

The file predict_data.json contains the following:

    {"texts":"A movie review"}
    {"texts":"A book chapter"}
    >>> from flash import Trainer
    >>> from flash.text import SummarizationTask, SummarizationData
    >>> datamodule = SummarizationData.from_json(
    ...     "texts",
    ...     "summaries",
    ...     train_file="train_data.json",
    ...     predict_file="predict_data.json",
    ...     batch_size=2,
    ... )
    >>> model = SummarizationTask(backbone="JulesBelveze/t5-small-headline-generator")
    >>> trainer = Trainer(fast_dev_run=True)
    >>> trainer.fit(model, datamodule=datamodule)
    Training...
    >>> trainer.predict(model, datamodule=datamodule)
    Predicting...
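The example files store one JSON object per line. The following minimal standard-library sketch writes them in exactly that layout. write_json_lines is a hypothetical helper and is not part of Flash. Passing separators=(",", ":") to json.dumps drops the default spaces, so the output matches the compact lines shown above.

```python
import json

train_records = [
    {"texts": "A long paragraph", "summaries": "A short paragraph"},
    {"texts": "A news article", "summaries": "A news headline"},
]
predict_records = [{"texts": "A movie review"}, {"texts": "A book chapter"}]


def write_json_lines(path, records):
    """Hypothetical helper: write one compact JSON object per line."""
    with open(path, "w") as f:
        for record in records:
            f.write(json.dumps(record, separators=(",", ":")) + "\n")


write_json_lines("train_data.json", train_records)
write_json_lines("predict_data.json", predict_records)
```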
- classmethod from_lists(train_data=None, train_targets=None, val_data=None, val_targets=None, test_data=None, test_targets=None, predict_data=None, input_cls=<class 'flash.text.seq2seq.core.input.Seq2SeqListInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)
Load the SummarizationData from lists of input text snippets and corresponding lists of target text snippets.

To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

- Parameters
  - train_data (Optional[List[str]]) – The list of input text snippets to use when training.
  - train_targets (Optional[List[str]]) – The list of target text snippets to use when training.
  - val_data (Optional[List[str]]) – The list of input text snippets to use when validating.
  - val_targets (Optional[List[str]]) – The list of target text snippets to use when validating.
  - test_data (Optional[List[str]]) – The list of input text snippets to use when testing.
  - test_targets (Optional[List[str]]) – The list of target text snippets to use when testing.
  - predict_data (Optional[List[str]]) – The list of input text snippets to use when predicting.
  - input_cls (Type[Input]) – The Input type to use for loading the data.
  - transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.
  - transform_kwargs (Optional[Dict]) – A dict of keyword arguments to pass when instantiating the transforms.
  - data_module_kwargs (Any) – Additional keyword arguments to pass to the DataModule constructor (for example, batch_size or val_split).
- Return type
  SummarizationData
- Returns
  The constructed SummarizationData.
Examples
    >>> from flash import Trainer
    >>> from flash.text import SummarizationTask, SummarizationData
    >>> datamodule = SummarizationData.from_lists(
    ...     train_data=["A long paragraph", "A news article"],
    ...     train_targets=["A short paragraph", "A news headline"],
    ...     predict_data=["A movie review", "A book chapter"],
    ...     batch_size=2,
    ... )
    >>> model = SummarizationTask(backbone="JulesBelveze/t5-small-headline-generator")
    >>> trainer = Trainer(fast_dev_run=True)
    >>> trainer.fit(model, datamodule=datamodule)
    Training...
    >>> trainer.predict(model, datamodule=datamodule)
    Predicting...
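Because the inputs and targets are passed as two separate lists, "corresponding" presumably means positional: the i-th entry of train_targets is the summary of the i-th entry of train_data. If the data is kept as (text, summary) pairs, building both lists from the same sequence keeps them aligned. The following is a minimal plain-Python sketch of that. split_pairs is a hypothetical helper and is not part of Flash.

```python
from typing import List, Sequence, Tuple


def split_pairs(pairs: Sequence[Tuple[str, str]]) -> Tuple[List[str], List[str]]:
    """Hypothetical helper: [(text, summary), ...] -> ([text, ...], [summary, ...])."""
    texts = [text for text, _ in pairs]
    summaries = [summary for _, summary in pairs]
    return texts, summaries


# Both lists come from the same pairs, so index i always refers to the same example
train_data, train_targets = split_pairs(
    [
        ("A long paragraph", "A short paragraph"),
        ("A news article", "A news headline"),
    ]
)
```

The two lists can then be passed as train_data=train_data and train_targets=train_targets, as in the example above.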
- input_transform_cls