Segmentation

class torchok.tasks.segmentation.SegmentationTask(hparams: DictConfig, backbone_name: str, head_name: str, neck_name: str, backbone_params: Optional[dict] = None, neck_params: Optional[dict] = None, head_params: Optional[dict] = None, **kwargs)

Bases: BaseTask

A class for the image segmentation task.

__init__(hparams: DictConfig, backbone_name: str, head_name: str, neck_name: str, backbone_params: Optional[dict] = None, neck_params: Optional[dict] = None, head_params: Optional[dict] = None, **kwargs)

Initialize SegmentationTask.

Parameters
  • hparams – hyperparameters that are set in the YAML file.

  • backbone_name – name of the backbone architecture in the BACKBONES registry.

  • neck_name – name of the neck architecture in the DETECTION_NECKS registry.

  • head_name – name of the head architecture in the HEADS registry.

  • backbone_params – parameters for backbone constructor.

  • neck_params – parameters for neck constructor. in_channels will be set automatically based on backbone.

  • head_params – parameters for head constructor. in_channels will be set automatically based on neck.

  • inputs – information about the model's input shapes and dtypes.
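
The automatic in_channels wiring described above can be pictured with a plain PyTorch sketch. TinyBackbone, TinyNeck, TinyHead, build_model, and the out_channels attribute are hypothetical stand-ins introduced for illustration, not TorchOk classes; the real registries and constructors may differ.

```python
import torch
from torch import nn


class TinyBackbone(nn.Module):
    """Hypothetical backbone: halves the resolution and exposes out_channels."""

    def __init__(self, in_channels: int = 3, out_channels: int = 16):
        super().__init__()
        self.out_channels = out_channels
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.conv(x))


class TinyNeck(nn.Module):
    """Hypothetical neck: restores the input resolution."""

    def __init__(self, in_channels: int, out_channels: int = 8):
        super().__init__()
        self.out_channels = out_channels
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.up(torch.relu(self.conv(x)))


class TinyHead(nn.Module):
    """Hypothetical head: a 1x1 convolution producing per-pixel class logits."""

    def __init__(self, in_channels: int, num_classes: int):
        super().__init__()
        self.classifier = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(x)


def build_model(backbone_params=None, neck_params=None, head_params=None):
    """Build the three parts; in_channels is derived from the previous part."""
    backbone = TinyBackbone(**(backbone_params or {}))
    neck = TinyNeck(in_channels=backbone.out_channels, **(neck_params or {}))
    head = TinyHead(in_channels=neck.out_channels, **(head_params or {}))
    return backbone, neck, head


# The caller never passes in_channels for the neck or the head.
backbone, neck, head = build_model(head_params={"num_classes": 5})
```

In this sketch the user-supplied parameter dictionaries only carry architecture-specific settings, while channel counts flow from backbone to neck to head.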

forward(x: Tensor) → Tensor

Forward method.

Parameters

x – torch.Tensor of shape (B, C, H, W); a batch of input images.

Returns

torch.Tensor of shape (B, num_classes, H, W), representing the logit masks for each image.
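
As an illustration of how logits of this shape are usually consumed, the following sketch converts them into per-pixel probabilities and a class-index mask. The random tensor stands in for a real forward output, and the variable names are chosen here for the example only.

```python
import torch

batch_size, num_classes, height, width = 2, 4, 8, 8

# Stand-in for the output of forward(): one logit map per class.
logits = torch.randn(batch_size, num_classes, height, width)

# Softmax over the class dimension gives per-pixel class probabilities.
probs = logits.softmax(dim=1)

# Argmax over the class dimension gives a (B, H, W) mask of class indices.
pred_mask = logits.argmax(dim=1)
```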

forward_with_gt(batch: Dict[str, Tensor]) → Dict[str, Tensor]

Forward with ground truth labels.

Parameters

batch

Dictionary with the following keys and values:

  • image (torch.Tensor):

    tensor of shape (B, C, H, W), representing input images.

  • target (torch.Tensor):

    tensor of shape (B, H, W), representing the target class-label mask for each image.

Returns

Dictionary with the following keys and values:

  • 'prediction': torch.Tensor of shape (B, num_classes, H, W), representing the logit masks for each image.

  • 'target': torch.Tensor of shape (B, H, W), representing the target class-label mask for each image. May be absent.
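
The input and output dictionaries can be sketched as follows. This is a minimal illustration, not the TorchOk implementation: forward_with_gt_sketch and the one-layer model are hypothetical, and the example assumes the prediction holds per-pixel logits of shape (B, num_classes, H, W), as returned by forward.

```python
from typing import Dict

import torch
import torch.nn.functional as F
from torch import nn

# Hypothetical stand-in for the full segmentation model (3 channels in, 4 classes out).
model = nn.Conv2d(3, 4, kernel_size=1)


def forward_with_gt_sketch(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Run the model and pass the target through when the batch provides one."""
    output = {"prediction": model(batch["image"])}
    if "target" in batch:
        output["target"] = batch["target"]
    return output


batch = {
    "image": torch.randn(2, 3, 8, 8),            # (B, C, H, W)
    "target": torch.randint(0, 4, (2, 8, 8)),    # (B, H, W), integer class labels
}
output = forward_with_gt_sketch(batch)

# Per-pixel cross-entropy accepts (B, num_classes, H, W) logits and (B, H, W) labels.
loss = F.cross_entropy(output["prediction"], output["target"])

# Without a target (for example at inference time) the key is simply missing.
unlabeled_output = forward_with_gt_sketch({"image": torch.randn(2, 3, 8, 8)})
```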

as_module() → Sequential

Returns the model as a Sequential of its modules (required for checkpointing).
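
One plausible reading of this method, sketched in plain PyTorch: the backbone, neck, and head are chained into a single torch.nn.Sequential whose state dict can then be saved. The three convolutions below are hypothetical placeholders, and the actual composition inside TorchOk may differ.

```python
import torch
from torch import nn

# Hypothetical placeholders for the task's backbone, neck, and head.
backbone = nn.Conv2d(3, 8, kernel_size=3, padding=1)
neck = nn.Conv2d(8, 8, kernel_size=3, padding=1)
head = nn.Conv2d(8, 4, kernel_size=1)

# Chaining the parts yields one module with a single forward pass.
sequential = nn.Sequential(backbone, neck, head)

# Parameters are keyed by position ("0.weight", "1.weight", ...), ready for torch.save.
state = sequential.state_dict()

logits = sequential(torch.randn(1, 3, 16, 16))
```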