TextClassificationData
- class flash.text.classification.data.TextClassificationData(train_input=None, val_input=None, test_input=None, predict_input=None, data_fetcher=None, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, val_split=None, batch_size=None, num_workers=0, sampler=None, pin_memory=True, persistent_workers=False)
The TextClassificationData class is a DataModule with a set of classmethods for loading data for text classification.

- classmethod from_csv(input_field, target_fields=None, train_file=None, val_file=None, test_file=None, predict_file=None, target_formatter=None, input_cls=<class 'flash.text.classification.input.TextClassificationCSVInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)
Load the TextClassificationData from CSV files containing text snippets and their corresponding targets.

Input text snippets will be extracted from the input_field column in the CSV files. The targets will be extracted from the target_fields column(s) in the CSV files and can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

- Parameters
  - input_field (str) – The field (column name) in the CSV files containing the text snippets.
  - target_fields (Union[str, Sequence[str], None]) – The field (column name) or list of fields in the CSV files containing the targets.
  - train_file (Union[str, bytes, PathLike, None]) – The CSV file to use when training.
  - val_file (Union[str, bytes, PathLike, None]) – The CSV file to use when validating.
  - test_file (Union[str, bytes, PathLike, None]) – The CSV file to use when testing.
  - predict_file (Union[str, bytes, PathLike, None]) – The CSV file to use when predicting.
  - target_formatter (Optional[TargetFormatter]) – Optionally provide a TargetFormatter to control how targets are handled. See Formatting Classification Targets for more details.
  - input_cls (Type[Input]) – The Input type to use for loading the data.
  - transform (Optional[Dict[str, Callable]]) – The InputTransform type to use.
  - transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
  - data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
  TextClassificationData
- Returns
  The constructed TextClassificationData.
Examples

The files can be in Comma Separated Values (CSV) format with either a .csv or .txt extension.

The file train_data.csv contains the following:

reviews,targets
Best movie ever!,positive
Not good,negative
Fine I guess,neutral

The file predict_data.csv contains the following:

reviews
Worst movie ever!
I didn't enjoy it
It was ok

>>> from flash import Trainer
>>> from flash.text import TextClassifier, TextClassificationData
>>> datamodule = TextClassificationData.from_csv(
...     "reviews",
...     "targets",
...     train_file="train_data.csv",
...     predict_file="predict_data.csv",
...     batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['negative', 'neutral', 'positive']
>>> model = TextClassifier(num_classes=datamodule.num_classes, backbone="prajjwal1/bert-tiny")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
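To run the CSV example locally, the two input files have to exist first. The sketch below writes them with Python's standard csv module, using exactly the rows listed above; the helper name write_rows is hypothetical and not part of Flash.

```python
import csv

# Rows copied from the example files described above (header row first).
TRAIN_ROWS = [
    ("reviews", "targets"),
    ("Best movie ever!", "positive"),
    ("Not good", "negative"),
    ("Fine I guess", "neutral"),
]
PREDICT_ROWS = [
    ("reviews",),
    ("Worst movie ever!",),
    ("I didn't enjoy it",),
    ("It was ok",),
]


def write_rows(path, rows):
    """Hypothetical helper: write rows as comma separated values."""
    with open(path, "w", newline="", encoding="utf-8") as f:
        csv.writer(f, lineterminator="\n").writerows(rows)


write_rows("train_data.csv", TRAIN_ROWS)
write_rows("predict_data.csv", PREDICT_ROWS)
```

The predict file has only the reviews column, since no targets are needed when predicting.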
Alternatively, the files can be in Tab Separated Values (TSV) format with a .tsv extension.

The file train_data.tsv contains the following:

reviews	targets
Best movie ever!	positive
Not good	negative
Fine I guess	neutral

The file predict_data.tsv contains the following:

reviews
Worst movie ever!
I didn't enjoy it
It was ok

>>> from flash import Trainer
>>> from flash.text import TextClassifier, TextClassificationData
>>> datamodule = TextClassificationData.from_csv(
...     "reviews",
...     "targets",
...     train_file="train_data.tsv",
...     predict_file="predict_data.tsv",
...     batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['negative', 'neutral', 'positive']
>>> model = TextClassifier(num_classes=datamodule.num_classes, backbone="prajjwal1/bert-tiny")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
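The TSV variant differs from the CSV one only in the delimiter and the .tsv extension. As a sketch, the same standard-library csv writer produces both TSV files when given delimiter="\t"; the helper name write_tsv is hypothetical.

```python
import csv


def write_tsv(path, rows):
    """Hypothetical helper: write rows as tab separated values, header first."""
    with open(path, "w", newline="", encoding="utf-8") as f:
        csv.writer(f, delimiter="\t", lineterminator="\n").writerows(rows)


write_tsv(
    "train_data.tsv",
    [
        ("reviews", "targets"),
        ("Best movie ever!", "positive"),
        ("Not good", "negative"),
        ("Fine I guess", "neutral"),
    ],
)
write_tsv(
    "predict_data.tsv",
    [("reviews",), ("Worst movie ever!",), ("I didn't enjoy it",), ("It was ok",)],
)
```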
- classmethod from_data_frame(input_field, target_fields=None, train_data_frame=None, val_data_frame=None, test_data_frame=None, predict_data_frame=None, target_formatter=None, input_cls=<class 'flash.text.classification.input.TextClassificationDataFrameInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)
Load the TextClassificationData from Pandas DataFrame objects containing text snippets and their corresponding targets.

Input text snippets will be extracted from the input_field column in the DataFrame objects. The targets will be extracted from the target_fields column(s) in the DataFrame objects and can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

- Parameters
  - input_field (str) – The field (column name) in the DataFrame objects containing the text snippets.
  - target_fields (Union[str, Sequence[str], None]) – The field (column name) or list of fields in the DataFrame objects containing the targets.
  - train_data_frame (Optional[DataFrame]) – The DataFrame to use when training.
  - val_data_frame (Optional[DataFrame]) – The DataFrame to use when validating.
  - test_data_frame (Optional[DataFrame]) – The DataFrame to use when testing.
  - predict_data_frame (Optional[DataFrame]) – The DataFrame to use when predicting.
  - target_formatter (Optional[TargetFormatter]) – Optionally provide a TargetFormatter to control how targets are handled. See Formatting Classification Targets for more details.
  - input_cls (Type[Input]) – The Input type to use for loading the data.
  - transform (Optional[Dict[str, Callable]]) – The InputTransform type to use.
  - transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
  - data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
  TextClassificationData
- Returns
  The constructed TextClassificationData.
Examples

>>> from pandas import DataFrame
>>> from flash import Trainer
>>> from flash.text import TextClassifier, TextClassificationData
>>> train_data = DataFrame.from_dict(
...     {
...         "reviews": ["Best movie ever!", "Not good", "Fine I guess"],
...         "targets": ["positive", "negative", "neutral"],
...     }
... )
>>> predict_data = DataFrame.from_dict(
...     {
...         "reviews": ["Worst movie ever!", "I didn't enjoy it", "It was ok"],
...     }
... )
>>> datamodule = TextClassificationData.from_data_frame(
...     "reviews",
...     "targets",
...     train_data_frame=train_data,
...     predict_data_frame=predict_data,
...     batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['negative', 'neutral', 'positive']
>>> model = TextClassifier(num_classes=datamodule.num_classes, backbone="prajjwal1/bert-tiny")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
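The num_classes and labels outputs in these examples are consistent with collecting the unique string targets and sorting them alphabetically. The sketch below reproduces that mapping with a hypothetical helper, build_label_mapping; it is an illustration only, not Flash's TargetFormatter, whose exact behaviour may differ.

```python
from typing import Dict, List, Sequence, Tuple


def build_label_mapping(targets: Sequence[str]) -> Tuple[List[str], Dict[str, int]]:
    """Hypothetical helper: sorted unique labels plus a label -> index map."""
    labels = sorted(set(targets))
    return labels, {label: index for index, label in enumerate(labels)}


train_targets = ["positive", "negative", "neutral"]
labels, label_to_index = build_label_mapping(train_targets)
# Encode each string target as the index of its label.
encoded = [label_to_index[target] for target in train_targets]

print(len(labels))  # 3
print(labels)       # ['negative', 'neutral', 'positive']
print(encoded)      # [2, 0, 1]
```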
- classmethod from_hf_datasets(input_field, target_fields=None, train_hf_dataset=None, val_hf_dataset=None, test_hf_dataset=None, predict_hf_dataset=None, target_formatter=None, input_cls=<class 'flash.text.classification.input.TextClassificationInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)
Load the TextClassificationData from Hugging Face Dataset objects containing text snippets and their corresponding targets.

Input text snippets will be extracted from the input_field column in the Dataset objects. The targets will be extracted from the target_fields column(s) in the Dataset objects and can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

- Parameters
  - input_field (str) – The field (column name) in the Dataset objects containing the text snippets.
  - target_fields (Union[str, Sequence[str], None]) – The field (column name) or list of fields in the Dataset objects containing the targets.
  - train_hf_dataset (Optional[object]) – The Dataset to use when training.
  - val_hf_dataset (Optional[object]) – The Dataset to use when validating.
  - test_hf_dataset (Optional[object]) – The Dataset to use when testing.
  - predict_hf_dataset (Optional[object]) – The Dataset to use when predicting.
  - target_formatter (Optional[TargetFormatter]) – Optionally provide a TargetFormatter to control how targets are handled. See Formatting Classification Targets for more details.
  - input_cls (Type[Input]) – The Input type to use for loading the data.
  - transform (Optional[Dict[str, Callable]]) – The InputTransform type to use.
  - transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
  - data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
  TextClassificationData
- Returns
  The constructed TextClassificationData.
Examples

>>> from datasets import Dataset
>>> from flash import Trainer
>>> from flash.text import TextClassifier, TextClassificationData
>>> train_data = Dataset.from_dict(
...     {
...         "reviews": ["Best movie ever!", "Not good", "Fine I guess"],
...         "targets": ["positive", "negative", "neutral"],
...     }
... )
>>> predict_data = Dataset.from_dict(
...     {
...         "reviews": ["Worst movie ever!", "I didn't enjoy it", "It was ok"],
...     }
... )
>>> datamodule = TextClassificationData.from_hf_datasets(
...     "reviews",
...     "targets",
...     train_hf_dataset=train_data,
...     predict_hf_dataset=predict_data,
...     batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['negative', 'neutral', 'positive']
>>> model = TextClassifier(num_classes=datamodule.num_classes, backbone="prajjwal1/bert-tiny")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
- classmethod from_json(input_field, target_fields=None, train_file=None, val_file=None, test_file=None, predict_file=None, target_formatter=None, input_cls=<class 'flash.text.classification.input.TextClassificationJSONInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, field=None, **data_module_kwargs)
Load the TextClassificationData from JSON files containing text snippets and their corresponding targets.

Input text snippets will be extracted from the input_field in the JSON objects. The targets will be extracted from the target_fields in the JSON objects and can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

- Parameters
  - input_field (str) – The field in the JSON objects containing the text snippets.
  - target_fields (Union[str, Sequence[str], None]) – The field or list of fields in the JSON objects containing the targets.
  - train_file (Union[str, bytes, PathLike, None]) – The JSON file to use when training.
  - val_file (Union[str, bytes, PathLike, None]) – The JSON file to use when validating.
  - test_file (Union[str, bytes, PathLike, None]) – The JSON file to use when testing.
  - predict_file (Union[str, bytes, PathLike, None]) – The JSON file to use when predicting.
  - target_formatter (Optional[TargetFormatter]) – Optionally provide a TargetFormatter to control how targets are handled. See Formatting Classification Targets for more details.
  - input_cls (Type[Input]) – The Input type to use for loading the data.
  - transform (Optional[Dict[str, Callable]]) – The InputTransform type to use.
  - transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
  - field (Optional[str]) – The field that holds the data in the JSON file.
  - data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
  TextClassificationData
- Returns
  The constructed TextClassificationData.
Examples

The file train_data.json contains the following:

{"reviews":"Best movie ever!","targets":"positive"}
{"reviews":"Not good","targets":"negative"}
{"reviews":"Fine I guess","targets":"neutral"}

The file predict_data.json contains the following:

{"reviews":"Worst movie ever!"}
{"reviews":"I didn't enjoy it"}
{"reviews":"It was ok"}

>>> from flash import Trainer
>>> from flash.text import TextClassifier, TextClassificationData
>>> datamodule = TextClassificationData.from_json(
...     "reviews",
...     "targets",
...     train_file="train_data.json",
...     predict_file="predict_data.json",
...     batch_size=2,
... )
Downloading...
>>> datamodule.num_classes
3
>>> datamodule.labels
['negative', 'neutral', 'positive']
>>> model = TextClassifier(num_classes=datamodule.num_classes, backbone="prajjwal1/bert-tiny")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
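The example files hold one JSON object per line. The sketch below writes them with the standard json module, using compact separators so each line matches the listing above. The last lines illustrate one assumed use of the field parameter: a file (hypothetically named train_data_nested.json) whose records sit under a single top-level key, which would then be passed as field="data". That nested layout is an assumption for illustration, not a format this page specifies.

```python
import json

train_records = [
    {"reviews": "Best movie ever!", "targets": "positive"},
    {"reviews": "Not good", "targets": "negative"},
    {"reviews": "Fine I guess", "targets": "neutral"},
]
predict_records = [
    {"reviews": "Worst movie ever!"},
    {"reviews": "I didn't enjoy it"},
    {"reviews": "It was ok"},
]


def write_json_lines(path, records):
    """Hypothetical helper: write one compact JSON object per line."""
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, separators=(",", ":")) + "\n")


write_json_lines("train_data.json", train_records)
write_json_lines("predict_data.json", predict_records)

# Assumed alternative layout for the `field` argument: all records nested
# under one top-level key ("data" here is a hypothetical key name).
with open("train_data_nested.json", "w", encoding="utf-8") as f:
    json.dump({"data": train_records}, f)
```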
- classmethod from_labelstudio(export_json=None, train_export_json=None, val_export_json=None, test_export_json=None, predict_export_json=None, data_folder=None, train_data_folder=None, val_data_folder=None, test_data_folder=None, predict_data_folder=None, input_cls=<class 'flash.core.integrations.labelstudio.input.LabelStudioTextClassificationInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, val_split=None, multi_label=False, **data_module_kwargs)
Creates a DataModule object from the given Label Studio export file and data directory, loading them with the given input_cls (LabelStudioTextClassificationInput by default).

- Parameters
  - export_json (Optional[str]) – Path to the Label Studio export file.
  - train_export_json (Optional[str]) – Path to the Label Studio export file for the train set; overrides export_json if specified.
  - val_export_json (Optional[str]) – Path to the Label Studio export file for the validation set.
  - test_export_json (Optional[str]) – Path to the Label Studio export file for the test set.
  - predict_export_json (Optional[str]) – Path to the Label Studio export file for the predict set.
  - data_folder (Optional[str]) – Path to the Label Studio data folder.
  - train_data_folder (Optional[str]) – Path to the Label Studio data folder for the train set; overrides data_folder if specified.
  - val_data_folder (Optional[str]) – Path to the Label Studio data folder for the validation set.
  - test_data_folder (Optional[str]) – Path to the Label Studio data folder for the test set.
  - predict_data_folder (Optional[str]) – Path to the Label Studio data folder for the predict set.
  - input_cls (Type[Input]) – The Input type to use for loading the data.
  - transform (Optional[Dict[str, Callable]]) – The InputTransform type to use.
  - transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
  - val_split (Optional[float]) – The val_split argument to pass to the DataModule.
  - multi_label (Optional[bool]) – Whether the labels are multi-label encoded (each sample may carry several labels).
  - data_module_kwargs (Any) – Additional keyword arguments to use when constructing the DataModule.
- Return type
  TextClassificationData
- Returns
  The constructed data module.
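The export_json argument points at a Label Studio JSON export. As an assumption for illustration only, the sketch below uses a simplified export layout (a list of tasks whose text sits under "data" and whose chosen classes sit under "annotations" → "result" → "value" → "choices") and a hypothetical helper, extract_examples, to show which pieces a text classification loader needs. Real exports carry more keys, the file name export.json is hypothetical, and Flash's LabelStudioTextClassificationInput does its own parsing.

```python
import json

# Assumed, simplified Label Studio export: one dict per annotated task.
EXPORT = [
    {
        "id": 1,
        "data": {"text": "Best movie ever!"},
        "annotations": [
            {"result": [{"type": "choices", "value": {"choices": ["positive"]}}]}
        ],
    },
    {
        "id": 2,
        "data": {"text": "Not good"},
        "annotations": [
            {"result": [{"type": "choices", "value": {"choices": ["negative"]}}]}
        ],
    },
]


def extract_examples(tasks, text_key="text"):
    """Hypothetical helper: (text, choices) pairs from each task's first annotation."""
    examples = []
    for task in tasks:
        choices = []
        annotations = task.get("annotations") or []
        if annotations:
            for result in annotations[0].get("result", []):
                choices.extend(result.get("value", {}).get("choices", []))
        examples.append((task["data"][text_key], choices))
    return examples


# Write the export to disk, then read it back as a loader would.
with open("export.json", "w", encoding="utf-8") as f:
    json.dump(EXPORT, f)

with open("export.json", encoding="utf-8") as f:
    print(extract_examples(json.load(f)))
# [('Best movie ever!', ['positive']), ('Not good', ['negative'])]
```

Keeping the choices as a list per task leaves room for several labels per sample, which is the situation the multi_label flag appears to cover.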
- classmethod from_lists(train_data=None, train_targets=None, val_data=None, val_targets=None, test_data=None, test_targets=None, predict_data=None, target_formatter=None, input_cls=<class 'flash.text.classification.input.TextClassificationListInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)
Load the TextClassificationData from lists of text snippets and corresponding lists of targets.

The targets can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

- Parameters
  - train_data (Optional[List[str]]) – The list of text snippets to use when training.
  - train_targets (Union[List[Any], List[List[Any]], None]) – The list of targets to use when training.
  - val_data (Optional[List[str]]) – The list of text snippets to use when validating.
  - val_targets (Union[List[Any], List[List[Any]], None]) – The list of targets to use when validating.
  - test_data (Optional[List[str]]) – The list of text snippets to use when testing.
  - test_targets (Union[List[Any], List[List[Any]], None]) – The list of targets to use when testing.
  - predict_data (Optional[List[str]]) – The list of text snippets to use when predicting.
  - target_formatter (Optional[TargetFormatter]) – Optionally provide a TargetFormatter to control how targets are handled. See Formatting Classification Targets for more details.
  - input_cls (Type[Input]) – The Input type to use for loading the data.
  - transform (Optional[Dict[str, Callable]]) – The InputTransform type to use.
  - transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
  - data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
  TextClassificationData
- Returns
  The constructed TextClassificationData.
Examples

>>> from flash import Trainer
>>> from flash.text import TextClassifier, TextClassificationData
>>> datamodule = TextClassificationData.from_lists(
...     train_data=["Best movie ever!", "Not good", "Fine I guess"],
...     train_targets=["positive", "negative", "neutral"],
...     predict_data=["Worst movie ever!", "I didn't enjoy it", "It was ok"],
...     batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['negative', 'neutral', 'positive']
>>> model = TextClassifier(num_classes=datamodule.num_classes, backbone="prajjwal1/bert-tiny")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
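The target parameters accept either one target per snippet (List[Any]) or a list of targets per snippet (List[List[Any]]), the second form presumably for multi-label data. To illustrate what that form carries, the sketch below encodes hypothetical genre tags as multi-hot vectors with a hypothetical helper, to_multi_hot. It is not Flash's TargetFormatter, which may represent such targets differently.

```python
from typing import List, Sequence, Tuple


def to_multi_hot(targets: Sequence[Sequence[str]]) -> Tuple[List[str], List[List[int]]]:
    """Hypothetical helper: encode list-of-lists targets as multi-hot rows."""
    labels = sorted({label for row in targets for label in row})
    index = {label: i for i, label in enumerate(labels)}
    encoded = []
    for row in targets:
        vector = [0] * len(labels)
        for label in row:
            vector[index[label]] = 1  # mark every label present for this snippet
        encoded.append(vector)
    return labels, encoded


# Hypothetical multi-label targets: one list of tags per text snippet.
train_targets = [["action", "comedy"], ["drama"], ["comedy"]]
labels, encoded = to_multi_hot(train_targets)
print(labels)   # ['action', 'comedy', 'drama']
print(encoded)  # [[1, 1, 0], [0, 0, 1], [0, 1, 0]]
```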
- classmethod from_parquet(input_field, target_fields=None, train_file=None, val_file=None, test_file=None, predict_file=None, target_formatter=None, input_cls=<class 'flash.text.classification.input.TextClassificationParquetInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)
Load the TextClassificationData from Parquet files containing text snippets and their corresponding targets.

Input text snippets will be extracted from the input_field column in the Parquet files. The targets will be extracted from the target_fields column(s) in the Parquet files and can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

- Parameters
  - input_field (str) – The field (column name) in the Parquet files containing the text snippets.
  - target_fields (Union[str, Sequence[str], None]) – The field (column name) or list of fields in the Parquet files containing the targets.
  - train_file (Union[str, bytes, PathLike, None]) – The Parquet file to use when training.
  - val_file (Union[str, bytes, PathLike, None]) – The Parquet file to use when validating.
  - test_file (Union[str, bytes, PathLike, None]) – The Parquet file to use when testing.
  - predict_file (Union[str, bytes, PathLike, None]) – The Parquet file to use when predicting.
  - target_formatter (Optional[TargetFormatter]) – Optionally provide a TargetFormatter to control how targets are handled. See Formatting Classification Targets for more details.
  - input_cls (Type[Input]) – The Input type to use for loading the data.
  - transform (Optional[Dict[str, Callable]]) – The InputTransform type to use.
  - transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
  - data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
  TextClassificationData
- Returns
  The constructed TextClassificationData.
Examples

The file train_data.parquet contains the following contents encoded in the Parquet format:

reviews,targets
Best movie ever!,positive
Not good,negative
Fine I guess,neutral

The file predict_data.parquet contains the following contents encoded in the Parquet format:

reviews
Worst movie ever!
I didn't enjoy it
It was ok

>>> from flash import Trainer
>>> from flash.text import TextClassifier, TextClassificationData
>>> datamodule = TextClassificationData.from_parquet(
...     "reviews",
...     "targets",
...     train_file="train_data.parquet",
...     predict_file="predict_data.parquet",
...     batch_size=2,
... )
Downloading...
>>> datamodule.num_classes
3
>>> datamodule.labels
['negative', 'neutral', 'positive']
>>> model = TextClassifier(num_classes=datamodule.num_classes, backbone="prajjwal1/bert-tiny")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
- input_transform_cls