Source code for eztorch.models.soccernet_spotting

import shutil
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import hydra
import torch
import torch.distributed as dist
from lightning.pytorch.utilities import rank_zero_info
from omegaconf import DictConfig
from SoccerNet.Evaluation.ActionSpotting import evaluate
from torch import Tensor, nn
from torch.nn import Parameter

from eztorch.datasets.soccernet_utils.parse_utils import (
    REVERSE_ACTION_SPOTTING_LABELS, REVERSE_BALL_SPOTTING_LABELS,
    SoccerNetTask, get_soccernet_weights)
from eztorch.datasets.soccernet_utils.predictions import (
    add_clips_predictions, aggregate_and_filter_clips, initialize_predictions,
    save_raw_spotting_predictions, save_spotting_predictions)
from eztorch.datasets.utils_fn import get_shard_indices
from eztorch.losses.spot_loss import compute_spot_loss_fn
from eztorch.models.eztorch_base_module import EztorchBaseModule
from eztorch.utils.checkpoints import (get_sub_state_dict_from_pl_ckpt,
                                       remove_pattern_in_keys_from_dict)


[docs] class SoccerNetSpottingModel(EztorchBaseModule): r"""Model to perform action spotting. Args: trunk: Config to build a trunk. head_class: Config to build a head for classification. optimizer: Config to build an optimizer for trunk. pretrained_trunk_path: Path to the pretrained trunk file. pretrained_path: Path to the pretrained model. prediction_args: Arguments to configure predictions. loss_fn_args: Arguments for the loss function. eval_step_timestamp: Step between each timestamp. trunk_pattern: Pattern to retrieve the trunk model in checkpoint state_dict and delete the key. freeze_trunk: Whether to freeze the trunk. train_transform: Config to perform transformation on train input. val_transform: Config to perform transformation on val input. test_transform: Config to perform transformation on test input. save_val_preds_path: Path to store the validation predictions. save_test_preds_path: Path to store the test predictions. NMS_args: Arguments to configure the NMS. evaluation_args: Arguments to configure the evaluation. Example:: trunk = {...} # config to build a trunk head_class = {...} # config to build a head for classification optimizer = {...} # config to build an optimizer pretrained_trunk_path = ... # path where the trunk has been saved model = SoccerNetSpottingModel(trunk, head_class, optimizer, pretrained_trunk_path) """ def __init__( self, trunk: DictConfig, head_class: DictConfig, optimizer: DictConfig, pretrained_path: str | None = None, pretrained_trunk_path: str | None = None, loss_fn_args: DictConfig = DictConfig({}), prediction_args: DictConfig = DictConfig({}), eval_step_timestamp: float = 1.0, trunk_pattern: str = r"^(trunk\.)", freeze_trunk: bool = False, train_transform: DictConfig | None = None, val_transform: DictConfig | None = None, test_transform: DictConfig | None = None, save_val_preds_path: str | Path = "val_preds/", save_test_preds_path: str | Path = "test_preds/", NMS_args: DictConfig = DictConfig({"window": 10, "threshold": 0.001}), evaluation_args: DictConfig = DictConfig({}), do_compile: bool = False, ) -> None: super().__init__() self.save_hyperparameters() self.optimizer_cfg = optimizer self.freeze_trunk = freeze_trunk self.trunk: nn.Module = hydra.utils.instantiate(trunk) if pretrained_trunk_path is not None: trunk_state_dict = get_sub_state_dict_from_pl_ckpt( checkpoint_path=pretrained_trunk_path, pattern=trunk_pattern ) trunk_state_dict = remove_pattern_in_keys_from_dict( d=trunk_state_dict, pattern=trunk_pattern ) missing_keys, unexpected_keys = self.trunk.load_state_dict( trunk_state_dict, strict=False ) rank_zero_info( f"Loaded {__class__.__name__} from pretrained trunk weights model.\n" f"missing_keys:{missing_keys}\n" f"unexpected_keys:{unexpected_keys}" ) self.head_class: nn.Module = hydra.utils.instantiate(head_class) self.loss_fn = compute_spot_loss_fn if do_compile: self.trunk = torch.compile(self.trunk) self.loss_fn = torch.compile(self.loss_fn) self.train_transform = ( hydra.utils.instantiate(train_transform) if train_transform is not None else None ) self.val_transform = ( hydra.utils.instantiate(val_transform) if val_transform is not None else None ) self.test_transform = ( hydra.utils.instantiate(test_transform) if test_transform is not None else None ) self.loss_fn_args = dict(loss_fn_args) self.remove_inference_prediction_seconds = prediction_args.get( "remove_inference_prediction_seconds", -1.0 ) self.merge_predictions_type = prediction_args.get( "merge_predictions_type", "max" ) self.eval_step_timestamp = eval_step_timestamp self.save_val_preds_path = Path(save_val_preds_path) self.save_test_preds_path = Path(save_test_preds_path) self.NMS_args = NMS_args evaluation_args.pop("metric") self.evaluation_args = evaluation_args if self.evaluation_args["split"] == "val": self.evaluation_args["split"] = "valid" if pretrained_path is not None: state_dict = torch.load(pretrained_path, map_location="cpu") if "state_dict" in state_dict: state_dict = ["state_dict"] missing_keys, unexpected_keys = self.load_state_dict( state_dict, strict=False ) rank_zero_info( f"Loaded {__class__.__name__} from pretrained torch model.\n" f"missing_keys:{missing_keys}\n" f"unexpected_keys:{unexpected_keys}" ) rank_zero_info( f"{__class__.__name__} loaded trunk weights from {pretrained_trunk_path}." ) if self.freeze_trunk: for param in self.trunk.parameters(): param.requires_grad = False @property def learnable_params(self) -> List[Tuple[str, Parameter]]: """Learnable parameters of the model.""" params = [] if not self.freeze_trunk: params.extend(self.trunk.parameters()) params.extend(self.head_class.parameters()) return params @property def training_steps_per_epoch(self) -> Optional[int]: """Total training steps inferred from datamodule and devices.""" if ( self.trainer.datamodule is not None and self.trainer.datamodule.train_num_samples > 0 ): return ( self.trainer.datamodule.train_num_samples // self.trainer.datamodule.train_global_batch_size ) else: return None def on_fit_start(self) -> None: class_weights = self.loss_fn_args.get("class_weights", {"type": None}) self.loss_fn_args["class_weights"] = get_soccernet_weights( self.trainer.datamodule.train_dataset, device=self.device, **class_weights, ) def on_validation_epoch_start(self) -> None: shard_indices = get_shard_indices( self.trainer.datamodule.val_dataset.num_videos, shuffle_shards=False ) if len(shard_indices) == 0: self.min_video_index = self.max_video_index = -1 self.predictions = None return self.min_video_index = min(shard_indices) self.max_video_index = max(shard_indices) self.predictions = initialize_predictions( self.trainer.datamodule.val_dataset, self.eval_step_timestamp, self.max_video_index, self.min_video_index, self.device, ) return def on_test_start(self) -> None: shard_indices = get_shard_indices( self.trainer.datamodule.test_dataset.num_videos, shuffle_shards=False ) if len(shard_indices) == 0: self.min_video_index = self.max_video_index = -1 self.predictions = None return self.min_video_index = min(shard_indices) self.max_video_index = max(shard_indices) self.predictions = initialize_predictions( self.trainer.datamodule.test_dataset, self.eval_step_timestamp, self.max_video_index, self.min_video_index, self.device, ) return def on_train_epoch_start(self) -> None: if self.freeze_trunk: self.trunk.eval() def forward(self, x: Tensor) -> Dict[str, Any]: h = self.trunk(x) class_preds = self.head_class(h) return { "class_preds": class_preds, "h": h, } @property def num_layers(self) -> int: """Number of layers of the model.""" return self.num_layers_class @property def num_layers_class(self) -> int: return self.trunk.num_layers + self.head_class.num_layers def get_param_layer_id(self, name: str) -> int: """Get the layer id of the named parameter. Args: name: The name of the parameter. """ if name.startswith("trunk."): return self.trunk.get_param_layer_id(name[len("trunk.") :]) elif name.startswith("head_class."): return self.trunk.num_layers + self.head_class.get_param_layer_id( name[len("head_class.") :] ) else: raise NotImplementedError(f"{name} should not have been used.") def configure_optimizers(self) -> Dict[Any, Any]: optimizer, scheduler = hydra.utils.instantiate( self.optimizer_cfg, num_steps_per_epoch=self.optimizer_cfg.get( "num_steps_per_epoch", self.training_steps_per_epoch ), model=self, ) if scheduler is None: return optimizer return {"optimizer": optimizer, "lr_scheduler": scheduler} def shared_step(self, x: Tensor, inversed_temporal_masked_indices: Tensor | None): if self.freeze_trunk: with torch.no_grad(): if inversed_temporal_masked_indices is not None: h: Tensor = self.trunk( x, inversed_temporal_masked_indices=inversed_temporal_masked_indices, ) else: h = self.trunk(x) else: if inversed_temporal_masked_indices is not None: h = self.trunk( x, inversed_temporal_masked_indices=inversed_temporal_masked_indices ) else: h = self.trunk(x) if h.ndim == 2: h = h.reshape(h.shape[0], 1, h.shape[1]) class_preds: Tensor = self.head_class(h) return class_preds def training_step(self, batch: Dict[Any, Any], batch_idx: int) -> Tensor: if self.train_transform is not None: with torch.no_grad(): with torch.cuda.amp.autocast(enabled=False): batch = self.train_transform(batch) (x, labels, has_label, ignore_class,) = ( batch["input"], batch["labels"], batch["has_label"], batch["ignore_class"], ) inversed_temporal_masked_indices: Tensor | None = batch.get( "inversed_temporal_masked_indices", None ) class_preds = self.shared_step(x, inversed_temporal_masked_indices) loss = self.loss_fn( class_preds=class_preds, class_target=labels, ignore_class=ignore_class, has_label=has_label, mixup_weights=batch.get("mixup_weights", None), **self.loss_fn_args, ) self.log( "train/loss", loss, prog_bar=True, on_step=True, on_epoch=True, sync_dist=False, ) return loss def validation_step(self, batch: Tensor, batch_idx: int) -> Tensor: if self.val_transform is not None: with torch.no_grad(): with torch.cuda.amp.autocast(enabled=False): batch = self.val_transform(batch) ( x, timestamps, labels, has_label, ignore_class, video_indexes, halves_indexes, halves_duration, ) = ( batch["input"], batch["timestamps"], batch["labels"], batch["has_label"], batch["ignore_class"], batch["video_index"], batch["half_index"], batch["half_duration"], ) num_timestamps = ( torch.round(halves_duration / self.eval_step_timestamp) - 1 ).to(dtype=torch.long) class_preds = self.shared_step(x, inversed_temporal_masked_indices=None) loss = self.loss_fn( class_preds=class_preds, class_target=labels, ignore_class=ignore_class, has_label=has_label, mixup_weights=batch.get("mixup_weights", None), **self.loss_fn_args, ) class_preds = class_preds.sigmoid() kept_tensors = aggregate_and_filter_clips( class_preds, timestamps, num_timestamps, video_indexes, halves_indexes, self.max_video_index, self.min_video_index, ) if kept_tensors is not None: ( class_preds, timestamps, num_timestamps, video_indexes, halves_indexes, ) = kept_tensors add_clips_predictions( self.predictions, class_preds, timestamps, num_timestamps, video_indexes, halves_indexes, self.eval_step_timestamp, self.remove_inference_prediction_seconds, self.merge_predictions_type, ) self.log( "val/loss", loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True, ) return loss def test_step(self, batch: Tensor, batch_idx: int) -> Tensor: if self.test_transform is not None: with torch.no_grad(): with torch.cuda.amp.autocast(enabled=False): batch = self.test_transform(batch) x, timestamps, halves_duration, video_indexes, halves_indexes = ( batch["input"], batch["timestamps"], batch["half_duration"], batch["video_index"], batch["half_index"], ) num_timestamps = ( torch.round(halves_duration / self.eval_step_timestamp) - 1 ).to(dtype=torch.long) class_preds = self.shared_step(x, inversed_temporal_masked_indices=None) class_preds = class_preds.sigmoid() kept_tensors = aggregate_and_filter_clips( class_preds, timestamps, num_timestamps, video_indexes, halves_indexes, self.max_video_index, self.min_video_index, ) if kept_tensors is not None: ( class_preds, timestamps, num_timestamps, video_indexes, halves_indexes, ) = kept_tensors add_clips_predictions( self.predictions, class_preds, timestamps, num_timestamps, video_indexes, halves_indexes, self.eval_step_timestamp, self.remove_inference_prediction_seconds, self.merge_predictions_type, ) return 0 def _make_evaluation( self, predictions_path: str | Path, task: SoccerNetTask = SoccerNetTask.ACTION, logger: bool = False, ) -> None: if self.evaluation_args.get("split", None) not in ["valid", "test"]: return if task == SoccerNetTask.ACTION: dataset = "SoccerNet" eval_task = "spotting" num_classes = 17 label_files = "Labels-v2.json" framerate = 2 REVERSE_LABELS = REVERSE_ACTION_SPOTTING_LABELS default_metrics = ["tight", "at1", "at2", "at3", "at4", "at5", "loose"] elif task == SoccerNetTask.BALL: dataset = "Ball" eval_task = "spotting" num_classes = 2 label_files = "Labels-ball.json" framerate = 25 REVERSE_LABELS = REVERSE_BALL_SPOTTING_LABELS default_metrics = ["tight", "at1", "at2", "at3", "at4", "at5"] metrics = self.evaluation_args.pop("metrics", default_metrics) for metric in metrics: log_result = self.global_rank == 0 if log_result: results = evaluate( Predictions_path=str(predictions_path), dataset=dataset, task=eval_task, label_files=label_files, num_classes=num_classes, framerate=framerate, metric=metric, **self.evaluation_args, ) if logger: self.log( f"{metric}_Average_mAP/mAP", results["a_mAP"] if log_result else 0, prog_bar=True, on_step=False, on_epoch=True, sync_dist=False, ) rank_zero_info( f"{metric}_Average_mAP/mAP: {(results['a_mAP'] if log_result else 0):.04f}", ) for c in range(num_classes): map_class = results["a_mAP_per_class"][c] if log_result else 0 if logger: self.log( f"{metric}_Average_mAP/mAP_{REVERSE_LABELS[c]}", map_class, on_step=False, on_epoch=True, sync_dist=False, ) rank_zero_info( f"{metric}_Average_mAP/mAP_{REVERSE_LABELS[c]}: {map_class:.04f}", ) if logger: self.log( f"{metric}_Average_mAP_visible/mAP", results["a_mAP_visible"] if log_result else 0, on_step=False, on_epoch=True, sync_dist=False, ) rank_zero_info( f"{metric}_Average_mAP_visible/mAP: {(results['a_mAP_visible'] if log_result else 0):.04f}", ) for c in range(num_classes): map_class = results["a_mAP_per_class_visible"][c] if log_result else 0 if logger: self.log( f"{metric}_Average_mAP_visible/mAP_{REVERSE_LABELS[c]}", map_class, on_step=False, on_epoch=True, sync_dist=False, ) rank_zero_info( f"{metric}_Average_mAP_visible/mAP_{REVERSE_LABELS[c]}: {map_class:.04f}", ) if logger: self.log( f"{metric}_Average_mAP_unshown/mAP", results["a_mAP_unshown"] if log_result else 0, on_step=False, on_epoch=True, sync_dist=False, ) rank_zero_info( f"{metric}_Average_mAP_unshown/mAP: {(results['a_mAP_unshown'] if log_result else 0):.04f}", ) for c in range(num_classes): map_class = results["a_mAP_per_class_unshown"][c] if log_result else 0 if logger: self.log( f"{metric}_Average_mAP_unshown/mAP_{REVERSE_LABELS[c]}", map_class, on_step=False, on_epoch=True, sync_dist=False, ) rank_zero_info( f"{metric}_Average_mAP_unshown/mAP_{REVERSE_LABELS[c]}: {map_class:.04f}", ) self.evaluation_args["metrics"] = metrics def on_validation_epoch_end(self) -> None: preds_path = self.save_val_preds_path / str(self.trainer.current_epoch) raw_preds_path = self.save_val_preds_path / f"{self.trainer.current_epoch}_raw" if self.predictions is not None: save_raw_spotting_predictions( self.predictions, raw_preds_path, make_zip=False ) save_spotting_predictions( self.predictions, preds_path, self.trainer.datamodule.val_dataset, self.eval_step_timestamp, self.NMS_args, make_zip=False, ) if dist.is_available() and dist.is_initialized(): dist.barrier() if self.global_rank == 0: shutil.make_archive(str(preds_path), "zip", preds_path) shutil.make_archive(raw_preds_path, "zip", raw_preds_path) shutil.rmtree(preds_path) shutil.rmtree(raw_preds_path) if dist.is_available() and dist.is_initialized(): dist.barrier() if not self.trainer.sanity_checking: self._make_evaluation( self.save_val_preds_path / f"{self.trainer.current_epoch}.zip", task=self.trainer.datamodule.val_dataset.task, logger=True, ) return super().on_validation_epoch_end() def on_test_end(self) -> None: raw_preds_path = ( self.save_test_preds_path.parent / f"{self.save_test_preds_path.name}_raw" ) save_raw_spotting_predictions(self.predictions, raw_preds_path, make_zip=False) save_spotting_predictions( self.predictions, self.save_test_preds_path, self.trainer.datamodule.test_dataset, self.eval_step_timestamp, self.NMS_args, make_zip=False, ) if dist.is_available() and dist.is_initialized(): dist.barrier() if self.global_rank == 0: shutil.make_archive( str(self.save_test_preds_path), "zip", self.save_test_preds_path ) shutil.make_archive(str(raw_preds_path), "zip", raw_preds_path) shutil.rmtree(self.save_test_preds_path) shutil.rmtree(raw_preds_path) if dist.is_available() and dist.is_initialized(): dist.barrier() self._make_evaluation( self.save_test_preds_path.parent / f"{self.save_test_preds_path.name}.zip", task=self.trainer.datamodule.test_dataset.task, logger=False, ) return super().on_test_end()