Source code for eztorch.models.siamese.sce

from typing import Any, Optional

from omegaconf import DictConfig
from torch import Tensor

from eztorch.losses.sce_loss import compute_sce_loss, compute_sce_mask
from eztorch.models.siamese.shuffle_momentum_queue_base import \
    ShuffleMomentumQueueBaseModel
from eztorch.utils.utils import scheduler_value

# Large constant used to mask out forbidden similarities before the softmax.
LARGE_NUM = 1e9


class SCEModel(ShuffleMomentumQueueBaseModel):
    """SCE model.

    References:
        - SCE: https://arxiv.org/pdf/2111.14585.pdf

    Args:
        trunk: Config to build a trunk.
        optimizer: Config to build optimizers and schedulers.
        projector: Config to build a projector.
        predictor: Config to build a predictor.
        train_transform: Config to perform transformation on train input.
        val_transform: Config to perform transformation on val input.
        test_transform: Config to perform transformation on test input.
        normalize_outputs: If ``True``, normalize outputs.
        num_global_crops: Number of global crops, which are the first elements of each batch.
        num_local_crops: Number of local crops, which are the last elements of each batch.
        num_splits: Number of splits to apply to each crop.
        num_splits_per_combination: Number of splits used for combinations of features of each split.
        mutual_pass: If ``True``, perform one pass per branch per crop resolution.
        initial_momentum: Initial value for the momentum update.
        scheduler_momentum: Rule to update the momentum value.
        shuffle_bn: If ``True``, apply the shuffle batch-norm trick from MoCo.
        num_devices: Number of devices used to train the model in each node.
        simulate_n_devices: Number of devices to simulate to apply the shuffle trick. Requires ``shuffle_bn`` to be ``True`` and ``num_devices`` to be :math:`1`.
        queue: Config to build a queue.
        sym: If ``True``, symmetrize the loss.
        use_keys: If ``True``, add the keys to the negatives.
        temp: Temperature parameter to scale the online similarities.
        temp_m: Temperature parameter to scale the target similarities. Initial value if warmup is applied.
        start_warmup_temp_m: Initial temperature parameter to scale the target similarities in case of warmup.
        warmup_epoch_temp_m: Number of warmup epochs for the target temperature.
        warmup_scheduler_temp_m: Type of scheduler for warming up the target temperature. Options are: ``'linear'``, ``'cosine'``.
        coeff: Coefficient balancing the InfoNCE and relational aspects of the loss.
        warmup_scheduler_coeff: Type of scheduler for warming up the coefficient. Options are: ``'linear'``, ``'cosine'``.
        warmup_epoch_coeff: Number of warmup epochs for the coefficient.
        start_warmup_coeff: Starting value of the coefficient for warmup.
        scheduler_coeff: Type of scheduler for the coefficient after warmup. Options are: ``'linear'``, ``'cosine'``.
        final_scheduler_coeff: Final value of the scheduler coefficient.
""" def __init__( self, trunk: DictConfig, optimizer: DictConfig, projector: Optional[DictConfig] = None, predictor: Optional[DictConfig] = None, train_transform: Optional[DictConfig] = None, val_transform: Optional[DictConfig] = None, test_transform: Optional[DictConfig] = None, normalize_outputs: bool = True, num_global_crops: int = 2, num_local_crops: int = 0, num_splits: int = 0, num_splits_per_combination: int = 2, mutual_pass: bool = False, initial_momentum: int = 0.999, scheduler_momentum: str = "constant", shuffle_bn: bool = True, num_devices: int = 1, simulate_n_devices: int = 8, queue: Optional[DictConfig] = None, sym: bool = False, use_keys: bool = False, temp: float = 0.1, temp_m: float = 0.05, start_warmup_temp_m: float = 0.05, warmup_epoch_temp_m: int = 0, warmup_scheduler_temp_m: Optional[int] = "cosine", coeff: float = 0.5, warmup_scheduler_coeff: Optional[int] = "linear", warmup_epoch_coeff: int = 0, start_warmup_coeff: float = 1.0, scheduler_coeff: Optional[str] = None, final_scheduler_coeff: float = 0.0, ) -> None: super().__init__( trunk=trunk, optimizer=optimizer, projector=projector, predictor=predictor, train_transform=train_transform, val_transform=val_transform, test_transform=test_transform, normalize_outputs=normalize_outputs, num_global_crops=num_global_crops, num_local_crops=num_local_crops, num_splits=num_splits, num_splits_per_combination=num_splits_per_combination, mutual_pass=mutual_pass, initial_momentum=initial_momentum, scheduler_momentum=scheduler_momentum, shuffle_bn=shuffle_bn, num_devices=num_devices, simulate_n_devices=simulate_n_devices, queue=queue, sym=sym, use_keys=use_keys, ) self.save_hyperparameters() self.temp = temp self.temp_m = temp_m self.start_warmup_temp_m = start_warmup_temp_m self.final_temp_m = temp_m self.warmup_scheduler_temp_m = warmup_scheduler_temp_m self.warmup_epoch_temp_m = warmup_epoch_temp_m self.coeff = coeff self.initial_coeff = coeff self.warmup_scheduler_coeff = warmup_scheduler_coeff self.warmup_epoch_coeff = warmup_epoch_coeff self.start_warmup_coeff = start_warmup_coeff self.scheduler_coeff = scheduler_coeff self.final_scheduler_coeff = final_scheduler_coeff def _precompute_mask(self) -> None: batch_size = self.trainer.datamodule.train_local_batch_size self.mask = compute_sce_mask( batch_size=batch_size, num_negatives=self.queue.shape[1] if self.queue is not None else 0, use_keys=self.use_keys, rank=self.global_rank, world_size=self.trainer.world_size, device=self.device, ) def on_fit_start(self) -> None: super().on_fit_start() self._precompute_mask() def compute_loss( self, q: Tensor, k: Tensor, k_global: Tensor, queue: Tensor | None ) -> Tensor: """Compute the SCE loss for several tokens as output. Args: q: The representations of the queries. k: The representations of the keys. k_global: The global representations of the keys. queue: The queue of representations if not None. Returns: The loss. 
""" k_loss = k_global if self.use_keys else k loss = compute_sce_loss( q=q, k=k, k_global=k_loss, use_keys=self.use_keys, queue=queue, mask=self.mask, coeff=self.coeff, temp=self.temp, temp_m=self.temp_m, LARGE_NUM=LARGE_NUM, ) return loss def on_train_batch_start(self, batch: Any, batch_idx: int) -> None: if self.warmup_epoch_temp_m > 0: if self.current_epoch >= self.warmup_epoch_temp_m: self.temp_m = self.final_temp_m else: self.temp_m = scheduler_value( self.warmup_scheduler_temp_m, self.start_warmup_temp_m, self.final_temp_m, self.global_step, self.warmup_epoch_temp_m * self.training_steps_per_epoch - 1, ) if self.warmup_epoch_coeff > 0: if self.current_epoch >= self.warmup_epoch_coeff: self.coeff = self.initial_coeff else: self.coeff = scheduler_value( self.warmup_scheduler_coeff, self.start_warmup_coeff, self.initial_coeff, self.global_step, self.warmup_epoch_coeff * self.training_steps_per_epoch - 1, ) if self.scheduler_coeff is not None: if self.warmup_epoch_coeff > 0: if self.current_epoch >= self.warmup_epoch_coeff: self.coeff = scheduler_value( self.scheduler_coeff, self.initial_coeff, self.final_scheduler_coeff, self.global_step - self.warmup_epoch_coeff * self.training_steps_per_epoch, (self.trainer.max_epochs - self.warmup_epoch_coeff) * self.training_steps_per_epoch - 1, ) else: self.coeff = scheduler_value( self.scheduler_coeff, self.initial_coeff, self.final_scheduler_coeff, self.global_step, self.trainer.max_epochs * self.training_steps_per_epoch - 1, ) self.log("pretrain/temp", self.temp, on_step=True, on_epoch=True) self.log("pretrain/temp_m", self.temp_m, on_step=True, on_epoch=True) self.log("pretrain/coeff", self.coeff, on_step=True, on_epoch=True) return