Source code for pytorchvideo.data.clip_sampling

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import random
from abc import ABC, abstractmethod
from fractions import Fraction
from typing import Any, Dict, NamedTuple, Optional, Tuple, Union, List


class ClipInfo(NamedTuple):
    """
    Named-tuple for clip information with:
        clip_start_sec (Union[float, Fraction]): clip start time.
        clip_end_sec (Union[float, Fraction]): clip end time.
        clip_index (int): clip index in the video.
        aug_index (int): augmentation index for the clip. Different augmentation methods
            might generate multiple views for the same clip.
        is_last_clip (bool): a bool specifying whether there are more clips to be
            sampled from the video.
    """

    clip_start_sec: Union[float, Fraction]
    clip_end_sec: Union[float, Fraction]
    clip_index: int
    aug_index: int
    is_last_clip: bool


class ClipInfoList(NamedTuple):
    """
    Named-tuple for a list of clip information with:
        clip_start_sec (List[float]): list of clip start times.
        clip_end_sec (List[float]): list of clip end times.
        clip_index (List[float]): list of clip indices in the video.
        aug_index (List[float]): list of augmentation indices for the clips.
            Different augmentation methods might generate multiple views for
            the same clip.
        is_last_clip (List[float]): list of booleans specifying whether there
            are more clips to be sampled from the video.
    """

    clip_start_sec: List[float]
    clip_end_sec: List[float]
    clip_index: List[float]
    aug_index: List[float]
    is_last_clip: List[float]


class ClipSampler(ABC):
    """
    Interface for clip samplers that take a video duration and the time of the
    previously sampled clip, and return a ``ClipInfo`` named-tuple.
    """

    def __init__(self, clip_duration: Union[float, Fraction]) -> None:
        self._clip_duration = Fraction(clip_duration)
        self._current_clip_index = 0
        self._current_aug_index = 0

    @abstractmethod
    def __call__(
        self,
        last_clip_time: Union[float, Fraction],
        video_duration: Union[float, Fraction],
        annotation: Dict[str, Any],
    ) -> ClipInfo:
        pass

    def reset(self) -> None:
        """Resets any video-specific attributes in preperation for next video"""
        pass
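
# Example (editor's sketch, not part of the original module): a hypothetical
# ClipSampler subclass illustrating the interface. __call__ receives the last
# sampled clip time, the video duration, and an annotation dict, and returns a
# ClipInfo; this sampler always returns the single center clip of the video.
class _ExampleCenterClipSampler(ClipSampler):
    def __call__(
        self,
        last_clip_time: Union[float, Fraction],
        video_duration: Union[float, Fraction],
        annotation: Dict[str, Any],
    ) -> ClipInfo:
        # Center the clip in the video, clamping the start at 0 for videos
        # shorter than the clip duration.
        clip_start = max(
            (Fraction(video_duration) - self._clip_duration) / 2, Fraction(0)
        )
        return ClipInfo(clip_start, clip_start + self._clip_duration, 0, 0, True)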


def make_clip_sampler(sampling_type: str, *args) -> ClipSampler:
    """
    Constructs the clip samplers found in ``pytorchvideo.data.clip_sampling`` from the
    given arguments.

    Args:
        sampling_type (str): choose clip sampler to return. It has four options:

            * uniform: constructs and returns ``UniformClipSampler``
            * random: constructs and returns ``RandomClipSampler``
            * constant_clips_per_video: constructs and returns
              ``ConstantClipsPerVideoSampler``
            * random_multi: constructs and returns ``RandomMultiClipSampler``

        *args: the args to pass to the chosen clip sampler constructor.
    """
    if sampling_type == "uniform":
        return UniformClipSampler(*args)
    elif sampling_type == "random":
        return RandomClipSampler(*args)
    elif sampling_type == "constant_clips_per_video":
        return ConstantClipsPerVideoSampler(*args)
    elif sampling_type == "random_multi":
        return RandomMultiClipSampler(*args)
    else:
        raise NotImplementedError(f"{sampling_type} not supported")
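
# Usage sketch (editor's illustration, not part of the original source): the
# positional args after sampling_type are forwarded to the chosen constructor.
#
#     >>> sampler = make_clip_sampler("uniform", 2.0)  # 2-second clips
#     >>> sampler = make_clip_sampler("random", 2.0)
#     >>> sampler = make_clip_sampler("random_multi", 2.0, 3)
#     >>> sampler = make_clip_sampler("constant_clips_per_video", 2.0, 5)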


class UniformClipSampler(ClipSampler):
    """
    Evenly splits the video into clips of size clip_duration.
    """

    def __init__(
        self,
        clip_duration: Union[float, Fraction],
        stride: Optional[Union[float, Fraction]] = None,
        backpad_last: bool = False,
        eps: float = 1e-6,
    ):
        """
        Args:
            clip_duration (Union[float, Fraction]):
                The length of the clip to sample (in seconds).
            stride (Union[float, Fraction], optional):
                The amount of seconds to offset the next clip by; the default
                value of None is equivalent to no stride, i.e.
                stride == clip_duration.
            eps (float):
                Epsilon for floating point comparisons. Used to check the last clip.
            backpad_last (bool):
                Whether to include the last frame(s) by "back padding".

                For instance, if we have a video of 39 frames (30 fps = 1.3s)
                with a stride of 16 frames (0.533s) and a clip duration of
                32 frames (1.0667s), the clips will be (in frame numbers):

                with backpad_last = False
                - [0, 32]

                with backpad_last = True
                - [0, 32]
                - [8, 40], this is "back-padded" from [16, 48] to fit the last window

        Note that you can use Fraction for clip_duration and stride if you want to
        avoid float precision issues and need accurate frames in each clip.
        """
        super().__init__(clip_duration)
        self._stride = stride if stride is not None else clip_duration
        self._eps = eps
        self._backpad_last = backpad_last

        assert (
            self._stride > 0 and self._stride <= clip_duration
        ), f"stride must be >0 and <= clip_duration ({clip_duration})"

    def _clip_start_end(
        self,
        last_clip_time: Union[float, Fraction],
        video_duration: Union[float, Fraction],
        backpad_last: bool,
    ) -> Tuple[Fraction, Fraction]:
        """
        Helper to calculate the start/end clip with backpad logic.
        """
        clip_start = Fraction(
            max(last_clip_time - max(0, self._clip_duration - self._stride), 0)
        )
        clip_end = Fraction(clip_start + self._clip_duration)
        if backpad_last:
            buffer_amount = max(0, clip_end - video_duration)
            clip_start -= buffer_amount
            clip_start = Fraction(max(0, clip_start))  # handle rounding
            clip_end = Fraction(clip_start + self._clip_duration)

        return clip_start, clip_end

    def __call__(
        self, last_clip_time: float, video_duration: float, annotation: Dict[str, Any]
    ) -> ClipInfo:
        """
        Args:
            last_clip_time (float): the last clip end time sampled from this video.
                This should be 0.0 if the video hasn't had clips sampled yet.
            video_duration (float): the duration of the video that's being sampled
                in seconds.
            annotation (Dict): Not used by this sampler.
        Returns:
            clip_info (ClipInfo): includes the clip information
            (clip_start_time, clip_end_time, clip_index, aug_index, is_last_clip),
            where the times are in seconds and is_last_clip is False when there is
            still more time in the video to be sampled.
        """
        clip_start, clip_end = self._clip_start_end(
            last_clip_time, video_duration, backpad_last=self._backpad_last
        )

        # if the next clip ends at the same time, this is the last clip
        _, next_clip_end = self._clip_start_end(
            clip_end, video_duration, backpad_last=self._backpad_last
        )
        if self._backpad_last:
            is_last_clip = abs(next_clip_end - clip_end) < self._eps
        else:
            is_last_clip = next_clip_end > video_duration

        clip_index = self._current_clip_index
        self._current_clip_index += 1

        if is_last_clip:
            self.reset()

        return ClipInfo(clip_start, clip_end, clip_index, 0, is_last_clip)

    def reset(self):
        self._current_clip_index = 0
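
# Usage sketch (editor's illustration, not part of the original source):
# UniformClipSampler is driven in a loop, feeding each clip's end time back in
# as last_clip_time until is_last_clip is True. The 10-second video duration
# below is hypothetical.
#
#     >>> sampler = UniformClipSampler(clip_duration=2.0)
#     >>> last_end, done = 0.0, False
#     >>> while not done:
#     ...     clip = sampler(last_end, 10.0, {})
#     ...     last_end, done = clip.clip_end_sec, clip.is_last_clip
#     # yields the clips [0, 2], [2, 4], [4, 6], [6, 8], [8, 10] (in seconds)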


class RandomClipSampler(ClipSampler):
    """
    Randomly samples clips of size clip_duration from the videos.
    """

    def __call__(
        self, last_clip_time: float, video_duration: float, annotation: Dict[str, Any]
    ) -> ClipInfo:
        """
        Args:
            last_clip_time (float): Not used for RandomClipSampler.
            video_duration (float): the duration (in seconds) of the video that's
                being sampled.
            annotation (Dict): Not used by this sampler.
        Returns:
            clip_info (ClipInfo): includes the clip information of
            (clip_start_time, clip_end_time, clip_index, aug_index, is_last_clip).
            The times are in seconds. clip_index, aug_index and is_last_clip are
            always 0, 0 and True, respectively.
        """
        max_possible_clip_start = max(video_duration - self._clip_duration, 0)
        clip_start_sec = Fraction(random.uniform(0, max_possible_clip_start))
        return ClipInfo(
            clip_start_sec, clip_start_sec + self._clip_duration, 0, 0, True
        )


class RandomMultiClipSampler(RandomClipSampler):
    """
    Randomly samples num_clips clips of size clip_duration from the videos and
    returns them as a ClipInfoList.
    """

    def __init__(self, clip_duration: float, num_clips: int) -> None:
        super().__init__(clip_duration)
        self._num_clips = num_clips

    def __call__(
        self, last_clip_time: float, video_duration: float, annotation: Dict[str, Any]
    ) -> ClipInfoList:
        (
            clip_start_list,
            clip_end_list,
            clip_index_list,
            aug_index_list,
            is_last_clip_list,
        ) = (
            self._num_clips * [None],
            self._num_clips * [None],
            self._num_clips * [None],
            self._num_clips * [None],
            self._num_clips * [None],
        )
        for i in range(self._num_clips):
            (
                clip_start_list[i],
                clip_end_list[i],
                clip_index_list[i],
                aug_index_list[i],
                is_last_clip_list[i],
            ) = super().__call__(last_clip_time, video_duration, annotation)

        return ClipInfoList(
            clip_start_list,
            clip_end_list,
            clip_index_list,
            aug_index_list,
            is_last_clip_list,
        )


class ConstantClipsPerVideoSampler(ClipSampler):
    """
    Evenly splits the video into clips_per_video increments and samples clips of
    size clip_duration at these increments.
    """

    def __init__(
        self, clip_duration: float, clips_per_video: int, augs_per_clip: int = 1
    ) -> None:
        super().__init__(clip_duration)
        self._clips_per_video = clips_per_video
        self._augs_per_clip = augs_per_clip

    def __call__(
        self, last_clip_time: float, video_duration: float, annotation: Dict[str, Any]
    ) -> ClipInfo:
        """
        Args:
            last_clip_time (float): Not used for ConstantClipsPerVideoSampler.
            video_duration (float): the duration (in seconds) of the video that's
                being sampled.
            annotation (Dict): Not used by this sampler.
        Returns:
            a named-tuple ``ClipInfo``: includes the clip information of
            (clip_start_time, clip_end_time, clip_index, aug_index, is_last_clip).
            The times are in seconds. is_last_clip is True after clips_per_video
            clips have been sampled or the end of the video is reached.
        """
        max_possible_clip_start = Fraction(
            max(video_duration - self._clip_duration, 0)
        )
        uniform_clip = Fraction(max_possible_clip_start, self._clips_per_video)
        clip_start_sec = uniform_clip * self._current_clip_index
        clip_index = self._current_clip_index

        aug_index = self._current_aug_index
        self._current_aug_index += 1
        if self._current_aug_index >= self._augs_per_clip:
            self._current_clip_index += 1
            self._current_aug_index = 0

        # is_last_clip is True if we have sampled self._clips_per_video clips or
        # if the end of the video is reached.
        is_last_clip = False
        if (
            self._current_clip_index >= self._clips_per_video
            or uniform_clip * self._current_clip_index > max_possible_clip_start
        ):
            self._current_clip_index = 0
            is_last_clip = True

        if is_last_clip:
            self.reset()

        return ClipInfo(
            clip_start_sec,
            clip_start_sec + self._clip_duration,
            clip_index,
            aug_index,
            is_last_clip,
        )

    def reset(self):
        self._current_clip_index = 0
        self._current_aug_index = 0
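
# Usage sketch (editor's illustration, not part of the original source):
# ConstantClipsPerVideoSampler is called repeatedly until is_last_clip is True
# and spaces the clip starts uniform_clip seconds apart. The 10-second video
# duration below is hypothetical.
#
#     >>> sampler = ConstantClipsPerVideoSampler(clip_duration=2.0, clips_per_video=4)
#     >>> done = False
#     >>> while not done:
#     ...     clip = sampler(0.0, 10.0, {})
#     ...     done = clip.is_last_clip
#     # clip starts: 0s, 2s, 4s, 6s (max start 10 - 2 = 8, split into 4 increments)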