Source code for eztorch.transforms.apply_key
from typing import Dict, List
import torch
from torch import Tensor
from torch.nn import Module
[docs]
class ApplyTransformToKey(Module):
    """Run a transform on one entry of a dictionary input.

    Args:
        key: Name of the dictionary entry passed to the transform.
        transform: Module called on the value stored under ``key``.
    """

    def __init__(self, key: str, transform: Module):
        super().__init__()
        self._transform = transform
        self._key = key

    def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Replace ``x[key]`` with its transformed value and return ``x``.

        The dictionary is updated in place; the returned object is the same
        dictionary that was passed in.
        """
        key = self._key
        transformed = self._transform(x[key])
        x[key] = transformed
        return x
[docs]
class ApplyTransformToKeyOnList(Module):
    """Run a transform on every element of a list stored in a dictionary.

    Args:
        key: Name of the dictionary entry holding the list.
        transform: Module called once per element of that list.

    Example::
        >>> transforms.ApplyTransformToKeyOnList(
        >>>       key='input',
        >>>       transform=UniformTemporalSubsample(num_video_samples),
        >>>   )
    """

    def __init__(self, key: str, transform: Module) -> None:  # pyre-ignore[24]
        super().__init__()
        self._transform = transform
        self._key = key

    def forward(self, x: Dict[str, List[Tensor]]) -> Dict[str, List[Tensor]]:
        """Replace ``x[key]`` with a new list of transformed elements.

        The dictionary is updated in place and returned.
        """
        transformed = []
        for element in x[self._key]:
            transformed.append(self._transform(element))
        x[self._key] = transformed
        return x

    def __repr__(self):
        name = self.__class__.__name__
        return f"{name}(key={self._key}, transform={self._transform})"
[docs]
class ApplySameTransformToKeyOnList(Module):
    """Applies the same transform to key of dictionary input where input is a list.

    The list elements are concatenated along ``dim``, the transform is called
    once on the concatenated tensor, and the result is split back into a list
    along the same dimension.

    Args:
        key: the dictionary key the transform is applied to.
        transform: the transform that is applied.
        dim: The dimension to retrieve the various elements of the list.
    """

    def __init__(
        self, key: str, transform: Module, dim: int = 1
    ) -> None:  # pyre-ignore[24]
        super().__init__()
        self._key = key
        self._transform = transform
        self._dim = dim

    def forward(self, x: Dict[str, List[Tensor]]) -> Dict[str, List[Tensor]]:
        """Transform the list under ``key`` jointly and store the split result.

        Args:
            x: Dictionary whose ``key`` entry is a list of tensors that can be
                concatenated along ``dim``.

        Returns:
            The same dictionary, with ``x[key]`` replaced by a list of tensors.
        """
        data = x[self._key]
        len_data = len(data)
        if len_data == 0:
            # torch.cat rejects an empty sequence; there is nothing to
            # transform, so hand back an empty list.
            x[self._key] = []
            return x

        merged = torch.cat(data, dim=self._dim)
        # Remember each element's extent so unevenly sized elements can be
        # recovered exactly when the transform keeps the dimension intact.
        sizes = [element.shape[self._dim] for element in data]
        merged = self._transform(merged)

        total = merged.shape[self._dim]
        if total == sum(sizes):
            # The transform preserved the extent along dim: split back at the
            # original element boundaries.
            chunks = merged.split(sizes, dim=self._dim)
        else:
            # The transform changed the extent along dim: the original
            # boundaries no longer apply, so divide evenly between elements.
            chunks = merged.split(total // len_data, dim=self._dim)

        x[self._key] = list(chunks)
        return x

    def __repr__(self):
        return f"{self.__class__.__name__}(key={self._key}, transform={self._transform}, dim={self._dim})"
[docs]
class ApplyTransformInputKeyOnList(ApplyTransformToKeyOnList):
    """Run a transform on every element of the list stored under ``"input"``.

    Args:
        transform: The transform to apply.
    """

    def __init__(self, transform: Module):
        # The key is fixed to "input"; only the transform is configurable.
        super().__init__(key="input", transform=transform)

    def __repr__(self):
        name = self.__class__.__name__
        return f"{name}(transform={self._transform})"
[docs]
class ApplySameTransformInputKeyOnList(ApplySameTransformToKeyOnList):
    """Run one shared transform over the list stored under ``"input"``.

    Args:
        transform: The transform to apply.
        dim: The dimension to retrieve the various elements of the list.
    """

    def __init__(self, transform: Module, dim: int = 1):
        # The key is fixed to "input"; transform and dim are forwarded as-is.
        super().__init__(key="input", transform=transform, dim=dim)

    def __repr__(self):
        name = self.__class__.__name__
        return f"{name}(transform={self._transform}, dim={self._dim})"
[docs]
class ApplyTransformAudioKeyOnList(ApplyTransformToKeyOnList):
    """Run a transform on every element of the list stored under ``"audio"``.

    Args:
        transform: The transform to apply.
    """

    def __init__(self, transform: Module):
        # The key is fixed to "audio"; only the transform is configurable.
        super().__init__(key="audio", transform=transform)

    def __repr__(self):
        name = self.__class__.__name__
        return f"{name}(transform={self._transform})"
[docs]
class ApplyTransformInputKey(ApplyTransformToKey):
    """Run a transform on the value stored under the ``"input"`` key.

    Args:
        transform: The transform to apply.
    """

    def __init__(self, transform: Module):
        # The key is fixed to "input"; only the transform is configurable.
        super().__init__(key="input", transform=transform)

    def __repr__(self):
        name = self.__class__.__name__
        return f"{name}(transform={self._transform})"
[docs]
class ApplyTransformAudioKey(ApplyTransformToKey):
    """Apply Transform to the audio key.

    Args:
        transform: The transform to apply.
    """

    def __init__(self, transform: Module):
        # The key is fixed to "audio"; only the transform is configurable.
        super().__init__("audio", transform=transform)

    def __repr__(self):
        # Same compact one-line format as the other fixed-key wrappers in
        # this module (e.g. ApplyTransformInputKey).
        return f"{self.__class__.__name__}(transform={self._transform})"
[docs]
class ApplyTransformOnDict(Module):
    """Apply a transform to the whole dictionary input.

    Unlike the key-based wrappers, the transform receives the dictionary
    itself rather than the value of a single key.

    Args:
        transform: The transform to apply; it is called with the dictionary.
    """

    def __init__(self, transform: Module):
        super().__init__()
        self._transform = transform

    def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Return whatever the wrapped transform produces for ``x``."""
        return self._transform(x)

    def __repr__(self):
        name = self.__class__.__name__
        return f"{name}(transform={self._transform})"