from typing import Optional
import hydra
from lightning.pytorch.utilities import rank_zero_info
from omegaconf import DictConfig
from eztorch.datamodules.video import VideoBaseDataModule
from eztorch.datasets.hmdb51 import Hmdb51
[docs]
class Hmdb51DataModule(VideoBaseDataModule):
    """Datamodule for the HMDB51 dataset.
    Args:
        datadir: Path to the data (eg: csv, folder, ...).
        train: Configuration for the training data to define the loading of data, the transforms and the dataloader.
        val: Configuration for the validation data to define the loading of data, the transforms and the dataloader.
        test: Configuration for the testing data to define the loading of data, the transforms and the dataloader.
        video_path_prefix: Path to root directory where the videos are stored.
            All the video paths before loading are prefixed with this path.
        decode_audio: If ``True``, decode audio.
        decoder: Defines which backend should be used to decode videos.
        decoder_args: Arguments to configure the decoder.
        split_id: Split used for training and testing.
    """
    def __init__(
        self,
        datadir: str,
        train: Optional[DictConfig] = None,
        val: Optional[DictConfig] = None,
        test: Optional[DictConfig] = None,
        split_id: int = 1,
        video_path_prefix: str = "",
        decode_audio: bool = False,
        decoder: str = "pyav",
        decoder_args: DictConfig = {},
    ) -> None:
        super().__init__(
            datadir=datadir,
            train=train,
            val=val,
            test=test,
            video_path_prefix=video_path_prefix,
            decode_audio=decode_audio,
            decoder=decoder,
            decoder_args=decoder_args,
        )
        self.split_id = split_id
    @property
    def num_classes(self) -> int:
        """Number of classes."""
        return 51
    def _verify_classes(self) -> None:
        dirs = [dir.stem for dir in self.datadir.iterdir() if dir.is_dir()]
        assert (
            len(dirs) == self.num_classes
        ), f"{len(dirs)}/{self.num_classes} classes found: {dirs}"
    def prepare_data(self) -> None:
        self._verify_classes()
    def setup(self, stage: Optional[str] = None) -> None:
        if stage == "fit":
            if self.train is None:
                raise RuntimeError("No training configuration has been passed.")
            self.train_transform = hydra.utils.instantiate(self.train.transform)
            self.train_clip_sampler = hydra.utils.instantiate(self.train.clip_sampler)
            rank_zero_info(f"Train transform: {self.train_transform}")
            self.train_dataset = Hmdb51(
                self.traindir,
                clip_sampler=self.train_clip_sampler,
                transform=self.train_transform,
                video_path_prefix=self.train_video_path_prefix,
                split_id=self.split_id,
                split_type="train",
                decode_audio=self.decode_audio,
                decoder=self.train_decoder,
                decoder_args=self.train_decoder_args,
            )
            if self.val is not None:
                self.val_transform = hydra.utils.instantiate(self.val.transform)
                self.val_clip_sampler = hydra.utils.instantiate(self.val.clip_sampler)
                rank_zero_info(f"Val transform: {self.val_transform}")
                self.val_dataset = Hmdb51(
                    self.valdir,
                    clip_sampler=self.val_clip_sampler,
                    transform=self.val_transform,
                    video_path_prefix=self.val_video_path_prefix,
                    split_id=self.split_id,
                    split_type="test",
                    decode_audio=self.decode_audio,
                    decoder=self.val_decoder,
                    decoder_args=self.val_decoder_args,
                )
        elif stage == "test":
            if self.test is None:
                raise RuntimeError("No testing configuration has been passed.")
            self.test_transform = hydra.utils.instantiate(self.test.transform)
            self.test_clip_sampler = hydra.utils.instantiate(self.test.clip_sampler)
            rank_zero_info(f"Test transform: {self.test_transform}")
            self.test_dataset = Hmdb51(
                self.testdir,
                clip_sampler=self.test_clip_sampler,
                transform=self.test_transform,
                video_path_prefix=self.test_video_path_prefix,
                split_id=self.split_id,
                split_type="test",
                decode_audio=self.decode_audio,
                decoder=self.test_decoder,
                decoder_args=self.test_decoder_args,
            )