TabularClassificationData
- class flash.tabular.classification.data.TabularClassificationData(train_input=None, val_input=None, test_input=None, predict_input=None, data_fetcher=None, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, val_split=None, batch_size=None, num_workers=0, sampler=None, pin_memory=True, persistent_workers=False)
The TabularClassificationData class is a DataModule with a set of classmethods for loading data for tabular classification.
- classmethod from_csv(categorical_fields=None, numerical_fields=None, target_fields=None, parameters=None, train_file=None, val_file=None, test_file=None, predict_file=None, target_formatter=None, input_cls=<class 'flash.tabular.classification.input.TabularClassificationCSVInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)
Creates a TabularClassificationData object from the given CSV files.
Note
The categorical_fields, numerical_fields, and target_fields do not need to be provided if parameters are passed instead. The parameters can be obtained from the parameters attribute of the TabularData object that contains your training data.
The targets will be extracted from the target_fields columns in the CSV files and can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
- Parameters
  - categorical_fields (Union[str, List[str], None]) – The fields (column names) in the CSV files containing categorical data.
  - numerical_fields (Union[str, List[str], None]) – The fields (column names) in the CSV files containing numerical data.
  - target_fields (Union[str, List[str], None]) – The field (column name) or list of fields in the CSV files containing the targets.
  - parameters (Optional[Dict[str, Any]]) – Parameters to use if categorical_fields, numerical_fields, and target_fields are not provided (e.g. when loading data for inference or validation).
  - train_file (Optional[str]) – The path to the CSV file to use when training.
  - val_file (Optional[str]) – The path to the CSV file to use when validating.
  - test_file (Optional[str]) – The path to the CSV file to use when testing.
  - predict_file (Optional[str]) – The path to the CSV file to use when predicting.
  - target_formatter (Optional[TargetFormatter]) – Optionally provide a TargetFormatter to control how targets are handled. See Formatting Classification Targets for more details.
  - input_cls (Type[Input]) – The Input type to use for loading the data.
  - transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.
  - transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
  - data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
  TabularClassificationData
- Returns
  The constructed TabularClassificationData.
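The three field arguments each accept a single column name, a list of names, or None (Union[str, List[str], None]). As a minimal, stdlib-only sketch of that convention, the code below normalizes the three arguments and splits one CSV row by role. The helpers `as_list` and `split_row` are hypothetical and are not part of Flash, which does its own parsing and preprocessing; treating numerical cells as floats is also an assumption of the sketch.

```python
from typing import Any, Dict, List, Optional, Union

# Same shape as the categorical_fields / numerical_fields / target_fields annotations.
Fields = Union[str, List[str], None]


def as_list(fields: Fields) -> List[str]:
    """Normalize a single column name, a list of names, or None to a list of names."""
    if fields is None:
        return []
    if isinstance(fields, str):
        return [fields]
    return list(fields)


def split_row(
    row: Dict[str, str],
    categorical_fields: Fields = None,
    numerical_fields: Fields = None,
    target_fields: Fields = None,
) -> Dict[str, Optional[List[Any]]]:
    """Hypothetical helper: partition one CSV row (as read by csv.DictReader) by role."""
    targets = [row[name] for name in as_list(target_fields) if name in row]
    return {
        "categorical": [row[name] for name in as_list(categorical_fields)],
        # CSV cells are strings; numerical columns are assumed to parse as floats.
        "numerical": [float(row[name]) for name in as_list(numerical_fields)],
        # Prediction files have no target column, so the targets may be absent.
        "target": targets or None,
    }


# Positional order mirrors the example: categorical, numerical, target.
print(split_row({"animal": "cat", "friendly": "yes", "weight": "6"}, "friendly", "weight", "animal"))
# {'categorical': ['yes'], 'numerical': [6.0], 'target': ['cat']}
print(split_row({"friendly": "no", "weight": "12"}, "friendly", "weight", "animal"))
# {'categorical': ['no'], 'numerical': [12.0], 'target': None}
```

The second call shows why passing target_fields is harmless for a prediction file: the target column is simply missing from those rows.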
Examples
The files can be in Comma Separated Values (CSV) format with either a .csv or .txt extension.
We have a train_data.csv with the following contents:
    animal,friendly,weight
    cat,yes,6
    dog,yes,10
    cat,no,5
and a predict_data.csv with the following contents:
    friendly,weight
    yes,7
    no,12
    yes,5
>>> from flash import Trainer
>>> from flash.tabular import TabularClassifier, TabularClassificationData
>>> datamodule = TabularClassificationData.from_csv(
...     "friendly",
...     "weight",
...     "animal",
...     train_file="train_data.csv",
...     predict_file="predict_data.csv",
...     batch_size=4,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['cat', 'dog']
>>> model = TabularClassifier.from_data(datamodule, backbone="tabnet")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
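To run the example locally, train_data.csv and predict_data.csv must already exist. A minimal sketch that writes them into the current working directory with the standard-library csv module; the helper `write_table` and the choice of Unix line endings are assumptions made for illustration.

```python
import csv
from pathlib import Path
from typing import Any, List, Sequence

# Rows from the example, header first; the predict file has no "animal" column.
TRAIN_ROWS = [
    ["animal", "friendly", "weight"],
    ["cat", "yes", 6],
    ["dog", "yes", 10],
    ["cat", "no", 5],
]
PREDICT_ROWS = [
    ["friendly", "weight"],
    ["yes", 7],
    ["no", 12],
    ["yes", 5],
]


def write_table(path: Path, rows: Sequence[List[Any]], delimiter: str = ",") -> None:
    """Write rows to a delimited text file with Unix line endings."""
    with path.open("w", newline="") as f:
        csv.writer(f, delimiter=delimiter, lineterminator="\n").writerows(rows)


write_table(Path("train_data.csv"), TRAIN_ROWS)
write_table(Path("predict_data.csv"), PREDICT_ROWS)
```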
Alternatively, the files can be in Tab Separated Values (TSV) format with a .tsv extension.
We have a train_data.tsv with the following contents:
    animal	friendly	weight
    cat	yes	6
    dog	yes	10
    cat	no	5
and a predict_data.tsv with the following contents:
    friendly	weight
    yes	7
    no	12
    yes	5
>>> from flash import Trainer
>>> from flash.tabular import TabularClassifier, TabularClassificationData
>>> datamodule = TabularClassificationData.from_csv(
...     "friendly",
...     "weight",
...     "animal",
...     train_file="train_data.tsv",
...     predict_file="predict_data.tsv",
...     batch_size=4,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['cat', 'dog']
>>> model = TabularClassifier.from_data(datamodule, backbone="tabnet")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
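The tab-separated inputs can be produced the same way as comma-separated ones by changing the writer's delimiter. A stdlib-only sketch that writes train_data.tsv and predict_data.tsv into the current working directory; the Unix line endings are an assumption for illustration.

```python
import csv
from pathlib import Path

# Rows from the example, header first; the predict file has no "animal" column.
train_rows = [["animal", "friendly", "weight"], ["cat", "yes", 6], ["dog", "yes", 10], ["cat", "no", 5]]
predict_rows = [["friendly", "weight"], ["yes", 7], ["no", 12], ["yes", 5]]

for name, rows in [("train_data.tsv", train_rows), ("predict_data.tsv", predict_rows)]:
    with Path(name).open("w", newline="") as f:
        # delimiter="\t" separates columns with tabs instead of commas.
        csv.writer(f, delimiter="\t", lineterminator="\n").writerows(rows)
```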
- classmethod from_data_frame(categorical_fields=None, numerical_fields=None, target_fields=None, parameters=None, train_data_frame=None, val_data_frame=None, test_data_frame=None, predict_data_frame=None, target_formatter=None, input_cls=<class 'flash.tabular.classification.input.TabularClassificationDataFrameInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)
Creates a TabularClassificationData object from the given data frames.
Note
The categorical_fields, numerical_fields, and target_fields do not need to be provided if parameters are passed instead. The parameters can be obtained from the parameters attribute of the TabularData object that contains your training data.
The targets will be extracted from the target_fields columns in the data frames and can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
- Parameters
  - categorical_fields (Union[str, List[str], None]) – The fields (column names) in the data frames containing categorical data.
  - numerical_fields (Union[str, List[str], None]) – The fields (column names) in the data frames containing numerical data.
  - target_fields (Union[str, List[str], None]) – The field (column name) or list of fields in the data frames containing the targets.
  - parameters (Optional[Dict[str, Any]]) – Parameters to use if categorical_fields, numerical_fields, and target_fields are not provided (e.g. when loading data for inference or validation).
  - train_data_frame (Optional[DataFrame]) – The DataFrame to use when training.
  - val_data_frame (Optional[DataFrame]) – The DataFrame to use when validating.
  - test_data_frame (Optional[DataFrame]) – The DataFrame to use when testing.
  - predict_data_frame (Optional[DataFrame]) – The DataFrame to use when predicting.
  - target_formatter (Optional[TargetFormatter]) – Optionally provide a TargetFormatter to control how targets are handled. See Formatting Classification Targets for more details.
  - input_cls (Type[Input]) – The Input type to use for loading the data.
  - transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.
  - transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
  - data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
  TabularClassificationData
- Returns
  The constructed TabularClassificationData.
Examples
We have a DataFrame train_data with the following contents:
>>> train_data.head(3)
  animal friendly  weight
0    cat      yes       6
1    dog      yes      10
2    cat       no       5
and a DataFrame predict_data with the following contents:
>>> predict_data.head(3)
  friendly  weight
0      yes       7
1       no      12
2      yes       5
>>> from flash import Trainer
>>> from flash.tabular import TabularClassifier, TabularClassificationData
>>> datamodule = TabularClassificationData.from_data_frame(
...     "friendly",
...     "weight",
...     "animal",
...     train_data_frame=train_data,
...     predict_data_frame=predict_data,
...     batch_size=4,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['cat', 'dog']
>>> model = TabularClassifier.from_data(datamodule, backbone="tabnet")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
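The example assumes the DataFrames train_data and predict_data already exist. A minimal sketch that builds them with pandas from literal columns matching the contents shown; reading the CSV files from the from_csv example with pd.read_csv would be an equally reasonable choice.

```python
import pandas as pd

# Training frame: "animal" is the target, "friendly" is categorical, "weight" is numerical.
train_data = pd.DataFrame(
    {
        "animal": ["cat", "dog", "cat"],
        "friendly": ["yes", "yes", "no"],
        "weight": [6, 10, 5],
    }
)

# Prediction frame: same feature columns, no target column.
predict_data = pd.DataFrame(
    {
        "friendly": ["yes", "no", "yes"],
        "weight": [7, 12, 5],
    }
)
```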
- classmethod from_dicts(categorical_fields=None, numerical_fields=None, target_fields=None, parameters=None, train_dict=None, val_dict=None, test_dict=None, predict_dict=None, target_formatter=None, input_cls=<class 'flash.tabular.classification.input.TabularClassificationDictInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)
Creates a TabularClassificationData object from the given dictionary.
Note
The categorical_fields, numerical_fields, and target_fields do not need to be provided if parameters are passed instead. The parameters can be obtained from the parameters attribute of the TabularData object that contains your training data.
The targets will be extracted from the target_fields in the dictionary and can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
- Parameters
  - categorical_fields (Union[str, List[str], None]) – The fields (column names) in the dictionary containing categorical data.
  - numerical_fields (Union[str, List[str], None]) – The fields (column names) in the dictionary containing numerical data.
  - target_fields (Union[str, List[str], None]) – The field (column name) or list of fields in the dictionary containing the targets.
  - parameters (Optional[Dict[str, Any]]) – Parameters to use if categorical_fields, numerical_fields, and target_fields are not provided (e.g. when loading data for inference or validation).
  - train_dict (Optional[Dict[str, List[Any]]]) – The data to use when training.
  - val_dict (Optional[Dict[str, List[Any]]]) – The data to use when validating.
  - test_dict (Optional[Dict[str, List[Any]]]) – The data to use when testing.
  - predict_dict (Optional[Dict[str, List[Any]]]) – The data to use when predicting.
  - target_formatter (Optional[TargetFormatter]) – Optionally provide a TargetFormatter to control how targets are handled. See Formatting Classification Targets for more details.
  - input_cls (Type[Input]) – The Input type to use for loading the data.
  - transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.
  - transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
  - data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
  TabularClassificationData
- Returns
  The constructed TabularClassificationData.
Examples
We have a dictionary train_data with the following contents:
    {
        "animal": ["cat", "dog", "cat"],
        "friendly": ["yes", "yes", "no"],
        "weight": [6, 10, 5]
    }
and a dictionary predict_data with the following contents:
    {
        "friendly": ["yes", "no", "yes"],
        "weight": [7, 12, 5]
    }
>>> from flash import Trainer
>>> from flash.tabular import TabularClassifier, TabularClassificationData
>>> datamodule = TabularClassificationData.from_dicts(
...     "friendly",
...     "weight",
...     "animal",
...     train_dict=train_data,
...     predict_dict=predict_data,
...     batch_size=4,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['cat', 'dog']
>>> model = TabularClassifier.from_data(datamodule, backbone="tabnet")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
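from_dicts takes column-oriented data (Dict[str, List[Any]]): one list per column, and every field named in the call ("friendly", "weight", "animal") has to be present. The sketch below builds dictionaries consistent with that call and checks the shape with a hypothetical `check_columns` helper. It only illustrates the expected input layout; it is not a validation step Flash is known to perform.

```python
from typing import Any, Dict, List

train_data: Dict[str, List[Any]] = {
    "animal": ["cat", "dog", "cat"],
    "friendly": ["yes", "yes", "no"],
    "weight": [6, 10, 5],
}
predict_data: Dict[str, List[Any]] = {
    "friendly": ["yes", "no", "yes"],
    "weight": [7, 12, 5],
}


def check_columns(data: Dict[str, List[Any]], fields: List[str]) -> int:
    """Hypothetical check: all fields present and all columns the same length.

    Returns the number of rows.
    """
    missing = [name for name in fields if name not in data]
    if missing:
        raise KeyError(f"missing columns: {missing}")
    lengths = {len(column) for column in data.values()}
    if len(lengths) != 1:
        raise ValueError(f"columns have different lengths: {sorted(lengths)}")
    return lengths.pop()


check_columns(train_data, ["friendly", "weight", "animal"])  # 3 rows
check_columns(predict_data, ["friendly", "weight"])  # no target column needed
```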
- classmethod from_lists(categorical_fields=None, numerical_fields=None, target_fields=None, parameters=None, train_list=None, val_list=None, test_list=None, predict_list=None, target_formatter=None, input_cls=<class 'flash.tabular.classification.input.TabularClassificationListInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)
Creates a TabularClassificationData object from the given data, provided as a list of tuples or dictionaries.
Note
The categorical_fields, numerical_fields, and target_fields do not need to be provided if parameters are passed instead. The parameters can be obtained from the parameters attribute of the TabularData object that contains your training data.
The targets will be extracted from the target_fields in the list entries and can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
- Parameters
  - categorical_fields (Union[str, List[str], None]) – The fields (column names) in the list entries containing categorical data.
  - numerical_fields (Union[str, List[str], None]) – The fields (column names) in the list entries containing numerical data.
  - target_fields (Union[str, List[str], None]) – The field (column name) or list of fields in the list entries containing the targets.
  - parameters (Optional[Dict[str, Any]]) – Parameters to use if categorical_fields, numerical_fields, and target_fields are not provided (e.g. when loading data for inference or validation).
  - train_list (Optional[List[Union[tuple, dict]]]) – The data to use when training.
  - val_list (Optional[List[Union[tuple, dict]]]) – The data to use when validating.
  - test_list (Optional[List[Union[tuple, dict]]]) – The data to use when testing.
  - predict_list (Optional[List[Union[tuple, dict]]]) – The data to use when predicting.
  - target_formatter (Optional[TargetFormatter]) – Optionally provide a TargetFormatter to control how targets are handled. See Formatting Classification Targets for more details.
  - input_cls (Type[Input]) – The Input type to use for loading the data.
  - transform (TypeVar(INPUT_TRANSFORM_TYPE, Type[flash.core.data.io.input_transform.InputTransform], Callable, Tuple[Union[StrEnum, str], Dict[str, Any]], Union[StrEnum, str], None)) – The InputTransform type to use.
  - transform_kwargs (Optional[Dict]) – Dict of keyword arguments to be provided when instantiating the transforms.
  - data_module_kwargs (Any) – Additional keyword arguments to provide to the DataModule constructor.
- Return type
  TabularClassificationData
- Returns
  The constructed TabularClassificationData.
Examples
We have a list of dictionaries train_data with the following contents:
    [
        {"animal": "cat", "friendly": "yes", "weight": 6},
        {"animal": "dog", "friendly": "yes", "weight": 10},
        {"animal": "cat", "friendly": "no", "weight": 5},
    ]
and a list of dictionaries predict_data with the following contents:
    [
        {"friendly": "yes", "weight": 7},
        {"friendly": "no", "weight": 12},
        {"friendly": "yes", "weight": 5},
    ]
>>> from flash import Trainer
>>> from flash.tabular import TabularClassifier, TabularClassificationData
>>> datamodule = TabularClassificationData.from_lists(
...     "friendly",
...     "weight",
...     "animal",
...     train_list=train_data,
...     predict_list=predict_data,
...     batch_size=4,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['cat', 'dog']
>>> model = TabularClassifier.from_data(datamodule, backbone="tabnet")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)
Training...
>>> trainer.predict(model, datamodule=datamodule)
Predicting...
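from_lists takes row-oriented data (one entry per row), whereas from_dicts takes column-oriented data (one list per column). The sketch below converts the dictionary-row form into the column form so the same table can be fed to either classmethod. `rows_to_columns` is a hypothetical helper; it handles dictionary rows only, since how tuple entries map to column names is not specified here, and whether Flash performs a similar conversion internally is not stated.

```python
from typing import Any, Dict, List


def rows_to_columns(rows: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Hypothetical helper: convert row dicts (from_lists style) to columns (from_dicts style)."""
    if not rows:
        return {}
    names = list(rows[0])
    for row in rows:
        # Every row must carry the same keys, or the columns would misalign.
        if set(row) != set(names):
            raise ValueError(f"row keys {sorted(row)} differ from {sorted(names)}")
    return {name: [row[name] for row in rows] for name in names}


train_data = [
    {"animal": "cat", "friendly": "yes", "weight": 6},
    {"animal": "dog", "friendly": "yes", "weight": 10},
    {"animal": "cat", "friendly": "no", "weight": 5},
]
print(rows_to_columns(train_data))
# {'animal': ['cat', 'dog', 'cat'], 'friendly': ['yes', 'yes', 'no'], 'weight': [6, 10, 5]}
```

Requiring identical keys in every row is a deliberate choice of the sketch: silently padding missing cells would hide malformed input.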