Finetuning
- class eztorch.models.finetuning.FinetuningModel(trunk, classifier, optimizer, pretrained_trunk_path, trunk_pattern='^(trunk\\.)', two_groups=False, freeze_trunk=False, train_transform=None, val_transform=None, test_transform=None, val_time_augmentation=None, test_time_augmentation=None, update_bn_momentum=False, freeze_bn_layers=False)
Model to perform fine-tuning of a pretrained trunk with a classifier.
- Parameters:
  - trunk – Config to build a trunk.
  - classifier – Config to build a classifier.
  - optimizer – Config to build an optimizer for the trunk.
  - pretrained_trunk_path – Path to the pretrained trunk file.
  - trunk_pattern (optional) – Pattern used to retrieve the trunk weights in the checkpoint state_dict; the matched part is deleted from each key. Default: '^(trunk\.)'
  - two_groups (optional) – If True, use two parameter groups for the optimizer: the trunk and the head. Default: False
  - freeze_trunk (optional) – If True, freeze the trunk. Default: False
  - train_transform (optional) – Config to perform transformations on the training input. Default: None
  - val_transform (optional) – Config to perform transformations on the validation input. Default: None
  - test_transform (optional) – Config to perform transformations on the test input. Default: None
  - val_time_augmentation (optional) – Ensembling method for test time augmentation used at validation. Default: None
  - test_time_augmentation (optional) – Ensembling method for test time augmentation used at test. Default: None
  - update_bn_momentum (optional) – If True, update the batch norm momentum according to \(\max(1 - 10/\text{steps\_per\_epoch}, 0.9)\). Default: False
  - freeze_bn_layers (optional) – If True, freeze the batch norm layers. Default: False
Example:
```python
from eztorch.models.finetuning import FinetuningModel

trunk = {...}  # config to build a trunk
classifier = {...}  # config to build a classifier
optimizer = {...}  # config to build an optimizer
pretrained_trunk_path = ...  # path where the trunk has been saved

model = FinetuningModel(trunk, classifier, optimizer, pretrained_trunk_path)
```
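The momentum rule quoted for update_bn_momentum can be evaluated with the small sketch below. bn_momentum is a hypothetical helper that only computes \(\max(1 - 10/\text{steps\_per\_epoch}, 0.9)\); it does not touch any model.

```python
def bn_momentum(steps_per_epoch: int) -> float:
    """Hypothetical helper evaluating max(1 - 10 / steps_per_epoch, 0.9)."""
    if steps_per_epoch <= 0:
        raise ValueError("steps_per_epoch must be positive")
    # Short epochs fall back to the 0.9 floor; long epochs approach 1.
    return max(1 - 10 / steps_per_epoch, 0.9)


for steps in (10, 100, 1000):
    print(steps, bn_momentum(steps))
```

The formula yields values close to 1, i.e. a decay-style momentum. PyTorch's BatchNorm momentum uses the opposite convention (weight of the current batch, default 0.1), and how eztorch maps this value onto the layers is not stated here.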
- class eztorch.models.soccernet_spotting.SoccerNetSpottingModel(trunk, head_class, optimizer, pretrained_path=None, pretrained_trunk_path=None, loss_fn_args={}, prediction_args={}, eval_step_timestamp=1.0, trunk_pattern='^(trunk\\.)', freeze_trunk=False, train_transform=None, val_transform=None, test_transform=None, save_val_preds_path='val_preds/', save_test_preds_path='test_preds/', NMS_args={'window': 10, 'threshold': 0.001}, evaluation_args={}, do_compile=False)
Model to perform action spotting.
- Parameters:
  - trunk – Config to build a trunk.
  - head_class – Config to build a head for classification.
  - optimizer – Config to build an optimizer for the trunk.
  - pretrained_path (optional) – Path to the pretrained model. Default: None
  - pretrained_trunk_path (optional) – Path to the pretrained trunk file. Default: None
  - loss_fn_args (optional) – Arguments for the loss function. Default: {}
  - prediction_args (optional) – Arguments to configure predictions. Default: {}
  - eval_step_timestamp (optional) – Step between two consecutive timestamps at evaluation. Default: 1.0
  - trunk_pattern (optional) – Pattern used to retrieve the trunk weights in the checkpoint state_dict; the matched part is deleted from each key. Default: '^(trunk\.)'
  - freeze_trunk (optional) – Whether to freeze the trunk. Default: False
  - train_transform (optional) – Config to perform transformations on the training input. Default: None
  - val_transform (optional) – Config to perform transformations on the validation input. Default: None
  - test_transform (optional) – Config to perform transformations on the test input. Default: None
  - save_val_preds_path (optional) – Path to store the validation predictions. Default: 'val_preds/'
  - save_test_preds_path (optional) – Path to store the test predictions. Default: 'test_preds/'
  - NMS_args (optional) – Arguments to configure the non-maximum suppression (NMS). Default: {'window': 10, 'threshold': 0.001}
  - evaluation_args (optional) – Arguments to configure the evaluation. Default: {}
Example:
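```python
from eztorch.models.soccernet_spotting import SoccerNetSpottingModel

trunk = {...}  # config to build a trunk
head_class = {...}  # config to build a head for classification
optimizer = {...}  # config to build an optimizer
pretrained_trunk_path = ...  # path where the trunk has been saved

# Pass the trunk path by keyword: the fourth positional parameter is pretrained_path.
model = SoccerNetSpottingModel(
    trunk, head_class, optimizer, pretrained_trunk_path=pretrained_trunk_path
)
```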
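To illustrate how NMS_args (default {'window': 10, 'threshold': 0.001}) and eval_step_timestamp could interact, here is a minimal sketch of greedy temporal non-maximum suppression over the per-step scores of one class. It assumes that window counts evaluation steps on each side of a kept peak, that threshold is a minimum score, and that step i corresponds to timestamp i * eval_step_timestamp. temporal_nms is hypothetical; the exact NMS used by eztorch may differ.

```python
from typing import List, Tuple


def temporal_nms(
    scores: List[float],
    window: int = 10,
    threshold: float = 0.001,
    eval_step_timestamp: float = 1.0,
) -> List[Tuple[float, float]]:
    """Hypothetical greedy 1D NMS returning (timestamp, score) pairs."""
    # Keep only candidates whose score reaches the threshold, best first.
    candidates = sorted(
        (i for i, s in enumerate(scores) if s >= threshold),
        key=lambda i: scores[i],
        reverse=True,
    )
    suppressed = set()
    kept = []
    for i in candidates:
        if i in suppressed:
            continue
        kept.append(i)
        # Suppress every step within `window` steps of the kept peak.
        suppressed.update(range(i - window, i + window + 1))
    # Convert step indices to timestamps and sort by time.
    return sorted((i * eval_step_timestamp, scores[i]) for i in kept)


scores = [0.1, 0.9, 0.2, 0.0, 0.0, 0.7, 0.0005]
print(temporal_nms(scores, window=2, eval_step_timestamp=0.5))
# [(0.5, 0.9), (2.5, 0.7)]
```

A larger window merges nearby detections into a single spot, while the threshold discards near-zero scores before suppression.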
- class eztorch.models.spotting.SpottingModel(trunk, head_class, optimizer, pretrained_path=None, pretrained_trunk_path=None, loss_fn_args={}, prediction_args={}, trunk_pattern='^(trunk\\.)', freeze_trunk=False, train_transform=None, val_transform=None, test_transform=None, save_val_preds_path='val_preds/', save_test_preds_path='test_preds/', NMS_args={'window': 10, 'threshold': 0.001}, evaluation_args={}, do_compile=False)
Model to perform spotting.
- Parameters:
  - trunk – Config to build a trunk.
  - head_class – Config to build a head for classification.
  - optimizer – Config to build an optimizer for the trunk.
  - pretrained_path (optional) – Path to the pretrained model. Default: None
  - pretrained_trunk_path (optional) – Path to the pretrained trunk file. Default: None
  - loss_fn_args (optional) – Arguments for the loss function. Default: {}
  - prediction_args (optional) – Arguments to configure predictions. Default: {}
  - trunk_pattern (optional) – Pattern used to retrieve the trunk weights in the checkpoint state_dict; the matched part is deleted from each key. Default: '^(trunk\.)'
  - freeze_trunk (optional) – Whether to freeze the trunk. Default: False
  - train_transform (optional) – Config to perform transformations on the training input. Default: None
  - val_transform (optional) – Config to perform transformations on the validation input. Default: None
  - test_transform (optional) – Config to perform transformations on the test input. Default: None
  - save_val_preds_path (optional) – Path to store the validation predictions. Default: 'val_preds/'
  - save_test_preds_path (optional) – Path to store the test predictions. Default: 'test_preds/'
  - NMS_args (optional) – Arguments to configure the non-maximum suppression (NMS). Default: {'window': 10, 'threshold': 0.001}
  - evaluation_args (optional) – Arguments to configure the evaluation. Default: {}
Example:
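```python
from eztorch.models.spotting import SpottingModel

trunk = {...}  # config to build a trunk
head_class = {...}  # config to build a head for classification
optimizer = {...}  # config to build an optimizer
pretrained_trunk_path = ...  # path where the trunk has been saved

# Pass the trunk path by keyword: the fourth positional parameter is pretrained_path.
model = SpottingModel(
    trunk, head_class, optimizer, pretrained_trunk_path=pretrained_trunk_path
)
```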
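In the SpottingModel signature, the fourth parameter is pretrained_path and the fifth is pretrained_trunk_path, so a trunk checkpoint path passed as the fourth positional argument is bound to pretrained_path. The sketch below shows this binding with Python's inspect module on a hypothetical stub, spotting_model_stub, that mirrors only those first five parameters; "trunk.ckpt" and the config strings are placeholder values.

```python
import inspect


# Hypothetical stand-in mirroring only the first five parameters of the
# SpottingModel signature; it does not build any model.
def spotting_model_stub(
    trunk, head_class, optimizer, pretrained_path=None, pretrained_trunk_path=None
):
    return None


sig = inspect.signature(spotting_model_stub)

# Positional call: the fourth argument lands in pretrained_path.
positional = sig.bind("trunk_cfg", "head_cfg", "optim_cfg", "trunk.ckpt").arguments

# Keyword call: the path lands in pretrained_trunk_path as intended.
keyword = sig.bind(
    "trunk_cfg", "head_cfg", "optim_cfg", pretrained_trunk_path="trunk.ckpt"
).arguments

print(positional["pretrained_path"])  # trunk.ckpt
print("pretrained_trunk_path" in positional)  # False
print(keyword["pretrained_trunk_path"])  # trunk.ckpt
```

SoccerNetSpottingModel shares the same first five parameters, so the same keyword form applies to it.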