Source code for eztorch.models.siamese.simclr

from typing import Any, Dict, Iterable, Optional

import torch
from omegaconf import DictConfig
from torch import Tensor, nn

from eztorch.losses.simclr_loss import (compute_simclr_loss,
                                        compute_simclr_masks)
from eztorch.models.siamese.base import SiameseBaseModel
from eztorch.modules.gather import concat_all_gather_with_backprop


class SimCLRModel(SiameseBaseModel):
    """SimCLR model supporting versions 1 and 2 through configuration.

    References:

    - SimCLR: https://arxiv.org/abs/2002.05709
    - SimCLRv2: https://arxiv.org/abs/2006.10029

    Args:
        trunk: Config to build a trunk.
        optimizer: Config to build optimizers and schedulers.
        projector: Config to build a projector.
        train_transform: Config to perform transformation on train input.
        val_transform: Config to perform transformation on val input.
        test_transform: Config to perform transformation on test input.
        normalize_outputs: If ``True``, normalize outputs.
        num_global_crops: Number of global crops, which are the first elements of each batch.
        num_local_crops: Number of local crops, which are the last elements of each batch.
        num_splits: Number of splits to apply to each crop.
        num_splits_per_combination: Number of splits used for combinations of features of each split.
        mutual_pass: If ``True``, perform one pass per crop resolution.
        temp: Temperature parameter to scale the online similarities.
    """

    def __init__(
        self,
        trunk: DictConfig,
        optimizer: DictConfig,
        projector: Optional[DictConfig] = None,
        train_transform: Optional[DictConfig] = None,
        val_transform: Optional[DictConfig] = None,
        test_transform: Optional[DictConfig] = None,
        normalize_outputs: bool = True,
        num_global_crops: int = 2,
        num_local_crops: int = 0,
        num_splits: int = 0,
        num_splits_per_combination: int = 2,
        mutual_pass: bool = False,
        temp: float = 0.1,
    ) -> None:
        super().__init__(
            trunk=trunk,
            optimizer=optimizer,
            projector=projector,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            num_global_crops=num_global_crops,
            num_local_crops=num_local_crops,
            num_splits=num_splits,
            num_splits_per_combination=num_splits_per_combination,
            mutual_pass=mutual_pass,
            predictor=None,
            normalize_outputs=normalize_outputs,
        )

        assert not self.use_split, "Splits not supported for SimCLR"

        self.temp = temp

        self.save_hyperparameters()

    def _precompute_mask(self) -> None:
        batch_size = self.trainer.datamodule.train_local_batch_size

        self.pos_mask, self.neg_mask = compute_simclr_masks(
            batch_size=batch_size,
            num_crops=self.num_crops,
            rank=self.global_rank,
            world_size=self.trainer.world_size,
            device=self.device,
        )

    def on_fit_start(self) -> None:
        super().on_fit_start()
        self._precompute_mask()

    def compute_loss(self, z: Tensor, z_global: Tensor) -> Tensor:
        """Compute the SimCLR loss.

        z_global is provided rather than computed in the loss to avoid gathering z
        multiple times, which requires synchronisation among processes.

        Args:
            z: The representations of all crops.
            z_global: The global representations of all crops, gathered from all devices.

        Returns:
            The loss.
        """

        return compute_simclr_loss(z, z_global, self.pos_mask, self.neg_mask, self.temp)

    def training_step(self, batch: Iterable[Any], batch_idx: int) -> Dict[str, Tensor]:
        X = batch["input"]

        assert len(X) == self.num_crops

        if self.train_transform is not None:
            with torch.no_grad():
                with torch.cuda.amp.autocast(enabled=False):
                    X = self.train_transform(X)

        outs_online = self.multi_crop_shared_step(X)

        z = torch.cat([out_online["z"] for out_online in outs_online])

        # Gather representations from all processes while keeping gradients.
        z_global = concat_all_gather_with_backprop(z)

        loss = self.compute_loss(z, z_global)

        outputs = {"loss": loss}

        # Only compute stats for first crop to avoid unnecessary computations.
        outputs.update(outs_online[0])

        # Detach every logged output except the loss.
        for name_output, output in outputs.items():
            if name_output != "loss":
                outputs[name_output] = output.detach()

        self.log(
            "pretrain/loss",
            outputs["loss"],
            prog_bar=True,
            on_step=True,
            on_epoch=True,
        )

        return outputs