# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import random
from abc import ABC, abstractmethod
from fractions import Fraction
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union


class ClipInfo(NamedTuple):
"""
Named-tuple for clip information with:
clip_start_sec (Union[float, Fraction]): clip start time.
clip_end_sec (Union[float, Fraction]): clip end time.
clip_index (int): clip index in the video.
aug_index (int): augmentation index for the clip. Different augmentation methods
might generate multiple views for the same clip.
        is_last_clip (bool): a bool specifying whether this is the last clip to be
            sampled from the video; when True, no further clips will be returned.
"""
clip_start_sec: Union[float, Fraction]
clip_end_sec: Union[float, Fraction]
clip_index: int
aug_index: int
is_last_clip: bool
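
# Example (illustrative sketch, values made up): constructing and unpacking a
# ClipInfo for the first 2 seconds of a video.
#
#   info = ClipInfo(Fraction(0), Fraction(2), 0, 0, False)
#   start, end, clip_idx, aug_idx, is_last = info  # fields unpack positionally
#   assert end - start == 2
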
class ClipInfoList(NamedTuple):
"""
Named-tuple for clip information with:
clip_start_sec (float): clip start time.
clip_end_sec (float): clip end time.
clip_index (int): clip index in the video.
aug_index (int): augmentation index for the clip. Different augmentation methods
might generate multiple views for the same clip.
is_last_clip (bool): a bool specifying whether there are more clips to be
sampled from the video.
"""
clip_start_sec: List[float]
clip_end_sec: List[float]
clip_index: List[float]
aug_index: List[float]
is_last_clip: List[float]
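
# Example (illustrative sketch, values made up): a ClipInfoList holds the same
# fields as ClipInfo, but as parallel lists covering several clips.
#
#   infos = ClipInfoList(
#       clip_start_sec=[0.0, 2.0],
#       clip_end_sec=[2.0, 4.0],
#       clip_index=[0, 1],
#       aug_index=[0, 0],
#       is_last_clip=[False, True],
#   )
#   assert len(infos.clip_start_sec) == 2
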
class ClipSampler(ABC):
"""
Interface for clip samplers that take a video time, previous sampled clip time,
and returns a named-tuple ``ClipInfo``.
"""
def __init__(self, clip_duration: Union[float, Fraction]) -> None:
self._clip_duration = Fraction(clip_duration)
self._current_clip_index = 0
self._current_aug_index = 0
@abstractmethod
def __call__(
self,
last_clip_time: Union[float, Fraction],
video_duration: Union[float, Fraction],
annotation: Dict[str, Any],
) -> ClipInfo:
pass
def reset(self) -> None:
"""Resets any video-specific attributes in preperation for next video"""
pass
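
# Illustrative sketch (not part of the library): a minimal, hypothetical
# ClipSampler subclass that always returns the first clip_duration seconds of
# the video, showing what an implementation of ``__call__`` looks like.
#
#   class FirstClipSampler(ClipSampler):
#       def __call__(self, last_clip_time, video_duration, annotation):
#           # Single deterministic clip, so clip_index/aug_index are 0 and
#           # is_last_clip is always True.
#           return ClipInfo(Fraction(0), self._clip_duration, 0, 0, True)
#
#   sampler = FirstClipSampler(clip_duration=2.0)
#   clip = sampler(0.0, 30.0, {})
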
def make_clip_sampler(sampling_type: str, *args) -> ClipSampler:
"""
Constructs the clip samplers found in ``pytorchvideo.data.clip_sampling`` from the
given arguments.
Args:
sampling_type (str): choose clip sampler to return. It has three options:
* uniform: constructs and return ``UniformClipSampler``
* random: construct and return ``RandomClipSampler``
* constant_clips_per_video: construct and return ``ConstantClipsPerVideoSampler``
*args: the args to pass to the chosen clip sampler constructor.
"""
if sampling_type == "uniform":
return UniformClipSampler(*args)
elif sampling_type == "random":
return RandomClipSampler(*args)
elif sampling_type == "constant_clips_per_video":
return ConstantClipsPerVideoSampler(*args)
elif sampling_type == "random_multi":
return RandomMultiClipSampler(*args)
else:
raise NotImplementedError(f"{sampling_type} not supported")
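
# Example (illustrative sketch): the positional ``*args`` are forwarded directly to
# the chosen sampler's constructor, so they must match its signature. The durations
# and clip counts below are arbitrary.
#
#   random_sampler = make_clip_sampler("random", 2.0)  # RandomClipSampler(2.0)
#   constant_sampler = make_clip_sampler(
#       "constant_clips_per_video", 2.0, 5
#   )  # ConstantClipsPerVideoSampler(clip_duration=2.0, clips_per_video=5)
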
class RandomClipSampler(ClipSampler):
"""
Randomly samples clip of size clip_duration from the videos.
"""
def __call__(
self, last_clip_time: float, video_duration: float, annotation: Dict[str, Any]
) -> ClipInfo:
"""
Args:
last_clip_time (float): Not used for RandomClipSampler.
video_duration: (float): the duration (in seconds) for the video that's
being sampled
annotation (Dict): Not used by this sampler.
Returns:
clip_info (ClipInfo): includes the clip information of (clip_start_time,
clip_end_time, clip_index, aug_index, is_last_clip). The times are in seconds.
clip_index, aux_index and is_last_clip are always 0, 0 and True, respectively.
"""
max_possible_clip_start = max(video_duration - self._clip_duration, 0)
clip_start_sec = Fraction(random.uniform(0, max_possible_clip_start))
return ClipInfo(
clip_start_sec, clip_start_sec + self._clip_duration, 0, 0, True
)
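
# Example (illustrative sketch): sample one random 2-second window from a
# 10-second video; the annotation dict is unused by this sampler.
#
#   sampler = RandomClipSampler(clip_duration=2.0)
#   clip = sampler(0.0, 10.0, {})
#   assert 0 <= clip.clip_start_sec <= 8 and clip.is_last_clip
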
class RandomMultiClipSampler(RandomClipSampler):
"""
TODO
"""
def __init__(self, clip_duration: float, num_clips: int) -> None:
super().__init__(clip_duration)
self._num_clips = num_clips
def __call__(
self, last_clip_time: float, video_duration: float, annotation: Dict[str, Any]
) -> ClipInfoList:
(
clip_start_list,
clip_end_list,
clip_index_list,
aug_index_list,
is_last_clip_list,
) = (
self._num_clips * [None],
self._num_clips * [None],
self._num_clips * [None],
self._num_clips * [None],
self._num_clips * [None],
)
for i in range(self._num_clips):
(
clip_start_list[i],
clip_end_list[i],
clip_index_list[i],
aug_index_list[i],
is_last_clip_list[i],
) = super().__call__(last_clip_time, video_duration, annotation)
return ClipInfoList(
clip_start_list,
clip_end_list,
clip_index_list,
aug_index_list,
is_last_clip_list,
)
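
# Example (illustrative sketch): draw three independent random 2-second clips from
# a 10-second video in a single call; the returned fields are parallel lists.
#
#   sampler = RandomMultiClipSampler(clip_duration=2.0, num_clips=3)
#   clips = sampler(0.0, 10.0, {})
#   assert len(clips.clip_start_sec) == 3
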
class ConstantClipsPerVideoSampler(ClipSampler):
"""
Evenly splits the video into clips_per_video increments and samples clips of size
clip_duration at these increments.
"""
def __init__(
self, clip_duration: float, clips_per_video: int, augs_per_clip: int = 1
) -> None:
super().__init__(clip_duration)
self._clips_per_video = clips_per_video
self._augs_per_clip = augs_per_clip
def __call__(
self, last_clip_time: float, video_duration: float, annotation: Dict[str, Any]
) -> ClipInfo:
"""
Args:
            last_clip_time (float): Not used for ConstantClipsPerVideoSampler.
            video_duration (float): the duration (in seconds) of the video that is
                being sampled.
annotation (Dict): Not used by this sampler.
Returns:
a named-tuple `ClipInfo`: includes the clip information of (clip_start_time,
clip_end_time, clip_index, aug_index, is_last_clip). The times are in seconds.
is_last_clip is True after clips_per_video clips have been sampled or the end
of the video is reached.
"""
max_possible_clip_start = Fraction(max(video_duration - self._clip_duration, 0))
uniform_clip = Fraction(max_possible_clip_start, self._clips_per_video)
clip_start_sec = uniform_clip * self._current_clip_index
clip_index = self._current_clip_index
aug_index = self._current_aug_index
self._current_aug_index += 1
if self._current_aug_index >= self._augs_per_clip:
self._current_clip_index += 1
self._current_aug_index = 0
        # is_last_clip is True once clips_per_video clips have been sampled or the
        # end of the video is reached.
is_last_clip = False
if (
self._current_clip_index >= self._clips_per_video
or uniform_clip * self._current_clip_index > max_possible_clip_start
):
self._current_clip_index = 0
is_last_clip = True
if is_last_clip:
self.reset()
return ClipInfo(
clip_start_sec,
clip_start_sec + self._clip_duration,
clip_index,
aug_index,
is_last_clip,
)
def reset(self):
self._current_clip_index = 0
self._current_aug_index = 0
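
# Example (illustrative sketch): sample 4 evenly spaced 2-second clips from a
# 10-second video. The sampler is stateful, so it is called repeatedly until
# is_last_clip is True; the numbers below are arbitrary.
#
#   sampler = ConstantClipsPerVideoSampler(clip_duration=2.0, clips_per_video=4)
#   starts = []
#   while True:
#       clip = sampler(0.0, 10.0, {})
#       starts.append(clip.clip_start_sec)
#       if clip.is_last_clip:
#           break
#   # starts == [0, 2, 4, 6] (as Fractions): the 8-second range of valid start
#   # times divided into clips_per_video uniform steps.
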