TextClassificationData
- class flash.text.classification.data.TextClassificationData(train_input=None, val_input=None, test_input=None, predict_input=None, data_fetcher=None, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, val_split=None, batch_size=None, num_workers=0, sampler=None, pin_memory=True, persistent_workers=False)
The TextClassificationData class is a DataModule with a set of classmethods for loading data for text classification.
- classmethod from_csv(input_field, target_fields=None, train_file=None, val_file=None, test_file=None, predict_file=None, target_formatter=None, input_cls=<class 'flash.text.classification.input.TextClassificationCSVInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)
Load the TextClassificationData from CSV files containing text snippets and their corresponding targets.
Input text snippets will be extracted from the input_field column in the CSV files. The targets will be extracted from the target_fields in the CSV files and can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
- Parameters
input_field (str) – The field (column name) in the CSV files containing the text snippets.
target_fields (Union[str, Sequence[str], None]) – The field (column name) or list of fields in the CSV files containing the targets.
train_file (Union[str, bytes, PathLike, None]) – The CSV file to use when training.
val_file (Union[str, bytes, PathLike, None]) – The CSV file to use when validating.
test_file (Union[str, bytes, PathLike, None]) – The CSV file to use when testing.
predict_file (Union[str, bytes, PathLike, None]) – The CSV file to use when predicting.
target_formatter (Optional[TargetFormatter]) – Optionally provide a TargetFormatter to control how targets are handled. See Formatting Classification Targets for more details.
input_cls (Type[Input]) – The Input type to use for loading the data.
transform (Optional[Dict[str, Callable]]) – The InputTransform type to use.
transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
TextClassificationData
- Returns
The constructed TextClassificationData.
Examples
The files can be in Comma Separated Values (CSV) format with either a .csv or .txt extension.
The file train_data.csv contains the following:

    reviews,targets
    Best movie ever!,positive
    Not good,negative
    Fine I guess,neutral

The file predict_data.csv contains the following:

    reviews
    Worst movie ever!
    I didn't enjoy it
    It was ok

>>> from flash import Trainer
>>> from flash.text import TextClassifier, TextClassificationData
>>> datamodule = TextClassificationData.from_csv(
...     "reviews",
...     "targets",
...     train_file="train_data.csv",
...     predict_file="predict_data.csv",
...     batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['negative', 'neutral', 'positive']
>>> model = TextClassifier(num_classes=datamodule.num_classes, backbone="prajjwal1/bert-tiny")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...

Alternatively, the files can be in Tab Separated Values (TSV) format with a .tsv extension.
The file train_data.tsv contains the following (columns separated by a tab character):

    reviews	targets
    Best movie ever!	positive
    Not good	negative
    Fine I guess	neutral

The file predict_data.tsv contains the following:

    reviews
    Worst movie ever!
    I didn't enjoy it
    It was ok

>>> from flash import Trainer
>>> from flash.text import TextClassifier, TextClassificationData
>>> datamodule = TextClassificationData.from_csv(
...     "reviews",
...     "targets",
...     train_file="train_data.tsv",
...     predict_file="predict_data.tsv",
...     batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['negative', 'neutral', 'positive']
>>> model = TextClassifier(num_classes=datamodule.num_classes, backbone="prajjwal1/bert-tiny")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
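To try the CSV and TSV examples locally, the sample files first have to exist on disk. The following is a minimal sketch that writes them with the Python standard library; the helper name write_sample_files and its parameters are hypothetical and not part of Flash.

```python
import csv
from pathlib import Path

# Rows copied from the example files shown in this section.
TRAIN_ROWS = [
    ("Best movie ever!", "positive"),
    ("Not good", "negative"),
    ("Fine I guess", "neutral"),
]
PREDICT_ROWS = ["Worst movie ever!", "I didn't enjoy it", "It was ok"]


def write_sample_files(directory, delimiter=",", suffix=".csv"):
    """Write train_data/predict_data files with a header row; return both paths.

    Hypothetical helper for illustration only (not a Flash API).
    """
    directory = Path(directory)
    train_path = directory / f"train_data{suffix}"
    predict_path = directory / f"predict_data{suffix}"

    # newline="" lets the csv module control line endings itself.
    with train_path.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.writer(handle, delimiter=delimiter)
        writer.writerow(["reviews", "targets"])
        writer.writerows(TRAIN_ROWS)

    with predict_path.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.writer(handle, delimiter=delimiter)
        writer.writerow(["reviews"])
        writer.writerows([[text] for text in PREDICT_ROWS])

    return train_path, predict_path
```

Calling write_sample_files(".") produces the CSV pair, and write_sample_files(".", delimiter="\t", suffix=".tsv") produces the TSV pair.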
- classmethod from_data_frame(input_field, target_fields=None, train_data_frame=None, val_data_frame=None, test_data_frame=None, predict_data_frame=None, target_formatter=None, input_cls=<class 'flash.text.classification.input.TextClassificationDataFrameInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)
Load the TextClassificationData from Pandas DataFrame objects containing text snippets and their corresponding targets.
Input text snippets will be extracted from the input_field column in the DataFrame objects. The targets will be extracted from the target_fields in the DataFrame objects and can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
- Parameters
input_field (str) – The field (column name) in the DataFrame objects containing the text snippets.
target_fields (Union[str, Sequence[str], None]) – The field (column name) or list of fields in the DataFrame objects containing the targets.
train_data_frame (Optional[DataFrame]) – The DataFrame to use when training.
val_data_frame (Optional[DataFrame]) – The DataFrame to use when validating.
test_data_frame (Optional[DataFrame]) – The DataFrame to use when testing.
predict_data_frame (Optional[DataFrame]) – The DataFrame to use when predicting.
target_formatter (Optional[TargetFormatter]) – Optionally provide a TargetFormatter to control how targets are handled. See Formatting Classification Targets for more details.
input_cls (Type[Input]) – The Input type to use for loading the data.
transform (Optional[Dict[str, Callable]]) – The InputTransform type to use.
transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
TextClassificationData
- Returns
The constructed TextClassificationData.
Examples
>>> from pandas import DataFrame
>>> from flash import Trainer
>>> from flash.text import TextClassifier, TextClassificationData
>>> train_data = DataFrame.from_dict(
...     {
...         "reviews": ["Best movie ever!", "Not good", "Fine I guess"],
...         "targets": ["positive", "negative", "neutral"],
...     }
... )
>>> predict_data = DataFrame.from_dict(
...     {
...         "reviews": ["Worst movie ever!", "I didn't enjoy it", "It was ok"],
...     }
... )
>>> datamodule = TextClassificationData.from_data_frame(
...     "reviews",
...     "targets",
...     train_data_frame=train_data,
...     predict_data_frame=predict_data,
...     batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['negative', 'neutral', 'positive']
>>> model = TextClassifier(num_classes=datamodule.num_classes, backbone="prajjwal1/bert-tiny")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
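The target_fields parameter also accepts a list of column names. The sketch below shows one way this might be used for multi-label data, assuming each target column holds a 0/1 flag (one of the supported classification target formats). The label assignments, the helper to_binary_columns, and the function build_multi_label_datamodule are illustrative assumptions, not part of Flash.

```python
REVIEWS = ["Best movie ever!", "Not good", "Fine I guess"]
# Illustrative label sets (made up for this sketch): a review may carry several labels.
LABEL_SETS = [{"positive"}, {"negative"}, {"neutral", "positive"}]
LABELS = ["negative", "neutral", "positive"]


def to_binary_columns(label_sets, labels):
    """Return one 0/1 column per label, aligned with the order of label_sets."""
    return {
        label: [int(label in label_set) for label_set in label_sets]
        for label in labels
    }


def build_multi_label_datamodule():
    """Hypothetical wrapper; imports are local so the helper above works without Flash."""
    from pandas import DataFrame
    from flash.text import TextClassificationData

    columns = {"reviews": REVIEWS, **to_binary_columns(LABEL_SETS, LABELS)}
    return TextClassificationData.from_data_frame(
        "reviews",
        LABELS,  # a list of target columns instead of a single column name
        train_data_frame=DataFrame(columns),
        batch_size=2,
    )
```

Keeping the encoding step separate from the Flash call makes it easy to inspect the generated columns before building the data module.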
- classmethod from_hf_datasets(input_field, target_fields=None, train_hf_dataset=None, val_hf_dataset=None, test_hf_dataset=None, predict_hf_dataset=None, target_formatter=None, input_cls=<class 'flash.text.classification.input.TextClassificationInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)
Load the TextClassificationData from Hugging Face Dataset objects containing text snippets and their corresponding targets.
Input text snippets will be extracted from the input_field column in the Dataset objects. The targets will be extracted from the target_fields in the Dataset objects and can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
- Parameters
input_field (str) – The field (column name) in the Dataset objects containing the text snippets.
target_fields (Union[str, Sequence[str], None]) – The field (column name) or list of fields in the Dataset objects containing the targets.
train_hf_dataset (Optional[object]) – The Dataset to use when training.
val_hf_dataset (Optional[object]) – The Dataset to use when validating.
test_hf_dataset (Optional[object]) – The Dataset to use when testing.
predict_hf_dataset (Optional[object]) – The Dataset to use when predicting.
target_formatter (Optional[TargetFormatter]) – Optionally provide a TargetFormatter to control how targets are handled. See Formatting Classification Targets for more details.
input_cls (Type[Input]) – The Input type to use for loading the data.
transform (Optional[Dict[str, Callable]]) – The InputTransform type to use.
transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
TextClassificationData
- Returns
The constructed TextClassificationData.
Examples
>>> from datasets import Dataset
>>> from flash import Trainer
>>> from flash.text import TextClassifier, TextClassificationData
>>> train_data = Dataset.from_dict(
...     {
...         "reviews": ["Best movie ever!", "Not good", "Fine I guess"],
...         "targets": ["positive", "negative", "neutral"],
...     }
... )
>>> predict_data = Dataset.from_dict(
...     {
...         "reviews": ["Worst movie ever!", "I didn't enjoy it", "It was ok"],
...     }
... )
>>> datamodule = TextClassificationData.from_hf_datasets(
...     "reviews",
...     "targets",
...     train_hf_dataset=train_data,
...     predict_hf_dataset=predict_data,
...     batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['negative', 'neutral', 'positive']
>>> model = TextClassifier(num_classes=datamodule.num_classes, backbone="prajjwal1/bert-tiny")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
- classmethod from_json(input_field, target_fields=None, train_file=None, val_file=None, test_file=None, predict_file=None, target_formatter=None, input_cls=<class 'flash.text.classification.input.TextClassificationJSONInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, field=None, **data_module_kwargs)
Load the TextClassificationData from JSON files containing text snippets and their corresponding targets.
Input text snippets will be extracted from the input_field in the JSON objects. The targets will be extracted from the target_fields in the JSON objects and can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
- Parameters
input_field (str) – The field in the JSON objects containing the text snippets.
target_fields (Union[str, Sequence[str], None]) – The field or list of fields in the JSON objects containing the targets.
train_file (Union[str, bytes, PathLike, None]) – The JSON file to use when training.
val_file (Union[str, bytes, PathLike, None]) – The JSON file to use when validating.
test_file (Union[str, bytes, PathLike, None]) – The JSON file to use when testing.
predict_file (Union[str, bytes, PathLike, None]) – The JSON file to use when predicting.
target_formatter (Optional[TargetFormatter]) – Optionally provide a TargetFormatter to control how targets are handled. See Formatting Classification Targets for more details.
input_cls (Type[Input]) – The Input type to use for loading the data.
transform (Optional[Dict[str, Callable]]) – The InputTransform type to use.
transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
field (Optional[str]) – The field that holds the data in the JSON file.
data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
TextClassificationData
- Returns
The constructed TextClassificationData.
Examples
The file train_data.json contains the following:

    {"reviews":"Best movie ever!","targets":"positive"}
    {"reviews":"Not good","targets":"negative"}
    {"reviews":"Fine I guess","targets":"neutral"}

The file predict_data.json contains the following:

    {"reviews":"Worst movie ever!"}
    {"reviews":"I didn't enjoy it"}
    {"reviews":"It was ok"}

>>> from flash import Trainer
>>> from flash.text import TextClassifier, TextClassificationData
>>> datamodule = TextClassificationData.from_json(
...     "reviews",
...     "targets",
...     train_file="train_data.json",
...     predict_file="predict_data.json",
...     batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['negative', 'neutral', 'positive']
>>> model = TextClassifier(num_classes=datamodule.num_classes, backbone="prajjwal1/bert-tiny")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
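The field parameter of from_json is not covered by the example above. The following sketch assumes that field names a top-level key whose value is the list of records, for instance a file shaped like {"data": [...]}. That interpretation, the key name "data", and the helpers write_nested_json and build_datamodule are assumptions made for illustration.

```python
import json
from pathlib import Path

TRAIN_RECORDS = [
    {"reviews": "Best movie ever!", "targets": "positive"},
    {"reviews": "Not good", "targets": "negative"},
    {"reviews": "Fine I guess", "targets": "neutral"},
]


def write_nested_json(path, records, field="data"):
    """Write records under a single top-level key and return the path.

    Hypothetical helper; "data" is an arbitrary key chosen for this sketch.
    """
    path = Path(path)
    path.write_text(json.dumps({field: records}), encoding="utf-8")
    return path


def build_datamodule(train_file, field="data"):
    """Hypothetical wrapper; the Flash import is local so the helper above runs without it."""
    from flash.text import TextClassificationData

    return TextClassificationData.from_json(
        "reviews",
        "targets",
        train_file=str(train_file),
        field=field,  # assumed to select the top-level key that holds the records
        batch_size=2,
    )
```

If the records are stored one JSON object per line, as in the example files above, field can be left at its default of None.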
- classmethod from_labelstudio(export_json=None, train_export_json=None, val_export_json=None, test_export_json=None, predict_export_json=None, data_folder=None, train_data_folder=None, val_data_folder=None, test_data_folder=None, predict_data_folder=None, input_cls=<class 'flash.core.integrations.labelstudio.input.LabelStudioTextClassificationInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, val_split=None, multi_label=False, **data_module_kwargs)
Creates a DataModule object from the given export file and data directory using the Input of name FOLDERS from the passed or constructed InputTransform.
- Parameters
export_json (Optional[str]) – Path to the Label Studio export file.
train_export_json (Optional[str]) – Path to the Label Studio export file for the train set. Overrides export_json if specified.
val_export_json (Optional[str]) – Path to the Label Studio export file for validation.
test_export_json (Optional[str]) – Path to the Label Studio export file for testing.
predict_export_json (Optional[str]) – Path to the Label Studio export file for prediction.
data_folder (Optional[str]) – Path to the Label Studio data folder.
train_data_folder (Optional[str]) – Path to the Label Studio data folder for the train set. Overrides data_folder if specified.
val_data_folder (Optional[str]) – Path to the Label Studio data folder for validation data.
test_data_folder (Optional[str]) – Path to the Label Studio data folder for test data.
predict_data_folder (Optional[str]) – Path to the Label Studio data folder for predict data.
input_cls (Type[Input]) – The Input type to use for loading the data.
transform (Optional[Dict[str, Callable]]) – The InputTransform type to use.
transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
val_split (Optional[float]) – The val_split argument to pass to the DataModule.
multi_label (Optional[bool]) – Whether the labels are multi-label encoded.
data_module_kwargs (Any) – Additional keyword arguments to use when constructing the datamodule.
- Return type
- Returns
The constructed data module.
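This method has no example in the reference. The sketch below shows how a call might be assembled from the parameters listed above. The path "data/project.json", the split of 0.2, the batch size of 4, and both helper names are placeholders chosen for illustration, not values taken from Flash.

```python
def labelstudio_kwargs(export_json, data_folder=None, val_split=0.2, multi_label=False):
    """Collect keyword arguments for from_labelstudio (hypothetical helper)."""
    kwargs = {
        "export_json": export_json,
        "val_split": val_split,  # fraction of the export held out for validation
        "multi_label": multi_label,
    }
    # Only pass data_folder when one was given, so the method's default stays in effect.
    if data_folder is not None:
        kwargs["data_folder"] = data_folder
    return kwargs


def build_labelstudio_datamodule(export_json="data/project.json", **overrides):
    """Hypothetical wrapper; Flash is imported locally so the helper above runs without it."""
    from flash.text import TextClassificationData

    return TextClassificationData.from_labelstudio(
        batch_size=4,  # forwarded to the DataModule through **data_module_kwargs
        **labelstudio_kwargs(export_json, **overrides),
    )
```

When separate exports exist per stage, train_export_json, val_export_json, and test_export_json can be passed instead of a single export_json with val_split.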
- classmethod from_lists(train_data=None, train_targets=None, val_data=None, val_targets=None, test_data=None, test_targets=None, predict_data=None, target_formatter=None, input_cls=<class 'flash.text.classification.input.TextClassificationListInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)
Load the TextClassificationData from lists of text snippets and corresponding lists of targets.
The targets can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
- Parameters
train_data (Optional[List[str]]) – The list of text snippets to use when training.
train_targets (Union[List[Any], List[List[Any]], None]) – The list of targets to use when training.
val_data (Optional[List[str]]) – The list of text snippets to use when validating.
val_targets (Union[List[Any], List[List[Any]], None]) – The list of targets to use when validating.
test_data (Optional[List[str]]) – The list of text snippets to use when testing.
test_targets (Union[List[Any], List[List[Any]], None]) – The list of targets to use when testing.
predict_data (Optional[List[str]]) – The list of text snippets to use when predicting.
target_formatter (Optional[TargetFormatter]) – Optionally provide a TargetFormatter to control how targets are handled. See Formatting Classification Targets for more details.
input_cls (Type[Input]) – The Input type to use for loading the data.
transform (Optional[Dict[str, Callable]]) – The InputTransform type to use.
transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
TextClassificationData
- Returns
The constructed TextClassificationData.
Examples
>>> from flash import Trainer
>>> from flash.text import TextClassifier, TextClassificationData
>>> datamodule = TextClassificationData.from_lists(
...     train_data=["Best movie ever!", "Not good", "Fine I guess"],
...     train_targets=["positive", "negative", "neutral"],
...     predict_data=["Worst movie ever!", "I didn't enjoy it", "It was ok"],
...     batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['negative', 'neutral', 'positive']
>>> model = TextClassifier(num_classes=datamodule.num_classes, backbone="prajjwal1/bert-tiny")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
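Extra keyword arguments are forwarded to the DataModule constructor, which is how batch_size reaches it in the example above. The sketch below forwards val_split the same way to hold out part of the training lists for validation, and checks first that texts and targets line up. The wrapper name build_datamodule and the default split of 0.2 are assumptions made for illustration.

```python
def build_datamodule(texts, targets, val_split=0.2, batch_size=2):
    """Build a TextClassificationData from parallel lists (hypothetical wrapper).

    Raises ValueError before touching Flash if the two lists differ in length.
    """
    if len(texts) != len(targets):
        raise ValueError(
            f"Expected one target per text, got {len(texts)} texts and {len(targets)} targets."
        )

    # Imported here so the length check can be exercised without Flash installed.
    from flash.text import TextClassificationData

    return TextClassificationData.from_lists(
        train_data=texts,
        train_targets=targets,
        val_split=val_split,  # forwarded to the DataModule constructor
        batch_size=batch_size,
    )
```

With only a handful of samples a validation split is not meaningful, so this pattern is intended for larger lists than the three-item example.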
- classmethod from_parquet(input_field, target_fields=None, train_file=None, val_file=None, test_file=None, predict_file=None, target_formatter=None, input_cls=<class 'flash.text.classification.input.TextClassificationParquetInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)
Load the TextClassificationData from Parquet files containing text snippets and their corresponding targets.
Input text snippets will be extracted from the input_field column in the Parquet files. The targets will be extracted from the target_fields in the Parquet files and can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
- Parameters
input_field (str) – The field (column name) in the Parquet files containing the text snippets.
target_fields (Union[str, Sequence[str], None]) – The field (column name) or list of fields in the Parquet files containing the targets.
train_file (Union[str, bytes, PathLike, None]) – The Parquet file to use when training.
val_file (Union[str, bytes, PathLike, None]) – The Parquet file to use when validating.
test_file (Union[str, bytes, PathLike, None]) – The Parquet file to use when testing.
predict_file (Union[str, bytes, PathLike, None]) – The Parquet file to use when predicting.
target_formatter (Optional[TargetFormatter]) – Optionally provide a TargetFormatter to control how targets are handled. See Formatting Classification Targets for more details.
input_cls (Type[Input]) – The Input type to use for loading the data.
transform (Optional[Dict[str, Callable]]) – The InputTransform type to use.
transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
TextClassificationData
- Returns
The constructed TextClassificationData.
Examples
The file train_data.parquet contains the following contents encoded in the Parquet format:

    reviews,targets
    Best movie ever!,positive
    Not good,negative
    Fine I guess,neutral

The file predict_data.parquet contains the following contents encoded in the Parquet format:

    reviews
    Worst movie ever!
    I didn't enjoy it
    It was ok

>>> from flash import Trainer
>>> from flash.text import TextClassifier, TextClassificationData
>>> datamodule = TextClassificationData.from_parquet(
...     "reviews",
...     "targets",
...     train_file="train_data.parquet",
...     predict_file="predict_data.parquet",
...     batch_size=2,
... )
>>> datamodule.num_classes
3
>>> datamodule.labels
['negative', 'neutral', 'positive']
>>> model = TextClassifier(num_classes=datamodule.num_classes, backbone="prajjwal1/bert-tiny")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
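Parquet is a binary format, so the example files cannot be typed by hand. One way to produce them is with pandas, as in the sketch below. It assumes pandas plus a Parquet engine such as pyarrow is installed, and the helper name write_parquet_files is hypothetical.

```python
from pathlib import Path

# Column-oriented copies of the example contents shown in this section.
TRAIN_COLUMNS = {
    "reviews": ["Best movie ever!", "Not good", "Fine I guess"],
    "targets": ["positive", "negative", "neutral"],
}
PREDICT_COLUMNS = {
    "reviews": ["Worst movie ever!", "I didn't enjoy it", "It was ok"],
}


def write_parquet_files(directory):
    """Write train_data.parquet and predict_data.parquet; return both paths.

    Hypothetical helper. pandas is imported locally because it needs an
    optional Parquet engine (for example pyarrow) to be installed.
    """
    import pandas as pd

    directory = Path(directory)
    train_path = directory / "train_data.parquet"
    predict_path = directory / "predict_data.parquet"

    # index=False keeps the row index out of the stored columns.
    pd.DataFrame(TRAIN_COLUMNS).to_parquet(train_path, index=False)
    pd.DataFrame(PREDICT_COLUMNS).to_parquet(predict_path, index=False)
    return train_path, predict_path
```

After running write_parquet_files("."), the from_parquet example can read both files from the current directory.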
- input_transform_cls