# Source code for eztorch.datamodules.hmdb51

from typing import Optional

import hydra
from lightning.pytorch.utilities import rank_zero_info
from omegaconf import DictConfig

from eztorch.datamodules.video import VideoBaseDataModule
from eztorch.datasets.hmdb51 import Hmdb51


class Hmdb51DataModule(VideoBaseDataModule):
    """Datamodule for the HMDB51 dataset.

    Args:
        datadir: Path to the data (eg: csv, folder, ...).
        train: Configuration for the training data to define the loading of data, the transforms and the dataloader.
        val: Configuration for the validation data to define the loading of data, the transforms and the dataloader.
        test: Configuration for the testing data to define the loading of data, the transforms and the dataloader.
        split_id: Split used for training and testing.
        video_path_prefix: Path to root directory where the videos are stored.
            All the video paths before loading are prefixed with this path.
        decode_audio: If ``True``, decode audio.
        decoder: Defines which backend should be used to decode videos.
        decoder_args: Arguments to configure the decoder. ``None`` (the
            default) is treated as an empty configuration.
    """

    def __init__(
        self,
        datadir: str,
        train: Optional[DictConfig] = None,
        val: Optional[DictConfig] = None,
        test: Optional[DictConfig] = None,
        split_id: int = 1,
        video_path_prefix: str = "",
        decode_audio: bool = False,
        decoder: str = "pyav",
        decoder_args: Optional[DictConfig] = None,
    ) -> None:
        super().__init__(
            datadir=datadir,
            train=train,
            val=val,
            test=test,
            video_path_prefix=video_path_prefix,
            decode_audio=decode_audio,
            decoder=decoder,
            # Build a fresh dict per instance: a ``{}`` default in the
            # signature would be one object shared by every datamodule.
            decoder_args={} if decoder_args is None else decoder_args,
        )
        self.split_id = split_id

    @property
    def num_classes(self) -> int:
        """Number of classes."""
        return 51

    def _verify_classes(self) -> None:
        """Check that ``datadir`` holds one sub-directory per class.

        Raises:
            RuntimeError: If the number of sub-directories found differs from
                :attr:`num_classes`.
        """
        # NOTE(review): assumes the base class stores ``datadir`` as a
        # ``pathlib.Path`` (``iterdir`` is called on it) — confirm.
        class_dirs = [
            path.stem for path in self.datadir.iterdir() if path.is_dir()
        ]
        # Explicit raise instead of ``assert``: asserts vanish under ``python -O``.
        if len(class_dirs) != self.num_classes:
            raise RuntimeError(
                f"{len(class_dirs)}/{self.num_classes} classes found: {class_dirs}"
            )

    def prepare_data(self) -> None:
        """Verify the class directories before any dataset is built."""
        self._verify_classes()

    def setup(self, stage: Optional[str] = None) -> None:
        """Instantiate the datasets required by ``stage``.

        Args:
            stage: Lightning stage. ``"fit"`` builds the training dataset and,
                if a validation configuration was given, the validation
                dataset. ``"validate"`` builds the validation dataset.
                ``"test"`` builds the testing dataset. Any other value
                (including ``None``) builds nothing.

        Raises:
            RuntimeError: If the configuration required by ``stage`` is missing.
        """
        if stage == "fit":
            if self.train is None:
                raise RuntimeError("No training configuration has been passed.")
            self._setup_train()
            if self.val is not None:
                self._setup_val()
        elif stage == "validate":
            # ``trainer.validate()`` calls ``setup("validate")``; without this
            # branch ``val_dataset`` would never be created.
            if self.val is None:
                raise RuntimeError("No validation configuration has been passed.")
            self._setup_val()
        elif stage == "test":
            if self.test is None:
                raise RuntimeError("No testing configuration has been passed.")
            self._setup_test()

    def _setup_train(self) -> None:
        """Build the training transform, clip sampler and dataset."""
        self.train_transform = hydra.utils.instantiate(self.train.transform)
        self.train_clip_sampler = hydra.utils.instantiate(self.train.clip_sampler)
        rank_zero_info(f"Train transform: {self.train_transform}")
        self.train_dataset = Hmdb51(
            self.traindir,
            clip_sampler=self.train_clip_sampler,
            transform=self.train_transform,
            video_path_prefix=self.train_video_path_prefix,
            split_id=self.split_id,
            split_type="train",
            decode_audio=self.decode_audio,
            decoder=self.train_decoder,
            decoder_args=self.train_decoder_args,
        )

    def _setup_val(self) -> None:
        """Build the validation transform, clip sampler and dataset."""
        self.val_transform = hydra.utils.instantiate(self.val.transform)
        self.val_clip_sampler = hydra.utils.instantiate(self.val.clip_sampler)
        rank_zero_info(f"Val transform: {self.val_transform}")
        # Validation reads the "test" split list; presumably because the split
        # files only define train/test lists — confirm against ``Hmdb51``.
        self.val_dataset = Hmdb51(
            self.valdir,
            clip_sampler=self.val_clip_sampler,
            transform=self.val_transform,
            video_path_prefix=self.val_video_path_prefix,
            split_id=self.split_id,
            split_type="test",
            decode_audio=self.decode_audio,
            decoder=self.val_decoder,
            decoder_args=self.val_decoder_args,
        )

    def _setup_test(self) -> None:
        """Build the testing transform, clip sampler and dataset."""
        self.test_transform = hydra.utils.instantiate(self.test.transform)
        self.test_clip_sampler = hydra.utils.instantiate(self.test.clip_sampler)
        rank_zero_info(f"Test transform: {self.test_transform}")
        self.test_dataset = Hmdb51(
            self.testdir,
            clip_sampler=self.test_clip_sampler,
            transform=self.test_transform,
            video_path_prefix=self.test_video_path_prefix,
            split_id=self.split_id,
            split_type="test",
            decode_audio=self.decode_audio,
            decoder=self.test_decoder,
            decoder_args=self.test_decoder_args,
        )