
TabularClassificationData

class flash.tabular.classification.data.TabularClassificationData(train_input=None, val_input=None, test_input=None, predict_input=None, data_fetcher=None, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, val_split=None, batch_size=None, num_workers=0, sampler=None, pin_memory=True, persistent_workers=False)

The TabularClassificationData class is a DataModule with a set of classmethods for loading data for tabular classification.

classmethod from_csv(categorical_fields=None, numerical_fields=None, target_fields=None, parameters=None, train_file=None, val_file=None, test_file=None, predict_file=None, target_formatter=None, input_cls=<class 'flash.tabular.classification.input.TabularClassificationCSVInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)

Creates a TabularClassificationData object from the given CSV files.

Note

The categorical_fields, numerical_fields, and target_fields do not need to be provided if parameters are passed instead. These can be obtained from the parameters attribute of the TabularData object that contains your training data.

The targets will be extracted from the target_fields columns in the CSV files and can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.
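In the examples on this page, the first three positional arguments are categorical_fields, numerical_fields and target_fields, in that order, and each may be a single column name or a list of names. The sketch below only illustrates how those three arguments select columns from one row. split_row and as_list are hypothetical helpers written for this illustration; they are not part of Flash and do not reproduce its actual preprocessing.

```python
from typing import Dict, List, Sequence, Union

# A field argument may be one column name, several names, or None.
Fields = Union[str, Sequence[str], None]


def as_list(fields: Fields) -> List[str]:
    """Hypothetical helper: normalize a name, a list of names, or None to a list."""
    if fields is None:
        return []
    if isinstance(fields, str):
        return [fields]
    return list(fields)


def split_row(
    row: Dict[str, object],
    categorical_fields: Fields = None,
    numerical_fields: Fields = None,
    target_fields: Fields = None,
) -> Dict[str, Dict[str, object]]:
    """Hypothetical helper (not a Flash API): group one row's values by role."""
    return {
        "categorical": {name: row[name] for name in as_list(categorical_fields)},
        "numerical": {name: row[name] for name in as_list(numerical_fields)},
        # Prediction rows carry no target column, so missing target names are skipped.
        "target": {name: row[name] for name in as_list(target_fields) if name in row},
    }


# Same positional order as the examples: categorical, numerical, target.
print(split_row({"animal": "cat", "friendly": "yes", "weight": 6}, "friendly", "weight", "animal"))
# {'categorical': {'friendly': 'yes'}, 'numerical': {'weight': 6}, 'target': {'animal': 'cat'}}
```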

Parameters
Return type

TabularClassificationData

Returns

The constructed TabularClassificationData.

Examples

The files can be in Comma Separated Values (CSV) format with either a .csv or .txt extension.

We have a train_data.csv with the following contents:

animal,friendly,weight
cat,yes,6
dog,yes,10
cat,no,5

and a predict_data.csv with the following contents:

friendly,weight
yes,7
no,12
yes,5
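To run the example below locally, the two files can be created with Python's standard csv module. This is only a convenience sketch; the file names and contents are exactly the ones shown above.

```python
import csv

# Header row first, then one row per sample; "animal" is the target column.
train_rows = [
    ["animal", "friendly", "weight"],
    ["cat", "yes", 6],
    ["dog", "yes", 10],
    ["cat", "no", 5],
]
# Prediction data has the same feature columns but no target column.
predict_rows = [
    ["friendly", "weight"],
    ["yes", 7],
    ["no", 12],
    ["yes", 5],
]

# newline="" lets the csv module control line endings itself.
with open("train_data.csv", "w", newline="") as f:
    csv.writer(f).writerows(train_rows)
with open("predict_data.csv", "w", newline="") as f:
    csv.writer(f).writerows(predict_rows)
```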
>>> from flash import Trainer
>>> from flash.tabular import TabularClassifier, TabularClassificationData
>>> datamodule = TabularClassificationData.from_csv(
...     "friendly",
...     "weight",
...     "animal",
...     train_file="train_data.csv",
...     predict_file="predict_data.csv",
...     batch_size=4,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['cat', 'dog']
>>> model = TabularClassifier.from_data(datamodule, backbone="tabnet")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...
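The num_classes and labels values above follow from the distinct values of the animal target column. Flash's target formatter is what actually determines them; as an assumption made only for illustration, collecting and sorting the unique target values reproduces the same result from the training CSV shown earlier.

```python
import csv
import io

# Same contents as train_data.csv above, kept in memory so the example is self-contained.
TRAIN_CSV = """animal,friendly,weight
cat,yes,6
dog,yes,10
cat,no,5
"""

rows = list(csv.DictReader(io.StringIO(TRAIN_CSV)))
# Assumption: labels are the sorted unique values of the target column.
labels = sorted({row["animal"] for row in rows})
num_classes = len(labels)
print(num_classes, labels)  # 2 ['cat', 'dog']
```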

Alternatively, the files can be in Tab Separated Values (TSV) format with a .tsv extension.

We have a train_data.tsv with the following contents:

animal  friendly    weight
cat     yes         6
dog     yes         10
cat     no          5

and a predict_data.tsv with the following contents:

friendly    weight
yes         7
no          12
yes         5
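The columns above are aligned with spaces for readability; in the actual files each value is separated by a single tab. The sketch below writes both TSV files with the standard csv module and reads one back, choosing the delimiter from the file extension. delimiter_for is a hypothetical helper that mirrors the extension rule stated on this page (.csv and .txt for commas, .tsv for tabs); it is not Flash's own code.

```python
import csv
import os


def delimiter_for(path: str) -> str:
    """Hypothetical helper: pick the delimiter from the file extension."""
    ext = os.path.splitext(path)[1].lower()
    if ext == ".tsv":
        return "\t"
    if ext in (".csv", ".txt"):
        return ","
    raise ValueError(f"Unsupported extension: {ext!r}")


train_rows = [["animal", "friendly", "weight"], ["cat", "yes", 6], ["dog", "yes", 10], ["cat", "no", 5]]
predict_rows = [["friendly", "weight"], ["yes", 7], ["no", 12], ["yes", 5]]

# Write each file with the delimiter implied by its extension (a tab here).
for path, rows in (("train_data.tsv", train_rows), ("predict_data.tsv", predict_rows)):
    with open(path, "w", newline="") as f:
        csv.writer(f, delimiter=delimiter_for(path)).writerows(rows)

# Read the training file back as dictionaries keyed by the header row.
with open("train_data.tsv", newline="") as f:
    train_records = list(csv.DictReader(f, delimiter=delimiter_for("train_data.tsv")))
print(train_records[0])  # {'animal': 'cat', 'friendly': 'yes', 'weight': '6'}
```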
>>> from flash import Trainer
>>> from flash.tabular import TabularClassifier, TabularClassificationData
>>> datamodule = TabularClassificationData.from_csv(
...     "friendly",
...     "weight",
...     "animal",
...     train_file="train_data.tsv",
...     predict_file="predict_data.tsv",
...     batch_size=4,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['cat', 'dog']
>>> model = TabularClassifier.from_data(datamodule, backbone="tabnet")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...

classmethod from_data_frame(categorical_fields=None, numerical_fields=None, target_fields=None, parameters=None, train_data_frame=None, val_data_frame=None, test_data_frame=None, predict_data_frame=None, target_formatter=None, input_cls=<class 'flash.tabular.classification.input.TabularClassificationDataFrameInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)

Creates a TabularClassificationData object from the given data frames.

Note

The categorical_fields, numerical_fields, and target_fields do not need to be provided if parameters are passed instead. These can be obtained from the parameters attribute of the TabularData object that contains your training data.

The targets will be extracted from the target_fields in the data frames and can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

Parameters
Return type

TabularClassificationData

Returns

The constructed TabularClassificationData.

Examples

We have a DataFrame train_data with the following contents:

>>> train_data.head(3)
  animal friendly  weight
0    cat      yes       6
1    dog      yes      10
2    cat       no       5

and a DataFrame predict_data with the following contents:

>>> predict_data.head(3)
  friendly  weight
0      yes       7
1       no      12
2      yes       5
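The train_data and predict_data frames used in this example can be built with pandas as sketched below; only the contents shown above are assumed. The prediction frame has no animal column, since the target is what the model is asked to predict.

```python
import pandas as pd

# Training data: "animal" is the classification target.
train_data = pd.DataFrame(
    {
        "animal": ["cat", "dog", "cat"],
        "friendly": ["yes", "yes", "no"],
        "weight": [6, 10, 5],
    }
)
# Prediction data: same feature columns, no target column.
predict_data = pd.DataFrame(
    {
        "friendly": ["yes", "no", "yes"],
        "weight": [7, 12, 5],
    }
)

print(train_data.head(3))
print(predict_data.head(3))
```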
>>> from flash import Trainer
>>> from flash.tabular import TabularClassifier, TabularClassificationData
>>> datamodule = TabularClassificationData.from_data_frame(
...     "friendly",
...     "weight",
...     "animal",
...     train_data_frame=train_data,
...     predict_data_frame=predict_data,
...     batch_size=4,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['cat', 'dog']
>>> model = TabularClassifier.from_data(datamodule, backbone="tabnet")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...

classmethod from_dicts(categorical_fields=None, numerical_fields=None, target_fields=None, parameters=None, train_dict=None, val_dict=None, test_dict=None, predict_dict=None, target_formatter=None, input_cls=<class 'flash.tabular.classification.input.TabularClassificationDictInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)

Creates a TabularClassificationData object from the given dictionary.

Note

The categorical_fields, numerical_fields, and target_fields do not need to be provided if parameters are passed instead. These can be obtained from the parameters attribute of the TabularData object that contains your training data.

The targets will be extracted from the target_fields in the dictionaries and can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

Parameters
Return type

TabularClassificationData

Returns

The constructed TabularClassificationData.

Examples

We have a dictionary train_data with the following contents:

{
    "animal": ["cat", "dog", "cat"],
    "friendly": ["yes", "yes", "no"],
    "weight": [6, 10, 5]
}

and a dictionary predict_data with the following contents:

{
    "friendly": ["yes", "no", "yes"],
    "weight": [7, 12, 5]
}
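Each dictionary here is column-oriented: it maps a column name to the list of that column's values, so in a well-formed table every list has the same length. The sketch below builds dictionaries with the animal, friendly and weight columns referenced in the from_dicts call that follows, and checks that the columns line up. num_rows is a hypothetical helper for this illustration, not part of Flash.

```python
from typing import Dict, List

# Column-oriented layout: column name -> list of that column's values.
train_data: Dict[str, List[object]] = {
    "animal": ["cat", "dog", "cat"],
    "friendly": ["yes", "yes", "no"],
    "weight": [6, 10, 5],
}
# No "animal" key: the target is what the model predicts.
predict_data: Dict[str, List[object]] = {
    "friendly": ["yes", "no", "yes"],
    "weight": [7, 12, 5],
}


def num_rows(columns: Dict[str, List[object]]) -> int:
    """Hypothetical helper: return the row count, checking all columns agree."""
    lengths = {len(values) for values in columns.values()}
    if len(lengths) != 1:
        raise ValueError(f"Columns have inconsistent lengths: {sorted(lengths)}")
    return lengths.pop()


print(num_rows(train_data), num_rows(predict_data))  # 3 3
```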
>>> from flash import Trainer
>>> from flash.tabular import TabularClassifier, TabularClassificationData
>>> datamodule = TabularClassificationData.from_dicts(
...     "friendly",
...     "weight",
...     "animal",
...     train_dict=train_data,
...     predict_dict=predict_data,
...     batch_size=4,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['cat', 'dog']
>>> model = TabularClassifier.from_data(datamodule, backbone="tabnet")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...

classmethod from_lists(categorical_fields=None, numerical_fields=None, target_fields=None, parameters=None, train_list=None, val_list=None, test_list=None, predict_list=None, target_formatter=None, input_cls=<class 'flash.tabular.classification.input.TabularClassificationListInput'>, transform=<class 'flash.core.data.io.input_transform.InputTransform'>, transform_kwargs=None, **data_module_kwargs)

Creates a TabularClassificationData object from the given data, provided as a list in which each element is a tuple or a dictionary.

Note

The categorical_fields, numerical_fields, and target_fields do not need to be provided if parameters are passed instead. These can be obtained from the parameters attribute of the TabularData object that contains your training data.

The targets will be extracted from the target_fields of each element in the list and can be in any of our supported classification target formats. To learn how to customize the transforms applied for each stage, read our customizing transforms guide.

Parameters
Return type

TabularClassificationData

Returns

The constructed TabularClassificationData.

Examples

We have a list of dictionaries train_data with the following contents:

[
    {"animal": "cat", "friendly": "yes", "weight": 6},
    {"animal": "dog", "friendly": "yes", "weight": 10},
    {"animal": "cat", "friendly": "no", "weight": 5},
]

and a list of dictionaries predict_data with the following contents:

[
    {"friendly": "yes", "weight": 7},
    {"friendly": "no", "weight": 12},
    {"friendly": "yes", "weight": 5},
]
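Each element of these lists is one row, keyed by column name. This row-oriented layout carries the same information as the column-oriented dictionaries accepted by from_dicts. The hypothetical helpers below (not part of Flash) convert between the two layouts with the standard library, assuming every row has the same set of columns.

```python
from typing import Dict, List

Row = Dict[str, object]
Columns = Dict[str, List[object]]

train_data: List[Row] = [
    {"animal": "cat", "friendly": "yes", "weight": 6},
    {"animal": "dog", "friendly": "yes", "weight": 10},
    {"animal": "cat", "friendly": "no", "weight": 5},
]
predict_data: List[Row] = [
    {"friendly": "yes", "weight": 7},
    {"friendly": "no", "weight": 12},
    {"friendly": "yes", "weight": 5},
]


def rows_to_columns(rows: List[Row]) -> Columns:
    """Hypothetical helper: list of row dicts -> dict of column lists."""
    if not rows:
        return {}
    names = list(rows[0])
    for row in rows:
        # Assumption: all rows share the same columns; otherwise columns would misalign.
        if set(row) != set(names):
            raise ValueError(f"Row has columns {sorted(row)}, expected {sorted(names)}")
    return {name: [row[name] for row in rows] for name in names}


def columns_to_rows(columns: Columns) -> List[Row]:
    """Hypothetical helper: dict of column lists -> list of row dicts."""
    names = list(columns)
    return [dict(zip(names, values)) for values in zip(*columns.values())]


print(rows_to_columns(train_data))
# {'animal': ['cat', 'dog', 'cat'], 'friendly': ['yes', 'yes', 'no'], 'weight': [6, 10, 5]}
```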
>>> from flash import Trainer
>>> from flash.tabular import TabularClassifier, TabularClassificationData
>>> datamodule = TabularClassificationData.from_lists(
...     "friendly",
...     "weight",
...     "animal",
...     train_list=train_data,
...     predict_list=predict_data,
...     batch_size=4,
... )
>>> datamodule.num_classes
2
>>> datamodule.labels
['cat', 'dog']
>>> model = TabularClassifier.from_data(datamodule, backbone="tabnet")
>>> trainer = Trainer(fast_dev_run=True)
>>> trainer.fit(model, datamodule=datamodule)  
Training...
>>> trainer.predict(model, datamodule=datamodule)  
Predicting...
Read the Docs v: 0.8.1.post0