from typing import Optional
from omegaconf import DictConfig
from torch import Tensor
from eztorch.losses.ressl_loss import compute_ressl_loss, compute_ressl_mask
from eztorch.models.siamese.shuffle_momentum_queue_base import \
ShuffleMomentumQueueBaseModel
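# Large constant presumably used inside ``compute_ressl_loss`` to push masked-out
# similarities (e.g. a query's own key) towards -inf before the softmax; this reading
# is an assumption, the exact usage lives in ``eztorch.losses.ressl_loss``.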
LARGE_NUM = 1e9
class ReSSLModel(ShuffleMomentumQueueBaseModel):
"""ReSSL model.
References:
- ReSSL: https://proceedings.neurips.cc/paper/2021/file/14c4f36143b4b09cbc320d7c95a50ee7-Paper.pdf
Args:
trunk: Config to build a trunk.
optimizer: Config to build optimizers and schedulers.
projector: Config to build a projector.
predictor: Config to build a predictor.
train_transform: Config to perform transformation on train input.
val_transform: Config to perform transformation on val input.
test_transform: Config to perform transformation on test input.
normalize_outputs: If ``True``, normalize outputs.
num_global_crops: Number of global crops, which are the first elements of each batch.
num_local_crops: Number of local crops, which are the last elements of each batch.
num_splits: Number of splits to apply to each crop.
num_splits_per_combination: Number of splits used for combinations of features of each split.
mutual_pass: If ``True``, perform one pass per branch per crop resolution.
initial_momentum: Initial value for the momentum update.
scheduler_momentum: Rule to update the momentum value.
shuffle_bn: If ``True``, apply the shuffling BN trick from MoCo.
num_devices: Number of devices used to train the model in each node.
simulate_n_devices: Number of devices to simulate to apply shuffle trick. Requires ``shuffle_bn`` to be ``True`` and ``num_devices`` to be :math:`1`.
queue: Config to build a queue.
sym: If ``True``, symmetrize the loss.
use_keys: If ``True``, add keys to negatives.
temp: Temperature parameter to scale the online similarities.
temp_m: Temperature parameter to scale the target similarities. Final value if a warmup is applied.
initial_temp_m: Initial value of the target temperature if a warmup is applied.
"""
def __init__(
self,
trunk: DictConfig,
optimizer: DictConfig,
projector: Optional[DictConfig] = None,
predictor: Optional[DictConfig] = None,
train_transform: Optional[DictConfig] = None,
val_transform: Optional[DictConfig] = None,
test_transform: Optional[DictConfig] = None,
normalize_outputs: bool = True,
num_global_crops: int = 2,
num_local_crops: int = 0,
num_splits: int = 0,
num_splits_per_combination: int = 2,
mutual_pass: bool = False,
initial_momentum: float = 0.999,
scheduler_momentum: str = "constant",
shuffle_bn: bool = True,
num_devices: int = 1,
simulate_n_devices: int = 8,
queue: Optional[DictConfig] = None,
sym: bool = False,
use_keys: bool = False,
temp: float = 0.1,
temp_m: float = 0.04,
initial_temp_m: float = 0.04,
) -> None:
super().__init__(
trunk=trunk,
optimizer=optimizer,
projector=projector,
predictor=predictor,
train_transform=train_transform,
val_transform=val_transform,
test_transform=test_transform,
normalize_outputs=normalize_outputs,
num_global_crops=num_global_crops,
num_local_crops=num_local_crops,
num_splits=num_splits,
num_splits_per_combination=num_splits_per_combination,
mutual_pass=mutual_pass,
initial_momentum=initial_momentum,
scheduler_momentum=scheduler_momentum,
shuffle_bn=shuffle_bn,
num_devices=num_devices,
simulate_n_devices=simulate_n_devices,
queue=queue,
sym=sym,
use_keys=use_keys,
)
self.save_hyperparameters()
self.temp = temp
self.temp_m = temp_m
self.initial_temp_m = initial_temp_m
self.final_temp_m = temp_m
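# The ``initial_temp_m`` / ``final_temp_m`` attributes suggest the target temperature is
# warmed up from the former to the latter during training; the schedule itself is not
# defined in this class (assumption based on the attribute names).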
def _precompute_mask(self) -> None:
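# Assumption based on the arguments below: the mask flags, for each local query, the
# columns of the similarity matrix that correspond to its own key (and to the gathered
# keys appended as negatives when ``use_keys`` is True), so they can be excluded from
# the loss. The exact semantics live in ``compute_ressl_mask``.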
batch_size = self.trainer.datamodule.train_local_batch_size
self.mask = compute_ressl_mask(
batch_size=batch_size,
num_negatives=self.queue.shape[1] if self.queue is not None else 0,
use_keys=self.use_keys,
rank=self.global_rank,
world_size=self.trainer.world_size,
device=self.device,
)
def on_fit_start(self) -> None:
super().on_fit_start()
self._precompute_mask()
def compute_loss(
self, q: Tensor, k: Tensor, k_global: Tensor, queue: Optional[Tensor]
) -> Tensor:
"""Compute the ReSSL loss.
Args:
q: The representations of the queries.
k: The representations of the keys.
k_global: The global representations of the keys.
queue: The queue of representations if not None.
Returns:
The loss.
"""
k_loss = k_global if self.use_keys else k
loss = compute_ressl_loss(
q=q,
k=k,
k_global=k_loss,
use_keys=self.use_keys,
queue=queue,
mask=self.mask,
temp=self.temp,
temp_m=self.temp_m,
LARGE_NUM=LARGE_NUM,
)
return loss
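# The function below is not part of eztorch; it is a minimal, self-contained sketch of
# the relational loss described in the ReSSL paper, added here only to illustrate what
# ``compute_loss`` delegates to via ``compute_ressl_loss``. It assumes ``q`` (online,
# strongly augmented view) and ``k`` (target, weakly augmented view) are L2-normalized
# and that ``queue`` holds K past target features with shape (K, dim); the ``use_keys``,
# ``mask`` and distributed handling of the real implementation is deliberately omitted.
def _ressl_loss_sketch(
    q: Tensor, k: Tensor, queue: Tensor, temp: float = 0.1, temp_m: float = 0.04
) -> Tensor:
    import torch.nn.functional as F

    # Relation of each view to the queue features.
    logits_q = q @ queue.T / temp    # online relation, softer sharpening (temp = 0.1)
    logits_k = k @ queue.T / temp_m  # target relation, stronger sharpening (temp_m = 0.04)

    # The target distribution comes from the momentum branch and receives no gradient.
    target = F.softmax(logits_k.detach(), dim=1)

    # Cross-entropy between the target relation and the online relation.
    return -(target * F.log_softmax(logits_q, dim=1)).sum(dim=1).mean()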