SummarizationData

class flash.text.seq2seq.summarization.data.SummarizationData(train_input=None, val_input=None, test_input=None, predict_input=None, data_fetcher=None, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, val_split=None, batch_size=None, num_workers=0, sampler=None, pin_memory=True, persistent_workers=False)

The SummarizationData class is a DataModule with a set of classmethods for loading data for text summarization.

classmethod from_csv(input_field, target_field=None, train_file=None, val_file=None, test_file=None, predict_file=None, input_cls=<class 'flash.text.seq2seq.core.input.Seq2SeqCSVInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)

Load the SummarizationData from CSV files containing input text snippets and their corresponding target text snippets.

Input text snippets will be extracted from the input_field column in the CSV files. Target text snippets will be extracted from the target_field column in the CSV files. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

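Parameters
  • input_field (str) – The field (column name) in the CSV files containing the input text snippets.

  • target_field (Optional[str]) – The field (column name) in the CSV files containing the target text snippets.

  • train_file – The CSV file to use when training.

  • val_file – The CSV file to use when validating.

  • test_file – The CSV file to use when testing.

  • predict_file – The CSV file to use when predicting.

  • input_cls (Type[Input]) – The Input type to use for loading the data.

  • transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.

  • transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.

  • data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.

Return type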

SummarizationData

Returns

The constructed SummarizationData.

Examples

The files can be in Comma Separated Values (CSV) format with either a .csv or .txt extension.

The file train_data.csv contains the following:

texts,summaries
A long paragraph,A short paragraph
A news article,A news headline

The file predict_data.csv contains the following:

texts
A movie review
A book chapter

>>> from flash import Trainer
>>> from flash.text import SummarizationTask, SummarizationData
>>> datamodule = SummarizationData.from_csv(
...     "texts",
...     "summaries",
...     train_file="train_data.csv",
...     predict_file="predict_data.csv",
...     batch_size=2,
... )
>>> model = SummarizationTask(backbone="JulesBelveze/t5-small-headline-generator")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...
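
The example above assumes that train_data.csv and predict_data.csv already exist. As a minimal sketch, the two files could be created with the standard library csv module as shown below. The temporary directory and the variable names are illustrative assumptions, not part of the Flash API.

```python
import csv
import tempfile
from pathlib import Path

# Illustrative location only: a temporary directory keeps the sketch from
# touching the working directory.
data_dir = Path(tempfile.mkdtemp())
train_file = data_dir / "train_data.csv"
predict_file = data_dir / "predict_data.csv"

# Rows mirror the file listings above.
train_rows = [
    ("A long paragraph", "A short paragraph"),
    ("A news article", "A news headline"),
]
predict_rows = [("A movie review",), ("A book chapter",)]

# newline="" lets the csv module control line endings itself.
with train_file.open("w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(["texts", "summaries"])  # header row names the columns
    writer.writerows(train_rows)

with predict_file.open("w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(["texts"])  # prediction data has no target column
    writer.writerows(predict_rows)

# Read the training file back to confirm that the header matches the
# input_field and target_field values used in the example.
with train_file.open(newline="", encoding="utf-8") as f:
    train_records = list(csv.DictReader(f))
```

The paths str(train_file) and str(predict_file) could then be passed as train_file and predict_file in the from_csv call above.
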

Alternatively, the files can be in Tab Separated Values (TSV) format with a .tsv extension.

The file train_data.tsv contains the following:

texts               summaries
A long paragraph    A short paragraph
A news article      A news headline

The file predict_data.tsv contains the following:

texts
A movie review
A book chapter

>>> from flash import Trainer
>>> from flash.text import SummarizationTask, SummarizationData
>>> datamodule = SummarizationData.from_csv(
...     "texts",
...     "summaries",
...     train_file="train_data.tsv",
...     predict_file="predict_data.tsv",
...     batch_size=2,
... )
>>> model = SummarizationTask(backbone="JulesBelveze/t5-small-headline-generator")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...
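
The TSV listings above appear with aligned spaces, but fields in a TSV file are conventionally separated by a single tab character. The sketch below, which uses only the standard library, shows one way to produce such content and parse it back. The in-memory buffer and variable names are illustrative assumptions.

```python
import csv
import io

rows = [
    ["texts", "summaries"],
    ["A long paragraph", "A short paragraph"],
    ["A news article", "A news headline"],
]

# Write the rows with a tab delimiter into an in-memory buffer.
buffer = io.StringIO()
writer = csv.writer(buffer, delimiter="\t", lineterminator="\n")
writer.writerows(rows)
tsv_text = buffer.getvalue()

# Parse the text back with the same delimiter to check the round trip.
parsed_rows = list(csv.reader(io.StringIO(tsv_text), delimiter="\t"))
```

Writing tsv_text to a file named train_data.tsv would give a file matching the listing above.
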

classmethod from_hf_datasets(input_field, target_field=None, train_hf_dataset=None, val_hf_dataset=None, test_hf_dataset=None, predict_hf_dataset=None, input_cls=<class 'flash.text.seq2seq.core.input.Seq2SeqInputBase'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)

Load the SummarizationData from Hugging Face Dataset objects containing input text snippets and their corresponding target text snippets.

Input text snippets will be extracted from the input_field column in the Dataset objects. Target text snippets will be extracted from the target_field column in the Dataset objects. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

Parameters
  • input_field (str) – The field (column name) in the Dataset objects containing the input text snippets.

  • target_field (Optional[str]) – The field (column name) in the Dataset objects containing the target text snippets.

  • train_hf_dataset (Optional[object]) – The Dataset to use when training.

  • val_hf_dataset (Optional[object]) – The Dataset to use when validating.

  • test_hf_dataset (Optional[object]) – The Dataset to use when testing.

  • predict_hf_dataset (Optional[object]) – The Dataset to use when predicting.

  • input_cls (Type[Input]) – The Input type to use for loading the data.

  • transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.

  • transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.

  • data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.

Return type

SummarizationData

Returns

The constructed SummarizationData.

Examples

>>> from datasets import Dataset
>>> from flash import Trainer
>>> from flash.text import SummarizationTask, SummarizationData
>>> train_data = Dataset.from_dict(
...     {
...         "texts": ["A long paragraph", "A news article"],
...         "summaries": ["A short paragraph", "A news headline"],
...     }
... )
>>> predict_data = Dataset.from_dict(
...     {
...         "texts": ["A movie review", "A book chapter"],
...     }
... )
>>> datamodule = SummarizationData.from_hf_datasets(
...     "texts",
...     "summaries",
...     train_hf_dataset=train_data,
...     predict_hf_dataset=predict_data,
...     batch_size=2,
... )  
>>> model = SummarizationTask(backbone="JulesBelveze/t5-small-headline-generator")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...
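
The example above builds each Dataset from a dict that maps column names to lists of values. When the source data is instead a list of per-example records, it first has to be regrouped by column. The helper below is a hypothetical sketch of that step, not a Flash or datasets function, and it uses only the standard library.

```python
from typing import Dict, List


def records_to_columns(records: List[Dict[str, str]]) -> Dict[str, List[str]]:
    """Regroup row-oriented records into a dict of column name -> values."""
    if not records:
        return {}
    fields = list(records[0])
    columns: Dict[str, List[str]] = {field: [] for field in fields}
    for index, record in enumerate(records):
        # Every record must carry exactly the same fields, otherwise the
        # columns would end up with different lengths.
        if set(record) != set(fields):
            raise ValueError(
                f"record {index} has fields {sorted(record)}, expected {sorted(fields)}"
            )
        for field in fields:
            columns[field].append(record[field])
    return columns


records = [
    {"texts": "A long paragraph", "summaries": "A short paragraph"},
    {"texts": "A news article", "summaries": "A news headline"},
]
train_columns = records_to_columns(records)
```

The resulting train_columns dict has the same layout as the argument given to Dataset.from_dict in the example above.
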

classmethod from_json(input_field, target_field=None, train_file=None, val_file=None, test_file=None, predict_file=None, input_cls=<class 'flash.text.seq2seq.core.input.Seq2SeqJSONInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, field=None, **data_module_kwargs)

Load the SummarizationData from JSON files containing input text snippets and their corresponding target text snippets.

Input text snippets will be extracted from the input_field key of each record in the JSON files. Target text snippets will be extracted from the target_field key of each record in the JSON files. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

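Parameters
  • input_field (str) – The field (column name) in the JSON files containing the input text snippets.

  • target_field (Optional[str]) – The field (column name) in the JSON files containing the target text snippets.

  • train_file – The JSON file to use when training.

  • val_file – The JSON file to use when validating.

  • test_file – The JSON file to use when testing.

  • predict_file – The JSON file to use when predicting.

  • input_cls (Type[Input]) – The Input type to use for loading the data.

  • transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.

  • transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.

  • field – The field that holds the data in the JSON file.

  • data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.

Return type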

SummarizationData

Returns

The constructed SummarizationData.

Examples

The file train_data.json contains the following, with one JSON object per line:

{"texts":"A long paragraph","summaries":"A short paragraph"}
{"texts":"A news article","summaries":"A news headline"}

The file predict_data.json contains the following:

{"texts":"A movie review"}
{"texts":"A book chapter"}

>>> from flash import Trainer
>>> from flash.text import SummarizationTask, SummarizationData
>>> datamodule = SummarizationData.from_json(
...     "texts",
...     "summaries",
...     train_file="train_data.json",
...     predict_file="predict_data.json",
...     batch_size=2,
... )  
Downloading...
>>> model = SummarizationTask(backbone="JulesBelveze/t5-small-headline-generator")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...
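
The JSON listings above hold one object per line rather than a single top-level array. As a minimal sketch, such a file could be written with the standard library json module as shown below. The temporary directory and variable names are illustrative assumptions.

```python
import json
import tempfile
from pathlib import Path

# Illustrative location only.
data_dir = Path(tempfile.mkdtemp())
train_file = data_dir / "train_data.json"

train_records = [
    {"texts": "A long paragraph", "summaries": "A short paragraph"},
    {"texts": "A news article", "summaries": "A news headline"},
]

# Write one compact JSON object per line, matching the listing above.
with train_file.open("w", encoding="utf-8") as f:
    for record in train_records:
        f.write(json.dumps(record, separators=(",", ":")) + "\n")

# Read the file back line by line to confirm that each line parses alone.
lines = train_file.read_text(encoding="utf-8").splitlines()
loaded_records = [json.loads(line) for line in lines]
```

The keys of each object are the names passed as input_field and target_field in the from_json call above.
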

classmethod from_lists(train_data=None, train_targets=None, val_data=None, val_targets=None, test_data=None, test_targets=None, predict_data=None, input_cls=<class 'flash.text.seq2seq.core.input.Seq2SeqListInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)

Load the SummarizationData from lists of input text snippets and corresponding lists of target text snippets.

To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

Parameters
  • train_data (Optional[List[str]]) – The list of input text snippets to use when training.

  • train_targets (Optional[List[str]]) – The list of target text snippets to use when training.

  • val_data (Optional[List[str]]) – The list of input text snippets to use when validating.

  • val_targets (Optional[List[str]]) – The list of target text snippets to use when validating.

  • test_data (Optional[List[str]]) – The list of input text snippets to use when testing.

  • test_targets (Optional[List[str]]) – The list of target text snippets to use when testing.

  • predict_data (Optional[List[str]]) – The list of input text snippets to use when predicting.

  • input_cls (Type[Input]) – The Input type to use for loading the data.

  • transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.

  • transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.

  • data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.

Return type

SummarizationData

Returns

The constructed SummarizationData.

Examples

>>> from flash import Trainer
>>> from flash.text import SummarizationTask, SummarizationData
>>> datamodule = SummarizationData.from_lists(
...     train_data=["A long paragraph", "A news article"],
...     train_targets=["A short paragraph", "A news headline"],
...     predict_data=["A movie review", "A book chapter"],
...     batch_size=2,
... )  
>>> model = SummarizationTask(backbone="JulesBelveze/t5-small-headline-generator")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...
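
Because from_lists takes inputs and targets as separate lists, any manual train/validation split has to keep the two lists aligned. The helper below is a hypothetical sketch of such a split using only the standard library. The function name, the val_fraction and seed parameters, and the sample strings are assumptions made for illustration.

```python
import random
from typing import List, Tuple


def split_pairs(
    texts: List[str],
    summaries: List[str],
    val_fraction: float = 0.2,
    seed: int = 0,
) -> Tuple[List[str], List[str], List[str], List[str]]:
    """Split aligned inputs and targets into train and validation lists."""
    if len(texts) != len(summaries):
        raise ValueError("texts and summaries must have the same length")
    if len(texts) < 2:
        raise ValueError("at least two examples are needed to split")
    if not 0.0 < val_fraction < 1.0:
        raise ValueError("val_fraction must be strictly between 0 and 1")

    # Shuffle indices rather than the lists so each pair stays together.
    indices = list(range(len(texts)))
    random.Random(seed).shuffle(indices)

    n_val = max(1, int(len(indices) * val_fraction))
    val_idx, train_idx = indices[:n_val], indices[n_val:]

    train_data = [texts[i] for i in train_idx]
    train_targets = [summaries[i] for i in train_idx]
    val_data = [texts[i] for i in val_idx]
    val_targets = [summaries[i] for i in val_idx]
    return train_data, train_targets, val_data, val_targets


texts = [f"document {i}" for i in range(10)]
summaries = [f"summary {i}" for i in range(10)]
train_data, train_targets, val_data, val_targets = split_pairs(texts, summaries)
```

The four returned lists correspond to the train_data, train_targets, val_data and val_targets arguments of from_lists. The class signature also lists a val_split argument, which may make a manual split unnecessary.
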

input_transform_cls

alias of flash.core.data.io.input_transform.InputTransform
