TabularClassificationData
- class flash.tabular.classification.data.TabularClassificationData(train_input=None, val_input=None, test_input=None, predict_input=None, data_fetcher=None, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, val_split=None, batch_size=None, num_workers=0, sampler=None, pin_memory=True, persistent_workers=False)
The TabularClassificationData class is a DataModule with a set of classmethods for loading data for tabular classification.
- classmethod from_csv(categorical_fields=None, numerical_fields=None, target_fields=None, parameters=None, train_file=None, val_file=None, test_file=None, predict_file=None, target_formatter=None, input_cls=<class 'flash.tabular.classification.input.TabularClassificationCSVInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)
Creates a TabularClassificationData object from the given CSV files.
Note
The categorical_fields, numerical_fields, and target_fields do not need to be provided if parameters are passed instead. These can be obtained from the parameters attribute of the TabularData object that contains your training data.
The targets will be extracted from the target_fields columns in the CSV files and can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
- Parameters
categorical_fields (Union[str, List[str], None]) – The fields (column names) in the CSV files containing categorical data.
numerical_fields (Union[str, List[str], None]) – The fields (column names) in the CSV files containing numerical data.
target_fields (Union[str, List[str], None]) – The field (column name) or list of fields in the CSV files containing the targets.
parameters (Optional[Dict[str, Any]]) – Parameters to use if categorical_fields, numerical_fields, and target_fields are not provided (e.g. when loading data for inference or validation).
train_file (Optional[str]) – The path to the CSV file to use when training.
val_file (Optional[str]) – The path to the CSV file to use when validating.
test_file (Optional[str]) – The path to the CSV file to use when testing.
predict_file (Optional[str]) – The path to the CSV file to use when predicting.
target_formatter (Optional[TargetFormatter]) – Optionally provide a TargetFormatter to control how targets are handled. See Formatting Classification Targets for more details.
input_cls (Type[Input]) – The Input type to use for loading the data.
transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.
transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
TabularClassificationData
- Returns
The constructed TabularClassificationData.
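The Note above says the field arguments can be skipped when parameters are passed, and that parameters come from the object holding your training data. A minimal sketch of why that works, assuming the parameters record the field names plus statistics computed once from the training split: compute_parameters, encode_row, and every dictionary key below are hypothetical and are not Flash's actual parameters format or preprocessing.

```python
from statistics import mean, pstdev
from typing import Any, Dict, List


def compute_parameters(
    rows: List[Dict[str, Any]],
    categorical_fields: List[str],
    numerical_fields: List[str],
    target_field: str,
) -> Dict[str, Any]:
    """Hypothetical: gather what is needed to encode other splits the same way."""
    return {
        "categorical_fields": categorical_fields,
        "numerical_fields": numerical_fields,
        "target_field": target_field,
        # Sorted vocabulary per categorical column, learned from training rows only.
        "codes": {f: sorted({str(r[f]) for r in rows}) for f in categorical_fields},
        # Normalization statistics per numerical column.
        "mean": {f: mean(float(r[f]) for r in rows) for f in numerical_fields},
        "std": {f: pstdev(float(r[f]) for r in rows) for f in numerical_fields},
        "classes": sorted({str(r[target_field]) for r in rows}),
    }


def encode_row(row: Dict[str, Any], params: Dict[str, Any]) -> List[float]:
    """Hypothetical: encode one row (e.g. a prediction row) with stored parameters."""
    features: List[float] = []
    for f in params["categorical_fields"]:
        vocab = params["codes"][f]
        value = str(row[f])
        # In this sketch, categories unseen during training map to -1.
        features.append(float(vocab.index(value)) if value in vocab else -1.0)
    for f in params["numerical_fields"]:
        std = params["std"][f] or 1.0  # fall back to 1.0 if the std is zero
        features.append((float(row[f]) - params["mean"][f]) / std)
    return features


train_rows = [
    {"animal": "cat", "friendly": "yes", "weight": 6},
    {"animal": "dog", "friendly": "yes", "weight": 10},
    {"animal": "cat", "friendly": "no", "weight": 5},
]
params = compute_parameters(train_rows, ["friendly"], ["weight"], "animal")
# A prediction row has no target; the stored parameters describe its fields.
print(encode_row({"friendly": "yes", "weight": 7}, params))
```

Keeping the vocabulary and statistics from the training split, instead of recomputing them per split, is what keeps encodings consistent between training and prediction in this sketch.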
Examples
The files can be in Comma Separated Values (CSV) format with either a .csv or .txt extension.
We have a train_data.csv with the following contents:
animal,friendly,weight
cat,yes,6
dog,yes,10
cat,no,5
and a predict_data.csv with the following contents:
friendly,weight
yes,7
no,12
yes,5
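To run the example locally, the two files can be created with Python's standard csv module. This is a convenience sketch: the rows are copied from the listings above, write_csv is a helper introduced here for illustration, and the files are written to the current working directory.

```python
import csv
from pathlib import Path
from typing import Any, Dict, List

# Rows copied from the train_data.csv and predict_data.csv listings above.
TRAIN_ROWS: List[Dict[str, Any]] = [
    {"animal": "cat", "friendly": "yes", "weight": 6},
    {"animal": "dog", "friendly": "yes", "weight": 10},
    {"animal": "cat", "friendly": "no", "weight": 5},
]
PREDICT_ROWS: List[Dict[str, Any]] = [
    {"friendly": "yes", "weight": 7},
    {"friendly": "no", "weight": 12},
    {"friendly": "yes", "weight": 5},
]


def write_csv(path: str, rows: List[Dict[str, Any]], fieldnames: List[str]) -> None:
    """Write rows to a comma-separated file, header line first."""
    with open(path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


write_csv("train_data.csv", TRAIN_ROWS, ["animal", "friendly", "weight"])
# The prediction file has no target column.
write_csv("predict_data.csv", PREDICT_ROWS, ["friendly", "weight"])
print(Path("train_data.csv").read_text())
```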
>>> from flash import Trainer
>>> from flash.tabular import TabularClassifier, TabularClassificationData
>>> datamodule = TabularClassificationData.from_csv(
...     "friendly",  # categorical_fields
...     "weight",  # numerical_fields
...     "animal",  # target_fields
...     train_file="train_data.csv",
...     predict_file="predict_data.csv",
...     batch_size=4,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['cat', 'dog']
>>> model = TabularClassifier.from_data(datamodule, backbone="tabnet")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
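The num_classes and labels values printed above are consistent with taking the sorted distinct values of the animal target column. The sketch below shows that computation in plain Python; infer_labels is a hypothetical helper, and Flash's own target handling (see Formatting Classification Targets) may differ in detail.

```python
from typing import Any, List


def infer_labels(targets: List[Any]) -> List[str]:
    """Hypothetical: the sorted distinct target values, as strings."""
    return sorted({str(t) for t in targets})


# Target column ("animal") of train_data.csv.
targets = ["cat", "dog", "cat"]
labels = infer_labels(targets)
num_classes = len(labels)
# Integer class index for each training row, based on the label order.
class_indices = [labels.index(t) for t in targets]
print(labels, num_classes, class_indices)
```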
Alternatively, the files can be in Tab Separated Values (TSV) format with a .tsv extension.
We have a train_data.tsv with the following contents:
animal	friendly	weight
cat	yes	6
dog	yes	10
cat	no	5
and a predict_data.tsv with the following contents:
friendly	weight
yes	7
no	12
yes	5
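A sketch of writing and reading these tab-separated files with the standard csv module. delimiter_for is a hypothetical helper that mirrors the extension rule described above (tabs for .tsv, commas for .csv or .txt); it is not part of Flash.

```python
import csv
import os
from typing import Any, Dict, List


def delimiter_for(path: str) -> str:
    """Hypothetical helper: tab for .tsv files, comma otherwise (.csv or .txt)."""
    return "\t" if os.path.splitext(path)[1].lower() == ".tsv" else ","


def write_table(path: str, rows: List[Dict[str, Any]], fieldnames: List[str]) -> None:
    """Write rows with a header, choosing the delimiter from the file extension."""
    with open(path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames, delimiter=delimiter_for(path))
        writer.writeheader()
        writer.writerows(rows)


def read_table(path: str) -> List[Dict[str, str]]:
    """Read rows back as dictionaries; the csv module returns every value as a string."""
    with open(path, newline="") as f:
        return list(csv.DictReader(f, delimiter=delimiter_for(path)))


write_table(
    "train_data.tsv",
    [
        {"animal": "cat", "friendly": "yes", "weight": 6},
        {"animal": "dog", "friendly": "yes", "weight": 10},
        {"animal": "cat", "friendly": "no", "weight": 5},
    ],
    ["animal", "friendly", "weight"],
)
write_table(
    "predict_data.tsv",
    [
        {"friendly": "yes", "weight": 7},
        {"friendly": "no", "weight": 12},
        {"friendly": "yes", "weight": 5},
    ],
    ["friendly", "weight"],
)
print(read_table("train_data.tsv")[0])
```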
>>> from flash import Trainer
>>> from flash.tabular import TabularClassifier, TabularClassificationData
>>> datamodule = TabularClassificationData.from_csv(
...     "friendly",  # categorical_fields
...     "weight",  # numerical_fields
...     "animal",  # target_fields
...     train_file="train_data.tsv",
...     predict_file="predict_data.tsv",
...     batch_size=4,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['cat', 'dog']
>>> model = TabularClassifier.from_data(datamodule, backbone="tabnet")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
- classmethod from_data_frame(categorical_fields=None, numerical_fields=None, target_fields=None, parameters=None, train_data_frame=None, val_data_frame=None, test_data_frame=None, predict_data_frame=None, target_formatter=None, input_cls=<class 'flash.tabular.classification.input.TabularClassificationDataFrameInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)
Creates a TabularClassificationData object from the given data frames.
Note
The categorical_fields, numerical_fields, and target_fields do not need to be provided if parameters are passed instead. These can be obtained from the parameters attribute of the TabularData object that contains your training data.
The targets will be extracted from the target_fields in the data frames and can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
- Parameters
categorical_fields (Union[str, List[str], None]) – The fields (column names) in the data frames containing categorical data.
numerical_fields (Union[str, List[str], None]) – The fields (column names) in the data frames containing numerical data.
target_fields (Union[str, List[str], None]) – The field (column name) or list of fields in the data frames containing the targets.
parameters (Optional[Dict[str, Any]]) – Parameters to use if categorical_fields, numerical_fields, and target_fields are not provided (e.g. when loading data for inference or validation).
train_data_frame (Optional[DataFrame]) – The DataFrame to use when training.
val_data_frame (Optional[DataFrame]) – The DataFrame to use when validating.
test_data_frame (Optional[DataFrame]) – The DataFrame to use when testing.
predict_data_frame (Optional[DataFrame]) – The DataFrame to use when predicting.
target_formatter (Optional[TargetFormatter]) – Optionally provide a TargetFormatter to control how targets are handled. See Formatting Classification Targets for more details.
input_cls (Type[Input]) – The Input type to use for loading the data.
transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.
transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
TabularClassificationData
- Returns
The constructed TabularClassificationData.
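Each of categorical_fields, numerical_fields, and target_fields accepts a single column name, a list of names, or None. A minimal sketch of normalizing such an argument into a list of column names; as_field_list is hypothetical and only illustrates the accepted types, not how Flash handles them internally.

```python
from typing import List, Optional, Union

# Mirrors the documented type Union[str, List[str], None].
FieldsArg = Optional[Union[str, List[str]]]


def as_field_list(fields: FieldsArg) -> List[str]:
    """Hypothetical: normalize a fields argument into a list of column names."""
    if fields is None:
        return []
    if isinstance(fields, str):
        # A single column name becomes a one-element list.
        return [fields]
    return list(fields)


print(as_field_list("friendly"))
print(as_field_list(["friendly", "weight"]))
print(as_field_list(None))
```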
Examples
We have a DataFrame train_data with the following contents:
>>> train_data.head(3)
  animal friendly  weight
0    cat      yes       6
1    dog      yes      10
2    cat       no       5
and a DataFrame predict_data with the following contents:
>>> predict_data.head(3)
  friendly  weight
0      yes       7
1       no      12
2      yes       5
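One way to build the train_data and predict_data DataFrames used in this example, assuming pandas is installed; the values are copied from the head(3) listings above.

```python
import pandas as pd

# Values copied from the head(3) listings above.
train_data = pd.DataFrame(
    {
        "animal": ["cat", "dog", "cat"],  # target column
        "friendly": ["yes", "yes", "no"],  # categorical feature
        "weight": [6, 10, 5],  # numerical feature
    }
)
# Prediction data carries only the feature columns.
predict_data = pd.DataFrame({"friendly": ["yes", "no", "yes"], "weight": [7, 12, 5]})

print(train_data.head(3))
print(predict_data.head(3))
```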
>>> from flash import Trainer
>>> from flash.tabular import TabularClassifier, TabularClassificationData
>>> datamodule = TabularClassificationData.from_data_frame(
...     "friendly",  # categorical_fields
...     "weight",  # numerical_fields
...     "animal",  # target_fields
...     train_data_frame=train_data,
...     predict_data_frame=predict_data,
...     batch_size=4,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['cat', 'dog']
>>> model = TabularClassifier.from_data(datamodule, backbone="tabnet")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
- classmethod from_dicts(categorical_fields=None, numerical_fields=None, target_fields=None, parameters=None, train_dict=None, val_dict=None, test_dict=None, predict_dict=None, target_formatter=None, input_cls=<class 'flash.tabular.classification.input.TabularClassificationDictInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)
Creates a TabularClassificationData object from the given dictionaries.
Note
The categorical_fields, numerical_fields, and target_fields do not need to be provided if parameters are passed instead. These can be obtained from the parameters attribute of the TabularData object that contains your training data.
The targets will be extracted from the target_fields in the dictionaries and can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
- Parameters
categorical_fields (Union[str, List[str], None]) – The fields (column names) in the dictionaries containing categorical data.
numerical_fields (Union[str, List[str], None]) – The fields (column names) in the dictionaries containing numerical data.
target_fields (Union[str, List[str], None]) – The field (column name) or list of fields in the dictionaries containing the targets.
parameters (Optional[Dict[str, Any]]) – Parameters to use if categorical_fields, numerical_fields, and target_fields are not provided (e.g. when loading data for inference or validation).
train_dict (Optional[Dict[str, List[Any]]]) – The data to use when training.
val_dict (Optional[Dict[str, List[Any]]]) – The data to use when validating.
test_dict (Optional[Dict[str, List[Any]]]) – The data to use when testing.
predict_dict (Optional[Dict[str, List[Any]]]) – The data to use when predicting.
target_formatter (Optional[TargetFormatter]) – Optionally provide a TargetFormatter to control how targets are handled. See Formatting Classification Targets for more details.
input_cls (Type[Input]) – The Input type to use for loading the data.
transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.
transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
TabularClassificationData
- Returns
The constructed TabularClassificationData.
Examples
We have a dictionary train_data with the following contents:
{
    "animal": ["cat", "dog", "cat"],
    "friendly": ["yes", "yes", "no"],
    "weight": [6, 10, 5]
}
and a dictionary predict_data with the following contents:
{
    "friendly": ["yes", "no", "yes"],
    "weight": [7, 12, 5]
}
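In these examples the training dictionary holds the categorical, numerical, and target fields, while the prediction dictionary holds only the feature fields. A sketch of a pre-flight check for that; missing_fields and columns_aligned are hypothetical helpers, not Flash API. A dictionary that used some other key in place of friendly would be flagged before any data is loaded.

```python
from typing import Any, Dict, List, Sequence


def missing_fields(data: Dict[str, List[Any]], fields: Sequence[str]) -> List[str]:
    """Hypothetical check: the requested fields that are not keys of data."""
    return [f for f in fields if f not in data]


def columns_aligned(data: Dict[str, List[Any]]) -> bool:
    """Hypothetical check: every column list has the same length."""
    return len({len(column) for column in data.values()}) <= 1


train_data = {
    "animal": ["cat", "dog", "cat"],
    "friendly": ["yes", "yes", "no"],
    "weight": [6, 10, 5],
}
predict_data = {"friendly": ["yes", "no", "yes"], "weight": [7, 12, 5]}

# Training uses the categorical, numerical, and target fields.
print(missing_fields(train_data, ["friendly", "weight", "animal"]))
# Prediction uses only the feature fields.
print(missing_fields(predict_data, ["friendly", "weight"]))
print(columns_aligned(train_data), columns_aligned(predict_data))
```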
>>> from flash import Trainer
>>> from flash.tabular import TabularClassifier, TabularClassificationData
>>> datamodule = TabularClassificationData.from_dicts(
...     "friendly",  # categorical_fields
...     "weight",  # numerical_fields
...     "animal",  # target_fields
...     train_dict=train_data,
...     predict_dict=predict_data,
...     batch_size=4,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['cat', 'dog']
>>> model = TabularClassifier.from_data(datamodule, backbone="tabnet")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
- classmethod from_lists(categorical_fields=None, numerical_fields=None, target_fields=None, parameters=None, train_list=None, val_list=None, test_list=None, predict_list=None, target_formatter=None, input_cls=<class 'flash.tabular.classification.input.TabularClassificationListInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)
Creates a TabularClassificationData object from the given data (in the form of lists of tuples or dictionaries).
Note
The categorical_fields, numerical_fields, and target_fields do not need to be provided if parameters are passed instead. These can be obtained from the parameters attribute of the TabularData object that contains your training data.
The targets will be extracted from the target_fields in the lists and can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
- Parameters
categorical_fields (Union[str, List[str], None]) – The fields (column names) in the data containing categorical data.
numerical_fields (Union[str, List[str], None]) – The fields (column names) in the data containing numerical data.
target_fields (Union[str, List[str], None]) – The field (column name) or list of fields in the data containing the targets.
parameters (Optional[Dict[str, Any]]) – Parameters to use if categorical_fields, numerical_fields, and target_fields are not provided (e.g. when loading data for inference or validation).
train_list (Optional[List[Union[tuple, dict]]]) – The data to use when training.
val_list (Optional[List[Union[tuple, dict]]]) – The data to use when validating.
test_list (Optional[List[Union[tuple, dict]]]) – The data to use when testing.
predict_list (Optional[List[Union[tuple, dict]]]) – The data to use when predicting.
target_formatter (Optional[TargetFormatter]) – Optionally provide a TargetFormatter to control how targets are handled. See Formatting Classification Targets for more details.
input_cls (Type[Input]) – The Input type to use for loading the data.
transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.
transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
TabularClassificationData
- Returns
The constructed TabularClassificationData.
Examples
We have a list of dictionaries train_data with the following contents:
[
    {"animal": "cat", "friendly": "yes", "weight": 6},
    {"animal": "dog", "friendly": "yes", "weight": 10},
    {"animal": "cat", "friendly": "no", "weight": 5},
]
and a list of dictionaries predict_data with the following contents:
[
    {"friendly": "yes", "weight": 7},
    {"friendly": "no", "weight": 12},
    {"friendly": "yes", "weight": 5},
]
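A list of dictionaries like the one above is the row-wise form of the dictionary of columns accepted by from_dicts. A sketch converting between the two with a hypothetical rows_to_columns helper; it handles dictionary rows only, since how tuple rows map to fields is not described here, and it is not a claim about how Flash processes lists internally.

```python
from typing import Any, Dict, List


def rows_to_columns(rows: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Hypothetical: turn a list of row dictionaries into a dictionary of columns."""
    if not rows:
        return {}
    fields = list(rows[0])
    for i, row in enumerate(rows):
        # Every row must carry the same fields, or the columns would misalign.
        if set(row) != set(fields):
            raise ValueError(f"row {i} has fields {sorted(row)}, expected {sorted(fields)}")
    return {f: [row[f] for row in rows] for f in fields}


train_data = [
    {"animal": "cat", "friendly": "yes", "weight": 6},
    {"animal": "dog", "friendly": "yes", "weight": 10},
    {"animal": "cat", "friendly": "no", "weight": 5},
]
print(rows_to_columns(train_data))
```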
>>> from flash import Trainer
>>> from flash.tabular import TabularClassifier, TabularClassificationData
>>> datamodule = TabularClassificationData.from_lists(
...     "friendly",  # categorical_fields
...     "weight",  # numerical_fields
...     "animal",  # target_fields
...     train_list=train_data,
...     predict_list=predict_data,
...     batch_size=4,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['cat', 'dog']
>>> model = TabularClassifier.from_data(datamodule, backbone="tabnet")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...