Shortcuts

Source code for flash.video.classification.input

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Callable, Dict, List, Optional, Type, Union

import pandas as pd
import torch
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.utils.data import Sampler

from flash.core.data.io.classification_input import ClassificationInputMixin
from flash.core.data.io.input import DataKeys, Input, IterableInput
from flash.core.data.utilities.classification import MultiBinaryTargetFormatter, TargetFormatter
from flash.core.data.utilities.data_frame import read_csv, resolve_files, resolve_targets
from flash.core.data.utilities.paths import list_valid_files, make_dataset, PATH_TYPE
from flash.core.integrations.fiftyone.utils import FiftyOneLabelUtilities
from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _PYTORCHVIDEO_AVAILABLE, lazy_import, requires

# Optional FiftyOne integration: `fol` and `SampleCollection` are bound in both branches so the rest of the
# module can reference them whether or not FiftyOne is installed.
if _FIFTYONE_AVAILABLE:
    # Obtained through `lazy_import`; `fol.Classification` is used when validating FiftyOne label fields.
    fol = lazy_import("fiftyone.core.labels")
    # Dotted path stored as a string; it is only used as a parameter annotation in this module.
    SampleCollection = "fiftyone.core.collections.SampleCollection"
else:
    fol = None
    SampleCollection = None

# Optional PyTorchVideo integration: every name imported in the first branch is also bound in the fallback
# branch, so both branches expose the same set of module-level names.
if _PYTORCHVIDEO_AVAILABLE:
    from pytorchvideo.data.clip_sampling import ClipSampler, make_clip_sampler
    from pytorchvideo.data.encoded_video import EncodedVideo
    from pytorchvideo.data.labeled_video_dataset import LabeledVideoDataset
    from pytorchvideo.data.labeled_video_paths import LabeledVideoPaths
else:
    ClipSampler, make_clip_sampler, EncodedVideo = None, None, None
    LabeledVideoDataset, LabeledVideoPaths = None, None
    # Not imported in the branch above; kept so that any existing import of this name keeps working.
    ApplyTransformToKey = None


def _make_clip_sampler(
    clip_sampler: Union[str, "ClipSampler"] = "random",
    clip_duration: float = 2,
    clip_sampler_kwargs: Dict[str, Any] = None,
) -> "ClipSampler":
    if clip_sampler_kwargs is None:
        clip_sampler_kwargs = {}
    return make_clip_sampler(clip_sampler, clip_duration, **clip_sampler_kwargs)


class VideoClassificationInput(IterableInput, ClassificationInputMixin):
    """Iterable input that pairs video files with classification targets in a ``LabeledVideoDataset``."""

    def load_data(
        self,
        files: List[PATH_TYPE],
        targets: List[Any],
        clip_sampler: Union[str, "ClipSampler"] = "random",
        clip_duration: float = 2,
        clip_sampler_kwargs: Optional[Dict[str, Any]] = None,
        video_sampler: Type[Sampler] = torch.utils.data.RandomSampler,
        decode_audio: bool = False,
        decoder: str = "pyav",
        target_formatter: Optional[TargetFormatter] = None,
    ) -> "LabeledVideoDataset":
        """Build the labeled video dataset and, unless predicting, register target metadata.

        Args:
            files: Video file paths.
            targets: Targets, paired element-wise with ``files`` via ``zip``.
            clip_sampler: Clip sampler or sampling strategy name, resolved by ``_make_clip_sampler``.
            clip_duration: Forwarded to ``_make_clip_sampler``.
            clip_sampler_kwargs: Forwarded to ``_make_clip_sampler``.
            video_sampler: Forwarded to ``LabeledVideoDataset``.
            decode_audio: Forwarded to ``LabeledVideoDataset``.
            decoder: Forwarded to ``LabeledVideoDataset``.
            target_formatter: Forwarded to ``load_target_metadata``.

        Returns:
            The constructed ``LabeledVideoDataset``.
        """
        # Keep the (file, target) pairs locally so the targets can be read back without reaching into
        # private attributes of the dataset object.
        labeled_paths = list(zip(files, targets))
        dataset = LabeledVideoDataset(
            LabeledVideoPaths(labeled_paths),
            _make_clip_sampler(clip_sampler, clip_duration, clip_sampler_kwargs),
            video_sampler=video_sampler,
            decode_audio=decode_audio,
            decoder=decoder,
        )
        if not self.predicting:
            self.load_target_metadata(
                [target for _, target in labeled_paths], target_formatter=target_formatter
            )
        return dataset

    def load_sample(self, sample):
        """Replace ``sample["label"]`` with its formatted target (in place) and return the sample."""
        sample["label"] = self.format_target(sample["label"])
        return sample
