Callbacks

class torchok.callbacks.checkpoint_onnx.CheckpointONNX(*args, onnx_params=None, remove_head=False, **kwargs)

Bases: ModelCheckpoint

Checkpoint callback that saves the model in ONNX format.

ONNX_EXTENSION = '.onnx'
__init__(*args, onnx_params=None, remove_head=False, **kwargs)

Init CheckpointONNX.

class torchok.callbacks.finalize_logger.FinalizeLogger

Bases: Callback

Callback that finalizes the logger if an error occurs.

on_exception(trainer, pl_module, outputs)

class torchok.callbacks.freeze_unfreeze.FreezeUnfreeze(freeze_modules: List[Dict[str, Any]], top_down_freeze_order: bool = True)

Bases: BaseFinetuning

Callback that freezes modules and incrementally unfreezes them during training.

__init__(freeze_modules: List[Dict[str, Any]], top_down_freeze_order: bool = True)

Init FreezeUnfreeze.

Parameters
  • freeze_modules

    List of dictionaries, one per module to be frozen and later unfrozen. Each dictionary may contain the following keys:

    • module_name (str):

      name of the module, relative to the task, to which the freeze is applied, for example backbone.layer1. An empty string stands for the whole model.

    • epoch (int, optional):

      number of epochs during which the module is frozen. If not specified, the module stays frozen for the whole training run.

    • stages (int, optional):

      if specified together with a module_name whose module has a get_stage(int) attribute, the freeze is applied only to the modules returned by get_stage(int). Usually used with a backbone: stage 0 refers to the stem layer, and stage i refers to layer block i-1 of the model together with all preceding blocks. If not specified, all blocks are frozen.

    • module_class (str, optional):

      if specified, the freeze is applied only to modules of the given type or of its subclasses.

    • bn_requires_grad (bool, optional):

      if specified, gradients of batch norm layers are configured separately from the other blocks. If not specified, batch norm layers are treated like the other modules.

    • bn_track_running_stats (bool, optional):

      if specified, the train mode of batch norm layers is configured separately from the other blocks. If not specified, batch norm layers are treated like the other modules.

  • top_down_freeze_order – If True, freeze policies are applied to the top modules first, e.g. aa before aa.bb. In this case the freeze policy for aa.bb overrides the part of the aa policy that relates to aa.bb. Otherwise, freeze policies for the bottom layers are applied first.
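As a rough illustration of the parameters above, the sketch below builds a freeze_modules list and shows how the two ordering modes could sequence the policies. The module names (backbone, backbone.layer1, head), the values, the format of the module_class string, and the helper order_policies are all hypothetical. Sorting by the depth of the dotted name is an assumption about what "top" and "bottom" mean here, not TorchOk's actual implementation.

```python
from typing import Any, Dict, List

# Hypothetical freeze policies; the keys follow the list documented above.
freeze_modules: List[Dict[str, Any]] = [
    # No "epoch" key: this module stays frozen for the whole run.
    {"module_name": "backbone.layer1", "bn_requires_grad": True},
    {"module_name": "backbone", "epoch": 5, "stages": 2,
     "bn_track_running_stats": False},
    # The exact format expected for module_class is an assumption.
    {"module_name": "head", "epoch": 1, "module_class": "Linear"},
]


def order_policies(policies: List[Dict[str, Any]],
                   top_down_freeze_order: bool = True) -> List[Dict[str, Any]]:
    """Sort policies by the nesting depth of their dotted module name."""
    def depth(policy: Dict[str, Any]) -> int:
        name = policy["module_name"]
        # The empty string denotes the whole model, i.e. the topmost level.
        return 0 if name == "" else name.count(".") + 1

    # Shallow names first when top-down, deep names first otherwise.
    return sorted(policies, key=depth, reverse=not top_down_freeze_order)


top_down = [p["module_name"] for p in order_policies(freeze_modules, True)]
bottom_up = [p["module_name"] for p in order_policies(freeze_modules, False)]
```

With TorchOk installed, a list like this would be passed to the constructor documented above, e.g. FreezeUnfreeze(freeze_modules=freeze_modules, top_down_freeze_order=True).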

static make_trainable(modules: Union[Module, Iterable[Union[Module, Iterable]]]) → None

Unfreezes the parameters of the provided modules.

Parameters

modules – A given module or an iterable of modules

static freeze(modules: Union[Module, Iterable[Union[Module, Iterable]]], module_dict: Dict[str, Any]) → None

Freezes the parameters of the provided modules.

Parameters
  • modules – A given module or an iterable of modules

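  • module_dict – Dictionary with the freeze settings for the provided modules, presumably one entry of freeze_modules (for example with the bn_requires_grad and bn_track_running_stats keys)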

static unfreeze_and_add_param_group(modules: Union[Module, Iterable[Union[Module, Iterable]]], optimizer: Optimizer, lr: Optional[float] = None, initial_denom_lr: float = 10.0, train_bn: bool = True) → None

Unfreezes a module and adds its parameters to an optimizer.

Parameters
  • modules – A module or iterable of modules to unfreeze. Their parameters will be added to an optimizer as a new param group.

  • optimizer – The optimizer that receives the new parameters; they are added through its add_param_group method.

  • lr – Learning rate for the new param group.

  • initial_denom_lr – If no lr is provided, the learning rate of the first param group is used, divided by initial_denom_lr.

  • train_bn – Whether to train the BatchNormalization layers.
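The interaction between lr and initial_denom_lr can be sketched in plain Python. The helper resolve_new_group_lr and the dictionaries standing in for optimizer.param_groups are hypothetical. The sketch only mirrors the rule stated in the parameter descriptions above and is not the library's code.

```python
from typing import Any, Dict, List, Optional


def resolve_new_group_lr(param_groups: List[Dict[str, Any]],
                         lr: Optional[float] = None,
                         initial_denom_lr: float = 10.0) -> float:
    """Pick the learning rate for a newly added param group."""
    if lr is not None:
        # An explicit learning rate always wins.
        return float(lr)
    # Otherwise fall back to the first group's rate, scaled down.
    return float(param_groups[0]["lr"]) / initial_denom_lr


# Stand-in for optimizer.param_groups (only the "lr" entry matters here).
existing_groups = [{"lr": 0.01}, {"lr": 0.1}]

default_lr = resolve_new_group_lr(existing_groups)
explicit_lr = resolve_new_group_lr(existing_groups, lr=0.003)
```

Under this rule a newly unfrozen module starts, by default, at a tenth of the first group's learning rate, a common choice when fine-tuning pretrained layers.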

freeze_before_training(pl_module: Module)

Freeze modules before training.

Freeze all the modules named in self.epoch2module_names.

Parameters

pl_module – Module that contains the modules to be frozen and unfrozen.

finetune_function(pl_module: Module, current_epoch: int, optimizer: Optimizer, optimizer_idx: int)

Unfreeze modules from self.epoch2module_names dictionary.

Parameters
  • pl_module – Module that contains the modules to be frozen and unfrozen.

  • current_epoch – Current epoch.

  • optimizer – Optimizer.

  • optimizer_idx – Optimizer index.
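The reference mentions self.epoch2module_names without describing its layout. The sketch below assumes it maps an epoch number to the names of the modules to unfreeze at that epoch, and that policies without an epoch key never appear in it. Both helper functions and the example policies are hypothetical and only illustrate the idea of a per-epoch unfreeze schedule.

```python
from collections import defaultdict
from typing import Any, Dict, List


def build_epoch2module_names(
        freeze_modules: List[Dict[str, Any]]) -> Dict[int, List[str]]:
    """Group module names by the epoch at which they would be unfrozen."""
    schedule: Dict[int, List[str]] = defaultdict(list)
    for policy in freeze_modules:
        if "epoch" in policy:  # no epoch: the module is never unfrozen
            schedule[policy["epoch"]].append(policy["module_name"])
    return dict(schedule)


def modules_to_unfreeze(schedule: Dict[int, List[str]],
                        current_epoch: int) -> List[str]:
    """Return the module names scheduled for the given epoch, if any."""
    return schedule.get(current_epoch, [])


policies = [
    {"module_name": "backbone", "epoch": 5},
    {"module_name": "head", "epoch": 1},
    {"module_name": "backbone.layer1"},  # stays frozen
]
schedule = build_epoch2module_names(policies)
due_at_epoch_1 = modules_to_unfreeze(schedule, 1)
due_at_epoch_2 = modules_to_unfreeze(schedule, 2)
```

In this reading, a function like finetune_function would look up current_epoch in the schedule on every epoch and unfreeze whatever is due.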