Finetuning
- class eztorch.models.finetuning.FinetuningModel(trunk, classifier, optimizer, pretrained_trunk_path, trunk_pattern='^(trunk\\.)', two_groups=False, freeze_trunk=False, train_transform=None, val_transform=None, test_transform=None, val_time_augmentation=None, test_time_augmentation=None, update_bn_momentum=False, freeze_bn_layers=False)
Model to fine-tune a pretrained trunk together with a classifier.

Parameters:
- trunk (DictConfig) – Config to build a trunk.
- classifier (DictConfig) – Config to build a classifier.
- optimizer (DictConfig) – Config to build an optimizer for the trunk.
- pretrained_trunk_path (str) – Path to the pretrained trunk file.
- trunk_pattern (str, optional) – Regular expression used to select the trunk weights in the checkpoint state_dict; the matched part is removed from each selected key. Default: '^(trunk\.)'
- two_groups (bool, optional) – If True, use two groups of parameters for the optimizer: the trunk and the head. Default: False
- freeze_trunk (bool, optional) – If True, freeze the trunk. Default: False
- train_transform (Optional[DictConfig], optional) – Config to perform transformation on train input. Default: None
- val_transform (Optional[DictConfig], optional) – Config to perform transformation on val input. Default: None
- test_transform (Optional[DictConfig], optional) – Config to perform transformation on test input. Default: None
- val_time_augmentation (Optional[DictConfig], optional) – Ensembling method for test-time augmentation used during validation. Default: None
- test_time_augmentation (Optional[DictConfig], optional) – Ensembling method for test-time augmentation used during testing. Default: None
- update_bn_momentum (bool, optional) – If True, update the batch norm momentum according to \(\max(1 - 10/\text{steps\_per\_epoch}, 0.9)\). Default: False
- freeze_bn_layers (bool, optional) – If True, freeze the batch norm layers. Default: False

Example:

    from eztorch.models.finetuning import FinetuningModel

    trunk = {...}  # config to build a trunk
    classifier = {...}  # config to build a classifier
    optimizer = {...}  # config to build an optimizer
    pretrained_trunk_path = ...  # path where the trunk has been saved
    model = FinetuningModel(trunk, classifier, optimizer, pretrained_trunk_path)

- class eztorch.models.soccernet_spotting.SoccerNetSpottingModel(trunk, head_class, optimizer, pretrained_path=None, pretrained_trunk_path=None, loss_fn_args={}, prediction_args={}, eval_step_timestamp=1.0, trunk_pattern='^(trunk\\.)', freeze_trunk=False, train_transform=None, val_transform=None, test_transform=None, save_val_preds_path='val_preds/', save_test_preds_path='test_preds/', NMS_args={'window': 10, 'threshold': 0.001}, evaluation_args={}, do_compile=False)
Model to perform action spotting.

Parameters:
- trunk (DictConfig) – Config to build a trunk.
- head_class (DictConfig) – Config to build a head for classification.
- optimizer (DictConfig) – Config to build an optimizer for the trunk.
- pretrained_path (str | None, optional) – Path to the pretrained model. Default: None
- pretrained_trunk_path (str | None, optional) – Path to the pretrained trunk file. Default: None
- loss_fn_args (DictConfig, optional) – Arguments for the loss function. Default: {}
- prediction_args (DictConfig, optional) – Arguments to configure predictions. Default: {}
- eval_step_timestamp (float, optional) – Step between consecutive timestamps during evaluation. Default: 1.0
- trunk_pattern (str, optional) – Regular expression used to select the trunk weights in the checkpoint state_dict; the matched part is removed from each selected key. Default: '^(trunk\.)'
- freeze_trunk (bool, optional) – Whether to freeze the trunk. Default: False
- train_transform (DictConfig | None, optional) – Config to perform transformation on train input. Default: None
- val_transform (DictConfig | None, optional) – Config to perform transformation on val input. Default: None
- test_transform (DictConfig | None, optional) – Config to perform transformation on test input. Default: None
- save_val_preds_path (str | Path, optional) – Path to store the validation predictions. Default: 'val_preds/'
- save_test_preds_path (str | Path, optional) – Path to store the test predictions. Default: 'test_preds/'
- NMS_args (DictConfig, optional) – Arguments to configure the non-maximum suppression (NMS). Default: {'window': 10, 'threshold': 0.001}
- evaluation_args (DictConfig, optional) – Arguments to configure the evaluation. Default: {}

Example:

    from eztorch.models.soccernet_spotting import SoccerNetSpottingModel

    trunk = {...}  # config to build a trunk
    head_class = {...}  # config to build a head for classification
    optimizer = {...}  # config to build an optimizer
    pretrained_trunk_path = ...  # path where the trunk has been saved
    # Pass by keyword: the fourth positional parameter is pretrained_path.
    model = SoccerNetSpottingModel(trunk, head_class, optimizer, pretrained_trunk_path=pretrained_trunk_path)

- class eztorch.models.spotting.SpottingModel(trunk, head_class, optimizer, pretrained_path=None, pretrained_trunk_path=None, loss_fn_args={}, prediction_args={}, trunk_pattern='^(trunk\\.)', freeze_trunk=False, train_transform=None, val_transform=None, test_transform=None, save_val_preds_path='val_preds/', save_test_preds_path='test_preds/', NMS_args={'window': 10, 'threshold': 0.001}, evaluation_args={}, do_compile=False)
Model to perform spotting.

Parameters:
- trunk (DictConfig) – Config to build a trunk.
- head_class (DictConfig) – Config to build a head for classification.
- optimizer (DictConfig) – Config to build an optimizer for the trunk.
- pretrained_path (str | None, optional) – Path to the pretrained model. Default: None
- pretrained_trunk_path (str | None, optional) – Path to the pretrained trunk file. Default: None
- loss_fn_args (DictConfig, optional) – Arguments for the loss function. Default: {}
- prediction_args (DictConfig, optional) – Arguments to configure predictions. Default: {}
- trunk_pattern (str, optional) – Regular expression used to select the trunk weights in the checkpoint state_dict; the matched part is removed from each selected key. Default: '^(trunk\.)'
- freeze_trunk (bool, optional) – Whether to freeze the trunk. Default: False
- train_transform (DictConfig | None, optional) – Config to perform transformation on train input. Default: None
- val_transform (DictConfig | None, optional) – Config to perform transformation on val input. Default: None
- test_transform (DictConfig | None, optional) – Config to perform transformation on test input. Default: None
- save_val_preds_path (str | Path, optional) – Path to store the validation predictions. Default: 'val_preds/'
- save_test_preds_path (str | Path, optional) – Path to store the test predictions. Default: 'test_preds/'
- NMS_args (DictConfig, optional) – Arguments to configure the non-maximum suppression (NMS). Default: {'window': 10, 'threshold': 0.001}
- evaluation_args (DictConfig, optional) – Arguments to configure the evaluation. Default: {}

Example:

    from eztorch.models.spotting import SpottingModel

    trunk = {...}  # config to build a trunk
    head_class = {...}  # config to build a head for classification
    optimizer = {...}  # config to build an optimizer
    pretrained_trunk_path = ...  # path where the trunk has been saved
    # Pass by keyword: the fourth positional parameter is pretrained_path.
    model = SpottingModel(trunk, head_class, optimizer, pretrained_trunk_path=pretrained_trunk_path)
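All three models accept `trunk_pattern`, which selects the trunk weights in a checkpoint `state_dict` and strips the matched part from each key. The following is a minimal sketch of that idea. It assumes the checkpoint is a flat mapping from parameter names to values and that the match is removed with `re.sub`. The helper `extract_trunk_state_dict` is hypothetical and is not eztorch's own loading code.

```python
import re
from typing import Any, Dict


def extract_trunk_state_dict(
    state_dict: Dict[str, Any], trunk_pattern: str = r"^(trunk\.)"
) -> Dict[str, Any]:
    """Keep only keys matching ``trunk_pattern`` and strip the matched part.

    Hypothetical helper illustrating the documented behaviour of
    ``trunk_pattern``; eztorch's actual implementation may differ.
    """
    pattern = re.compile(trunk_pattern)
    trunk_state: Dict[str, Any] = {}
    for key, value in state_dict.items():
        if pattern.search(key):
            # Remove the matched prefix so keys line up with a bare trunk module.
            trunk_state[pattern.sub("", key, count=1)] = value
    return trunk_state


# Toy checkpoint: integers stand in for tensors.
checkpoint = {"trunk.conv1.weight": 1, "trunk.fc.bias": 2, "head.fc.weight": 3}
print(extract_trunk_state_dict(checkpoint))
# {'conv1.weight': 1, 'fc.bias': 2}
```

Keys that do not match, such as those of the head, are simply dropped in this sketch, so the result could be loaded into a trunk module on its own.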
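`SoccerNetSpottingModel` and `SpottingModel` both take `NMS_args`, defaulting to `{'window': 10, 'threshold': 0.001}`. This reference does not spell out the NMS algorithm, so the following is only a sketch of one common greedy form of 1D temporal non-maximum suppression for a single class. It assumes that `window` is the number of timesteps suppressed on each side of a kept peak and that `threshold` is the minimum score a kept prediction must reach. The function `temporal_nms` is hypothetical.

```python
from typing import List, Tuple


def temporal_nms(
    scores: List[float], window: int = 10, threshold: float = 0.001
) -> List[Tuple[int, float]]:
    """Greedy 1D NMS over per-timestep scores for one class.

    Assumptions (not taken from eztorch): ``window`` is the suppression
    radius in timesteps and ``threshold`` is the minimum score to keep.
    Returns (timestep, score) pairs sorted by timestep.
    """
    remaining = list(scores)
    kept: List[Tuple[int, float]] = []
    while remaining:
        # Select the highest score that has not been suppressed yet.
        best_idx = max(range(len(remaining)), key=remaining.__getitem__)
        best_score = remaining[best_idx]
        if best_score < threshold:
            break
        kept.append((best_idx, best_score))
        # Suppress the peak and every timestep within `window` of it.
        lo = max(0, best_idx - window)
        hi = min(len(remaining), best_idx + window + 1)
        for i in range(lo, hi):
            remaining[i] = float("-inf")
    return sorted(kept)


print(temporal_nms([0.1, 0.9, 0.2, 0.0, 0.0, 0.7, 0.3], window=2, threshold=0.05))
# [(1, 0.9), (5, 0.7)]
```

With a larger window (for example the default of 10) the same scores collapse to a single detection at timestep 1, which shows how `window` trades recall for fewer duplicate spots.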
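For `FinetuningModel`, `update_bn_momentum=True` is documented as using \(\max(1 - 10/\text{steps\_per\_epoch}, 0.9)\). The sketch below only evaluates that expression. How eztorch then applies the value to its batch norm layers is not covered here, and `bn_momentum_from_steps` is a hypothetical name.

```python
def bn_momentum_from_steps(steps_per_epoch: int) -> float:
    """Evaluate max(1 - 10 / steps_per_epoch, 0.9) from the FinetuningModel docs."""
    if steps_per_epoch <= 0:
        raise ValueError("steps_per_epoch must be positive")
    return max(1 - 10 / steps_per_epoch, 0.9)


for steps in (20, 100, 1000):
    print(steps, bn_momentum_from_steps(steps))
```

Because of the floor at 0.9, every epoch length of 100 steps or fewer yields 0.9; only longer epochs push the value toward 1.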