class VideoClassificationFoldersInput(VideoClassificationInput):
    """Video classification input whose files and targets come from ``make_dataset`` applied to a path."""

    def load_data(
        self,
        path: str,
        clip_sampler: Union[str, "ClipSampler"] = "random",
        clip_duration: float = 2,
        clip_sampler_kwargs: Optional[Dict[str, Any]] = None,
        video_sampler: Type[Sampler] = torch.utils.data.RandomSampler,
        decode_audio: bool = False,
        decoder: str = "pyav",
        target_formatter: Optional[TargetFormatter] = None,
    ) -> "LabeledVideoDataset":
        """Collect ``mp4``/``avi`` entries under ``path`` and delegate to the parent ``load_data``.

        Args:
            path: Location handed to ``make_dataset``.
            clip_sampler: Forwarded unchanged to the parent ``load_data``.
            clip_duration: Forwarded unchanged to the parent ``load_data``.
            clip_sampler_kwargs: Forwarded unchanged to the parent ``load_data``.
            video_sampler: Forwarded unchanged to the parent ``load_data``.
            decode_audio: Forwarded unchanged to the parent ``load_data``.
            decoder: Forwarded unchanged to the parent ``load_data``.
            target_formatter: Forwarded unchanged to the parent ``load_data``.

        Returns:
            The dataset returned by the parent ``load_data``.
        """
        # The result of ``make_dataset`` is unpacked into the leading positional arguments of the parent.
        return super().load_data(
            *make_dataset(path, extensions=("mp4", "avi")),
            clip_sampler=clip_sampler,
            clip_duration=clip_duration,
            clip_sampler_kwargs=clip_sampler_kwargs,
            video_sampler=video_sampler,
            decode_audio=decode_audio,
            decoder=decoder,
            target_formatter=target_formatter,
        )
class VideoClassificationFilesInput(VideoClassificationInput):
    """Video classification input built from explicit lists of file paths and targets."""

    def load_data(
        self,
        paths: List[str],
        targets: List[Any],
        clip_sampler: Union[str, "ClipSampler"] = "random",
        clip_duration: float = 2,
        clip_sampler_kwargs: Optional[Dict[str, Any]] = None,
        video_sampler: Type[Sampler] = torch.utils.data.RandomSampler,
        decode_audio: bool = False,
        decoder: str = "pyav",
        target_formatter: Optional[TargetFormatter] = None,
    ) -> "LabeledVideoDataset":
        """Delegate to the parent ``load_data`` with ``paths`` as the files.

        Args:
            paths: Video file paths.
            targets: Targets passed alongside ``paths``.
            clip_sampler: Forwarded unchanged to the parent ``load_data``.
            clip_duration: Forwarded unchanged to the parent ``load_data``.
            clip_sampler_kwargs: Forwarded unchanged to the parent ``load_data``.
            video_sampler: Forwarded unchanged to the parent ``load_data``.
            decode_audio: Forwarded unchanged to the parent ``load_data``.
            decoder: Forwarded unchanged to the parent ``load_data``.
            target_formatter: Forwarded unchanged to the parent ``load_data``.

        Returns:
            The dataset returned by the parent ``load_data``.
        """
        return super().load_data(
            paths,
            targets,
            clip_sampler=clip_sampler,
            clip_duration=clip_duration,
            clip_sampler_kwargs=clip_sampler_kwargs,
            video_sampler=video_sampler,
            decode_audio=decode_audio,
            decoder=decoder,
            target_formatter=target_formatter,
        )
