# Source code for eztorch.datasets.clip_samplers.soccernet.soccernet_clip_sampler
from abc import ABC
from torch.utils.data import Sampler
from eztorch.datasets.soccernet import SoccerNet
from eztorch.utils.utils import get_default_seed
# [docs]
class SoccerNetClipSampler(Sampler, ABC):
    """Base class for SoccerNet clip samplers.

    The base class only stores shared state (dataset, epoch, seed and the
    shuffle flag); it does not define ``__iter__`` itself, so index
    generation is left to subclasses.

    Args:
        data_source: SoccerNet dataset.
        shuffle: Whether to shuffle indices.
    """

    def __init__(
        self,
        data_source: SoccerNet,
        shuffle: bool = False,
    ) -> None:
        # NOTE(review): recent PyTorch releases may warn that ``data_source``
        # is unused by ``Sampler.__init__``, while older releases may require
        # it positionally; kept as-is for compatibility — confirm against the
        # pinned torch version before changing.
        super().__init__(data_source)
        self.data_source = data_source
        # Epoch counter; starts at 0 and is updated through ``set_epoch``.
        self.epoch = 0
        # Base seed from the project default; presumably combined with
        # ``epoch`` by subclasses to vary windows per epoch — TODO confirm.
        self.seed = get_default_seed()
        self._shuffle = shuffle

    @property
    def shuffle(self) -> bool:
        """Whether indices should be shuffled."""
        return self._shuffle

    @shuffle.setter
    def shuffle(self, shuffle: bool) -> None:
        # Route attribute-style assignment (``sampler.shuffle = x``) through
        # ``set_shuffle`` so subclasses that override it keep their behavior.
        self.set_shuffle(shuffle)

    def set_shuffle(self, shuffle: bool) -> None:
        """Set shuffle value.

        Args:
            shuffle: Value for shuffle.
        """
        self._shuffle = shuffle

    def set_epoch(self, epoch: int) -> None:
        """Sets the epoch for this sampler.

        This ensures that at each epoch the windows are not the same for
        relevant subclass samplers. The base class only stores the value.

        Args:
            epoch: Epoch number.
        """
        self.epoch = epoch