TextClassificationData

class flash.text.classification.data.TextClassificationData(train_input=None, val_input=None, test_input=None, predict_input=None, data_fetcher=None, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, val_split=None, batch_size=None, num_workers=0, sampler=None, pin_memory=True, persistent_workers=False)[source]

The TextClassificationData class is a DataModule with a set of classmethods for loading data for text classification.

classmethod from_csv(input_field, target_fields=None, train_file=None, val_file=None, test_file=None, predict_file=None, target_formatter=None, input_cls=<class 'flash.text.classification.input.TextClassificationCSVInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)[source]

Load the TextClassificationData from CSV files containing text snippets and their corresponding targets.

Input text snippets will be extracted from the input_field column in the CSV files. The targets will be extracted from the target_fields in the CSV files and can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

Parameters
Return type

TextClassificationData

Returns

The constructed TextClassificationData.

Examples

The files can be in Comma Separated Values (CSV) format with either a .csv or .txt extension.

The file train_data.csv contains the following:

reviews,targets
Best movie ever!,positive
Not good,negative
Fine I guess,neutral

The file predict_data.csv contains the following:

reviews
Worst movie ever!
I didn't enjoy it
It was ok

>>> from flash import Trainer
>>> from flash.text import TextClassifier, TextClassificationData
>>> datamodule = TextClassificationData.from_csv(
...     "reviews",
...     "targets",
...     train_file="train_data.csv",
...     predict_file="predict_data.csv",
...     batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['negative', 'neutral', 'positive']
>>> model = TextClassifier(num_classes=datamodule.num_classes, backbone="prajjwal1/bert-tiny")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...
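
To try the CSV example above, the two input files need to exist first. The following is a minimal sketch that writes them with Python's standard csv module. The scratch directory and the names data_dir, write_rows, train_file, and predict_file are assumptions made for illustration and are not part of Flash.

```python
import csv
import tempfile
from pathlib import Path

# Hypothetical scratch directory; the example above assumes the files
# sit in the current working directory instead.
data_dir = Path(tempfile.mkdtemp())

train_rows = [
    ("reviews", "targets"),
    ("Best movie ever!", "positive"),
    ("Not good", "negative"),
    ("Fine I guess", "neutral"),
]
predict_rows = [
    ("reviews",),
    ("Worst movie ever!",),
    ("I didn't enjoy it",),
    ("It was ok",),
]


def write_rows(path, rows, delimiter=","):
    """Write rows to path as delimited text using the csv module."""
    # newline="" lets the csv module control line endings itself.
    with open(path, "w", newline="", encoding="utf-8") as handle:
        csv.writer(handle, delimiter=delimiter).writerows(rows)


train_file = data_dir / "train_data.csv"
predict_file = data_dir / "predict_data.csv"
write_rows(train_file, train_rows)
write_rows(predict_file, predict_rows)
```

The resulting paths could then be passed to from_csv as train_file=str(train_file) and predict_file=str(predict_file).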

Alternatively, the files can be in Tab Separated Values (TSV) format with a .tsv extension.

The file train_data.tsv contains the following:

reviews             targets
Best movie ever!    positive
Not good            negative
Fine I guess        neutral

The file predict_data.tsv contains the following:

reviews
Worst movie ever!
I didn't enjoy it
It was ok

>>> from flash import Trainer
>>> from flash.text import TextClassifier, TextClassificationData
>>> datamodule = TextClassificationData.from_csv(
...     "reviews",
...     "targets",
...     train_file="train_data.tsv",
...     predict_file="predict_data.tsv",
...     batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['negative', 'neutral', 'positive']
>>> model = TextClassifier(num_classes=datamodule.num_classes, backbone="prajjwal1/bert-tiny")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...
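
The TSV listing above is rendered with aligned columns, but in a TSV file the fields are separated by a single tab character. The sketch below writes such a file with the standard csv module and reads it back by column name, mirroring the roles that input_field ("reviews") and target_fields ("targets") play in from_csv. The scratch path and variable names are assumptions for illustration, and the read-back step only illustrates column extraction; it is not how Flash loads the file.

```python
import csv
import tempfile
from pathlib import Path

# Hypothetical scratch location, used for illustration only.
tsv_file = Path(tempfile.mkdtemp()) / "train_data.tsv"

rows = [
    ["reviews", "targets"],
    ["Best movie ever!", "positive"],
    ["Not good", "negative"],
    ["Fine I guess", "neutral"],
]

# Fields are separated by one tab character, not by aligned spaces.
with open(tsv_file, "w", newline="", encoding="utf-8") as handle:
    csv.writer(handle, delimiter="\t").writerows(rows)

# Read the file back and pull out the two columns by name.
with open(tsv_file, newline="", encoding="utf-8") as handle:
    records = list(csv.DictReader(handle, delimiter="\t"))

texts = [record["reviews"] for record in records]
targets = [record["targets"] for record in records]
```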
classmethod from_data_frame(input_field, target_fields=None, train_data_frame=None, val_data_frame=None, test_data_frame=None, predict_data_frame=None, target_formatter=None, input_cls=<class 'flash.text.classification.input.TextClassificationDataFrameInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)[source]

Load the TextClassificationData from Pandas DataFrame objects containing text snippets and their corresponding targets.

Input text snippets will be extracted from the input_field column in the DataFrame objects. The targets will be extracted from the target_fields in the DataFrame objects and can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

Parameters
Return type

TextClassificationData

Returns

The constructed TextClassificationData.

Examples

>>> from pandas import DataFrame
>>> from flash import Trainer
>>> from flash.text import TextClassifier, TextClassificationData
>>> train_data = DataFrame.from_dict(
...     {
...         "reviews": ["Best movie ever!", "Not good", "Fine I guess"],
...         "targets": ["positive", "negative", "neutral"],
...     }
... )
>>> predict_data = DataFrame.from_dict(
...     {
...         "reviews": ["Worst movie ever!", "I didn't enjoy it", "It was ok"],
...     }
... )
>>> datamodule = TextClassificationData.from_data_frame(
...     "reviews",
...     "targets",
...     train_data_frame=train_data,
...     predict_data_frame=predict_data,
...     batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['negative', 'neutral', 'positive']
>>> model = TextClassifier(num_classes=datamodule.num_classes, backbone="prajjwal1/bert-tiny")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...

classmethod from_data_frame(input_field, target_fields=None, train_data_frame=None, val_data_frame=None, test_data_frame=None, predict_data_frame=None, target_formatter=None, input_cls=<class 'flash.text.classification.input.TextClassificationDataFrameInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)[source]

Load the TextClassificationData from Hugging Face Dataset objects containing text snippets and their corresponding targets.

Input text snippets will be extracted from the input_field column in the Dataset objects. The targets will be extracted from the target_fields in the Dataset objects and can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

Parameters
Return type

TextClassificationData

Returns

The constructed TextClassificationData.

Examples

>>> from datasets import Dataset
>>> from flash import Trainer
>>> from flash.text import TextClassifier, TextClassificationData
>>> train_data = Dataset.from_dict(
...     {
...         "reviews": ["Best movie ever!", "Not good", "Fine I guess"],
...         "targets": ["positive", "negative", "neutral"],
...     }
... )
>>> predict_data = Dataset.from_dict(
...     {
...         "reviews": ["Worst movie ever!", "I didn't enjoy it", "It was ok"],
...     }
... )
>>> datamodule = TextClassificationData.from_hf_datasets(
...     "reviews",
...     "targets",
...     train_hf_dataset=train_data,
...     predict_hf_dataset=predict_data,
...     batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['negative', 'neutral', 'positive']
>>> model = TextClassifier(num_classes=datamodule.num_classes, backbone="prajjwal1/bert-tiny")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...

classmethod from_json(input_field, target_fields=None, train_file=None, val_file=None, test_file=None, predict_file=None, target_formatter=None, input_cls=<class 'flash.text.classification.input.TextClassificationJSONInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, field=None, **data_module_kwargs)[source]

Load the TextClassificationData from JSON files containing text snippets and their corresponding targets.

Input text snippets will be extracted from the input_field in the JSON objects. The targets will be extracted from the target_fields in the JSON objects and can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

Parameters
Return type

TextClassificationData

Returns

The constructed TextClassificationData.

Examples

The file train_data.json contains the following:

{"reviews":"Best movie ever!","targets":"positive"}
{"reviews":"Not good","targets":"negative"}
{"reviews":"Fine I guess","targets":"neutral"}

The file predict_data.json contains the following:

{"reviews":"Worst movie ever!"}
{"reviews":"I didn't enjoy it"}
{"reviews":"It was ok"}

>>> from flash import Trainer
>>> from flash.text import TextClassifier, TextClassificationData
>>> datamodule = TextClassificationData.from_json(
...     "reviews",
...     "targets",
...     train_file="train_data.json",
...     predict_file="predict_data.json",
...     batch_size=2,
... )  

...
>>> datamodule.num_classes
3
>>> datamodule.labels
['negative', 'neutral', 'positive']
>>> model = TextClassifier(num_classes=datamodule.num_classes, backbone="prajjwal1/bert-tiny")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...
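
The JSON listings above contain one JSON object per line rather than a single JSON array. A minimal sketch for producing files in that layout with the standard json module follows. The helper write_json_lines, the scratch directory, and the variable names are assumptions for illustration, not part of Flash.

```python
import json
import tempfile
from pathlib import Path

# Hypothetical scratch directory, used for illustration only.
json_dir = Path(tempfile.mkdtemp())

train_records = [
    {"reviews": "Best movie ever!", "targets": "positive"},
    {"reviews": "Not good", "targets": "negative"},
    {"reviews": "Fine I guess", "targets": "neutral"},
]
predict_records = [
    {"reviews": "Worst movie ever!"},
    {"reviews": "I didn't enjoy it"},
    {"reviews": "It was ok"},
]


def write_json_lines(path, records):
    """Write one compact JSON object per line."""
    # Compact separators reproduce the spacing used in the listings above.
    lines = [json.dumps(record, separators=(",", ":")) for record in records]
    path.write_text("\n".join(lines) + "\n", encoding="utf-8")


train_json = json_dir / "train_data.json"
predict_json = json_dir / "predict_data.json"
write_json_lines(train_json, train_records)
write_json_lines(predict_json, predict_records)
```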
classmethod from_labelstudio(export_json=None, train_export_json=None, val_export_json=None, test_export_json=None, predict_export_json=None, data_folder=None, train_data_folder=None, val_data_folder=None, test_data_folder=None, predict_data_folder=None, input_cls=<class 'flash.core.integrations.labelstudio.input.LabelStudioTextClassificationInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, val_split=None, multi_label=False, **data_module_kwargs)[source]

Creates a DataModule object from the given export file and data directory using the Input of name FOLDERS from the passed or constructed InputTransform.

Parameters
  • export_json (Optional[str]) – Path to the Label Studio export file.

  • train_export_json (Optional[str]) – Path to the Label Studio export file for the train set; overrides export_json if specified.

  • val_export_json (Optional[str]) – Path to the Label Studio export file for the validation set.

  • test_export_json (Optional[str]) – Path to the Label Studio export file for the test set.

  • predict_export_json (Optional[str]) – Path to the Label Studio export file for the predict set.

  • data_folder (Optional[str]) – Path to the Label Studio data folder.

  • train_data_folder (Optional[str]) – Path to the Label Studio data folder for the train set; overrides data_folder if specified.

  • val_data_folder (Optional[str]) – Path to the Label Studio data folder for the validation data.

  • test_data_folder (Optional[str]) – Path to the Label Studio data folder for the test data.

  • predict_data_folder (Optional[str]) – Path to the Label Studio data folder for the predict data.

  • input_cls (Type[Input]) – The Input type to use for loading the data.

  • transform (Optional[Dict[str, Callable]]) – The InputTransform type to use.

  • transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.

  • val_split (Optional[float]) – The val_split argument to pass to the DataModule.

  • multi_label (Optional[bool]) – Whether the labels are multi-label encoded.

  • data_module_kwargs (Any) – Additional keyword arguments to use when constructing the datamodule.

Return type

TextClassificationData

Returns

The constructed data module.
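
This method has no example above, so the following is a hedged sketch of how a call might look, based only on the signature and parameter descriptions shown here. The wrapper function build_labelstudio_datamodule and its default values are hypothetical; Flash is imported lazily inside the function, and actually calling it requires Flash to be installed and a real Label Studio export file.

```python
def build_labelstudio_datamodule(export_json, val_split=0.2, batch_size=2):
    """Build a TextClassificationData from a single Label Studio export.

    Hypothetical convenience wrapper; the defaults are illustrative only.
    """
    # Imported lazily so the sketch can be defined without Flash installed.
    from flash.text import TextClassificationData

    return TextClassificationData.from_labelstudio(
        export_json=export_json,  # one export file used for all stages
        val_split=val_split,      # fraction forwarded to the DataModule
        multi_label=False,        # single-label targets
        batch_size=batch_size,    # forwarded through **data_module_kwargs
    )
```

A call such as build_labelstudio_datamodule("project.json") would then return the data module, where "project.json" stands in for the path to an actual export file.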

classmethod from_lists(train_data=None, train_targets=None, val_data=None, val_targets=None, test_data=None, test_targets=None, predict_data=None, target_formatter=None, input_cls=<class 'flash.text.classification.input.TextClassificationListInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)[source]

Load the TextClassificationData from lists of text snippets and corresponding lists of targets.

The targets can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

Parameters
Return type

TextClassificationData

Returns

The constructed TextClassificationData.

Examples

>>> from flash import Trainer
>>> from flash.text import TextClassifier, TextClassificationData
>>> datamodule = TextClassificationData.from_lists(
...     train_data=["Best movie ever!", "Not good", "Fine I guess"],
...     train_targets=["positive", "negative", "neutral"],
...     predict_data=["Worst movie ever!", "I didn't enjoy it", "It was ok"],
...     batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['negative', 'neutral', 'positive']
>>> model = TextClassifier(num_classes=datamodule.num_classes, backbone="prajjwal1/bert-tiny")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...
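
In the examples, the targets are supplied in the order positive, negative, neutral, while datamodule.labels reports ['negative', 'neutral', 'positive']. That output is consistent with the labels being the unique targets in sorted order. The sketch below reproduces this mapping in plain Python; the sorting rule is an assumption inferred from the example output, not taken from Flash's source, and the names label_to_index and encoded_targets are illustrative.

```python
train_targets = ["positive", "negative", "neutral"]

# Assumption: labels are the unique targets in sorted order, which matches
# the output shown in the examples but is not confirmed by Flash's source.
labels = sorted(set(train_targets))
num_classes = len(labels)

# Map each label to its class index and encode the original targets.
label_to_index = {label: index for index, label in enumerate(labels)}
encoded_targets = [label_to_index[target] for target in train_targets]
```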
classmethod from_parquet(input_field, target_fields=None, train_file=None, val_file=None, test_file=None, predict_file=None, target_formatter=None, input_cls=<class 'flash.text.classification.input.TextClassificationParquetInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)[source]

Load the TextClassificationData from Parquet files containing text snippets and their corresponding targets.

Input text snippets will be extracted from the input_field column in the Parquet files. The targets will be extracted from the target_fields in the Parquet files and can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

Parameters
Return type

TextClassificationData

Returns

The constructed TextClassificationData.

Examples

The file train_data.parquet contains the following data, encoded in the Parquet format:

reviews,targets
Best movie ever!,positive
Not good,negative
Fine I guess,neutral

The file predict_data.parquet contains the following data, encoded in the Parquet format:

reviews
Worst movie ever!
I didn't enjoy it
It was ok

>>> from flash import Trainer
>>> from flash.text import TextClassifier, TextClassificationData
>>> datamodule = TextClassificationData.from_parquet(
...     "reviews",
...     "targets",
...     train_file="train_data.parquet",
...     predict_file="predict_data.parquet",
...     batch_size=2,
... )  

...
>>> datamodule.num_classes
3
>>> datamodule.labels
['negative', 'neutral', 'positive']
>>> model = TextClassifier(num_classes=datamodule.num_classes, backbone="prajjwal1/bert-tiny")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...

input_transform_cls

alias of flash.core.data.io.input_transform.InputTransform