class VideoClassificationDataFrameInput(VideoClassificationInput):
    """Video classification input whose files and targets are resolved from a ``pandas.DataFrame``."""

    def load_data(
        self,
        data_frame: pd.DataFrame,
        input_key: str,
        target_keys: Union[str, List[str]],
        root: Optional[PATH_TYPE] = None,
        resolver: Optional[Callable[[Optional[PATH_TYPE], Any], PATH_TYPE]] = None,
        clip_sampler: Union[str, "ClipSampler"] = "random",
        clip_duration: float = 2,
        clip_sampler_kwargs: Optional[Dict[str, Any]] = None,
        video_sampler: Type[Sampler] = torch.utils.data.RandomSampler,
        decode_audio: bool = False,
        decoder: str = "pyav",
        target_formatter: Optional[TargetFormatter] = None,
    ) -> "LabeledVideoDataset":
        """Resolve files and targets from ``data_frame`` and delegate to the parent ``load_data``.

        Args:
            data_frame: Source data frame.
            input_key: Passed to ``resolve_files`` together with ``root`` and ``resolver``.
            target_keys: Passed to ``resolve_targets``. When it is a list and the target formatter is a
                ``MultiBinaryTargetFormatter`` during training, it is also stored as ``self.labels``.
            root: Passed to ``resolve_files``.
            resolver: Passed to ``resolve_files``.
            clip_sampler: Forwarded unchanged to the parent ``load_data``.
            clip_duration: Forwarded unchanged to the parent ``load_data``.
            clip_sampler_kwargs: Forwarded unchanged to the parent ``load_data``.
            video_sampler: Forwarded unchanged to the parent ``load_data``.
            decode_audio: Forwarded unchanged to the parent ``load_data``.
            decoder: Forwarded unchanged to the parent ``load_data``.
            target_formatter: Forwarded unchanged to the parent ``load_data``.

        Returns:
            The dataset returned by the parent ``load_data``.
        """
        result = super().load_data(
            resolve_files(data_frame, input_key, root, resolver),
            resolve_targets(data_frame, target_keys),
            clip_sampler=clip_sampler,
            clip_duration=clip_duration,
            clip_sampler_kwargs=clip_sampler_kwargs,
            video_sampler=video_sampler,
            decode_audio=decode_audio,
            decoder=decoder,
            target_formatter=target_formatter,
        )

        # If we had binary multi-class targets then we also know the labels (column names)
        if (
            self.training
            and isinstance(self.target_formatter, MultiBinaryTargetFormatter)
            and isinstance(target_keys, list)
        ):
            self.labels = target_keys

        return result
class VideoClassificationCSVInput(VideoClassificationDataFrameInput):
    """Video classification input that reads a CSV file and then behaves like the data frame input."""

    def load_data(
        self,
        csv_file: PATH_TYPE,
        input_key: str,
        target_keys: Optional[Union[str, List[str]]] = None,
        root: Optional[PATH_TYPE] = None,
        resolver: Optional[Callable[[Optional[PATH_TYPE], Any], PATH_TYPE]] = None,
        clip_sampler: Union[str, "ClipSampler"] = "random",
        clip_duration: float = 2,
        clip_sampler_kwargs: Optional[Dict[str, Any]] = None,
        video_sampler: Type[Sampler] = torch.utils.data.RandomSampler,
        decode_audio: bool = False,
        decoder: str = "pyav",
        target_formatter: Optional[TargetFormatter] = None,
    ) -> "LabeledVideoDataset":
        """Read ``csv_file`` with ``read_csv`` and delegate to the parent ``load_data``.

        Args:
            csv_file: Path of the CSV file to read.
            input_key: Forwarded unchanged to the parent ``load_data``.
            target_keys: Forwarded unchanged to the parent ``load_data``.
            root: Forwarded to the parent ``load_data``. When ``None``, the directory containing
                ``csv_file`` is used instead.
            resolver: Forwarded unchanged to the parent ``load_data``.
            clip_sampler: Forwarded unchanged to the parent ``load_data``.
            clip_duration: Forwarded unchanged to the parent ``load_data``.
            clip_sampler_kwargs: Forwarded unchanged to the parent ``load_data``.
            video_sampler: Forwarded unchanged to the parent ``load_data``.
            decode_audio: Forwarded unchanged to the parent ``load_data``.
            decoder: Forwarded unchanged to the parent ``load_data``.
            target_formatter: Forwarded unchanged to the parent ``load_data``.

        Returns:
            The dataset returned by the parent ``load_data``.
        """
        data_frame = read_csv(csv_file)
        if root is None:
            # Default the root to the CSV's own directory.
            root = os.path.dirname(csv_file)
        return super().load_data(
            data_frame,
            input_key,
            target_keys,
            root,
            resolver,
            clip_sampler=clip_sampler,
            clip_duration=clip_duration,
            clip_sampler_kwargs=clip_sampler_kwargs,
            video_sampler=video_sampler,
            decode_audio=decode_audio,
            decoder=decoder,
            target_formatter=target_formatter,
        )
