Classification
- class torchok.tasks.classification.ClassificationTask(hparams: DictConfig, backbone_name: str, neck_name: Optional[str] = None, pooling_name: Optional[str] = None, head_name: Optional[str] = None, backbone_params: Optional[dict] = None, neck_params: Optional[dict] = None, pooling_params: Optional[dict] = None, head_params: Optional[dict] = None, inputs: Optional[dict] = None)
Bases: BaseTask
A class for image classification task.
- __init__(hparams: DictConfig, backbone_name: str, neck_name: Optional[str] = None, pooling_name: Optional[str] = None, head_name: Optional[str] = None, backbone_params: Optional[dict] = None, neck_params: Optional[dict] = None, pooling_params: Optional[dict] = None, head_params: Optional[dict] = None, inputs: Optional[dict] = None)
Init ClassificationTask.
- Parameters
hparams – hyperparameters that are set in the YAML file.
backbone_name – name of the backbone architecture in the BACKBONES registry.
pooling_name – name of the pooling architecture in the POOLINGS registry.
head_name – name of the head architecture in the HEADS registry.
neck_name – if present, name of the neck architecture in the NECKS registry. Otherwise, the model will be created without a neck.
backbone_params – parameters for backbone constructor.
neck_params – parameters for neck constructor. in_channels will be set automatically based on backbone.
pooling_params – parameters for pooling constructor. in_channels will be set automatically based on the neck, or on the backbone if the neck is absent.
head_params – parameters for head constructor. in_channels will be set automatically based on neck.
inputs – information about the shapes and dtypes of the model inputs.
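The automatic in_channels wiring described above can be sketched in plain Python. This is only an illustration of the stated rule, not the library's implementation: `resolve_in_channels` and its arguments are hypothetical names, and the head is left out because the source only says its in_channels is set automatically.

```python
from typing import Dict, Optional


def resolve_in_channels(backbone_out: int, neck_out: Optional[int] = None) -> Dict[str, int]:
    """Hypothetical helper: return the in_channels each stage would receive.

    backbone_out -- number of channels produced by the backbone
    neck_out     -- number of channels produced by the neck, or None if no neck is used
    """
    resolved: Dict[str, int] = {}
    if neck_out is not None:
        # With a neck: the neck reads from the backbone, the pooling reads from the neck.
        resolved["neck"] = backbone_out
        resolved["pooling"] = neck_out
    else:
        # Without a neck: the pooling reads directly from the backbone.
        resolved["pooling"] = backbone_out
    return resolved
```

Under this sketch, any in_channels value passed by hand in neck_params or pooling_params would be redundant, since the value is derived from the preceding stage.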
- forward(x: Tensor) Tensor
Forward method.
- Parameters
x – torch.Tensor of shape (B, C, H, W). Batch of input images.
- Returns
torch.Tensor of shape (B, num_classes), representing logits for each image.
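Because forward returns logits rather than probabilities, a caller usually needs a post-processing step. The following is a minimal sketch for one row of logits, assuming a single-label setup where softmax is the appropriate normalization (the source does not state this); `logits_to_prediction` is a hypothetical helper that uses plain Python lists instead of tensors.

```python
import math
from typing import List, Tuple


def logits_to_prediction(logits: List[float]) -> Tuple[int, List[float]]:
    """Hypothetical helper: turn one row of logits into (class index, probabilities)."""
    # Subtract the maximum before exponentiating for numerical stability.
    highest = max(logits)
    exps = [math.exp(value - highest) for value in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # The predicted class is the index of the largest probability.
    best = max(range(len(probs)), key=probs.__getitem__)
    return best, probs
```

For a full batch of shape (B, num_classes), the same operation would be applied to each of the B rows.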
- forward_with_gt(batch: Dict[str, Tensor]) Dict[str, Tensor]
Forward with ground truth labels.
- Parameters
batch –
Dictionary with the following keys and values:
- image (torch.Tensor):
tensor of shape (B, C, H, W), representing input images.
- target (torch.Tensor):
tensor of shape (B), target class or labels per each image.
- Returns
Dictionary with the following keys and values:
'embeddings': torch.Tensor of shape (B, num_features), representing embeddings for each image.
'prediction': torch.Tensor of shape (B, num_classes), representing logits for each image.
'target': torch.Tensor of shape (B), target class or labels for each image. May be absent.
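Since the target key may be missing from the returned dictionary, code that consumes the output should not index it unconditionally. A minimal sketch of defensive unpacking follows; `split_output` is a hypothetical helper and works on any dictionary with the documented keys.

```python
from typing import Any, Dict, Optional, Tuple


def split_output(output: Dict[str, Any]) -> Tuple[Any, Any, Optional[Any]]:
    """Hypothetical helper: unpack the dictionary returned by forward_with_gt."""
    embeddings = output["embeddings"]   # documented as always present
    prediction = output["prediction"]   # documented as always present
    target = output.get("target")       # None when no ground truth was provided
    return embeddings, prediction, target
```

A caller can then compute a loss or metric only when the returned target is not None.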
- as_module() Sequential
Returns the model as a sequence of modules (needed for ONNX checkpointing).