
TabularClassificationData

class flash.tabular.classification.data.TabularClassificationData(train_input=None, val_input=None, test_input=None, predict_input=None, data_fetcher=None, val_split=None, batch_size=None, num_workers=0, sampler=None, pin_memory=True, persistent_workers=False)

The TabularClassificationData class is a DataModule with a set of classmethods for loading data for tabular classification.

classmethod from_csv(categorical_fields=None, numerical_fields=None, target_fields=None, parameters=None, train_file=None, val_file=None, test_file=None, predict_file=None, train_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, val_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, test_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, predict_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, target_formatter=None, input_cls=<class 'flash.tabular.classification.input.TabularClassificationCSVInput'>, transform_kwargs=None, **data_module_kwargs)

Creates a TabularClassificationData object from the given CSV files.

Note

The categorical_fields, numerical_fields, and target_fields do not need to be provided if parameters are passed instead. These can be obtained from the parameters attribute of the TabularData object that contains your training data.
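The idea behind reusing parameters can be sketched in plain Python: encoding statistics are computed once from the training rows and then applied, unchanged, to rows that have no target column (such as those in predict_data.csv in the example below). This is a minimal stdlib sketch under stated assumptions: compute_parameters and encode_row are hypothetical helpers, and the dictionary layout is illustrative rather than Flash's actual representation of the parameters attribute.

```python
import statistics
from typing import Dict, List, Tuple

Row = Dict[str, str]


# Hypothetical helper: derive encoding "parameters" from the training rows only.
# The dictionary layout is illustrative; it is not Flash's actual format.
def compute_parameters(rows: List[Row], categorical_fields: List[str], numerical_fields: List[str]) -> dict:
    codes = {
        field: {value: i for i, value in enumerate(sorted({row[field] for row in rows}))}
        for field in categorical_fields
    }
    stats = {
        field: (
            statistics.mean(float(row[field]) for row in rows),
            statistics.pstdev(float(row[field]) for row in rows),
        )
        for field in numerical_fields
    }
    return {
        "categorical_fields": list(categorical_fields),
        "numerical_fields": list(numerical_fields),
        "codes": codes,
        "stats": stats,
    }


# Hypothetical helper: encode one row using previously computed parameters.
# No target column is read, so prediction rows can be encoded the same way.
def encode_row(row: Row, parameters: dict) -> Tuple[List[int], List[float]]:
    # Assumption: categories unseen during training map to -1.
    categorical = [
        parameters["codes"][field].get(row[field], -1)
        for field in parameters["categorical_fields"]
    ]
    numerical = []
    for field in parameters["numerical_fields"]:
        mean, std = parameters["stats"][field]
        # Standardize with the training mean and population standard deviation.
        numerical.append((float(row[field]) - mean) / std if std else 0.0)
    return categorical, numerical


train_rows = [
    {"animal": "cat", "friendly": "yes", "weight": "6"},
    {"animal": "dog", "friendly": "yes", "weight": "10"},
    {"animal": "cat", "friendly": "no", "weight": "5"},
]
params = compute_parameters(train_rows, ["friendly"], ["weight"])
# A prediction row without the "animal" column, encoded with the training parameters.
print(encode_row({"friendly": "yes", "weight": "7"}, params))  # ([1], [0.0])
```

The point of the design is that nothing in encode_row looks at validation, test, or prediction data to decide how to encode it; every code and statistic comes from training.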

The targets will be extracted from the target_fields columns in the CSV files and can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
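As a rough illustration of how string targets can become class indices, the stdlib sketch below reads CSV text shaped like train_data.csv in the example below, collects the sorted set of values in the target column, and maps each row to an index. The extract_targets helper is hypothetical, and sorting the unique values is an assumption chosen because it agrees with the labels ['cat', 'dog'] and num_classes of 2 shown in the example; Flash's own target formatters may order or encode labels differently.

```python
import csv
import io
from typing import List, Tuple


# Hypothetical helper: infer labels from one target column and encode each row.
# Sorting the unique values is an illustrative choice, not Flash's documented rule.
def extract_targets(csv_text: str, target_field: str) -> Tuple[List[str], List[int]]:
    rows = list(csv.DictReader(io.StringIO(csv_text)))
    labels = sorted({row[target_field] for row in rows})
    label_to_index = {label: i for i, label in enumerate(labels)}
    targets = [label_to_index[row[target_field]] for row in rows]
    return labels, targets


train_csv = "animal,friendly,weight\ncat,yes,6\ndog,yes,10\ncat,no,5\n"
labels, targets = extract_targets(train_csv, "animal")
print(labels, len(labels), targets)  # ['cat', 'dog'] 2 [0, 1, 0]
```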

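Parameters

categorical_fields: The column(s) in the CSV files containing categorical inputs.
numerical_fields: The column(s) in the CSV files containing numerical inputs.
target_fields: The column(s) in the CSV files containing the targets.
parameters: Parameters to use in place of categorical_fields, numerical_fields, and target_fields (see the note above).
train_file: The CSV file to use when training.
val_file: The CSV file to use when validating.
test_file: The CSV file to use when testing.
predict_file: The CSV file to use when predicting.
train_transform: The InputTransform type to use when training.
val_transform: The InputTransform type to use when validating.
test_transform: The InputTransform type to use when testing.
predict_transform: The InputTransform type to use when predicting.
target_formatter: Optionally provide a TargetFormatter to control how targets are handled.
input_cls: The Input type to use for loading the data.
transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs: Additional keyword arguments to use when constructing the DataModule.
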
Return type

TabularClassificationData

Returns

The constructed TabularClassificationData.

Examples

We have a train_data.csv with the following contents:

animal,friendly,weight
cat,yes,6
dog,yes,10
cat,no,5

and a predict_data.csv with the following contents:

friendly,weight
yes,7
no,12
yes,5

>>> from flash import Trainer
>>> from flash.tabular import TabularClassifier, TabularClassificationData
>>> datamodule = TabularClassificationData.from_csv(
...     categorical_fields="friendly",
...     numerical_fields="weight",
...     target_fields="animal",
...     train_file="train_data.csv",
...     predict_file="predict_data.csv",
...     batch_size=4,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['cat', 'dog']
>>> model = TabularClassifier.from_data(datamodule, backbone="tabnet")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...
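
The doctest above assumes that train_data.csv and predict_data.csv already exist in the working directory. One way to create them with the standard library, reproducing the contents listed earlier, is sketched below; the write_example_csvs helper is hypothetical and not part of Flash.

```python
import csv
from pathlib import Path
from typing import Dict

# Rows copied from the train_data.csv and predict_data.csv listings above.
TRAIN_ROWS = [
    ["animal", "friendly", "weight"],
    ["cat", "yes", "6"],
    ["dog", "yes", "10"],
    ["cat", "no", "5"],
]
PREDICT_ROWS = [
    ["friendly", "weight"],
    ["yes", "7"],
    ["no", "12"],
    ["yes", "5"],
]


# Hypothetical helper: write both example files into `directory`
# and return a mapping from file name to the written path.
def write_example_csvs(directory: str = ".") -> Dict[str, Path]:
    paths = {}
    for name, rows in (("train_data.csv", TRAIN_ROWS), ("predict_data.csv", PREDICT_ROWS)):
        path = Path(directory) / name
        # newline="" lets the csv module control line endings itself.
        with path.open("w", newline="") as f:
            csv.writer(f).writerows(rows)
        paths[name] = path
    return paths
```

Calling write_example_csvs() before running the doctest leaves both files in the current directory.
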

classmethod from_data_frame(categorical_fields=None, numerical_fields=None, target_fields=None, parameters=None, train_data_frame=None, val_data_frame=None, test_data_frame=None, predict_data_frame=None, train_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, val_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, test_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, predict_transform=<class 'flash.core.data.io.input_transform.InputTransform'>, target_formatter=None, input_cls=<class 'flash.tabular.classification.input.TabularClassificationDataFrameInput'>, transform_kwargs=None, **data_module_kwargs)

Creates a TabularClassificationData object from the given data frames.

Note

The categorical_fields, numerical_fields, and target_fields do not need to be provided if parameters are passed instead. These can be obtained from the parameters attribute of the TabularData object that contains your training data.

The targets will be extracted from the target_fields in the data frames and can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
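A DataFrame is column-oriented, so categorical_fields, numerical_fields, and target_fields each just name columns to pull out. The stdlib sketch below mimics this with a dict of lists standing in for a DataFrame. The split_columns and _as_list helpers are hypothetical, not part of Flash, and accepting either a single column name or a list is an assumption based on the examples passing single names such as "friendly". The train_data and predict_data values mirror the DataFrames shown in the example below.

```python
from typing import Dict, List, Optional, Sequence, Union

Fields = Optional[Union[str, Sequence[str]]]


def _as_list(fields: Fields) -> List[str]:
    # A single column name (as in the examples) is treated as a one-element list.
    if fields is None:
        return []
    return [fields] if isinstance(fields, str) else list(fields)


# Hypothetical helper: select input and target columns from a
# column-oriented table (a dict of lists standing in for a DataFrame).
def split_columns(
    table: Dict[str, list],
    categorical_fields: Fields,
    numerical_fields: Fields,
    target_fields: Fields = None,
):
    wanted = _as_list(categorical_fields) + _as_list(numerical_fields) + _as_list(target_fields)
    missing = [field for field in wanted if field not in table]
    if missing:
        raise KeyError(f"columns not found: {missing}")
    categorical = {field: table[field] for field in _as_list(categorical_fields)}
    numerical = {field: table[field] for field in _as_list(numerical_fields)}
    targets = {field: table[field] for field in _as_list(target_fields)}
    return categorical, numerical, targets


train_data = {
    "animal": ["cat", "dog", "cat"],
    "friendly": ["yes", "yes", "no"],
    "weight": [6, 10, 5],
}
predict_data = {"friendly": ["yes", "no", "yes"], "weight": [7, 12, 5]}

# Training data supplies all three roles; prediction data has no target column.
cat_cols, num_cols, target_cols = split_columns(train_data, "friendly", "weight", "animal")
_, _, no_targets = split_columns(predict_data, "friendly", "weight")
```

In this sketch, requesting a target column that a table lacks raises KeyError, which is why no target_fields are passed for predict_data.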

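Parameters

categorical_fields: The column(s) in the data frames containing categorical inputs.
numerical_fields: The column(s) in the data frames containing numerical inputs.
target_fields: The column(s) in the data frames containing the targets.
parameters: Parameters to use in place of categorical_fields, numerical_fields, and target_fields (see the note above).
train_data_frame: The DataFrame to use when training.
val_data_frame: The DataFrame to use when validating.
test_data_frame: The DataFrame to use when testing.
predict_data_frame: The DataFrame to use when predicting.
train_transform: The InputTransform type to use when training.
val_transform: The InputTransform type to use when validating.
test_transform: The InputTransform type to use when testing.
predict_transform: The InputTransform type to use when predicting.
target_formatter: Optionally provide a TargetFormatter to control how targets are handled.
input_cls: The Input type to use for loading the data.
transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs: Additional keyword arguments to use when constructing the DataModule.
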
Return type

TabularClassificationData

Returns

The constructed TabularClassificationData.

Examples

We have a DataFrame train_data with the following contents:

>>> train_data.head(3)
  animal friendly  weight
0    cat      yes       6
1    dog      yes      10
2    cat       no       5

and a DataFrame predict_data with the following contents:

>>> predict_data.head(3)
  friendly  weight
0      yes       7
1       no      12
2      yes       5

>>> from flash import Trainer
>>> from flash.tabular import TabularClassifier, TabularClassificationData
>>> datamodule = TabularClassificationData.from_data_frame(
...     categorical_fields="friendly",
...     numerical_fields="weight",
...     target_fields="animal",
...     train_data_frame=train_data,
...     predict_data_frame=predict_data,
...     batch_size=4,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['cat', 'dog']
>>> model = TabularClassifier.from_data(datamodule, backbone="tabnet")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...
Documentation version: 0.7.2