Create your own Task
- class torchok.tasks.base.BaseTask(hparams: DictConfig, inputs=None, **kwargs)
Bases: LightningModule, ABC
An abstract class that defines the main methods of a task.
- __init__(hparams: DictConfig, inputs=None, **kwargs)
Initialize BaseTask.
- Parameters
hparams – Hyperparameters set in the YAML configuration file.
inputs – Information about the model's input shapes and dtypes.
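BaseTask inherits from ABC, so a subclass that leaves any of its abstract methods (forward, forward_with_gt, as_module) unimplemented should not be instantiable. The sketch below shows that contract using only the standard library. TaskContract, IncompleteTask and CompleteTask are hypothetical stand-ins, not torchok classes. The real BaseTask also inherits from LightningModule and takes hparams.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class TaskContract(ABC):
    """Hypothetical stand-in mirroring BaseTask's three abstract methods."""

    @abstractmethod
    def forward(self, *args, **kwargs):
        """Inference path used for validation and test."""

    @abstractmethod
    def forward_with_gt(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Training path that also carries ground-truth labels."""

    @abstractmethod
    def as_module(self):
        """Return the model as a sequence of modules."""


class IncompleteTask(TaskContract):
    # Implements only forward, so instantiation must fail.
    def forward(self, x):
        return x


class CompleteTask(TaskContract):
    def forward(self, x):
        return x

    def forward_with_gt(self, batch):
        return {"prediction": self.forward(batch["input"]), "target": batch["target"]}

    def as_module(self):
        # Placeholder only; a real task would return an nn.Sequential.
        return (self.forward,)


def can_instantiate(cls) -> bool:
    # ABCMeta raises TypeError when abstract methods remain unimplemented.
    try:
        cls()
        return True
    except TypeError:
        return False


print(can_instantiate(IncompleteTask))  # False
print(can_instantiate(CompleteTask))    # True
```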
- abstract forward(*args, **kwargs) Tensor
Abstract forward method for validation and test.
- abstract forward_with_gt(batch: Dict[str, Any]) Dict[str, Tensor]
Abstract forward method for training (with ground-truth labels).
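A small classifier can illustrate the split between the two forward methods. This is a hypothetical sketch. It subclasses plain torch.nn.Module rather than BaseTask so that it runs on its own. The batch keys "image" and "target" are assumed names that the documentation does not define.

```python
from typing import Any, Dict

import torch
from torch import Tensor, nn


class TinyClassificationTask(nn.Module):
    """Hypothetical sketch; a real task would subclass BaseTask."""

    def __init__(self, in_features: int = 4, num_classes: int = 3):
        super().__init__()
        self.backbone = nn.Linear(in_features, 8)
        self.head = nn.Linear(8, num_classes)

    def forward(self, x: Tensor) -> Tensor:
        # Inference path: raw input in, logits out.
        return self.head(torch.relu(self.backbone(x)))

    def forward_with_gt(self, batch: Dict[str, Any]) -> Dict[str, Tensor]:
        # Training path: run the model and pass the labels along,
        # so loss and metrics can be computed from a single dict.
        logits = self.forward(batch["image"])
        return {"prediction": logits, "target": batch["target"]}


task = TinyClassificationTask()
batch = {"image": torch.randn(2, 4), "target": torch.tensor([0, 2])}
print(task(batch["image"]).shape)                  # torch.Size([2, 3])
print(sorted(task.forward_with_gt(batch).keys()))  # ['prediction', 'target']
```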
- configure_optimizers() List[Dict[str, Union[Optimizer, Dict[str, Any]]]]
Configure optimizers.
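The return annotation of configure_optimizers matches the list-of-dicts form that PyTorch Lightning accepts. In that form, each dict pairs an optimizer with an optional lr_scheduler config. The minimal sketch below hard-codes Adam and StepLR settings. In a torchok task these choices would presumably come from hparams. The function name configure_optimizers_sketch is hypothetical.

```python
from typing import Any, Dict, List, Union

import torch
from torch import nn
from torch.optim import Optimizer


def configure_optimizers_sketch(
    model: nn.Module,
) -> List[Dict[str, Union[Optimizer, Dict[str, Any]]]]:
    # One optimizer, paired with its LR scheduler in Lightning's dict format.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    return [{
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "interval": "epoch"},
    }]


model = nn.Linear(4, 2)
configs = configure_optimizers_sketch(model)
print(type(configs[0]["optimizer"]).__name__)  # Adam
```

With "interval": "epoch", Lightning steps the scheduler once per epoch. Use "step" to step it after every optimizer step instead.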
- train_dataloader() Optional[List[DataLoader]]
Implement one or more PyTorch DataLoaders for training.
- val_dataloader() Optional[List[DataLoader]]
Implement one or multiple PyTorch DataLoaders for validation.
- test_dataloader() Optional[List[DataLoader]]
Implement one or multiple PyTorch DataLoaders for testing.
- predict_dataloader() Optional[List[DataLoader]]
Implement one or multiple PyTorch DataLoaders for prediction.
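The dataloader methods return a list, so a single task can serve several datasets in each stage. The sketch below builds two validation loaders over a hypothetical DictDataset that yields dict samples. These match the Dict[str, Tensor] batches that the step methods expect. The dataset class, its keys and its sizes are all assumptions.

```python
from typing import Dict, List, Optional

import torch
from torch import Tensor
from torch.utils.data import DataLoader, Dataset


class DictDataset(Dataset):
    """Hypothetical dataset yielding {"image", "target"} samples."""

    def __init__(self, size: int, in_features: int = 4, num_classes: int = 3):
        self.images = torch.randn(size, in_features)
        self.targets = torch.randint(0, num_classes, (size,))

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int) -> Dict[str, Tensor]:
        return {"image": self.images[idx], "target": self.targets[idx]}


def val_dataloader(batch_size: int = 4) -> Optional[List[DataLoader]]:
    # Two validation sets; Lightning passes dataloader_idx 0 or 1 to
    # validation_step so the task can tell their batches apart.
    datasets = [DictDataset(10), DictDataset(6)]
    return [DataLoader(ds, batch_size=batch_size, shuffle=False) for ds in datasets]


loaders = val_dataloader()
first_batch = next(iter(loaders[0]))
print([len(dl) for dl in loaders])  # [3, 2]
print(first_batch["image"].shape)   # torch.Size([4, 4])
```

PyTorch's default collate function stacks dict samples key by key. Each batch therefore arrives as a dict with "image" of shape (B, 4) and "target" of shape (B,).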
- on_fit_start() None
- on_test_start() None
- on_predict_start() None
- training_step(batch: Dict[str, Union[Tensor, int]], batch_idx: int) Dict[str, Tensor]
Run one training step on a batch.
- validation_step(batch: Dict[str, Union[Tensor, int]], batch_idx: int, dataloader_idx: int = 0) Dict[str, Tensor]
Run one validation step on a batch.
- test_step(batch: Dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0) None
Run one test step on a batch.
- predict_step(batch: Dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0) Dict[str, Tensor]
Run one prediction step on a batch.
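A training_step along these lines would call forward_with_gt and compute a loss from the returned dict. The sketch below is standalone, using a plain nn.Module with no Trainer. The class name, the CrossEntropyLoss criterion and the batch keys are assumptions. The one Lightning rule it relies on is that a dict returned from training_step must contain a "loss" key, which Lightning backpropagates.

```python
from typing import Any, Dict

import torch
from torch import Tensor, nn


class StepSketch(nn.Module):
    """Hypothetical stand-in for a BaseTask subclass; no Lightning needed."""

    def __init__(self, in_features: int = 4, num_classes: int = 3):
        super().__init__()
        self.model = nn.Linear(in_features, num_classes)
        self.criterion = nn.CrossEntropyLoss()

    def forward_with_gt(self, batch: Dict[str, Any]) -> Dict[str, Tensor]:
        return {"prediction": self.model(batch["image"]), "target": batch["target"]}

    def training_step(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Tensor]:
        output = self.forward_with_gt(batch)
        loss = self.criterion(output["prediction"], output["target"])
        # Lightning backpropagates whatever is stored under "loss".
        return {"loss": loss, **output}

    def validation_step(self, batch: Dict[str, Any], batch_idx: int,
                        dataloader_idx: int = 0) -> Dict[str, Tensor]:
        # Lightning already disables gradients during validation;
        # no_grad is explicit here only because no Trainer is involved.
        with torch.no_grad():
            output = self.forward_with_gt(batch)
            loss = self.criterion(output["prediction"], output["target"])
        return {"loss": loss, **output}


task = StepSketch()
batch = {"image": torch.randn(8, 4), "target": torch.randint(0, 3, (8,))}
train_out = task.training_step(batch, batch_idx=0)
train_out["loss"].backward()
val_out = task.validation_step(batch, batch_idx=0)
print(train_out["loss"].requires_grad, val_out["loss"].requires_grad)  # True False
```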
- on_train_batch_end(outputs: Dict[str, Tensor], batch: Dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0) Dict[str, Tensor]
- on_validation_batch_end(outputs: Dict[str, Tensor], batch: Dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0) Dict[str, Tensor]
- on_train_epoch_end() None
Called at the end of the training epoch with the outputs of all training steps.
- on_validation_epoch_end() None
Called at the end of the validation epoch with the outputs of all validation steps.
- on_test_epoch_end() None
Called at the end of a test epoch with the outputs of all test steps.
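In Lightning 2.x, on_train_epoch_end and on_validation_epoch_end receive no outputs argument. A module that aggregates step outputs therefore has to keep them itself. One common pattern is to append per-batch results in on_train_batch_end, then reduce and clear them at epoch end. The sketch below shows that pattern with hypothetical names. It is an illustration, not torchok's actual implementation.

```python
from typing import Any, Dict, List

import torch
from torch import Tensor


class EpochAggregatorSketch:
    """Hypothetical helper showing the buffer-and-reduce pattern."""

    def __init__(self) -> None:
        self.train_step_outputs: List[Tensor] = []
        self.epoch_mean_losses: List[float] = []

    def on_train_batch_end(self, outputs: Dict[str, Tensor], batch: Dict[str, Any],
                           batch_idx: int, dataloader_idx: int = 0) -> None:
        # Detach so the buffer does not keep each step's autograd graph alive.
        self.train_step_outputs.append(outputs["loss"].detach())

    def on_train_epoch_end(self) -> None:
        # Reduce the whole epoch, then reset the buffer for the next one.
        mean_loss = torch.stack(self.train_step_outputs).mean().item()
        self.epoch_mean_losses.append(mean_loss)
        self.train_step_outputs.clear()


agg = EpochAggregatorSketch()
for i, value in enumerate([1.0, 2.0, 3.0]):
    agg.on_train_batch_end({"loss": torch.tensor(value)}, batch={}, batch_idx=i)
agg.on_train_epoch_end()
print(agg.epoch_mean_losses)  # [2.0]
```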
- abstract as_module() Sequential
Abstract method that returns the model as an nn.Sequential of modules (needed for checkpointing).
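as_module can wrap the task's existing submodules in the same order that forward applies them. The sketch below assumes a hypothetical backbone, pooling, head layout. nn.Sequential holds references rather than copies, so the returned model shares weights with the task and should produce identical outputs.

```python
import torch
from torch import nn


class SequentialTaskSketch(nn.Module):
    """Hypothetical task whose forward is a plain chain of submodules."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(4, 8), nn.ReLU())
        self.pooling = nn.Identity()
        self.head = nn.Linear(8, 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.pooling(self.backbone(x)))

    def as_module(self) -> nn.Sequential:
        # Same submodules in the same order as forward(); nothing is copied,
        # so the Sequential shares its parameters with the task.
        return nn.Sequential(self.backbone, self.pooling, self.head)


task = SequentialTaskSketch()
x = torch.randn(2, 4)
seq = task.as_module()
print(torch.allclose(seq(x), task(x)))  # True
```

The result is an ordinary nn.Sequential, so its state_dict can be saved or exported without the task class that produced it.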