Source code for eztorch.datasets.decoders.frame_video

from __future__ import annotations

import logging
import math
import os
import pathlib
import random
import re
from abc import ABC
from concurrent.futures import ThreadPoolExecutor
from fractions import Fraction
from typing import Callable

import numpy as np
import torch
import torch.utils.data
from iopath.common.file_io import g_pathmgr
from numpy.typing import NDArray
from pytorchvideo.data.video import Video
from torchvision.io import decode_image, read_file, read_image

from eztorch.datasets.utils_fn import get_video_to_frame_path_fn
from eztorch.transforms.video.temporal_difference_transform import \
    temporal_difference

logger = logging.getLogger(__name__)


class GeneralFrameVideo(Video, ABC):
    """GeneralFrameVideo is an abstract class for accessing clips stored as frames.

    Args:
        num_threads_io: Controls whether parallelizable IO operations are
            performed across multiple threads.
        num_threads_decode: Controls whether parallelizable decode operations
            are performed across multiple threads. Decode threads are only
            used when ``num_threads_io > 0``.
        transform: The transform to apply to the frames.
    """

    def __init__(
        self,
        num_threads_io: int = 0,
        num_threads_decode: int = 0,
        transform: Callable | None = None,
    ) -> None:
        self._num_threads_io = num_threads_io
        self._num_threads_decode = num_threads_decode
        self._transform = transform

    def _load_images_with_retries(
        self,
        image_paths: list[str],
    ) -> torch.Tensor:
        """Loads the given image paths, decodes them, and returns them as a
        stacked tensor.

        Args:
            image_paths: A list of paths to images.

        Returns:
            A tensor of the clip's RGB frames with shape
            (time, channel, height, width). The frames are of type
            ``torch.uint8`` and in the range :math:`[0, 255]`.

        Raises:
            Exception: If unable to load images.
        """

        def fetch_img(image_path: str) -> torch.Tensor:
            # Read the raw, still-encoded image bytes.
            img_bytes = read_file(image_path)
            return img_bytes

        def decode_img(fetch_idx: int, img_bytes: torch.Tensor) -> None:
            # Decode the bytes and apply the optional per-frame transform.
            img = decode_image(img_bytes)
            if self._transform is not None:
                img = self._transform(img)
            imgs[fetch_idx] = img

        if self._num_threads_io > 0:
            imgs = [None for _ in range(len(image_paths))]
            work_queue_size = []
            # Read files in an IO pool and decode them in a separate pool so
            # reads and decodes can overlap.
            with ThreadPoolExecutor(max_workers=self._num_threads_io) as read_pool:
                imgs_bytes = read_pool.map(fetch_img, image_paths)
                with ThreadPoolExecutor(
                    max_workers=max(self._num_threads_decode, 1)
                ) as work_pool:
                    # Iterating the map drives the decode work; the queue size
                    # is tracked only for monitoring.
                    for _ in work_pool.map(
                        decode_img, range(len(image_paths)), imgs_bytes
                    ):
                        work_queue_size.append(work_pool._work_queue.qsize())
        else:
            imgs = [read_image(image_path) for image_path in image_paths]
            # Apply the transform in the sequential path as well, so both
            # paths return identically processed frames.
            if self._transform is not None:
                imgs = [self._transform(img) for img in imgs]

        imgs = torch.stack(imgs)
        return imgs
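
# Usage sketch (illustrative, not part of the library): a minimal concrete
# subclass. `GeneralFrameVideo` only provides threaded loading; subclasses
# implement `duration` and `get_clip` and delegate decoding to
# `_load_images_with_retries`. The class name and frame paths below are
# hypothetical.
#
#     class MyFrameClip(GeneralFrameVideo):
#         @property
#         def duration(self) -> float:
#             return 8 / 25  # 8 frames at 25 fps
#
#         def get_clip(self, start_sec: float, end_sec: float):
#             paths = [f"/data/clip/frame_{i:05d}.jpg" for i in range(8)]
#             frames = self._load_images_with_retries(paths)  # (T, C, H, W)
#             return {"video": frames.permute(1, 0, 2, 3), "audio": None}
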
class FrameVideo(GeneralFrameVideo):
    """FrameVideo is an abstraction for accessing clips based on their start and
    end time for a video where each frame is stored as an image.

    Args:
        video_path: The path of the video.
        duration: The duration of the video in seconds.
        fps: The target fps for the video. This is needed to link the frames to
            a second timestamp in the video.
        frame_filter: Function to subsample frames in a clip before loading.
            If ``None``, no subsampling is performed.
        time_difference_prob: Probability of decoding the clip as temporal
            differences between frames.
        video_frame_to_path_fn: A function that maps from the video path and a
            frame index integer to the file path where the frame is located.
        video_frame_paths: List of frame paths for each index of a video.
        num_threads_io: Controls whether parallelizable IO operations are
            performed across multiple threads.
        num_threads_decode: Controls whether parallelizable decode operations
            are performed across multiple threads.
        transform: The transform to apply to the frames.
        decode_float: Whether to decode the clip as float.
    """

    def __init__(
        self,
        video_path: str,
        duration: float,
        fps: int,
        frame_filter: Callable[[list[int]], tuple[NDArray, NDArray]] | None = None,
        time_difference_prob: float = 0.0,
        video_frame_to_path_fn: Callable[[str, int], str]
        | None = get_video_to_frame_path_fn(),
        video_frame_paths: list[str] | None = None,
        num_threads_io: int = 0,
        num_threads_decode: int = 0,
        transform: Callable | None = None,
        decode_float: bool = False,
    ) -> None:
        super().__init__(
            num_threads_io=num_threads_io,
            num_threads_decode=num_threads_decode,
            transform=transform,
        )

        self._duration = duration
        self._fps = fps
        self._time_difference_prob = time_difference_prob
        self._decode_float = decode_float
        self._frame_filter = frame_filter

        assert (video_frame_to_path_fn is None) != (
            video_frame_paths is None
        ), "Exactly one of video_frame_to_path_fn or video_frame_paths must be provided"
        self._video_frame_to_path_fn = video_frame_to_path_fn
        self._video_frame_paths = video_frame_paths

        # Keep the video directory path and derive the video name from it.
        self._video_path = video_path
        self._name = pathlib.Path(video_path).name

    @classmethod
    def from_directory(
        cls,
        video_path: str,
        fps: int,
        num_frames: int | None = None,
        path_order_cache: dict[str, list[str]] | None = None,
        frame_path_fn: Callable[[str, int], str] | None = None,
        **kwargs,
    ):
        """
        Args:
            video_path: Path to frame video directory.
            fps: The target fps for the video. This is needed to link the
                frames to a second timestamp in the video.
            num_frames: If not ``None``, number of frames in the video.
            path_order_cache: An optional mapping from directory-path to list
                of frames in the directory in numerical order. Used for
                speedup by caching the frame paths.
            frame_path_fn: Function to retrieve the frame path from the video
                directory path and the frame index.
        """
        if path_order_cache is not None and video_path in path_order_cache:
            return cls.from_frame_paths(
                video_path, fps, path_order_cache[video_path], num_frames, **kwargs
            )

        assert g_pathmgr.isdir(video_path), f"{video_path} is not a directory"

        if num_frames is None:
            rel_frame_paths = g_pathmgr.ls(video_path)

            def natural_keys(text):
                # Sort frame file names numerically, e.g. frame_2 before frame_10.
                return [
                    int(c) if c.isdigit() else c for c in re.split(r"(\d+)", text)
                ]

            rel_frame_paths.sort(key=natural_keys)
            frame_paths = [os.path.join(video_path, f) for f in rel_frame_paths]
        elif frame_path_fn is not None:
            frame_paths = [frame_path_fn(video_path, i) for i in range(num_frames)]
        else:
            frame_paths = None

        if path_order_cache is not None and frame_paths is not None:
            path_order_cache[video_path] = frame_paths

        return cls.from_frame_paths(video_path, fps, frame_paths, num_frames, **kwargs)

    @classmethod
    def from_frame_paths(
        cls,
        video_path: str,
        fps: int,
        video_frame_paths: list[str] | None = None,
        num_frames: int | None = None,
        **kwargs,
    ):
        """
        Args:
            video_path: Path to the video directory.
            fps: The target fps for the video. This is needed to link the
                frames to a second timestamp in the video.
            video_frame_paths: A list of paths to each frame in the video.
            num_frames: If not ``None``, number of frames in the video.
        """
        assert (
            video_frame_paths is not None or num_frames is not None
        ), "One of video_frame_paths or num_frames must be specified"
        duration = (
            Fraction(num_frames, fps)
            if num_frames is not None
            else Fraction(len(video_frame_paths), fps)
        )
        return cls(
            video_path, duration, fps, video_frame_paths=video_frame_paths, **kwargs
        )

    @property
    def name(self) -> str:
        """The name of the video."""
        return self._name

    @property
    def duration(self) -> float:
        """The video's duration/end-time in seconds."""
        return self._duration

    def _get_frame_index_for_time(self, time_sec: float) -> int:
        return math.ceil(self._fps * time_sec)

    def get_clip(
        self,
        start_sec: float,
        end_sec: float,
    ) -> dict[str, torch.Tensor | None] | None:
        """Retrieves frames from the stored video at the specified start and end
        times in seconds (the video always starts at 0 seconds). Returned frames
        will be in [start_sec, end_sec). Given that PathManager may be fetching
        the frames from network storage, to handle transient errors, frame
        reading is retried N times. Note that as end_sec is exclusive, you may
        need to use ``get_clip(start_sec, duration + EPS)`` to get the last
        frame.

        Args:
            start_sec: The clip start time in seconds.
            end_sec: The clip end time in seconds.

        Returns:
            A dictionary containing the clip data and information, or ``None``
            if no frames fall within the requested times.
        """
        if start_sec < 0 or start_sec > self._duration:
            logger.warning(
                f"No frames found within {start_sec} and {end_sec} seconds. Video"
                f" starts at time 0 and ends at {self._duration}."
            )
            return None

        end_sec = min(end_sec, self._duration)

        start_frame_index = self._get_frame_index_for_time(start_sec)
        end_frame_index = self._get_frame_index_for_time(end_sec)
        frame_indices = list(range(start_frame_index, end_frame_index))

        do_time_difference = (
            self._time_difference_prob > 0.0
            and self._time_difference_prob > random.random()
        )

        # Frame filter function to allow for subsampling before loading.
        if self._frame_filter:
            frame_indices, keep_frames = self._frame_filter(
                frame_indices, time_difference=do_time_difference
            )
        else:
            frame_indices = np.array(frame_indices)
            keep_frames = np.array([True for _ in range(len(frame_indices))])

        # Load each unique frame only once, then re-expand to the sampled order.
        unique, unique_indices = np.unique(frame_indices, return_inverse=True)
        clip_paths = [self._video_frame_to_path(i) for i in unique]

        clip_frames = self._load_images_with_retries(clip_paths)
        clip_frames = clip_frames[unique_indices]
        # (T, C, H, W) -> (C, T, H, W).
        clip_frames = clip_frames.permute(1, 0, 2, 3)

        if do_time_difference:
            clip_frames = temporal_difference(
                clip_frames, use_grayscale=True, absolute=False
            )[:, keep_frames, :, :]
            frame_indices = frame_indices[keep_frames]

        if self._decode_float:
            clip_frames = clip_frames.to(torch.float32)

        return {
            "video": clip_frames,
            "frame_indices": frame_indices,
            "audio": None,
            "time_difference": do_time_difference,
        }

    def _video_frame_to_path(self, frame_index: int) -> str:
        if self._video_frame_to_path_fn:
            return self._video_frame_to_path_fn(self._video_path, frame_index)
        elif self._video_frame_paths:
            return self._video_frame_paths[frame_index]
        else:
            raise Exception(
                "One of _video_frame_to_path_fn or _video_frame_paths must be set"
            )
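
# Usage sketch (illustrative, not part of the library): decoding a 2-second
# clip from a directory of frames with numerically sortable names such as
# frame_00001.jpg. The directory path and fps below are hypothetical. Since
# the constructor accepts exactly one of video_frame_to_path_fn /
# video_frame_paths, the default path fn is disabled when the frame paths are
# listed from disk.
#
#     video = FrameVideo.from_directory(
#         "/data/videos/match_001",  # hypothetical directory of frames
#         fps=25,
#         num_threads_io=4,
#         num_threads_decode=4,
#         video_frame_to_path_fn=None,
#     )
#     clip = video.get_clip(start_sec=0.0, end_sec=2.0)
#     frames = clip["video"]           # uint8 tensor, shape (C, T, H, W)
#     indices = clip["frame_indices"]  # absolute frame indices of the clip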