class VideoClassificationFiftyOneInput(VideoClassificationInput):
    """Video classification input that reads file paths and labels from a FiftyOne sample collection."""

    @requires("fiftyone")
    def load_data(
        self,
        sample_collection: SampleCollection,
        clip_sampler: Union[str, "ClipSampler"] = "random",
        clip_duration: float = 2,
        clip_sampler_kwargs: Optional[Dict[str, Any]] = None,
        video_sampler: Type[Sampler] = torch.utils.data.RandomSampler,
        decode_audio: bool = False,
        decoder: str = "pyav",
        label_field: str = "ground_truth",
        target_formatter: Optional[TargetFormatter] = None,
    ) -> "LabeledVideoDataset":
        """Validate ``label_field`` on the collection, then delegate to the parent ``load_data``.

        Args:
            sample_collection: Collection queried for ``"filepath"`` and ``label_field + ".label"`` values.
            clip_sampler: Forwarded unchanged to the parent ``load_data``.
            clip_duration: Forwarded unchanged to the parent ``load_data``.
            clip_sampler_kwargs: Forwarded unchanged to the parent ``load_data``.
            video_sampler: Forwarded unchanged to the parent ``load_data``.
            decode_audio: Forwarded unchanged to the parent ``load_data``.
            decoder: Forwarded unchanged to the parent ``load_data``.
            label_field: Name of the field holding the labels.
            target_formatter: Forwarded unchanged to the parent ``load_data``.

        Returns:
            The dataset returned by the parent ``load_data``.
        """
        # Run the FiftyOne label validation for ``label_field`` against ``fol.Classification`` first.
        label_utilities = FiftyOneLabelUtilities(label_field, fol.Classification)
        label_utilities.validate(sample_collection)

        return super().load_data(
            sample_collection.values("filepath"),
            sample_collection.values(label_field + ".label"),
            clip_sampler=clip_sampler,
            clip_duration=clip_duration,
            clip_sampler_kwargs=clip_sampler_kwargs,
            video_sampler=video_sampler,
            decode_audio=decode_audio,
            decoder=decoder,
            target_formatter=target_formatter,
        )
