Finetuning

class eztorch.models.finetuning.FinetuningModel(trunk, classifier, optimizer, pretrained_trunk_path, trunk_pattern='^(trunk\\.)', two_groups=False, freeze_trunk=False, train_transform=None, val_transform=None, test_transform=None, val_time_augmentation=None, test_time_augmentation=None, update_bn_momentum=False, freeze_bn_layers=False)

Model to fine-tune a pretrained trunk together with a classifier.

Parameters:
  • trunk (DictConfig) – Config to build a trunk.

  • classifier (DictConfig) – Config to build a classifier.

  • optimizer (DictConfig) – Config to build an optimizer for trunk.

  • pretrained_trunk_path (str) – Path to the pretrained trunk file.

  • trunk_pattern (str, optional) – Regular expression that matches the trunk's keys in the checkpoint state_dict; the matched part is removed from each key.

    Default: '^(trunk\.)'

  • two_groups (bool, optional) – If True, give the optimizer two parameter groups: one for the trunk and one for the head.

    Default: False

  • freeze_trunk (bool, optional) – If True, freeze the trunk.

    Default: False

  • train_transform (Optional[DictConfig], optional) – Config to perform transformation on train input.

    Default: None

  • val_transform (Optional[DictConfig], optional) – Config to perform transformation on val input.

    Default: None

  • test_transform (Optional[DictConfig], optional) – Config to perform transformation on test input.

    Default: None

  • val_time_augmentation (Optional[DictConfig], optional) – Ensembling method for test time augmentation used at validation.

    Default: None

  • test_time_augmentation (Optional[DictConfig], optional) – Ensembling method for test time augmentation used at test.

    Default: None

  • update_bn_momentum (bool, optional) – If True, update the batch norm momentum according to \(\max(1 - 10 / \mathrm{steps\_per\_epoch},\ 0.9)\).

    Default: False

  • freeze_bn_layers (bool, optional) – If True, freeze the batch norm layers.

    Default: False
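
The formula given for update_bn_momentum can be evaluated directly. The sketch below only computes the value for a few epoch lengths; the helper name bn_momentum_from_steps is hypothetical, and how the library applies the result to the batch norm layers is not stated here (PyTorch's own momentum argument weights the new batch statistic, so the two conventions may differ).

```python
def bn_momentum_from_steps(steps_per_epoch: int) -> float:
    """Evaluate max(1 - 10 / steps_per_epoch, 0.9) from the parameter description.

    Hypothetical helper for illustration; it is not part of eztorch.
    """
    if steps_per_epoch <= 0:
        raise ValueError("steps_per_epoch must be positive")
    return max(1 - 10 / steps_per_epoch, 0.9)


# Short epochs are clamped to 0.9; longer epochs approach 1.
for steps in (50, 100, 1000, 5000):
    print(steps, round(bn_momentum_from_steps(steps), 4))
```

The floor of 0.9 applies whenever steps_per_epoch is 100 or less, since 1 - 10/100 equals 0.9.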

Example:

trunk = {...} # config to build a trunk
classifier = {...} # config to build a classifier
optimizer = {...} # config to build an optimizer
pretrained_trunk_path = ... # path where the trunk has been saved

model = FinetuningModel(trunk, classifier, optimizer, pretrained_trunk_path)
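
As a rough illustration of what trunk_pattern describes, the sketch below filters a toy state_dict with the default pattern and strips the matched prefix from each key. The helper extract_trunk_state_dict and the key names are assumptions made for this example, not eztorch functions, and strings stand in for tensors.

```python
import re


def extract_trunk_state_dict(state_dict: dict, pattern: str = r"^(trunk\.)") -> dict:
    """Keep keys matching `pattern` and strip the matched part (hypothetical helper)."""
    regex = re.compile(pattern)
    return {
        regex.sub("", key, count=1): value
        for key, value in state_dict.items()
        if regex.search(key)
    }


# Toy checkpoint: strings stand in for tensors.
checkpoint_state_dict = {
    "trunk.conv1.weight": "w0",
    "trunk.bn1.bias": "b0",
    "projector.fc.weight": "w1",
}

trunk_state_dict = extract_trunk_state_dict(checkpoint_state_dict)
print(trunk_state_dict)
```

Keys that do not start with "trunk." are dropped, and the remaining keys lose that prefix so they can match the parameter names of a standalone trunk module.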

class eztorch.models.soccernet_spotting.SoccerNetSpottingModel(trunk, head_class, optimizer, pretrained_path=None, pretrained_trunk_path=None, loss_fn_args={}, prediction_args={}, eval_step_timestamp=1.0, trunk_pattern='^(trunk\\.)', freeze_trunk=False, train_transform=None, val_transform=None, test_transform=None, save_val_preds_path='val_preds/', save_test_preds_path='test_preds/', NMS_args={'window': 10, 'threshold': 0.001}, evaluation_args={}, do_compile=False)

Model to perform action spotting.

Parameters:
  • trunk (DictConfig) – Config to build a trunk.

  • head_class (DictConfig) – Config to build a head for classification.

  • optimizer (DictConfig) – Config to build an optimizer for trunk.

  • pretrained_trunk_path (str | None, optional) – Path to the pretrained trunk file.

    Default: None

  • pretrained_path (str | None, optional) – Path to the pretrained model.

    Default: None

  • prediction_args (DictConfig, optional) – Arguments to configure predictions.

    Default: {}

  • loss_fn_args (DictConfig, optional) – Arguments for the loss function.

    Default: {}

  • eval_step_timestamp (float, optional) – Step between each timestamp.

    Default: 1.0

  • trunk_pattern (str, optional) – Regular expression that matches the trunk's keys in the checkpoint state_dict; the matched part is removed from each key.

    Default: '^(trunk\.)'

  • freeze_trunk (bool, optional) – Whether to freeze the trunk.

    Default: False

  • train_transform (DictConfig | None, optional) – Config to perform transformation on train input.

    Default: None

  • val_transform (DictConfig | None, optional) – Config to perform transformation on val input.

    Default: None

  • test_transform (DictConfig | None, optional) – Config to perform transformation on test input.

    Default: None

  • save_val_preds_path (str | Path, optional) – Path to store the validation predictions.

    Default: 'val_preds/'

  • save_test_preds_path (str | Path, optional) – Path to store the test predictions.

    Default: 'test_preds/'

  • NMS_args (DictConfig, optional) – Arguments to configure the NMS.

    Default: {'window': 10, 'threshold': 0.001}

  • evaluation_args (DictConfig, optional) – Arguments to configure the evaluation.

    Default: {}
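
To give an idea of what the window and threshold entries of NMS_args could control, here is a minimal sketch of greedy one-dimensional non-maximum suppression over per-frame scores for a single class. It is an assumption-laden illustration, not the library's implementation: the function temporal_nms is hypothetical, and it assumes that window is a distance in frames on either side of a kept peak and that threshold discards low scores before suppression.

```python
def temporal_nms(scores, window=10, threshold=0.001):
    """Greedy 1D NMS (illustrative): keep peaks, suppress neighbours within `window`."""
    # Candidate indices at or above the threshold, best score first
    # (ties are broken in favour of the earlier index).
    candidates = sorted(
        (i for i, score in enumerate(scores) if score >= threshold),
        key=lambda i: (-scores[i], i),
    )
    kept = []
    for i in candidates:
        # Keep a candidate only if it is farther than `window` from every kept peak.
        if all(abs(i - j) > window for j in kept):
            kept.append(i)
    return sorted(kept)


# Toy per-frame scores for one class.
scores = [0.0] * 30
scores[5] = 0.9      # strong peak
scores[8] = 0.5      # within the window of index 5, so it is suppressed
scores[20] = 0.7     # far enough from index 5 to be kept
scores[25] = 0.0005  # below the threshold, so it is never considered

print(temporal_nms(scores, window=10, threshold=0.001))
```

Under these assumptions a larger window yields fewer, more widely spaced detections, while a larger threshold removes weak detections before suppression runs.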

Example:

trunk = {...} # config to build a trunk
head_class = {...} # config to build a head for classification
optimizer = {...} # config to build an optimizer
pretrained_trunk_path = ... # path where the trunk has been saved

model = SoccerNetSpottingModel(trunk, head_class, optimizer, pretrained_trunk_path=pretrained_trunk_path)

class eztorch.models.spotting.SpottingModel(trunk, head_class, optimizer, pretrained_path=None, pretrained_trunk_path=None, loss_fn_args={}, prediction_args={}, trunk_pattern='^(trunk\\.)', freeze_trunk=False, train_transform=None, val_transform=None, test_transform=None, save_val_preds_path='val_preds/', save_test_preds_path='test_preds/', NMS_args={'window': 10, 'threshold': 0.001}, evaluation_args={}, do_compile=False)

Model to perform spotting.

Parameters:
  • trunk (DictConfig) – Config to build a trunk.

  • head_class (DictConfig) – Config to build a head for classification.

  • optimizer (DictConfig) – Config to build an optimizer for trunk.

  • pretrained_trunk_path (str | None, optional) – Path to the pretrained trunk file.

    Default: None

  • pretrained_path (str | None, optional) – Path to the pretrained model.

    Default: None

  • prediction_args (DictConfig, optional) – Arguments to configure predictions.

    Default: {}

  • loss_fn_args (DictConfig, optional) – Arguments for the loss function.

    Default: {}

  • trunk_pattern (str, optional) – Regular expression that matches the trunk's keys in the checkpoint state_dict; the matched part is removed from each key.

    Default: '^(trunk\.)'

  • freeze_trunk (bool, optional) – Whether to freeze the trunk.

    Default: False

  • train_transform (DictConfig | None, optional) – Config to perform transformation on train input.

    Default: None

  • val_transform (DictConfig | None, optional) – Config to perform transformation on val input.

    Default: None

  • test_transform (DictConfig | None, optional) – Config to perform transformation on test input.

    Default: None

  • save_val_preds_path (str | Path, optional) – Path to store the validation predictions.

    Default: 'val_preds/'

  • save_test_preds_path (str | Path, optional) – Path to store the test predictions.

    Default: 'test_preds/'

  • NMS_args (DictConfig, optional) – Arguments to configure the NMS.

    Default: {'window': 10, 'threshold': 0.001}

  • evaluation_args (DictConfig, optional) – Arguments to configure the evaluation.

    Default: {}

Example:

trunk = {...} # config to build a trunk
head_class = {...} # config to build a head for classification
optimizer = {...} # config to build an optimizer
pretrained_trunk_path = ... # path where the trunk has been saved

model = SpottingModel(trunk, head_class, optimizer, pretrained_trunk_path=pretrained_trunk_path)
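
In the spotting signatures, pretrained_path comes before pretrained_trunk_path, so a fourth positional argument binds to pretrained_path. The sketch below shows this with a hypothetical stub that mirrors only the first five parameters; it is not the real class, and the argument values are placeholders.

```python
import inspect


def spotting_model_stub(trunk, head_class, optimizer,
                        pretrained_path=None, pretrained_trunk_path=None):
    """Hypothetical stand-in mirroring only the leading parameters."""


signature = inspect.signature(spotting_model_stub)

# Fourth positional argument: it binds to pretrained_path.
positional = signature.bind("trunk_cfg", "head_cfg", "optim_cfg", "trunk.ckpt")
print(positional.arguments["pretrained_path"])

# Keyword argument: the path reaches pretrained_trunk_path as intended.
keyword = signature.bind(
    "trunk_cfg", "head_cfg", "optim_cfg", pretrained_trunk_path="trunk.ckpt"
)
keyword.apply_defaults()
print(keyword.arguments["pretrained_trunk_path"], keyword.arguments["pretrained_path"])
```

Passing the trunk checkpoint by keyword avoids having it interpreted as a full pretrained model.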