Losses
TorchOK supports all loss functions from PyTorch, as well as several custom loss functions.
Customized loss functions
Classification
- class torchok.losses.classification.binary_cross_entropy.BCEWithLogitsLoss(weight: Optional[Tensor] = None, reduction: str = 'mean', pos_weight: Optional[Union[str, list]] = None, ignore_index: int = -1)
Bases: BCEWithLogitsLoss
BCEWithLogitsLoss with the ability to load pos_weight from a JSON file (dict) or a config (list).
- __init__(weight: Optional[Tensor] = None, reduction: str = 'mean', pos_weight: Optional[Union[str, list]] = None, ignore_index: int = -1)
Init BCEWithLogitsLoss.
- Parameters
weight – A manual rescaling weight given to the loss of each batch element. If given, has to be a Tensor of size nbatch.
reduction – Specifies the reduction to apply to the output: ‘none’ | ‘mean’ | ‘sum’. ‘none’: no reduction will be applied, ‘mean’: the sum of the output will be divided by the number of elements in the output, ‘sum’: the output will be summed.
pos_weight – A weight for positive examples. Must be a vector with length equal to the number of classes. Can be a string, the path to a JSON file whose keys are class indices and whose values are weights, or a list of weights, which can be read from a YAML config.
ignore_index – Index that will be ignored during training.
- forward(input, target)
- reduction: str
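A framework-free sketch of the computation described above, assuming the usual BCE-with-logits formula with a per-class positive weight p: loss = -[p·y·log σ(x) + (1 - y)·log(1 - σ(x))]. The helpers `load_pos_weight` and `bce_with_logits` are hypothetical and not part of TorchOK; the JSON layout follows the pos_weight description, and masking out elements whose target equals ignore_index before the 'mean' reduction is an assumption about how ignore_index is applied.

```python
import json
import math
from typing import List, Optional, Union


def load_pos_weight(pos_weight: Union[str, List[float]]) -> List[float]:
    """Hypothetical helper: turn a JSON path or a list into per-class weights."""
    if isinstance(pos_weight, str):
        with open(pos_weight) as f:
            mapping = json.load(f)  # e.g. {"0": 1.0, "1": 3.5}
        # Order the weights by integer class index.
        return [float(mapping[k]) for k in sorted(mapping, key=int)]
    return [float(w) for w in pos_weight]


def softplus(z: float) -> float:
    # Numerically stable log(1 + exp(z)).
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def bce_with_logits(logits: List[List[float]], targets: List[List[float]],
                    pos_weight: Optional[List[float]] = None,
                    ignore_index: int = -1, reduction: str = 'mean') -> float:
    """Hypothetical sketch: rows are samples, columns are classes."""
    losses = []
    for row_x, row_y in zip(logits, targets):
        for c, (x, y) in enumerate(zip(row_x, row_y)):
            if y == ignore_index:
                continue  # assumption: ignored targets are masked out
            p = pos_weight[c] if pos_weight is not None else 1.0
            # -log(sigmoid(x)) == softplus(-x); -log(1 - sigmoid(x)) == softplus(x)
            losses.append(p * y * softplus(-x) + (1 - y) * softplus(x))
    if reduction == 'sum':
        return sum(losses)
    if reduction == 'mean':
        # Returning 0.0 when every element is ignored is a choice made for this sketch.
        return sum(losses) / len(losses) if losses else 0.0
    raise ValueError(f"Unsupported reduction in this sketch: {reduction}")
```

For example, bce_with_logits([[0.0]], [[1.0]]) returns log 2 ≈ 0.693, and pos_weight=[2.0] doubles it, because the weight scales only the positive term.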
Segmentation
TorchOK DiceLoss module, adapted from https://github.com/BloodAxe/pytorch-toolbelt/blob/develop/pytorch_toolbelt/losses/dice.py
Copyright (c) Eugene Khvedchenya. Licensed under the MIT License (see LICENSE for details).
- class torchok.losses.segmentation.dice.DiceLoss(mode: str, classes: Optional[List[int]] = None, log_loss: bool = False, from_logits: bool = True, smooth: float = 0, eps: float = 1e-07)
Bases: Module
Implementation of Dice loss for image segmentation tasks. It supports binary, multiclass and multilabel cases.
- __init__(mode: str, classes: Optional[List[int]] = None, log_loss: bool = False, from_logits: bool = True, smooth: float = 0, eps: float = 1e-07)
Init DiceLoss.
- Parameters
mode – Loss mode: one of {‘binary’, ‘multiclass’, ‘multilabel’}.
classes – Optional list of classes that contribute to the loss computation. By default, all channels are included.
log_loss – If True, the loss is computed as -log(dice); otherwise as 1 - dice.
from_logits – If True, the input is assumed to be raw logits.
smooth – Smoothing constant added to the numerator and denominator of the Dice score.
eps – Small epsilon for numerical stability.
- Raises
ValueError – If mode is not one of {‘binary’, ‘multiclass’, ‘multilabel’}.
ValueError – If the classes parameter is not None in binary mode.
- forward(input: Tensor, target: Tensor) → Tensor
Forward method for Dice loss.
- Parameters
input – NxCxHxW if mode is multiclass or multilabel and NxHxW if mode is binary
target – NxHxW
- Returns
Dice loss value (a scalar).
- Return type
Tensor
- Raises
ValueError – If the shape of the input tensor does not match the target.
- training: bool
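The soft Dice computation behind this module can be sketched in plain Python. This is a simplified illustration, not the TorchOK code: it assumes a sigmoid activation in binary mode and a softmax over channels in multiclass mode (the multiclass helper always treats its input as logits), adds smooth to the numerator and denominator, clamps the denominator with eps, and averages the per-class losses. Any extra per-class masking or reduction done by the real module is omitted, and all function names are hypothetical.

```python
import math
from typing import List, Optional


def sigmoid(x: float) -> float:
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def soft_dice_score(pred: List[float], target: List[float],
                    smooth: float = 0.0, eps: float = 1e-7) -> float:
    # (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth), denominator clamped to eps.
    intersection = sum(p * t for p, t in zip(pred, target))
    cardinality = sum(pred) + sum(target)
    return (2.0 * intersection + smooth) / max(cardinality + smooth, eps)


def score_to_loss(score: float, log_loss: bool, eps: float) -> float:
    # -log(dice) if log_loss is True, otherwise 1 - dice.
    return -math.log(max(score, eps)) if log_loss else 1.0 - score


def dice_loss_binary(pred: List[float], target: List[float], log_loss: bool = False,
                     from_logits: bool = True, smooth: float = 0.0,
                     eps: float = 1e-7) -> float:
    # pred and target are flattened HxW maps of one image; target holds 0/1 values.
    probs = [sigmoid(x) for x in pred] if from_logits else list(pred)
    return score_to_loss(soft_dice_score(probs, target, smooth, eps), log_loss, eps)


def dice_loss_multiclass(pred: List[List[float]], target: List[int],
                         classes: Optional[List[int]] = None, log_loss: bool = False,
                         smooth: float = 0.0, eps: float = 1e-7) -> float:
    # pred: one list of C logits per pixel; target: one class index per pixel.
    probs = []
    for logits in pred:
        m = max(logits)
        exps = [math.exp(v - m) for v in logits]  # softmax over channels
        total = sum(exps)
        probs.append([e / total for e in exps])
    selected = classes if classes is not None else range(len(pred[0]))
    losses = []
    for c in selected:
        pred_c = [p[c] for p in probs]
        target_c = [1.0 if t == c else 0.0 for t in target]  # one-hot channel c
        losses.append(score_to_loss(soft_dice_score(pred_c, target_c, smooth, eps),
                                    log_loss, eps))
    return sum(losses) / len(losses)
```

For instance, a perfect binary prediction such as dice_loss_binary([1.0, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0], from_logits=False) gives a loss of 0.0, while an all-zero prediction against an all-one target with log_loss=True is capped at -log(eps).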
Representation
- class torchok.losses.representation.pairwise.ContrastiveLoss(margin: float, reg: Optional[str] = None, reduction: Optional[str] = 'mean', eps: Optional[float] = 0.001)
Bases: GeneralPairWeightingLoss
Contrastive loss. See the base class documentation for more details.
- calc_loss(emb1: Tensor, emb2: Tensor, R: Tensor) → Tensor
See the documentation of the base class forward method.
- training: bool
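Because the base class is named GeneralPairWeightingLoss, a reasonable guess is that the loss follows the contrastive formulation from the General Pair Weighting framework (Wang et al., "Multi-Similarity Loss with General Pair Weighting for Deep Metric Learning", CVPR 2019): positive pairs contribute 1 - S_ij and negative pairs contribute max(0, S_ij - margin), where S_ij is the similarity between emb1[i] and emb2[j]. The sketch below assumes cosine similarity, a binary relevance matrix R (R[i][j] = 1 for positive pairs) and a mean over all pairs. The reg and eps parameters are not modelled, and the function names are hypothetical.

```python
import math
from typing import List


def cosine_similarity_matrix(a: List[List[float]], b: List[List[float]]) -> List[List[float]]:
    # S[i][j] = cos(a[i], b[j]); zero vectors are left unnormalized.
    def normalize(v: List[float]) -> List[float]:
        norm = math.sqrt(sum(x * x for x in v))
        return [x / norm for x in v] if norm > 0 else list(v)

    a_n = [normalize(v) for v in a]
    b_n = [normalize(v) for v in b]
    return [[sum(x * y for x, y in zip(u, w)) for w in b_n] for u in a_n]


def contrastive_loss(emb1: List[List[float]], emb2: List[List[float]],
                     R: List[List[int]], margin: float) -> float:
    S = cosine_similarity_matrix(emb1, emb2)
    total, count = 0.0, 0
    for i, row in enumerate(S):
        for j, s in enumerate(row):
            if R[i][j]:
                total += 1.0 - s  # positive pair: pull similarity towards 1
            else:
                total += max(0.0, s - margin)  # negative pair: penalize similarity above margin
            count += 1
    return total / count  # assumption: mean over all (i, j) pairs
```

With emb1 = [[1, 0]], emb2 = [[0, 1], [1, 0]], R = [[1, 0]] and margin = 0.5, the similarities are [[0, 1]], so the positive pair contributes 1, the negative pair contributes 0.5, and the mean is 0.75.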
Detection
Currently, all mmdetection loss functions (https://github.com/open-mmlab/mmdetection/tree/master/mmdet/models/losses) are supported. They can be accessed with the MM_ prefix, for example, MM_FocalLoss.
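The MM_ naming convention amounts to prefix-based name resolution. The sketch below only illustrates that idea: the two dictionaries are hypothetical stand-ins for loss registries, and TorchOK's actual lookup code may work differently.

```python
from typing import Callable, Dict

MM_PREFIX = "MM_"

# Hypothetical registries standing in for TorchOK's own losses and mmdetection's losses.
NATIVE_LOSSES: Dict[str, Callable[[], str]] = {"DiceLoss": lambda: "TorchOK DiceLoss"}
MMDET_LOSSES: Dict[str, Callable[[], str]] = {"FocalLoss": lambda: "mmdetection FocalLoss"}


def resolve_loss(name: str) -> Callable[[], str]:
    # Names starting with MM_ are looked up, without the prefix, among mmdetection losses.
    if name.startswith(MM_PREFIX):
        return MMDET_LOSSES[name[len(MM_PREFIX):]]
    return NATIVE_LOSSES[name]
```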