Losses

TorchOk supports all loss functions from PyTorch, as well as some custom loss functions.

Custom loss functions

Classification

class torchok.losses.classification.binary_cross_entropy.BCEWithLogitsLoss(weight: Optional[Tensor] = None, reduction: str = 'mean', pos_weight: Optional[Union[str, list]] = None, ignore_index: int = -1)

Bases: BCEWithLogitsLoss

BCEWithLogitsLoss with the ability to load pos_weight from a JSON file (dict) or from a config (list).

__init__(weight: Optional[Tensor] = None, reduction: str = 'mean', pos_weight: Optional[Union[str, list]] = None, ignore_index: int = -1)

Initialize BCEWithLogitsLoss.

Parameters
  • weight – A manual rescaling weight given to the loss of each batch element. If given, has to be a Tensor of size nbatch.

  • reduction – Specifies the reduction to apply to the output: ‘none’ | ‘mean’ | ‘sum’. ‘none’: no reduction will be applied, ‘mean’: the sum of the output will be divided by the number of elements in the output, ‘sum’: the output will be summed.

  • pos_weight – Weight of positive examples. Must be a vector with length equal to the number of classes. Can be either a string, the path to a JSON file whose keys are class indices and whose values are weights, or a list of weights, e.g. one read from a YAML config.

  • ignore_index – Target value (index) that is ignored during training.
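Because pos_weight may arrive either as a JSON file path or as a list, it has to be normalized into one weight per class before use. The following is a minimal, dependency-free sketch of that step. The helper name `parse_pos_weight` is hypothetical, and the rule that the JSON keys must cover every class index from 0 to C-1 is our assumption; TorchOk's own loader may treat missing classes differently.

```python
import json
from typing import List, Union


def parse_pos_weight(pos_weight: Union[str, List[float]], num_classes: int) -> List[float]:
    """Hypothetical helper: turn a JSON path or a list into a per-class weight list."""
    if isinstance(pos_weight, str):
        # JSON object keys are always strings, e.g. {"0": 1.0, "1": 4.5}
        with open(pos_weight, encoding="utf-8") as f:
            raw = json.load(f)
        by_index = {int(key): float(value) for key, value in raw.items()}
        # Assumption: every class index 0..num_classes-1 must be present
        missing = [i for i in range(num_classes) if i not in by_index]
        if missing:
            raise ValueError(f"no weight given for classes {missing}")
        weights = [by_index[i] for i in range(num_classes)]
    else:
        # A list read from a YAML config is taken in class order as-is
        weights = [float(w) for w in pos_weight]
    if len(weights) != num_classes:
        raise ValueError(f"expected {num_classes} weights, got {len(weights)}")
    return weights
```

With PyTorch installed, the resulting list would typically be wrapped with torch.tensor(weights) and passed as pos_weight to torch.nn.BCEWithLogitsLoss.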

forward(input, target)
reduction: str
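PyTorch's BCEWithLogitsLoss computes, per element, l = -[p_c · y · log σ(x) + (1 - y) · log(1 - σ(x))], where p_c is the positive weight of class c. The dependency-free sketch below applies that formula to an N x C batch and skips elements whose target equals ignore_index. The function name is hypothetical, the per-sample weight argument is omitted, and averaging over the non-ignored elements only (rather than over all elements) is an assumption about how TorchOk reduces the masked loss.

```python
import math
from typing import List, Optional


def softplus(z: float) -> float:
    # Numerically stable log(1 + exp(z))
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def bce_with_logits_ignore(
    logits: List[List[float]],   # N x C raw scores
    targets: List[List[float]],  # N x C values in {0, 1} or ignore_index
    pos_weight: Optional[List[float]] = None,  # one weight per class
    ignore_index: int = -1,
    reduction: str = "mean",
):
    num_classes = len(logits[0])
    pw = pos_weight if pos_weight is not None else [1.0] * num_classes
    losses = []
    for row_x, row_y in zip(logits, targets):
        for c, (x, y) in enumerate(zip(row_x, row_y)):
            if y == ignore_index:
                continue  # masked out: contributes nothing to the loss
            # -log(sigmoid(x)) == softplus(-x); -log(1 - sigmoid(x)) == softplus(x)
            losses.append(pw[c] * y * softplus(-x) + (1.0 - y) * softplus(x))
    if reduction == "none":
        return losses  # flat list of the non-ignored element losses
    total = sum(losses)
    if reduction == "sum":
        return total
    if reduction == "mean":
        return total / len(losses) if losses else 0.0
    raise ValueError(f"unknown reduction: {reduction}")
```

For example, a logit of 0 with target 1 gives log 2, and an element whose target is -1 is dropped before averaging.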

Segmentation

TorchOk DiceLoss module, adapted from https://github.com/BloodAxe/pytorch-toolbelt/blob/develop/pytorch_toolbelt/losses/dice.py
Copyright (c) Eugene Khvedchenya. Licensed under the MIT License (see LICENSE for details).

class torchok.losses.segmentation.dice.DiceLoss(mode: str, classes: Optional[List[int]] = None, log_loss: bool = False, from_logits: bool = True, smooth: float = 0, eps: float = 1e-07)

Bases: Module

Implementation of the Dice loss for image segmentation. Supports binary, multiclass, and multilabel cases.

__init__(mode: str, classes: Optional[List[int]] = None, log_loss: bool = False, from_logits: bool = True, smooth: float = 0, eps: float = 1e-07)

Initialize DiceLoss.

Parameters
  • mode – Loss mode, one of {‘binary’, ‘multiclass’, ‘multilabel’}.

  • classes – Optional list of classes that contribute to the loss computation. By default, all channels are included.

  • log_loss – If True, the loss is computed as -log(dice); otherwise as 1 - dice.

  • from_logits – If True, assumes the input contains raw logits.

  • smooth – Smoothing constant added to the numerator and denominator of the Dice score.

  • eps – Small epsilon for numerical stability.

Raises
  • ValueError – If mode is not one of {‘binary’, ‘multiclass’, ‘multilabel’}.

  • ValueError – If classes parameter is not None in binary mode.
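In the binary case, these parameters combine into the soft Dice score used in pytorch-toolbelt, from which this module is adapted: dice = (2 · sum(p · t) + smooth) / max(sum(p) + sum(t) + smooth, eps), and the loss is 1 - dice, or -log(max(dice, eps)) when log_loss=True. The dependency-free sketch below works on flattened pixel lists. The function name is hypothetical, and details of TorchOk's implementation such as masking of empty targets are not reproduced.

```python
import math
from typing import List


def sigmoid(x: float) -> float:
    # Overflow-safe logistic function
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def binary_dice_loss(
    pred: List[float],    # flattened predictions (logits or probabilities)
    target: List[float],  # flattened 0/1 mask of the same length
    log_loss: bool = False,
    from_logits: bool = True,
    smooth: float = 0.0,
    eps: float = 1e-7,
) -> float:
    if len(pred) != len(target):
        raise ValueError("input and target must have the same number of elements")
    if from_logits:
        pred = [sigmoid(x) for x in pred]  # raw scores -> probabilities
    intersection = sum(p * t for p, t in zip(pred, target))
    cardinality = sum(pred) + sum(target)
    dice = (2.0 * intersection + smooth) / max(cardinality + smooth, eps)
    if log_loss:
        return -math.log(max(dice, eps))
    return 1.0 - dice
```

With smooth=0, an empty prediction on an empty mask yields a loss of 1 (eps only prevents division by zero); a positive smooth value turns that case into a loss of 0.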

forward(input: Tensor, target: Tensor) → Tensor

Forward method for Dice loss.

Parameters
  • input – Tensor of shape NxCxHxW in multiclass or multilabel mode, or NxHxW in binary mode.

  • target – NxHxW

Returns

Dice loss value (a scalar tensor).

Return type

Tensor

Raises

ValueError – If the shape of the input tensor is not compatible with the shape of the target.
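In multiclass mode, the NxHxW target holds one class index per pixel, so it is compared against the C channels of the NxCxHxW input after a softmax over channels. The sketch below flattens H and W into one pixel axis, accumulates a soft Dice score per class over the whole batch, averages 1 - dice over the classes selected by classes, and raises ValueError on mismatched shapes. It assumes from_logits=True and log_loss=False; the function name is hypothetical, and it does not necessarily match how TorchOk handles classes absent from the target.

```python
import math
from typing import List, Optional, Sequence


def softmax(values: Sequence[float]) -> List[float]:
    # Subtract the max for numerical stability
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def multiclass_dice_loss(
    logits: List[List[List[float]]],  # N x C x P, where P = H * W
    target: List[List[int]],          # N x P, one class index per pixel
    classes: Optional[List[int]] = None,
    smooth: float = 0.0,
    eps: float = 1e-7,
) -> float:
    num_classes = len(logits[0])
    if len(target) != len(logits) or any(
        len(t_row) != len(x[0]) for x, t_row in zip(logits, target)
    ):
        raise ValueError("shape of input does not match target")
    intersection = [0.0] * num_classes
    cardinality = [0.0] * num_classes
    for x, t_row in zip(logits, target):
        for p, label in enumerate(t_row):
            probs = softmax([x[k][p] for k in range(num_classes)])
            for k in range(num_classes):
                onehot = 1.0 if label == k else 0.0
                intersection[k] += probs[k] * onehot
                cardinality[k] += probs[k] + onehot
    selected = classes if classes is not None else list(range(num_classes))
    losses = [
        1.0 - (2.0 * intersection[k] + smooth) / max(cardinality[k] + smooth, eps)
        for k in selected
    ]
    return sum(losses) / len(losses)
```

A confident, correct prediction drives the loss toward 0, while a confident prediction of the wrong class drives it toward 1.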

training: bool

Representation

class torchok.losses.representation.pairwise.ContrastiveLoss(margin: float, reg: Optional[str] = None, reduction: Optional[str] = 'mean', eps: Optional[float] = 0.001)

Bases: GeneralPairWeightingLoss

Contrastive loss. See the base class (GeneralPairWeightingLoss) documentation for more details.

calc_loss(emb1: Tensor, emb2: Tensor, R: Tensor) → Tensor

See the documentation of the base class forward method.

training: bool
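calc_loss receives two batches of embeddings and a matrix R. The sketch below shows a common margin-based contrastive formulation, assuming R[i][j] is 1 when emb1[i] and emb2[j] form a positive pair and 0 otherwise: positive pairs are penalized by their squared Euclidean distance, and negative pairs by the squared amount by which they fall inside margin. The function name is hypothetical, and TorchOk's exact weighting, its reg option, and its use of eps are not reproduced.

```python
import math
from typing import List


def contrastive_loss(
    emb1: List[List[float]],  # M x D embeddings
    emb2: List[List[float]],  # K x D embeddings
    R: List[List[int]],       # M x K, assumed 1 for positive pairs, 0 for negative
    margin: float,
) -> float:
    total, count = 0.0, 0
    for i, a in enumerate(emb1):
        for j, b in enumerate(emb2):
            dist = math.dist(a, b)  # Euclidean distance (Python 3.8+)
            if R[i][j]:
                total += dist ** 2  # pull positive pairs together
            else:
                total += max(0.0, margin - dist) ** 2  # push negatives beyond the margin
            count += 1
    return total / count  # 'mean' reduction over all pairs
```

A negative pair that is already farther apart than margin contributes nothing, which is why the margin controls how strongly dissimilar embeddings are separated.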

Detection

At the moment, TorchOk supports all mmdetection loss functions (https://github.com/open-mmlab/mmdetection/tree/master/mmdet/models/losses). They can be accessed with the MM_ prefix, for example MM_FocalLoss.