Source code for eztorch.transforms.only_input_transform

from typing import Any, Callable, Dict, List

from torchvision.transforms import Compose

from eztorch.transforms.apply_key import (ApplySameTransformInputKeyOnList,
                                          ApplyTransformInputKey,
                                          ApplyTransformInputKeyOnList,
                                          ApplyTransformOnDict)
from eztorch.transforms.dict_keep_keys import DictKeepInputLabelIdx


[docs] class OnlyInputListTransform(Compose): """Apply Transform to only the key ``'input'`` in a list of sample dictionary. Args: transform: The transform to apply. """ def __init__(self, transform: Callable) -> None: transforms = [ApplyTransformInputKeyOnList(transform), DictKeepInputLabelIdx()] super().__init__(transforms=transforms)
[docs] class OnlyInputTransform(Compose): """Apply Transform to only the key ``'input'`` in a sample dictionary. Args: transform: The transform to apply. """ def __init__(self, transform: Callable) -> None: transforms = [ApplyTransformInputKey(transform), DictKeepInputLabelIdx()] super().__init__(transforms=transforms)
[docs] class OnlyInputListSameTransform(Compose): """Apply the same transform to only the key ``'input'`` in a list of sample dictionary. Args: transform: The transform to apply. """ def __init__(self, transform: Callable) -> None: transforms = [ ApplySameTransformInputKeyOnList(transform), DictKeepInputLabelIdx(), ] super().__init__(transforms=transforms)
[docs] class OnlyInputTransformWithDictTransform(Compose): """Apply Transform to only the key ``'input'`` in a sample dictionary with a transformation on the dictionary afterwards. Args: transform: The transform to apply to the input. dict_transform: The transform to apply to the dictionary. first_dict: If ``True``, first apply the transformation on the dict, else, first apply the transformation on the input. """ def __init__( self, transform: Callable, dict_transform: Callable, first_dict: bool = False ) -> None: if first_dict: transforms = [ ApplyTransformInputKey(transform), ApplyTransformOnDict(dict_transform), DictKeepInputLabelIdx(), ] else: transforms = [ ApplyTransformInputKey(transform), ApplyTransformOnDict(dict_transform), DictKeepInputLabelIdx(), ] super().__init__(transforms=transforms)
[docs] class OnlyInputListTransformWithDictTransform: """Apply Transform to only the key ``'input'`` in a list of sample dictionary with a transformation on the dictionary afterwards. Args: transform: The transform to apply to the input. dict_transform: The transform to apply to the dictionary. first_dict: If ``True``, first apply the transformation on the dict, else, first apply the transformation on the input. """ def __init__( self, transform: Callable, dict_transform: Callable, first_dict: bool = False ) -> None: self.transform = OnlyInputTransformWithDictTransform( transform, dict_transform, first_dict ) def __call__(self, x: List[Dict[str, Any]]) -> List[Dict[str, Any]]: return [self.transform(sample) for sample in x] def __repr__(self) -> str: format_string = self.__class__.__name__ + "(" format_string += "\n" format_string += f" {self.transform}" format_string += "\n)" return format_string