BaseVisualization
- class flash.core.data.base_viz.BaseVisualization(enabled=False)
This base class is used to build visualization tools on top of the InputTransform hooks. Override any of the show_{_hook_name} methods to receive the associated data and visualize it. Example:
    from typing import Any, List

    from flash.core.data.base_viz import BaseVisualization
    from flash.core.data.callback import BaseDataFetcher
    from flash.image import ImageClassificationData


    class CustomBaseVisualization(BaseVisualization):
        def show_load_sample(self, samples: List[Any], running_stage, limit_nb_samples=None, figsize=(6.4, 4.8)):
            ...  # plot samples

        def show_per_sample_transform(self, samples: List[Any], running_stage, limit_nb_samples=None, figsize=(6.4, 4.8)):
            ...  # plot samples

        def show_collate(self, batch: List[Any], running_stage, limit_nb_samples=None, figsize=(6.4, 4.8)):
            ...  # plot batch

        def show_per_batch_transform(self, batch: List[Any], running_stage, limit_nb_samples=None, figsize=(6.4, 4.8)):
            ...  # plot batch


    class CustomImageClassificationData(ImageClassificationData):
        @staticmethod
        def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
            # use the custom visualizer as the data fetcher
            return CustomBaseVisualization(*args, **kwargs)


    dm = CustomImageClassificationData.from_folders(
        train_folder="./data/train",
        val_folder="./data/val",
        test_folder="./data/test",
        predict_folder="./data/predict",
    )

    # visualize a ``train`` batch
    dm.show_train_batches()

    # visualize the next ``train`` batch
    dm.show_train_batches()

    # visualize a ``val`` batch
    dm.show_val_batches()

    # visualize a ``test`` batch
    dm.show_test_batches()

    # visualize a ``predict`` batch
    dm.show_predict_batches()
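The name-based dispatch described above can be sketched in plain Python. This is a simplified, hypothetical stand-in (MiniVisualization and RecordingVisualization are illustrative names, not Flash classes) that assumes show receives a dict mapping hook names to their outputs and forwards each entry to a show_{hook_name} method only when a subclass defines one; Flash's actual implementation may differ.

```python
from typing import Any, Dict, List, Optional, Tuple


class MiniVisualization:
    """Hypothetical, simplified stand-in for BaseVisualization (not Flash code)."""

    def show(
        self,
        batch: Dict[str, Any],
        running_stage: str,
        func_names_list: List[str],
        limit_nb_samples: Optional[int] = None,
        figsize: Tuple[float, float] = (6.4, 4.8),
    ) -> None:
        for func_name in func_names_list:
            # Look up e.g. "show_collate" for the "collate" hook.
            hook = getattr(self, f"show_{func_name}", None)
            # Skip hooks the subclass did not define or that produced no data.
            if hook is None or func_name not in batch:
                continue
            hook(batch[func_name], running_stage, limit_nb_samples, figsize)


class RecordingVisualization(MiniVisualization):
    """Only defines show_collate, so only collate outputs are handled."""

    def __init__(self) -> None:
        self.calls: List[Tuple[str, str, int]] = []

    def show_collate(self, batch, running_stage, limit_nb_samples=None, figsize=(6.4, 4.8)):
        self.calls.append(("collate", running_stage, len(batch)))


viz = RecordingVisualization()
viz.show(
    {"load_sample": [1, 2, 3], "collate": [[1, 2, 3]]},
    running_stage="train",
    func_names_list=["load_sample", "collate"],
)
print(viz.calls)  # [('collate', 'train', 1)]
```

Because show_load_sample is not defined on the subclass, its data is silently skipped; this mirrors the idea that you only override the hooks you care about.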
Note
To plot all transformation stages at once, override the show method directly. Example:
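    from typing import Any, Dict, List

    from flash.core.data.base_viz import BaseVisualization
    from flash.core.utilities.stages import RunningStage


    class CustomBaseVisualization(BaseVisualization):
        def show(self, batch: Dict[str, Any], running_stage: RunningStage, func_names_list: List[str], limit_nb_samples=None, figsize=(6.4, 4.8)):
            print(batch)
            # out:
            # {
            #     'load_sample': [...],
            #     'per_sample_transform': [...],
            #     'collate': [...],
            #     'per_batch_transform': [...],
            # }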
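Because an overridden show receives every stage's output in a single dict, a composition can walk the stages in a chosen order instead of printing the raw dict. The sketch below uses a hypothetical render_stages helper (not part of Flash) and assumes func_names_list holds the hook names to display.

```python
from typing import Any, Dict, List


def render_stages(batch: Dict[str, Any], func_names_list: List[str]) -> str:
    """Hypothetical helper: one text line per requested stage, in the given order."""
    lines = []
    for name in func_names_list:
        if name in batch:  # ignore stages that produced no output
            lines.append(f"{name}: {batch[name]!r}")
    return "\n".join(lines)


stages = {
    "load_sample": [{"input": 1}, {"input": 2}],
    "collate": {"input": [1, 2]},
}
print(render_stages(stages, ["load_sample", "collate"]))
# load_sample: [{'input': 1}, {'input': 2}]
# collate: {'input': [1, 2]}
```

A custom show could print render_stages(batch, func_names_list), or lay the stages out side by side in a single figure.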
Note
Because the InputTransform hooks are injected into the DataLoader workers, which run in separate processes, the captured data is not accessible from the main process when num_workers > 0.
- show(batch, running_stage, func_names_list, limit_nb_samples=None, figsize=(6.4, 4.8))
Override this method to visualize a composition, that is, the outputs of several hooks at once.
- Return type: None
- show_collate(batch, running_stage, limit_nb_samples=None, figsize=(6.4, 4.8))
Override to visualize the output data of collate.
- Return type: None
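The per-sample hooks (show_load_sample, show_per_sample_transform) receive a list of individual samples, while show_collate and the per-batch hooks receive a single collated batch. The hypothetical simple_collate below is a simplified pure-Python stand-in for collation (real pipelines usually stack tensors instead) that shows how the data's shape changes between those two groups of hooks.

```python
from typing import Any, Dict, List


def simple_collate(samples: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Hypothetical collate: list of per-sample dicts -> one dict of lists."""
    batch: Dict[str, List[Any]] = {}
    for sample in samples:
        for key, value in sample.items():
            # Group values by key, preserving sample order.
            batch.setdefault(key, []).append(value)
    return batch


samples = [{"input": "a.png", "target": 0}, {"input": "b.png", "target": 1}]
print(simple_collate(samples))
# {'input': ['a.png', 'b.png'], 'target': [0, 1]}
```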
- show_load_sample(samples, running_stage, limit_nb_samples=None, figsize=(6.4, 4.8))
Override to visualize the output data of load_sample.
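A per-hook override can use limit_nb_samples to bound how much it draws. The sketch below assumes limit_nb_samples caps the number of samples visualized (None meaning all) and logs text instead of plotting, so it runs without matplotlib; SampleLogger is a hypothetical name, and in practice the method would live on a BaseVisualization subclass.

```python
from typing import Any, List, Optional, Tuple


class SampleLogger:
    """Hypothetical stand-in; a real override would subclass BaseVisualization."""

    def __init__(self) -> None:
        self.lines: List[str] = []

    def show_load_sample(
        self,
        samples: List[Any],
        running_stage: str,
        limit_nb_samples: Optional[int] = None,
        figsize: Tuple[float, float] = (6.4, 4.8),  # unused in this text-only sketch
    ) -> None:
        # Assumption: limit_nb_samples caps how many samples are shown; None shows all.
        shown = samples if limit_nb_samples is None else samples[:limit_nb_samples]
        for index, sample in enumerate(shown):
            self.lines.append(f"{running_stage}[{index}]: {sample!r}")


logger = SampleLogger()
logger.show_load_sample(
    [{"input": "a.png", "target": 0}, {"input": "b.png", "target": 1}],
    "train",
    limit_nb_samples=1,
)
print(logger.lines)  # ["train[0]: {'input': 'a.png', 'target': 0}"]
```

In a plotting version, figsize would typically be forwarded to the figure constructor of the plotting library.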
- show_per_batch_transform(batch, running_stage, limit_nb_samples=None, figsize=(6.4, 4.8))
Override to visualize the output data of per_batch_transform.
- Return type: None
- show_per_batch_transform_on_device(batch, running_stage, limit_nb_samples=None, figsize=(6.4, 4.8))
Override to visualize the output data of per_batch_transform_on_device.
- Return type: None