Create your own Task

class torchok.tasks.base.BaseTask(hparams: DictConfig, inputs=None, **kwargs)

Bases: LightningModule, ABC

An abstract class that defines the main methods of a task.

__init__(hparams: DictConfig, inputs=None, **kwargs)

Initialize BaseTask.

Parameters
  • hparams – Hyperparameters set in the YAML config file.

  • inputs – Information about the model's input shapes and dtypes.

abstract forward(*args, **kwargs) → Tensor

Abstract forward method for validation and test.

abstract forward_with_gt(batch: Dict[str, Any]) → Dict[str, Tensor]

Abstract forward method for training (with ground-truth labels).
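The split between the two abstract forward methods can be illustrated with a toy, dependency-free sketch. It assumes a batch is a dict with hypothetical 'input' and 'target' keys and uses plain Python lists instead of tensors, so it mirrors only the calling convention (inputs only vs. inputs plus ground truth), not TorchOK's actual implementation.

```python
from typing import Any, Dict, List


class ToyRegressionTask:
    """Hypothetical stand-in for a BaseTask subclass (no torch/torchok needed)."""

    def __init__(self, weight: float) -> None:
        self.weight = weight

    def forward(self, x: List[float]) -> List[float]:
        # Inference path: takes inputs only, no labels required.
        return [self.weight * value for value in x]

    def forward_with_gt(self, batch: Dict[str, Any]) -> Dict[str, List[float]]:
        # Training path: run forward() and pass the ground truth through
        # so that losses and metrics can be computed downstream.
        prediction = self.forward(batch["input"])
        return {"prediction": prediction, "target": batch["target"]}


task = ToyRegressionTask(weight=2.0)
print(task.forward([1.0, 2.0]))  # [2.0, 4.0]
print(task.forward_with_gt({"input": [1.0, 2.0], "target": [2.0, 4.5]}))
```

Keeping the label-free forward separate lets the same model be called at inference time, where no ground truth exists.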

configure_optimizers() → List[Dict[str, Union[Optimizer, Dict[str, Any]]]]

Configure optimizers.

train_dataloader() → Optional[List[DataLoader]]

Implement one or more PyTorch DataLoaders for training.

val_dataloader() → Optional[List[DataLoader]]

Implement one or multiple PyTorch DataLoaders for validation.

test_dataloader() → Optional[List[DataLoader]]

Implement one or multiple PyTorch DataLoaders for testing.

predict_dataloader() → Optional[List[DataLoader]]

Implement one or multiple PyTorch DataLoaders for prediction.

on_fit_start() → None
on_test_start() → None
on_predict_start() → None
training_step(batch: Dict[str, Union[Tensor, int]], batch_idx: int) → Dict[str, Tensor]

Training loop step, called once per training batch.

validation_step(batch: Dict[str, Union[Tensor, int]], batch_idx: int, dataloader_idx: int = 0) → Dict[str, Tensor]

Validation loop step, called once per validation batch.

test_step(batch: Dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0) → None

Test loop step, called once per test batch.

predict_step(batch: Dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0) → Dict[str, Tensor]

Prediction loop step, called once per prediction batch.
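A hedged sketch of how the step methods might build on forward_with_gt() follows. It is dependency-free: lists replace tensors, mean squared error stands in for whatever loss a real task configures, and the helper names (mse, _shared_step) and batch keys are hypothetical. The one Lightning convention it relies on is that, under automatic optimization, a dict returned from training_step must contain a "loss" entry.

```python
from typing import Any, Dict, List


def mse(prediction: List[float], target: List[float]) -> float:
    """Mean squared error over paired elements."""
    return sum((p - t) ** 2 for p, t in zip(prediction, target)) / len(target)


class ToyStepTask:
    """Hypothetical sketch of step methods; not TorchOK's actual code."""

    def __init__(self, weight: float) -> None:
        self.weight = weight

    def forward_with_gt(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        prediction = [self.weight * x for x in batch["input"]]
        return {"prediction": prediction, "target": batch["target"]}

    def _shared_step(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        # Both steps reuse forward_with_gt() and attach a scalar loss.
        output = self.forward_with_gt(batch)
        output["loss"] = mse(output["prediction"], output["target"])
        return output

    def training_step(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any]:
        # With automatic optimization, Lightning backpropagates the "loss" entry.
        return self._shared_step(batch)

    def validation_step(self, batch: Dict[str, Any], batch_idx: int,
                        dataloader_idx: int = 0) -> Dict[str, Any]:
        return self._shared_step(batch)


task = ToyStepTask(weight=2.0)
batch = {"input": [1.0, 2.0], "target": [2.0, 5.0]}
print(task.training_step(batch, batch_idx=0)["loss"])  # 0.5
```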

on_train_batch_end(outputs: Dict[str, Tensor], batch: Dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0) → Dict[str, Tensor]
on_validation_batch_end(outputs: Dict[str, Tensor], batch: Dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0) → Dict[str, Tensor]
on_train_epoch_end() → None

Called at the end of the training epoch to process the outputs of all training steps.

on_validation_epoch_end() → None

Called at the end of the validation epoch to process the outputs of all validation steps.

on_test_epoch_end() → None

Called at the end of the test epoch to process the outputs of all test steps.
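Because the epoch-end hooks take no arguments, the per-step outputs have to be collected somewhere in between. One plausible pattern, sketched below without any dependencies, buffers them in on_validation_batch_end and reduces them in on_validation_epoch_end. The buffer attribute, the epoch_metrics dict, and the "valid/loss" key are hypothetical; a real LightningModule would report the value through self.log instead of storing it.

```python
from typing import Any, Dict, List


class EpochAggregationSketch:
    """Hypothetical: buffer per-batch outputs, reduce them once per epoch."""

    def __init__(self) -> None:
        self._val_outputs: List[Dict[str, float]] = []
        self.epoch_metrics: Dict[str, float] = {}

    def on_validation_batch_end(self, outputs: Dict[str, float], batch: Any,
                                batch_idx: int, dataloader_idx: int = 0) -> Dict[str, float]:
        # Keep each step's outputs until the epoch finishes.
        self._val_outputs.append(outputs)
        return outputs

    def on_validation_epoch_end(self) -> None:
        losses = [out["loss"] for out in self._val_outputs]
        if losses:
            # In a real LightningModule this value would go to self.log(...).
            self.epoch_metrics["valid/loss"] = sum(losses) / len(losses)
        # Reset the buffer so the next epoch starts empty.
        self._val_outputs.clear()


hooks = EpochAggregationSketch()
for idx, loss in enumerate([1.0, 2.0, 3.0]):
    hooks.on_validation_batch_end({"loss": loss}, batch=None, batch_idx=idx)
hooks.on_validation_epoch_end()
print(hooks.epoch_metrics)  # {'valid/loss': 2.0}
```

Clearing the buffer at epoch end keeps memory bounded and prevents outputs from one epoch leaking into the next.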

abstract as_module() → Sequential

Abstract method that represents the model as a Sequential of modules (needed for checkpointing).
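BaseTask marks forward, forward_with_gt, and as_module as abstract, so a custom task must override all three before it can be instantiated. The sketch below mirrors only that contract using the standard-library abc module (no LightningModule, no torch). The class names are hypothetical, and the tuple returned by as_module is a placeholder for the torch.nn.Sequential a real task would return.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class AbstractTaskSketch(ABC):
    """Stdlib mirror of BaseTask's abstract methods (hypothetical)."""

    @abstractmethod
    def forward(self, *args: Any, **kwargs: Any) -> Any: ...

    @abstractmethod
    def forward_with_gt(self, batch: Dict[str, Any]) -> Dict[str, Any]: ...

    @abstractmethod
    def as_module(self) -> Any: ...


class IncompleteTask(AbstractTaskSketch):
    # Overrides only forward(); forward_with_gt() and as_module() are missing.
    def forward(self, *args: Any, **kwargs: Any) -> Any:
        return None


class CompleteTask(AbstractTaskSketch):
    def forward(self, x: Any) -> Any:
        return x

    def forward_with_gt(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        return {"prediction": self.forward(batch["input"]), "target": batch["target"]}

    def as_module(self) -> Any:
        # Placeholder: a real task would return a torch.nn.Sequential here.
        return ("backbone", "head")


def can_instantiate(cls: type) -> bool:
    """Return True if cls() succeeds, False if abstract methods remain."""
    try:
        cls()
    except TypeError:
        return False
    return True


print(can_instantiate(IncompleteTask))  # False
print(can_instantiate(CompleteTask))    # True
```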