
Video Classification

The Task

Video classification typically refers to the task of producing a label for the action shown in a given video: the model predicts which class a video clip belongs to.
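To make "producing a label" concrete: a classifier generally emits one raw score per class for a clip, and the predicted label is the class with the highest score. The plain-Python sketch below illustrates only that last step. The class names and scores are hypothetical (not Flash output), and the softmax-then-argmax scheme is a common convention, not a description of Flash internals.

```python
import math


def softmax(scores):
    """Convert raw per-class scores (logits) into probabilities that sum to 1."""
    m = max(scores)  # subtract the max before exponentiating for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def predict_label(scores, class_names):
    """Return (class_name, probability) for the highest-probability class."""
    probs = softmax(scores)
    best = max(range(len(probs)), key=probs.__getitem__)
    return class_names[best], probs[best]


# Hypothetical logits for a single clip over three hypothetical classes.
classes = ["archery", "bowling", "juggling"]
label, prob = predict_label([2.0, 0.5, -1.0], classes)
print(label)  # archery
```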

The Lightning Flash VideoClassifier and VideoClassificationData classes rely on PyTorchVideo internally.


Example

Let’s develop a model to classify video clips of humans performing actions (such as archery, bowling, etc.). We’ll use data from the Kinetics dataset. Here’s an outline of the folder structure:

video_dataset
├── train
│   ├── archery
│   │   ├── -1q7jA3DXQM_000005_000015.mp4
│   │   ├── -5NN5hdIwTc_000036_000046.mp4
│   │   ...
│   ├── bowling
│   │   ├── -5ExwuF5IUI_000030_000040.mp4
│   │   ├── -7sTNNI1Bcg_000075_000085.mp4
│   ... ...
└── val
    ├── archery
    │   ├── 0S-P4lr_c7s_000022_000032.mp4
    │   ├── 2x1lIrgKxYo_000589_000599.mp4
    │   ...
    ├── bowling
    │   ├── 1W7HNDBA4pA_000002_000012.mp4
    │   ├── 4JxH3S5JwMs_000003_000013.mp4
    ... ...
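In this layout, each subdirectory name under train (or val) acts as a class label. The stdlib sketch below shows one way such a tree can be indexed into (path, label_index) pairs. It is an illustration, not Flash's or PyTorchVideo's actual loader: the function name index_video_folder is hypothetical, and sorting class names alphabetically and accepting only .mp4 files are assumptions that the real implementation may handle differently.

```python
import os
import tempfile

VIDEO_EXTENSIONS = (".mp4",)  # assumption: only .mp4 clips, as in the tree above


def index_video_folder(root):
    """Index a folder laid out as root/<class_name>/<clip>.mp4.

    Returns (class_names, samples), where class_names is the sorted list of
    subdirectory names and samples is a list of (path, label_index) pairs.
    """
    class_names = sorted(
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
    )
    samples = []
    for label_index, class_name in enumerate(class_names):
        class_dir = os.path.join(root, class_name)
        for filename in sorted(os.listdir(class_dir)):
            # Skip non-video files such as notes or metadata.
            if filename.lower().endswith(VIDEO_EXTENSIONS):
                samples.append((os.path.join(class_dir, filename), label_index))
    return class_names, samples


# Usage: build a tiny throwaway tree with empty placeholder files.
with tempfile.TemporaryDirectory() as demo_root:
    for cls, clip in [("archery", "a.mp4"), ("bowling", "b.mp4")]:
        os.makedirs(os.path.join(demo_root, cls), exist_ok=True)
        open(os.path.join(demo_root, cls, clip), "w").close()
    demo_names, demo_samples = index_video_folder(demo_root)
    print(demo_names)  # ['archery', 'bowling']
```

Because label indices come from the folder names, the val folder should contain the same class subdirectories as train so that index 0 means the same class in both splits.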

Once we’ve downloaded the data using download_data(), we create the VideoClassificationData. We select a pre-trained backbone to use for our VideoClassifier and fine-tune on the Kinetics data. The backbone can be any model from the PyTorchVideo Model Zoo. We then use the trained VideoClassifier for inference. Finally, we save the model. Here’s the full example:

import os

import flash
from flash.core.data.utils import download_data
from flash.video import VideoClassificationData, VideoClassifier

# 1. Create the DataModule
# Find more datasets at https://pytorchvideo.readthedocs.io/en/latest/data.html
download_data("https://pl-flash-data.s3.amazonaws.com/kinetics.zip", "./data")

datamodule = VideoClassificationData.from_folders(
    train_folder=os.path.join(os.getcwd(), "data/kinetics/train"),
    val_folder=os.path.join(os.getcwd(), "data/kinetics/val"),
    clip_sampler="uniform",
    clip_duration=1,
    decode_audio=False,
)

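# 2. Build the task
# pretrained=False initializes the x3d_xs backbone randomly (no weight download);
# set pretrained=True to start from the PyTorchVideo pre-trained weights instead.
model = VideoClassifier(backbone="x3d_xs", num_classes=datamodule.num_classes, pretrained=False)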

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3)
# strategy="freeze" keeps the backbone frozen and trains only the classification head
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Make a prediction
predictions = model.predict(os.path.join(os.getcwd(), "data/kinetics/predict"))
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("video_classification.pt")
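The clip_sampler="uniform" and clip_duration=1 arguments control how each video is cut into clips (clip_duration is in seconds). The sketch below illustrates the idea of uniform sampling as back-to-back, non-overlapping windows. It is a simplified stand-in, not PyTorchVideo's UniformClipSampler: the function name uniform_clip_windows is hypothetical, and dropping a trailing window shorter than clip_duration is an assumption; the real sampler may treat the remainder differently or support a custom stride.

```python
def uniform_clip_windows(video_duration, clip_duration):
    """Split [0, video_duration) into consecutive windows of clip_duration seconds.

    Simplified sketch: windows do not overlap, and a trailing window shorter
    than clip_duration is dropped.
    """
    if clip_duration <= 0:
        raise ValueError("clip_duration must be positive")
    windows = []
    start = 0
    while start + clip_duration <= video_duration:
        windows.append((start, start + clip_duration))
        start += clip_duration
    return windows


# A 10-second video (Kinetics clips are typically about 10 s long) with
# clip_duration=1 yields ten one-second windows.
print(len(uniform_clip_windows(10, 1)))  # 10
```

With fractional durations such as 0.1, repeated float addition can drift; passing fractions.Fraction values keeps the window boundaries exact.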