class VideoClassificationPathsPredictInput(Input):
    """Prediction input that loads one sampled clip per video file path."""

    def predict_load_data(
        self,
        paths: List[str],
        clip_sampler: Union[str, "ClipSampler"] = "random",
        clip_duration: float = 2,
        clip_sampler_kwargs: Optional[Dict[str, Any]] = None,
        decode_audio: bool = False,
        decoder: str = "pyav",
    ) -> List[str]:
        """Filter ``paths`` to valid video files and store the decoding settings on the instance.

        Args:
            paths: Candidate paths, filtered by ``list_valid_files`` to ``mp4``/``avi`` extensions.
            clip_sampler: Clip sampler or sampling strategy name, resolved by ``_make_clip_sampler``.
            clip_duration: Forwarded to ``_make_clip_sampler``.
            clip_sampler_kwargs: Forwarded to ``_make_clip_sampler``.
            decode_audio: Stored and later passed to ``EncodedVideo.from_path``.
            decoder: Stored and later passed to ``EncodedVideo.from_path``.

        Returns:
            The filtered list of paths.
        """
        paths = list_valid_files(paths, valid_extensions=("mp4", "avi"))
        # These settings are read again by ``predict_load_sample`` for every sample.
        self._clip_sampler = _make_clip_sampler(clip_sampler, clip_duration, clip_sampler_kwargs)
        self._decode_audio = decode_audio
        self._decoder = decoder
        return paths

    def predict_load_sample(self, sample: str) -> Dict[str, Any]:
        """Open the video at ``sample``, sample one clip from it and return the clip as a dictionary.

        Args:
            sample: Path of the video file to load.

        Returns:
            A dictionary with the clip frames under ``"video"``, the video name, a ``"video_index"`` of
            ``0``, the clip and augmentation indices from the sampler, ``"audio"`` when audio samples
            are present, and the file path under ``DataKeys.METADATA``.

        Raises:
            MisconfigurationException: If no clip could be loaded, or if audio decoding was requested
                and the clip has no audio.
        """
        video = EncodedVideo.from_path(sample, decode_audio=self._decode_audio, decoder=self._decoder)
        try:
            (
                clip_start,
                clip_end,
                clip_index,
                aug_index,
                is_last_clip,
            ) = self._clip_sampler(0.0, video.duration, None)

            loaded_clip = video.get_clip(clip_start, clip_end)

            clip_is_null = (
                loaded_clip is None
                or loaded_clip["video"] is None
                or (loaded_clip["audio"] is None and self._decode_audio)
            )

            if clip_is_null:
                raise MisconfigurationException(
                    f"The provided video is too short {video.duration} to be clipped at {self._clip_sampler._clip_duration}"
                )

            # Read everything still needed from the video object before it is closed.
            video_name = video.name
        finally:
            # Release the underlying video handle on both the success and the error path.
            video.close()

        frames = loaded_clip["video"]
        audio_samples = loaded_clip["audio"]
        return {
            "video": frames,
            "video_name": video_name,
            "video_index": 0,
            "clip_index": clip_index,
            "aug_index": aug_index,
            **({"audio": audio_samples} if audio_samples is not None else {}),
            DataKeys.METADATA: {"filepath": sample},
        }
class VideoClassificationDataFramePredictInput(VideoClassificationPathsPredictInput):
    """Prediction input whose video paths are resolved from a ``pandas.DataFrame``."""

    def predict_load_data(
        self,
        data_frame: pd.DataFrame,
        input_key: str,
        root: Optional[PATH_TYPE] = None,
        resolver: Optional[Callable[[Optional[PATH_TYPE], Any], PATH_TYPE]] = None,
        clip_sampler: Union[str, "ClipSampler"] = "random",
        clip_duration: float = 2,
        clip_sampler_kwargs: Optional[Dict[str, Any]] = None,
        decode_audio: bool = False,
        decoder: str = "pyav",
    ) -> List[str]:
        """Resolve paths with ``resolve_files`` and delegate to the parent ``predict_load_data``.

        Args:
            data_frame: Source data frame.
            input_key: Passed to ``resolve_files`` together with ``root`` and ``resolver``.
            root: Passed to ``resolve_files``.
            resolver: Passed to ``resolve_files``.
            clip_sampler: Forwarded unchanged to the parent ``predict_load_data``.
            clip_duration: Forwarded unchanged to the parent ``predict_load_data``.
            clip_sampler_kwargs: Forwarded unchanged to the parent ``predict_load_data``.
            decode_audio: Forwarded unchanged to the parent ``predict_load_data``.
            decoder: Forwarded unchanged to the parent ``predict_load_data``.

        Returns:
            The list of paths returned by the parent ``predict_load_data``.
        """
        return super().predict_load_data(
            resolve_files(data_frame, input_key, root, resolver),
            clip_sampler=clip_sampler,
            clip_duration=clip_duration,
            clip_sampler_kwargs=clip_sampler_kwargs,
            decode_audio=decode_audio,
            decoder=decoder,
        )
