# Source code for eztorch.transforms.multi_crop_transform

from typing import Any, Iterable, List

from PIL.Image import Image
from torch import Tensor


class MultiCropTransform:
    """Define multi crop transform that applies several sets of transforms to the inputs.

    Each set contributes ``num_views`` entries to a flat list of transforms.
    Calling the object applies that flat list either to a single image (every
    transform receives the same input) or to an iterable of images (the i-th
    transform receives the i-th image).

    Args:
        set_transforms: List of dictionaries, one per set, specifying the
            transform under the key ``'transform'`` and, optionally, the
            number of views under the key ``'num_views'`` (defaults to 1).

    Example::

        set_transforms = [
            {'transform': [...], 'num_views': ...},
            {'transform': [...], 'num_views': ...},
            ...
        ]

        transform = MultiCropTransform(
            set_transforms
        )
    """

    def __init__(self, set_transforms: List[Any]) -> None:
        super().__init__()

        self.set_transforms = set_transforms

        transforms = []
        for set_transform in self.set_transforms:
            # The default is written back into the given mapping (in place)
            # because __repr__ reads set_transform["num_views"] later on.
            if "num_views" not in set_transform:
                set_transform["num_views"] = 1
            # The same transform object is repeated once per requested view.
            transforms.extend(
                [set_transform["transform"]] * set_transform["num_views"]
            )

        self.transforms = transforms

    def __call__(
        self, img: Image | Tensor | Iterable[Image | Tensor]
    ) -> List[Any]:
        """Apply every configured transform and collect the results.

        Args:
            img: Either a single ``Image``/``Tensor`` (every transform is
                applied to it), or an iterable of inputs that is paired
                one-to-one with the flattened list of transforms.

        Returns:
            A list with one transformed output per view, in the order of
            ``self.transforms``.

        Raises:
            ValueError: If ``img`` is an iterable whose length differs from
                the number of transforms (raised by ``zip(..., strict=True)``).
        """
        # isinstance (rather than an exact type() match) so that subclasses of
        # Image or Tensor are still recognised as a single input instead of
        # being iterated over as if they were a collection of views.
        if isinstance(img, (Image, Tensor)):
            return [transform(img) for transform in self.transforms]

        return [
            transform(image)
            for transform, image in zip(self.transforms, img, strict=True)
        ]

    def __repr__(self) -> str:
        """Return the class name followed by one parenthesised entry per set."""
        parts = [self.__class__.__name__]
        for set_transform in self.set_transforms:
            parts.append("(\n")
            parts.append(" num views={}\n".format(set_transform["num_views"]))
            parts.append(" transforms={}".format(set_transform["transform"]))
            parts.append("\n)")
        return "".join(parts)