from typing import Any, Dict, Iterable, Optional
import torch
from omegaconf import DictConfig
from torch import Tensor, nn
from eztorch.losses.simclr_loss import (compute_simclr_loss,
                                        compute_simclr_masks)
from eztorch.models.siamese.base import SiameseBaseModel
from eztorch.modules.gather import concat_all_gather_with_backprop
# [docs]
class SimCLRModel(SiameseBaseModel):
    """Siamese model trained with the SimCLR objective (v1 or v2, by configuration).

    References:
        - SimCLR: https://arxiv.org/abs/2002.05709
        - SimCLRv2: https://arxiv.org/abs/2006.10029

    Args:
        trunk: Config to build a trunk.
        optimizer: Config to build optimizers and schedulers.
        projector: Config to build a projector.
        train_transform: Config to perform transformation on train input.
        val_transform: Config to perform transformation on val input.
        test_transform: Config to perform transformation on test input.
        normalize_outputs: If ``True``, normalize outputs.
        num_global_crops: Number of global crops which are the first elements of each batch.
        num_local_crops: Number of local crops which are the last elements of each batch.
        num_splits: Number of splits to apply to each crops.
        num_splits_per_combination: Number of splits used for combinations of features of each split.
        mutual_pass: If ``True``, perform one pass per crop resolution.
        temp: Temperature parameter to scale the online similarities.
    """

    def __init__(
        self,
        trunk: DictConfig,
        optimizer: DictConfig,
        projector: Optional[DictConfig] = None,
        train_transform: Optional[DictConfig] = None,
        val_transform: Optional[DictConfig] = None,
        test_transform: Optional[DictConfig] = None,
        normalize_outputs: bool = True,
        num_global_crops: int = 2,
        num_local_crops: int = 0,
        num_splits: int = 0,
        num_splits_per_combination: int = 2,
        mutual_pass: bool = False,
        temp: float = 0.1,
    ) -> None:
        # SimCLR has no predictor head, so the base model is always built
        # with ``predictor=None``.
        super().__init__(
            trunk=trunk,
            optimizer=optimizer,
            projector=projector,
            predictor=None,
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            normalize_outputs=normalize_outputs,
            num_global_crops=num_global_crops,
            num_local_crops=num_local_crops,
            num_splits=num_splits,
            num_splits_per_combination=num_splits_per_combination,
            mutual_pass=mutual_pass,
        )

        assert not self.use_split, "Splits not supported for SimCLR"

        self.temp = temp

        # NOTE(review): presumably the Lightning helper that records the
        # ``__init__`` arguments by inspecting this frame -- keep the call
        # directly in ``__init__``; confirm against the base class.
        self.save_hyperparameters()

    def _precompute_mask(self) -> None:
        """Build the positive and negative masks once and cache them on the model.

        The masks are produced by ``compute_simclr_masks`` from the local train
        batch size of the datamodule, the number of crops, the rank of this
        process, the world size and the current device. They are stored as
        ``self.pos_mask`` and ``self.neg_mask``.
        """
        trainer = self.trainer
        masks = compute_simclr_masks(
            batch_size=trainer.datamodule.train_local_batch_size,
            num_crops=self.num_crops,
            rank=self.global_rank,
            world_size=trainer.world_size,
            device=self.device,
        )
        self.pos_mask, self.neg_mask = masks

    def on_fit_start(self) -> None:
        """Run the parent hook, then precompute the SimCLR masks."""
        super().on_fit_start()
        self._precompute_mask()

    def compute_loss(self, z: Tensor, z_global: Tensor) -> Tensor:
        """Compute the SimCLR loss.

        ``z_global`` is passed in rather than gathered here, so that ``z`` is
        not gathered several times (each gathering requires synchronisation
        among processes).

        Args:
            z: The representations of all crops.
            z_global: The global representations of all crops. Aggregated on all devices.

        Returns:
            The loss.
        """
        pos_mask, neg_mask = self.pos_mask, self.neg_mask
        return compute_simclr_loss(z, z_global, pos_mask, neg_mask, self.temp)

    def training_step(self, batch: Iterable[Any], batch_idx: int) -> Dict[str, Tensor]:
        """Run one training step and log the loss.

        Args:
            batch: The batch; its crops are read from ``batch["input"]`` and
                there must be exactly ``self.num_crops`` of them.
            batch_idx: Index of the batch (unused here).

        Returns:
            A dictionary holding the ``"loss"`` together with the detached
            outputs of the first crop.
        """
        crops = batch["input"]
        assert len(crops) == self.num_crops

        if self.train_transform is not None:
            # The transform runs without gradient tracking and with autocast
            # explicitly disabled.
            with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):
                crops = self.train_transform(crops)

        crop_outputs = self.multi_crop_shared_step(crops)

        # Stack the "z" output of every crop, then gather across processes
        # once so the loss does not have to gather again.
        z = torch.cat([crop_output["z"] for crop_output in crop_outputs])
        z_global = concat_all_gather_with_backprop(z)
        loss = self.compute_loss(z, z_global)

        # Only compute stats for first crop to avoid unnecessary computations
        merged = {"loss": loss}
        merged.update(crop_outputs[0])

        # Everything except the loss is detached before being returned.
        outputs = {
            name_output: output if name_output == "loss" else output.detach()
            for name_output, output in merged.items()
        }

        self.log(
            "pretrain/loss",
            outputs["loss"],
            prog_bar=True,
            on_step=True,
            on_epoch=True,
        )
        return outputs