class VideoClassificationCSVPredictInput(VideoClassificationDataFramePredictInput):
    """Prediction input that reads a CSV file and then behaves like the data frame predict input."""

    def predict_load_data(
        self,
        csv_file: PATH_TYPE,
        input_key: str,
        root: Optional[PATH_TYPE] = None,
        resolver: Optional[Callable[[Optional[PATH_TYPE], Any], PATH_TYPE]] = None,
        clip_sampler: Union[str, "ClipSampler"] = "random",
        clip_duration: float = 2,
        clip_sampler_kwargs: Optional[Dict[str, Any]] = None,
        decode_audio: bool = False,
        decoder: str = "pyav",
    ) -> List[str]:
        """Read ``csv_file`` with ``read_csv`` and delegate to the parent ``predict_load_data``.

        Args:
            csv_file: Path of the CSV file to read.
            input_key: Forwarded unchanged to the parent ``predict_load_data``.
            root: Forwarded to the parent ``predict_load_data``. When ``None``, the directory containing
                ``csv_file`` is used instead.
            resolver: Forwarded unchanged to the parent ``predict_load_data``.
            clip_sampler: Forwarded unchanged to the parent ``predict_load_data``.
            clip_duration: Forwarded unchanged to the parent ``predict_load_data``.
            clip_sampler_kwargs: Forwarded unchanged to the parent ``predict_load_data``.
            decode_audio: Forwarded unchanged to the parent ``predict_load_data``.
            decoder: Forwarded unchanged to the parent ``predict_load_data``.

        Returns:
            The list of paths returned by the parent ``predict_load_data``.
        """
        data_frame = read_csv(csv_file)
        if root is None:
            # Default the root to the CSV's own directory.
            root = os.path.dirname(csv_file)
        return super().predict_load_data(
            data_frame,
            input_key,
            root,
            resolver,
            clip_sampler=clip_sampler,
            clip_duration=clip_duration,
            clip_sampler_kwargs=clip_sampler_kwargs,
            decode_audio=decode_audio,
            decoder=decoder,
        )
class VideoClassificationFiftyOnePredictInput(VideoClassificationPathsPredictInput):
    """Prediction input whose video paths come from the ``"filepath"`` values of a FiftyOne collection."""

    @requires("fiftyone")
    def predict_load_data(
        self,
        data: SampleCollection,
        clip_sampler: Union[str, "ClipSampler"] = "random",
        clip_duration: float = 2,
        clip_sampler_kwargs: Optional[Dict[str, Any]] = None,
        decode_audio: bool = False,
        decoder: str = "pyav",
    ) -> List[str]:
        """Read ``"filepath"`` values from ``data`` and delegate to the parent ``predict_load_data``.

        Args:
            data: Collection queried with ``data.values("filepath")``.
            clip_sampler: Forwarded unchanged to the parent ``predict_load_data``.
            clip_duration: Forwarded unchanged to the parent ``predict_load_data``.
            clip_sampler_kwargs: Forwarded unchanged to the parent ``predict_load_data``.
            decode_audio: Forwarded unchanged to the parent ``predict_load_data``.
            decoder: Forwarded unchanged to the parent ``predict_load_data``.

        Returns:
            The list of paths returned by the parent ``predict_load_data``.
        """
        return super().predict_load_data(
            data.values("filepath"),
            clip_sampler=clip_sampler,
            clip_duration=clip_duration,
            clip_sampler_kwargs=clip_sampler_kwargs,
            decode_audio=decode_audio,
            decoder=decoder,
        )

© Copyright 2020-2021, PyTorch Lightning. Revision 8e9123c7.

Built with Sphinx using a theme provided by Read the Docs.
Read the Docs v: 0.7.1
Versions
latest
stable
0.7.1
0.7.0
0.6.0
0.5.2
0.5.1
0.5.0
0.4.0
0.3.2
0.3.1
0.3.0
0.2.3
0.2.2
0.2.1
0.2.0
0.1.0post1
docs-fix_typing
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.