Source code for eztorch.transforms.video.spot.spotting_mixup

from typing import Any, Dict

import torch
from torch import Tensor
from torch.nn import Module


def mix_spotting(
    x: Tensor,
    mix_value: Tensor,
    permutation: Tensor,
    labels: Tensor,
    has_label: Tensor,
    ignore_class: Tensor,
):
    """Make mixup of the batch for action spotting.

    Args:
        x: The batch values to mix.
        mix_value: Value coefficients for mixing.
        permutation: Permutation to perform mix.
        labels: Labels of the timestamps in the batch.
        has_label: Whether each timestamp has a label.
        ignore_class: Whether each class in the batch should be ignored.

    Returns:
        Tuple containing:
        - The mixed input.
        - The concatenated labels of the mixed elements.
        - The concatenated `has_label` of the mixed elements.
        - The concatenated `ignore_class` of the mixed elements.
        - The concatenated mixup weights (`mix_value` and `1 - mix_value`) of the mixed elements.
    """
    x_permuted = x[permutation]
    labels_permuted = labels[permutation]
    has_label_permuted = has_label[permutation]
    ignore_class_permuted = ignore_class[permutation]

    labels_cat = torch.cat((labels, labels_permuted))
    has_label_cat = torch.cat((has_label, has_label_permuted))
    ignore_class_cat = torch.cat((ignore_class, ignore_class_permuted))

    mix_value_x = mix_value.view([-1, *([1] * (x.ndim - 1))])
    one_minus_mix_value_x = 1 - mix_value_x
    mix_value_label = mix_value.view([-1, *([1] * (labels.ndim - 1))])
    one_minus_mix_value_label = 1 - mix_value_label

    x_mixed = mix_value_x * x + one_minus_mix_value_x * x_permuted
    mixed_weights = torch.cat((mix_value_label, one_minus_mix_value_label))

    return (
        x_mixed,
        labels_cat,
        has_label_cat,
        ignore_class_cat,
        mixed_weights,
    )
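

# Illustrative sketch (not part of the library): one way ``mix_spotting`` could
# be called on a toy batch. The helper name, shapes, and values below are
# assumptions made up for demonstration; x is (batch, channels, time) and
# labels is (batch, time, classes).
def _example_mix_spotting():  # hypothetical helper, not in eztorch
    batch_size, channels, time, num_classes = 4, 3, 8, 2
    x = torch.randn(batch_size, channels, time)
    labels = torch.randint(0, 2, (batch_size, time, num_classes)).float()
    has_label = labels.sum(dim=-1) > 0
    ignore_class = torch.zeros(batch_size, time, num_classes, dtype=torch.bool)

    # Sample one mixing coefficient per element and a batch permutation.
    mix_value = torch.distributions.Beta(0.5, 0.5).sample((batch_size,))
    permutation = torch.randperm(batch_size)

    x_mixed, labels_cat, has_label_cat, ignore_class_cat, weights = mix_spotting(
        x, mix_value, permutation, labels, has_label, ignore_class
    )

    # The input keeps its batch size; the label-related outputs concatenate the
    # original and permuted halves along dim 0.
    assert x_mixed.shape == x.shape
    assert labels_cat.shape[0] == 2 * batch_size
    assert weights.shape[0] == 2 * batch_size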


class SpottingMixup(Module):
    """Perform mixup on batches for action spotting.

    Args:
        alpha: Alpha value for the beta distribution of mixup.
    """

    def __init__(
        self,
        alpha: float = 0.5,
    ) -> None:
        super().__init__()
        self.alpha = alpha
        self.mix_sampler = torch.distributions.Beta(alpha, alpha)

    def forward(self, batch: Dict[str, Any]):
        x, labels, has_label, ignore_class = (
            batch["input"],
            batch["labels"],
            batch["has_label"],
            batch["ignore_class"],
        )

        device, dtype = x.device, x.dtype
        batch_size = x.shape[0]

        with torch.inference_mode():
            mix_value = self.mix_sampler.sample((batch_size,)).to(
                device=device, dtype=dtype
            )
        # Clone outside inference mode so the coefficients can be used in
        # autograd-tracked computations.
        mix_value = mix_value.clone()

        permutation = torch.randperm(batch_size, device=device)

        (
            x_mixed,
            labels_after_mix,
            has_label_after_mix,
            ignore_class_after_mix,
            mixed_weights,
        ) = mix_spotting(
            x=x,
            mix_value=mix_value,
            permutation=permutation,
            labels=labels,
            has_label=has_label,
            ignore_class=ignore_class,
        )

        new_batch = {
            "input": x_mixed,
            "labels": labels_after_mix,
            "ignore_class": ignore_class_after_mix,
            "has_label": has_label_after_mix,
            "mixup_weights": mixed_weights,
        }

        return new_batch

    def __repr__(self):
        return f"{__class__.__name__}(alpha={self.alpha})"
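

# Illustrative sketch (not part of the library): applying ``SpottingMixup`` to a
# dummy batch dict. The helper name and tensor shapes are assumptions made up
# for demonstration; the dict keys follow the ones read in ``forward``.
def _example_spotting_mixup():  # hypothetical helper, not in eztorch
    transform = SpottingMixup(alpha=0.5)
    batch = {
        "input": torch.randn(4, 3, 8),
        "labels": torch.randint(0, 2, (4, 8, 2)).float(),
        "has_label": torch.ones(4, 8, dtype=torch.bool),
        "ignore_class": torch.zeros(4, 8, 2, dtype=torch.bool),
    }

    mixed_batch = transform(batch)

    # "input" keeps the original batch size, while the label-related tensors are
    # doubled (original + permuted halves) and "mixup_weights" holds the
    # per-element mixing coefficients for both halves.
    print(mixed_batch["input"].shape)          # torch.Size([4, 3, 8])
    print(mixed_batch["labels"].shape)         # torch.Size([8, 8, 2])
    print(mixed_batch["mixup_weights"].shape)  # torch.Size([8, 1, 1])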