
TabularClassifier

class flash.tabular.classification.model.TabularClassifier(parameters, embedding_sizes, cat_dims, num_features, num_classes, labels=None, backbone='tabnet', loss_fn=torch.nn.functional.cross_entropy, optimizer='Adam', lr_scheduler=None, metrics=None, learning_rate=None, **backbone_kwargs)

The TabularClassifier is a Task for classifying tabular data. For more details, see Tabular Classification.

Parameters
  • parameters (Dict[str, Any]) – The parameters computed from the training data (can be obtained from the parameters attribute of the TabularClassificationData object containing your training data).

  • embedding_sizes (list) – List of (num_classes, emb_dim) tuples used to build the categorical embeddings, one per categorical column. Here num_classes is the number of distinct values in that column.

  • cat_dims (list) – Number of distinct values in each categorical column.

  • num_features (int) – Number of columns in the table.

  • num_classes (int) – Number of target classes.

  • backbone (str) – Name of the backbone model to use. Defaults to 'tabnet'.

  • loss_fn (Callable) – Loss function for training, defaults to cross entropy.

  • optimizer (TypeVar(OPTIMIZER_TYPE, str, Callable, Tuple[str, Dict[str, Any]], None)) – Optimizer to use for training.

  • lr_scheduler (Optional[TypeVar(LR_SCHEDULER_TYPE, str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]], None)]) – The LR scheduler to use during training.

  • metrics (Optional[TypeVar(METRICS_TYPE, Metric, Mapping, Sequence, None)]) – Metrics to compute for training and evaluation. Each entry can be a metric from the torchmetrics package, a custom metric inheriting from torchmetrics.Metric, a callable function, or a list/dict combining any of these. In all cases, each metric must have the signature metric(preds, target) and return a single scalar tensor. Defaults to torchmetrics.Accuracy.

  • learning_rate (Optional[float]) – Learning rate to use for training.

  • **backbone_kwargs – Optional additional arguments for the model.
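The relationship between cat_dims and embedding_sizes can be sketched in plain Python: each categorical column contributes one (cardinality, embedding width) pair. The helper embedding_sizes_from_cat_dims and the width rule min(max_dim, (n + 1) // 2) are hypothetical and used only for illustration; neither is taken from Flash.

```python
from typing import List, Tuple


def embedding_sizes_from_cat_dims(
    cat_dims: List[int], max_dim: int = 50
) -> List[Tuple[int, int]]:
    """Pair each categorical column's cardinality with an embedding width.

    Hypothetical helper: the width heuristic min(max_dim, (n + 1) // 2)
    is an assumption for illustration, not Flash's rule.
    """
    return [(n, min(max_dim, (n + 1) // 2)) for n in cat_dims]


# Three hypothetical categorical columns with 3, 2 and 120 distinct values.
cat_dims = [3, 2, 120]
embedding_sizes = embedding_sizes_from_cat_dims(cat_dims)
print(embedding_sizes)  # [(3, 2), (2, 1), (120, 50)]
```

Capping the width keeps very high-cardinality columns from dominating the parameter count; the first element of each pair always matches the corresponding cat_dims entry.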

classmethod available_finetuning_strategies(cls)

Returns a list containing the keys of the available Finetuning Strategies.

Return type

List[str]

classmethod available_lr_schedulers(cls)

Returns a list containing the keys of the available LR schedulers.

Return type

List[str]

classmethod available_optimizers(cls)

Returns a list containing the keys of the available Optimizers.

Return type

List[str]
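The available_* classmethods return the keys of a registry; presumably these keys are the string values you can pass as optimizer or lr_scheduler. A minimal sketch of that pattern, assuming a simple dict-backed registry: the Registry and ToyTask classes below are hypothetical and do not reproduce Flash's internal implementation.

```python
from typing import Callable, Dict, List


class Registry:
    """Hypothetical name -> factory mapping (not Flash's actual class)."""

    def __init__(self) -> None:
        self._items: Dict[str, Callable] = {}

    def register(self, name: str, fn: Callable) -> None:
        self._items[name] = fn

    def available_keys(self) -> List[str]:
        # Sorted so the listing is stable and easy to read.
        return sorted(self._items)

    def get(self, name: str) -> Callable:
        # Fail early with the list of valid keys if the name is unknown.
        if name not in self._items:
            raise ValueError(
                f"Unknown key {name!r}; available: {self.available_keys()}"
            )
        return self._items[name]


class ToyTask:
    optimizers = Registry()

    @classmethod
    def available_optimizers(cls) -> List[str]:
        return cls.optimizers.available_keys()


# Hypothetical factories; a real registry would hold optimizer classes.
ToyTask.optimizers.register("sgd", lambda params, lr: ("sgd", lr))
ToyTask.optimizers.register("adam", lambda params, lr: ("adam", lr))
print(ToyTask.available_optimizers())  # ['adam', 'sgd']
```

Checking a user-supplied name against the registry before building the model surfaces typos with a helpful message instead of a failure deep inside training.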

classmethod available_outputs(cls)

Returns the list of available outputs (that can be used during prediction or serving) for this Task.

Examples

>>> from flash import Task
>>> print(Task.available_outputs())
['preds', 'raw']

Return type

List[str]
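The outputs differ in how much post-processing is applied to the model's scores. As a rough sketch, a "raw"-style output leaves the per-class scores untouched, while a class-style output (such as the output="classes" option used in the prediction example for data_parameters) reduces them to a class index. The helpers softmax and to_class are hypothetical illustrations, and Flash's own output classes may post-process differently.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def to_class(scores: List[float]) -> int:
    # Index of the highest score (argmax).
    return max(range(len(scores)), key=lambda i: scores[i])


raw = [0.2, 1.7, -0.5]  # hypothetical raw scores for a 3-class problem
probs = softmax(raw)
label = to_class(raw)
print(label)  # 1
```

Because softmax is monotonic, taking the argmax of the raw scores or of the probabilities yields the same class.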

property data_parameters: Dict[str, Any]

Get the parameters computed from the training data used to create this TabularClassifier. Use these parameters to load data for evaluation / prediction.

Examples

>>> import flash
>>> from flash.core.data.utils import download_data
>>> from flash.tabular import TabularClassificationData, TabularClassifier
>>> download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "./data")
>>> model = TabularClassifier.load_from_checkpoint(
...     "https://flash-weights.s3.amazonaws.com/0.7.0/tabular_classification_model.pt"
... )
>>> datamodule = TabularClassificationData.from_csv(
...     predict_file="data/titanic/predict.csv",
...     parameters=model.data_parameters,
...     batch_size=8,
... )
>>> trainer = flash.Trainer()
>>> trainer.predict(
...     model,
...     datamodule=datamodule,
...     output="classes",
... )
Predicting...

Return type

Dict[str, Any]
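The idea behind data_parameters is that statistics are fitted once on the training data and then reused unchanged when loading evaluation or prediction data. A minimal sketch, assuming the parameters hold per-column means and standard deviations for numerical columns and category-to-index codes for categorical columns: the dictionary keys and the helpers compute_parameters and transform are hypothetical, and Flash's actual parameters dictionary may be structured differently.

```python
from statistics import mean, pstdev
from typing import Any, Dict, List


def compute_parameters(
    rows: List[Dict[str, Any]], num_cols: List[str], cat_cols: List[str]
) -> Dict[str, Any]:
    """Fit normalisation stats and category codes on training rows only."""
    params: Dict[str, Any] = {"mean": {}, "std": {}, "codes": {}}
    for col in num_cols:
        values = [row[col] for row in rows]
        params["mean"][col] = mean(values)
        # Fall back to 1.0 so constant columns do not divide by zero.
        params["std"][col] = pstdev(values) or 1.0
    for col in cat_cols:
        categories = sorted({row[col] for row in rows})
        params["codes"][col] = {c: i for i, c in enumerate(categories)}
    return params


def transform(row: Dict[str, Any], params: Dict[str, Any]) -> Dict[str, Any]:
    """Apply previously fitted parameters to a new (e.g. prediction) row."""
    out: Dict[str, Any] = {}
    for col, mu in params["mean"].items():
        out[col] = (row[col] - mu) / params["std"][col]
    for col, codes in params["codes"].items():
        # Categories unseen during training map to one extra index.
        out[col] = codes.get(row[col], len(codes))
    return out


train = [
    {"age": 20.0, "sex": "male"},
    {"age": 30.0, "sex": "female"},
    {"age": 40.0, "sex": "female"},
]
params = compute_parameters(train, num_cols=["age"], cat_cols=["sex"])
print(transform({"age": 30.0, "sex": "male"}, params))  # {'age': 0.0, 'sex': 1}
```

Reusing the training-time parameters, rather than recomputing them on the prediction file, keeps the inputs on the same scale and encoding the model was trained on.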
