Finetuning
- class eztorch.models.finetuning.FinetuningModel(trunk, classifier, optimizer, pretrained_trunk_path, trunk_pattern=r'^(trunk\.)', two_groups=False, freeze_trunk=False, train_transform=None, val_transform=None, test_transform=None, val_time_augmentation=None, test_time_augmentation=None, update_bn_momentum=False, freeze_bn_layers=False)
Model to fine-tune a pretrained trunk together with a classifier.
- Parameters:
  - trunk (DictConfig) – Config to build a trunk.
  - classifier (DictConfig) – Config to build a classifier.
  - optimizer (DictConfig) – Config to build an optimizer for the trunk.
  - pretrained_trunk_path (str) – Path to the pretrained trunk file.
  - trunk_pattern (str, optional) – Pattern used to retrieve the trunk weights from the checkpoint state_dict; the matched part is removed from the keys. Default: '^(trunk\.)'
  - two_groups (bool, optional) – If True, use two parameter groups for the optimizer: one for the trunk and one for the head. Default: False
  - freeze_trunk (bool, optional) – If True, freeze the trunk. Default: False
  - train_transform (Optional[DictConfig], optional) – Config to perform transformations on the train input. Default: None
  - val_transform (Optional[DictConfig], optional) – Config to perform transformations on the validation input. Default: None
  - test_transform (Optional[DictConfig], optional) – Config to perform transformations on the test input. Default: None
  - val_time_augmentation (Optional[DictConfig], optional) – Ensembling method for test-time augmentation used during validation. Default: None
  - test_time_augmentation (Optional[DictConfig], optional) – Ensembling method for test-time augmentation used during testing. Default: None
  - update_bn_momentum (bool, optional) – If True, update the batch norm momentum according to \(\max(1 - 10/\mathrm{steps\_per\_epoch},\ 0.9)\). Default: False
  - freeze_bn_layers (bool, optional) – If True, freeze the batch norm layers. Default: False
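The trunk_pattern parameter is matched against the keys of the checkpoint state_dict. The following is a minimal sketch of that kind of key filtering, assuming the checkpoint is a flat mapping from parameter names to values and that the matched prefix is stripped from each kept key. The helper `extract_trunk_state_dict` is hypothetical and not part of eztorch.

```python
import re
from typing import Any, Dict


def extract_trunk_state_dict(
    state_dict: Dict[str, Any], trunk_pattern: str = r"^(trunk\.)"
) -> Dict[str, Any]:
    """Hypothetical helper: keep keys matching trunk_pattern and strip the match."""
    regex = re.compile(trunk_pattern)
    trunk_state = {}
    for key, value in state_dict.items():
        if regex.match(key):
            # Remove the matched prefix, e.g. "trunk.conv1.weight" -> "conv1.weight".
            trunk_state[regex.sub("", key, count=1)] = value
    return trunk_state


checkpoint = {
    "trunk.conv1.weight": 1,
    "trunk.bn1.bias": 2,
    "head.fc.weight": 3,  # not part of the trunk, so it is dropped
}
print(extract_trunk_state_dict(checkpoint))
# {'conv1.weight': 1, 'bn1.bias': 2}
```

Because the default pattern is anchored with `^`, a key such as `model.trunk.x` would not match; a checkpoint saved with a different prefix needs a matching trunk_pattern.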
Example:
```python
from eztorch.models.finetuning import FinetuningModel

trunk = {...}  # config to build a trunk
classifier = {...}  # config to build a classifier
optimizer = {...}  # config to build an optimizer
pretrained_trunk_path = ...  # path where the trunk has been saved

model = FinetuningModel(trunk, classifier, optimizer, pretrained_trunk_path)
```
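With update_bn_momentum=True, the documented rule sets the batch norm momentum to max(1 - 10/steps_per_epoch, 0.9). The sketch below only computes that value; the helper name `bn_momentum` is hypothetical. Note that PyTorch's `torch.nn.BatchNorm*` layers update running statistics as `(1 - momentum) * running + momentum * batch`, so how eztorch applies a value of at least 0.9 to the layers (directly or as `1 - value`) is not stated above and is left as an open assumption here.

```python
def bn_momentum(steps_per_epoch: int) -> float:
    """Hypothetical helper: evaluate max(1 - 10 / steps_per_epoch, 0.9)."""
    if steps_per_epoch <= 0:
        raise ValueError("steps_per_epoch must be positive")
    return max(1 - 10 / steps_per_epoch, 0.9)


# Short epochs hit the 0.9 floor; long epochs push the value towards 1.
for steps in (50, 100, 1000):
    print(steps, bn_momentum(steps))
```

The floor of 0.9 applies whenever steps_per_epoch is 100 or fewer; beyond that the value grows with the number of steps per epoch.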
- class eztorch.models.soccernet_spotting.SoccerNetSpottingModel(trunk, head_class, optimizer, pretrained_path=None, pretrained_trunk_path=None, loss_fn_args={}, prediction_args={}, eval_step_timestamp=1.0, trunk_pattern=r'^(trunk\.)', freeze_trunk=False, train_transform=None, val_transform=None, test_transform=None, save_val_preds_path='val_preds/', save_test_preds_path='test_preds/', NMS_args={'window': 10, 'threshold': 0.001}, evaluation_args={}, do_compile=False)
Model to perform action spotting.
- Parameters:
  - trunk (DictConfig) – Config to build a trunk.
  - head_class (DictConfig) – Config to build a head for classification.
  - optimizer (DictConfig) – Config to build an optimizer for the trunk.
  - pretrained_trunk_path (str | None, optional) – Path to the pretrained trunk file. Default: None
  - pretrained_path (str | None, optional) – Path to the pretrained model. Default: None
  - prediction_args (DictConfig, optional) – Arguments to configure predictions. Default: {}
  - loss_fn_args (DictConfig, optional) – Arguments for the loss function. Default: {}
  - eval_step_timestamp (float, optional) – Step between consecutive evaluation timestamps. Default: 1.0
  - trunk_pattern (str, optional) – Pattern used to retrieve the trunk weights from the checkpoint state_dict; the matched part is removed from the keys. Default: '^(trunk\.)'
  - freeze_trunk (bool, optional) – Whether to freeze the trunk. Default: False
  - train_transform (DictConfig | None, optional) – Config to perform transformations on the train input. Default: None
  - val_transform (DictConfig | None, optional) – Config to perform transformations on the validation input. Default: None
  - test_transform (DictConfig | None, optional) – Config to perform transformations on the test input. Default: None
  - save_val_preds_path (str | Path, optional) – Path where the validation predictions are stored. Default: 'val_preds/'
  - save_test_preds_path (str | Path, optional) – Path where the test predictions are stored. Default: 'test_preds/'
  - NMS_args (DictConfig, optional) – Arguments to configure non-maximum suppression (NMS). Default: {'window': 10, 'threshold': 0.001}
  - evaluation_args (DictConfig, optional) – Arguments to configure the evaluation. Default: {}
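NMS_args defaults to {'window': 10, 'threshold': 0.001}. The exact NMS that eztorch runs is not described here, so the following is a generic greedy 1D temporal NMS sketch for a single class, assuming `window` is a half-width in timestamp indices within which weaker detections are suppressed and `threshold` is a minimum score below which detections are discarded. Both interpretations, and the helper name `temporal_nms`, are assumptions.

```python
from typing import List, Tuple


def temporal_nms(
    scores: List[float], window: int = 10, threshold: float = 0.001
) -> List[Tuple[int, float]]:
    """Greedy 1D NMS over per-timestamp scores of one class (assumed semantics).

    Repeatedly keep the highest-scoring remaining index and suppress every
    index within +/- window of it; stop once scores fall below threshold.
    """
    # Candidate indices sorted by decreasing score.
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    suppressed = [False] * len(scores)
    kept = []
    for i in order:
        if scores[i] < threshold:
            break  # every remaining candidate scores even lower
        if suppressed[i]:
            continue
        kept.append((i, scores[i]))
        # Suppress the neighbourhood of the kept peak, including the peak itself.
        for j in range(max(0, i - window), min(len(scores), i + window + 1)):
            suppressed[j] = True
    return sorted(kept)  # report detections in temporal order


scores = [0.1, 0.9, 0.2, 0.0, 0.0, 0.7, 0.8, 0.0005]
print(temporal_nms(scores, window=1))
# [(1, 0.9), (6, 0.8)]
```

A larger window merges peaks that are further apart into a single detection, while the threshold only prunes near-zero scores.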
Example:
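```python
from eztorch.models.soccernet_spotting import SoccerNetSpottingModel

trunk = {...}  # config to build a trunk
head_class = {...}  # config to build a head for classification
optimizer = {...}  # config to build an optimizer
pretrained_trunk_path = ...  # path where the trunk has been saved

# Pass the trunk path by keyword: the fourth positional parameter is pretrained_path.
model = SoccerNetSpottingModel(
    trunk, head_class, optimizer, pretrained_trunk_path=pretrained_trunk_path
)
```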
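eval_step_timestamp sets the spacing between evaluation timestamps. A minimal sketch of turning per-step prediction indices into times follows, assuming the timestamps are evenly spaced and start at 0; the starting offset, the time unit, and the helper name `indices_to_times` are assumptions, not eztorch behaviour.

```python
from typing import List, Tuple


def indices_to_times(
    detections: List[Tuple[int, float]], eval_step_timestamp: float = 1.0
) -> List[Tuple[float, float]]:
    """Hypothetical helper: convert (index, score) pairs into (time, score) pairs.

    Assumes the i-th evaluation timestamp sits at i * eval_step_timestamp.
    """
    return [(index * eval_step_timestamp, score) for index, score in detections]


print(indices_to_times([(3, 0.9), (10, 0.4)], eval_step_timestamp=0.5))
# [(1.5, 0.9), (5.0, 0.4)]
```

Halving eval_step_timestamp doubles the temporal resolution of the predictions at the cost of twice as many evaluation points.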
- class eztorch.models.spotting.SpottingModel(trunk, head_class, optimizer, pretrained_path=None, pretrained_trunk_path=None, loss_fn_args={}, prediction_args={}, trunk_pattern=r'^(trunk\.)', freeze_trunk=False, train_transform=None, val_transform=None, test_transform=None, save_val_preds_path='val_preds/', save_test_preds_path='test_preds/', NMS_args={'window': 10, 'threshold': 0.001}, evaluation_args={}, do_compile=False)
Model to perform spotting.
- Parameters:
  - trunk (DictConfig) – Config to build a trunk.
  - head_class (DictConfig) – Config to build a head for classification.
  - optimizer (DictConfig) – Config to build an optimizer for the trunk.
  - pretrained_trunk_path (str | None, optional) – Path to the pretrained trunk file. Default: None
  - pretrained_path (str | None, optional) – Path to the pretrained model. Default: None
  - prediction_args (DictConfig, optional) – Arguments to configure predictions. Default: {}
  - loss_fn_args (DictConfig, optional) – Arguments for the loss function. Default: {}
  - trunk_pattern (str, optional) – Pattern used to retrieve the trunk weights from the checkpoint state_dict; the matched part is removed from the keys. Default: '^(trunk\.)'
  - freeze_trunk (bool, optional) – Whether to freeze the trunk. Default: False
  - train_transform (DictConfig | None, optional) – Config to perform transformations on the train input. Default: None
  - val_transform (DictConfig | None, optional) – Config to perform transformations on the validation input. Default: None
  - test_transform (DictConfig | None, optional) – Config to perform transformations on the test input. Default: None
  - save_val_preds_path (str | Path, optional) – Path where the validation predictions are stored. Default: 'val_preds/'
  - save_test_preds_path (str | Path, optional) – Path where the test predictions are stored. Default: 'test_preds/'
  - NMS_args (DictConfig, optional) – Arguments to configure non-maximum suppression (NMS). Default: {'window': 10, 'threshold': 0.001}
  - evaluation_args (DictConfig, optional) – Arguments to configure the evaluation. Default: {}
Example:
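```python
from eztorch.models.spotting import SpottingModel

trunk = {...}  # config to build a trunk
head_class = {...}  # config to build a head for classification
optimizer = {...}  # config to build an optimizer
pretrained_trunk_path = ...  # path where the trunk has been saved

# Pass the trunk path by keyword: the fourth positional parameter is pretrained_path.
model = SpottingModel(
    trunk, head_class, optimizer, pretrained_trunk_path=pretrained_trunk_path
)
```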
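In the signatures of SoccerNetSpottingModel and SpottingModel, pretrained_path comes before pretrained_trunk_path, so a trunk checkpoint passed as the fourth positional argument is bound to pretrained_path. The sketch below shows this with `inspect.signature(...).bind` on a hypothetical stub that copies only the leading parameters of those signatures; it is not the real class.

```python
import inspect


def spotting_model_stub(
    trunk,
    head_class,
    optimizer,
    pretrained_path=None,
    pretrained_trunk_path=None,
    **kwargs,
):
    """Hypothetical stub mirroring the leading parameters of the documented signature."""


sig = inspect.signature(spotting_model_stub)

# Positional call: the fourth argument is bound to pretrained_path.
positional = sig.bind("trunk_cfg", "head_cfg", "optim_cfg", "trunk.ckpt")
print(dict(positional.arguments))

# Keyword call: the path is bound to pretrained_trunk_path as intended.
keyword = sig.bind("trunk_cfg", "head_cfg", "optim_cfg", pretrained_trunk_path="trunk.ckpt")
print(dict(keyword.arguments))
```

FinetuningModel does not have this pitfall, since its fourth positional parameter is pretrained_trunk_path itself.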