TabularClassifier
- class flash.tabular.classification.model.TabularClassifier(parameters, embedding_sizes, cat_dims, num_features, num_classes, labels=None, backbone='tabnet', loss_fn=torch.nn.functional.cross_entropy, optimizer='Adam', lr_scheduler=None, metrics=None, learning_rate=None, **backbone_kwargs)
The TabularClassifier is a Task for classifying tabular data. For more details, see Tabular Classification.
Parameters
- parameters (Dict[str, Any]) – The parameters computed from the training data (can be obtained from the parameters attribute of the TabularClassificationData object containing your training data).
- embedding_sizes (list) – List of (num_classes, emb_dim) to form categorical embeddings.
- cat_dims (list) – Number of distinct values for each categorical column.
- loss_fn (Callable) – Loss function for training, defaults to cross entropy.
- optimizer (TypeVar(OPTIMIZER_TYPE, str, Callable, Tuple[str, Dict[str, Any]], None)) – Optimizer to use for training.
- lr_scheduler (Optional[TypeVar(LR_SCHEDULER_TYPE, str, Callable, Tuple[str, Dict[str, Any]], Tuple[str, Dict[str, Any], Dict[str, Any]], None)]) – The LR scheduler to use during training.
- metrics (Optional[TypeVar(METRICS_TYPE, Metric, Mapping, Sequence, None)]) – Metrics to compute for training and evaluation. Can be a metric from the torchmetrics package, a custom metric inheriting from torchmetrics.Metric, a callable function, or a list/dict containing a combination of the aforementioned. In all cases, each metric needs to have the signature metric(preds, target) and return a single scalar tensor. Defaults to torchmetrics.Accuracy.
- learning_rate (Optional[float]) – Learning rate to use for training.
- **backbone_kwargs – Optional additional arguments for the model.
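To make the relationship between cat_dims and embedding_sizes concrete, here is a minimal pure-Python sketch that derives both from a toy table. It reads the num_classes in each (num_classes, emb_dim) pair as the number of distinct categories in that column, and it picks emb_dim with the rule of thumb min(50, (n + 1) // 2). The rule, the toy rows, and the helper names categorical_cardinalities and embedding_sizes_for are illustrative assumptions, not Flash's implementation.

```python
from typing import Dict, List, Sequence, Tuple


def categorical_cardinalities(rows: Sequence[Dict[str, str]], cat_cols: Sequence[str]) -> List[int]:
    """Count the distinct values seen in each categorical column (the role cat_dims plays)."""
    return [len({row[col] for row in rows}) for col in cat_cols]


def embedding_sizes_for(cat_dims: Sequence[int], max_dim: int = 50) -> List[Tuple[int, int]]:
    """Pair each column's cardinality with an embedding width.

    The width rule min(max_dim, (n + 1) // 2) is an assumed rule of thumb,
    not necessarily the rule Flash uses.
    """
    return [(n, min(max_dim, (n + 1) // 2)) for n in cat_dims]


# A tiny hypothetical training table with two categorical columns.
train_rows = [
    {"Sex": "male", "Embarked": "S"},
    {"Sex": "female", "Embarked": "C"},
    {"Sex": "female", "Embarked": "Q"},
    {"Sex": "male", "Embarked": "S"},
]
cat_cols = ["Sex", "Embarked"]

cat_dims = categorical_cardinalities(train_rows, cat_cols)
embedding_sizes = embedding_sizes_for(cat_dims)
print(cat_dims)         # [2, 3]
print(embedding_sizes)  # [(2, 1), (3, 2)]
```

In a real pipeline these values are derived from the training data rather than written by hand, and some implementations reserve one extra index per column for categories not seen during training.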
- classmethod available_finetuning_strategies(cls)
Returns a list containing the keys of the available Finetuning Strategies.
- classmethod available_lr_schedulers(cls)
Returns a list containing the keys of the available LR schedulers.
- classmethod available_optimizers(cls)
Returns a list containing the keys of the available Optimizers.
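The optimizer annotation accepts, among other forms, a bare string key or a Tuple[str, Dict[str, Any]]. A minimal sketch of normalizing those two forms into a (key, kwargs) pair and validating the key against a list like the one available_optimizers() returns is shown below. The helper normalize_optimizer_spec, the reading of the dict as optimizer keyword arguments, and the key list ["Adam", "AdamW", "SGD"] are assumptions for illustration, not Flash's internal code.

```python
from typing import Any, Dict, Sequence, Tuple, Union

# Two of the shapes the optimizer parameter accepts: a bare key, or a (key, kwargs) pair.
OptimizerSpec = Union[str, Tuple[str, Dict[str, Any]]]


def normalize_optimizer_spec(spec: OptimizerSpec, available: Sequence[str]) -> Tuple[str, Dict[str, Any]]:
    """Hypothetical helper: turn a spec into (key, kwargs) and check the key is registered."""
    if isinstance(spec, str):
        name, kwargs = spec, {}
    else:
        name, kwargs = spec
    if name not in available:
        raise ValueError(f"Unknown optimizer {name!r}; expected one of {list(available)}")
    # Copy kwargs so later changes do not leak back into the caller's dict.
    return name, dict(kwargs)


# Hypothetical stand-in for what available_optimizers() might return.
available = ["Adam", "AdamW", "SGD"]

print(normalize_optimizer_spec("Adam", available))                       # ('Adam', {})
print(normalize_optimizer_spec(("SGD", {"momentum": 0.9}), available))   # ('SGD', {'momentum': 0.9})
```

Validating the key up front means a typo fails immediately with a readable message instead of surfacing later during training.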
- classmethod available_outputs(cls)
Returns the list of available outputs (that can be used during prediction or serving) for this Task.
Examples
>>> from flash import Task
>>> print(Task.available_outputs())
['preds', 'raw']
- property data_parameters: Dict[str, Any]
Get the parameters computed from the training data used to create this TabularClassifier. Use these parameters to load data for evaluation / prediction.
Examples
>>> import flash
>>> from flash.core.data.utils import download_data
>>> from flash.tabular import TabularClassificationData, TabularClassifier
>>> download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "./data")
>>> model = TabularClassifier.load_from_checkpoint(
...     "https://flash-weights.s3.amazonaws.com/0.7.0/tabular_classification_model.pt"
... )
>>> datamodule = TabularClassificationData.from_csv(
...     predict_file="data/titanic/predict.csv",
...     parameters=model.data_parameters,
...     batch_size=8,
... )
>>> trainer = flash.Trainer()
>>> trainer.predict(
...     model,
...     datamodule=datamodule,
...     output="classes",
... )
Predicting...
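Why pass data_parameters instead of recomputing statistics on the prediction file? The plain-Python sketch below fits category codes and mean/std normalization on training rows, then encodes prediction rows with them, reserving code 0 for categories never seen in training. The dictionary keys codes and stats and the functions fit_parameters and encode are hypothetical and only illustrate the idea; the actual contents of Flash's data_parameters may differ.

```python
from statistics import mean, pstdev
from typing import Any, Dict, List, Sequence

Row = Dict[str, Any]


def fit_parameters(rows: Sequence[Row], cat_cols: Sequence[str], num_cols: Sequence[str]) -> Dict[str, Any]:
    """Compute preprocessing statistics from the given rows (meant to be the training rows)."""
    # Category -> integer code per column; code 0 is reserved for categories unseen at fit time.
    codes = {
        col: {value: i + 1 for i, value in enumerate(sorted({row[col] for row in rows}))}
        for col in cat_cols
    }
    # (mean, std) per numerical column; fall back to 1.0 to avoid dividing by zero.
    stats = {}
    for col in num_cols:
        values = [float(row[col]) for row in rows]
        stats[col] = (mean(values), pstdev(values) or 1.0)
    return {"codes": codes, "stats": stats}


def encode(rows: Sequence[Row], params: Dict[str, Any]) -> List[List[float]]:
    """Turn rows into feature vectors using previously fitted parameters."""
    encoded = []
    for row in rows:
        features = [float(params["codes"][col].get(row[col], 0)) for col in params["codes"]]
        features += [(float(row[col]) - m) / s for col, (m, s) in params["stats"].items()]
        encoded.append(features)
    return encoded


train_rows = [
    {"Sex": "male", "Age": 20.0},
    {"Sex": "female", "Age": 40.0},
    {"Sex": "female", "Age": 20.0},
    {"Sex": "male", "Age": 40.0},
]
predict_rows = [
    {"Sex": "female", "Age": 50.0},
    {"Sex": "unknown", "Age": 30.0},  # category never seen during training
]

params = fit_parameters(train_rows, cat_cols=["Sex"], num_cols=["Age"])
print(encode(predict_rows, params))  # [[1.0, 2.0], [0.0, 0.0]]

# Refitting on the prediction rows silently changes the encoding.
print(encode(predict_rows, fit_parameters(predict_rows, ["Sex"], ["Age"])))  # [[1.0, 1.0], [2.0, -1.0]]
```

Refitting assigns "female" the same code by coincidence but shifts the Age scale and gives "unknown" a real code, so the model would receive inputs that differ from what it saw during training.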