TranslationData

class flash.text.seq2seq.translation.data.TranslationData(train_input=None, val_input=None, test_input=None, predict_input=None, data_fetcher=None, val_split=None, batch_size=None, num_workers=0, sampler=None, pin_memory=True, persistent_workers=False)[source]

The TranslationData class is a DataModule with a set of classmethods for loading data for text translation.

classmethod from_csv(input_field, target_field=None, train_file=None, val_file=None, test_file=None, predict_file=None, train_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, val_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, test_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, predict_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, input_cls=<class 'flash.text.seq2seq.core.input.Seq2SeqCSVInput'>, transform_kwargs=None, **data_module_kwargs)[source]

Load the TranslationData from CSV files containing input text snippets and their corresponding target text snippets.

Input text snippets will be extracted from the input_field column in the CSV files. Target text snippets will be extracted from the target_field column in the CSV files. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

Parameters
Return type

TranslationData

Returns

The constructed TranslationData.

Examples

The file train_data.csv contains the following:

pig latin,english
ayay entencesay inyay igpay atinlay,a sentence in pig latin
ellohay orldway,hello world

The file predict_data.csv contains the following:

pig latin
ayay entencesay orfay edictionpray

>>> from flash import Trainer
>>> from flash.text import TranslationTask, TranslationData
>>> datamodule = TranslationData.from_csv(
...     "pig latin",
...     "english",
...     train_file="train_data.csv",
...     predict_file="predict_data.csv",
...     batch_size=2,
... )  
Downloading...
>>> model = TranslationTask()
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...
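
To run the example above, the two CSV files have to exist on disk. The following is a minimal sketch, using only the standard library, of one way to produce them. The helper name write_csv and the use of a temporary directory are assumptions made for illustration and are not part of Flash.

```python
import csv
import tempfile
from pathlib import Path

# Rows mirror the example files above; the dict keys become the CSV header
# and are the values later passed as input_field and target_field.
TRAIN_ROWS = [
    {"pig latin": "ayay entencesay inyay igpay atinlay", "english": "a sentence in pig latin"},
    {"pig latin": "ellohay orldway", "english": "hello world"},
]
PREDICT_ROWS = [{"pig latin": "ayay entencesay orfay edictionpray"}]


def write_csv(path, rows):
    """Hypothetical helper: write a list of dicts to `path` with a header row."""
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
        writer.writeheader()
        writer.writerows(rows)


# A temporary directory keeps the sketch from touching the working directory.
out_dir = Path(tempfile.mkdtemp())
write_csv(out_dir / "train_data.csv", TRAIN_ROWS)
write_csv(out_dir / "predict_data.csv", PREDICT_ROWS)
```

The resulting paths can then be passed as train_file and predict_file. The prediction file carries only the input column, which matches the example, where no targets are supplied for the predict stage.
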
classmethod from_hf_datasets(input_field, target_field=None, train_hf_dataset=None, val_hf_dataset=None, test_hf_dataset=None, predict_hf_dataset=None, train_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, val_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, test_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, predict_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, input_cls=<class 'flash.text.seq2seq.core.input.Seq2SeqInputBase'>, transform_kwargs=None, **data_module_kwargs)[source]

Load the TranslationData from Hugging Face Dataset objects containing input text snippets and their corresponding target text snippets.

Input text snippets will be extracted from the input_field column in the Dataset objects. Target text snippets will be extracted from the target_field column in the Dataset objects. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

Parameters
Return type

TranslationData

Returns

The constructed TranslationData.

Examples

>>> from datasets import Dataset
>>> from flash import Trainer
>>> from flash.text import TranslationTask, TranslationData
>>> train_data = Dataset.from_dict(
...     {
...         "pig latin": ["ayay entencesay inyay igpay atinlay", "ellohay orldway"],
...         "english": ["a sentence in pig latin", "hello world"],
...     }
... )
>>> predict_data = Dataset.from_dict(
...     {
...         "pig latin": ["ayay entencesay orfay edictionpray"],
...     }
... )
>>> datamodule = TranslationData.from_hf_datasets(
...     "pig latin",
...     "english",
...     train_hf_dataset=train_data,
...     predict_hf_dataset=predict_data,
...     batch_size=2,
... )  
>>> model = TranslationTask()
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...
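
The example above builds each Dataset from a dictionary that maps column names to lists of values. If the data starts out as a list of per-example records instead, it can be pivoted into that column layout first. This is a minimal sketch using only the standard library; rows_to_columns is a hypothetical helper, not part of Flash or datasets.

```python
def rows_to_columns(rows, fields):
    """Hypothetical helper: turn a list of record dicts into a dict of column lists."""
    return {field: [row[field] for row in rows] for field in fields}


# Per-example records, as they might come from an upstream source.
train_rows = [
    {"pig latin": "ayay entencesay inyay igpay atinlay", "english": "a sentence in pig latin"},
    {"pig latin": "ellohay orldway", "english": "hello world"},
]

# Column-oriented layout, the same shape as the dict passed to Dataset.from_dict above.
train_columns = rows_to_columns(train_rows, ["pig latin", "english"])
```

The train_columns dictionary has the same structure as the argument to Dataset.from_dict in the example, so the column names can be reused as input_field and target_field.
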
classmethod from_json(input_field, target_field=None, train_file=None, val_file=None, test_file=None, predict_file=None, train_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, val_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, test_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, predict_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, input_cls=<class 'flash.text.seq2seq.core.input.Seq2SeqJSONInput'>, transform_kwargs=None, field=None, **data_module_kwargs)[source]

Load the TranslationData from JSON files containing input text snippets and their corresponding target text snippets.

Input text snippets will be extracted from the input_field key of each record in the JSON files. Target text snippets will be extracted from the target_field key of each record in the JSON files. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

Parameters
Return type

TranslationData

Returns

The constructed TranslationData.

Examples

The file train_data.json contains the following:

{"pig latin":"ayay entencesay inyay igpay atinlay","english":"a sentence in pig latin"}
{"pig latin":"ellohay orldway","english":"hello world"}

The file predict_data.json contains the following:

{"pig latin":"ayay entencesay orfay edictionpray"}

>>> from flash import Trainer
>>> from flash.text import TranslationTask, TranslationData
>>> datamodule = TranslationData.from_json(
...     "pig latin",
...     "english",
...     train_file="train_data.json",
...     predict_file="predict_data.json",
...     batch_size=2,
... )  
Downloading...
>>> model = TranslationTask()
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...
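
The example files above contain one JSON object per line. The following is a minimal sketch, using only the standard library, of one way to write files in that layout. The helper name write_json_lines and the temporary directory are assumptions made for illustration and are not part of Flash.

```python
import json
import tempfile
from pathlib import Path


def write_json_lines(path, records):
    """Hypothetical helper: write one compact JSON object per line."""
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            # Compact separators reproduce the spacing used in the example files.
            f.write(json.dumps(record, separators=(",", ":")) + "\n")


train_records = [
    {"pig latin": "ayay entencesay inyay igpay atinlay", "english": "a sentence in pig latin"},
    {"pig latin": "ellohay orldway", "english": "hello world"},
]
predict_records = [{"pig latin": "ayay entencesay orfay edictionpray"}]

# A temporary directory keeps the sketch from touching the working directory.
json_dir = Path(tempfile.mkdtemp())
write_json_lines(json_dir / "train_data.json", train_records)
write_json_lines(json_dir / "predict_data.json", predict_records)
```

The resulting paths can then be passed as train_file and predict_file, with the record keys reused as input_field and target_field.
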
classmethod from_lists(train_data=None, train_targets=None, val_data=None, val_targets=None, test_data=None, test_targets=None, predict_data=None, train_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, val_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, test_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, predict_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, input_cls=<class 'flash.text.seq2seq.core.input.Seq2SeqListInput'>, transform_kwargs=None, **data_module_kwargs)[source]

Load the TranslationData from lists of input text snippets and corresponding lists of target text snippets.

To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

Parameters
Return type

TranslationData

Returns

The constructed TranslationData.

Examples

>>> from flash import Trainer
>>> from flash.text import TranslationTask, TranslationData
>>> datamodule = TranslationData.from_lists(
...     train_data=["ayay entencesay inyay igpay atinlay", "ellohay orldway"],
...     train_targets=["a sentence in pig latin", "hello world"],
...     predict_data=["ayay entencesay orfay edictionpray"],
...     batch_size=2,
... )  
>>> model = TranslationTask()
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...
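
The example above passes inputs and targets as two separate lists in matching order. When the data is held as (input, target) pairs, it can be split into that form beforehand. This is a minimal sketch using only the standard library; split_pairs is a hypothetical helper, not part of Flash.

```python
def split_pairs(pairs):
    """Hypothetical helper: split (input, target) pairs into two parallel lists."""
    if not pairs:
        return [], []
    inputs, targets = zip(*pairs)
    return list(inputs), list(targets)


pairs = [
    ("ayay entencesay inyay igpay atinlay", "a sentence in pig latin"),
    ("ellohay orldway", "hello world"),
]

# The two lists stay aligned by index, as in the train_data and train_targets lists above.
train_data, train_targets = split_pairs(pairs)
```

The two results correspond to the train_data and train_targets arguments in the example. Splitting from pairs keeps each input aligned with its target by construction.
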
input_transform_cls

alias of flash.core.data.io.input_transform.InputTransform
