ONNX inference

class torchok.tasks.onnx.ONNXTask(hparams: DictConfig, path_to_onnx: str, providers, **kwargs)

Bases: BaseTask

A task that runs inference on an ONNX model file with ONNX Runtime.

str_type2numpy_type = {'tensor(double)': 'float64', 'tensor(float)': 'float32', 'tensor(float16)': 'float16', 'tensor(int16)': 'int16', 'tensor(int32)': 'int32', 'tensor(int64)': 'int64', 'tensor(int8)': 'int8', 'tensor(uint8)': 'uint8'}
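The keys of `str_type2numpy_type` are the type strings that ONNX Runtime reports for model inputs and outputs (for example through the `type` attribute of the objects returned by `InferenceSession.get_inputs()`), and the values are NumPy dtype names. The following is a minimal sketch, assuming NumPy, of how such a table can be used to cast an input array to the dtype a model expects. The helper `cast_for_onnx` is hypothetical and not part of torchok.

```python
import numpy as np

# Same table as ONNXTask.str_type2numpy_type: ONNX Runtime type string -> NumPy dtype name.
STR_TYPE2NUMPY_TYPE = {
    'tensor(double)': 'float64',
    'tensor(float)': 'float32',
    'tensor(float16)': 'float16',
    'tensor(int16)': 'int16',
    'tensor(int32)': 'int32',
    'tensor(int64)': 'int64',
    'tensor(int8)': 'int8',
    'tensor(uint8)': 'uint8',
}


def cast_for_onnx(array: np.ndarray, onnx_type: str) -> np.ndarray:
    """Hypothetical helper: cast `array` to the dtype named by an ONNX Runtime type string."""
    try:
        dtype = STR_TYPE2NUMPY_TYPE[onnx_type]
    except KeyError:
        raise ValueError(f"unsupported ONNX type: {onnx_type}") from None
    # copy=False skips the copy when the array already has the requested dtype.
    return array.astype(dtype, copy=False)


x = np.arange(4, dtype=np.float64)
print(cast_for_onnx(x, 'tensor(float)').dtype)  # float32
```

Types missing from the table (for example `tensor(bool)`) raise a `ValueError` in this sketch rather than being passed through silently.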
__init__(hparams: DictConfig, path_to_onnx: str, providers, **kwargs)

Initialize ONNXTask.

Parameters
  • hparams – Hyperparameters set in the YAML config file.

  • path_to_onnx – Path to the ONNX model file.

  • providers – Optional sequence of providers in order of decreasing precedence. Values can either be provider names or tuples of (provider name, options dict). If not provided, then all available providers are used with the default precedence.
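In ONNX Runtime, providers such as `"CUDAExecutionProvider"` and `"CPUExecutionProvider"` are listed in order of decreasing precedence, and the providers compiled into the installed build are reported by `onnxruntime.get_available_providers()`. The standard-library sketch below shows the two accepted forms of `providers` and one way to drop entries the current build does not support before the session is created. The helper `filter_providers` is hypothetical; torchok does not necessarily perform this filtering.

```python
from typing import Any, Dict, List, Sequence, Tuple, Union

# A provider is either a name or a (name, options dict) tuple, as described for `providers`.
Provider = Union[str, Tuple[str, Dict[str, Any]]]


def provider_name(provider: Provider) -> str:
    """Return the provider name for either accepted form."""
    return provider if isinstance(provider, str) else provider[0]


def filter_providers(requested: Sequence[Provider], available: Sequence[str]) -> List[Provider]:
    """Hypothetical helper: keep requested providers that are available, preserving precedence order."""
    return [p for p in requested if provider_name(p) in available]


requested = [
    ("CUDAExecutionProvider", {"device_id": 0}),  # tuple form with an options dict
    "CPUExecutionProvider",                        # plain name form
]
# With onnxruntime installed, `available` would come from onnxruntime.get_available_providers().
available = ["CPUExecutionProvider"]
print(filter_providers(requested, available))  # ['CPUExecutionProvider']
```

Keeping the list order intact matters because the first matching provider takes precedence.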

forward(x: Tensor) → Tensor
forward_with_gt(batch: Dict[str, Any]) → Dict[str, Tensor]
as_module() → Sequential
foward_infer(inputs: Dict[str, Tensor]) → Dict[str, Tensor]

Run the ONNX model forward on a dictionary of input tensors and return a dictionary of output tensors.
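A forward pass through an ONNX Runtime session typically means building an input feed keyed by the model's input names, calling `run`, and pairing the returned arrays with the output names. The following is a minimal sketch of that pattern, assuming the `onnxruntime.InferenceSession` methods `get_inputs()`, `get_outputs()` and `run(output_names, input_feed)`, with NumPy arrays in place of `Tensor`s. `run_session` and `FakeSession` are hypothetical; the stand-in session only lets the example run without a model file, and this is not torchok's implementation.

```python
from types import SimpleNamespace
from typing import Dict

import numpy as np

# Subset of ONNXTask.str_type2numpy_type, enough for this example.
STR_TYPE2NUMPY_TYPE = {'tensor(float)': 'float32', 'tensor(int64)': 'int64'}


def run_session(session, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """Hypothetical helper: feed named inputs to an ONNX Runtime-style session, return named outputs."""
    feed = {}
    for node in session.get_inputs():
        # Cast each input to the dtype the model declares for it.
        dtype = STR_TYPE2NUMPY_TYPE[node.type]
        feed[node.name] = np.asarray(inputs[node.name]).astype(dtype, copy=False)
    output_names = [node.name for node in session.get_outputs()]
    # session.run returns a list of arrays in the same order as output_names.
    outputs = session.run(output_names, feed)
    return dict(zip(output_names, outputs))


class FakeSession:
    """Stand-in for onnxruntime.InferenceSession: doubles its single input."""

    def get_inputs(self):
        return [SimpleNamespace(name="input", type="tensor(float)")]

    def get_outputs(self):
        return [SimpleNamespace(name="output")]

    def run(self, output_names, input_feed):
        return [input_feed["input"] * 2]


result = run_session(FakeSession(), {"input": np.array([1, 2, 3])})
print(result["output"], result["output"].dtype)  # [2. 4. 6.] float32
```

With PyTorch tensors as inputs, each value would first be moved to the CPU and converted, for example with `tensor.detach().cpu().numpy()`, and the outputs could be wrapped back with `torch.from_numpy`.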

forward_infer_with_gt(batch: Dict[str, Any]) → Dict[str, Tensor]

Forward method for the test stage.
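At the test stage a batch usually carries both model inputs and ground-truth entries. The following is a minimal sketch of separating them, assuming the batch is a flat dictionary whose model inputs are exactly the keys matching the session's input names. The helper `split_batch` and the key `"target"` are hypothetical and not taken from torchok.

```python
from typing import Any, Dict, Iterable, Tuple


def split_batch(batch: Dict[str, Any], input_names: Iterable[str]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Hypothetical helper: separate model inputs from everything else (e.g. ground truth)."""
    names = set(input_names)
    inputs = {k: v for k, v in batch.items() if k in names}
    rest = {k: v for k, v in batch.items() if k not in names}
    # Fail early if a required model input is absent from the batch.
    missing = names - set(inputs)
    if missing:
        raise KeyError(f"batch is missing model inputs: {sorted(missing)}")
    return inputs, rest


batch = {"input": [1.0, 2.0], "target": [0, 1]}
inputs, gt = split_batch(batch, ["input"])
print(inputs, gt)  # {'input': [1.0, 2.0]} {'target': [0, 1]}
```

The model outputs can then be merged with the ground-truth dictionary so that metrics receive predictions and targets together.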

test_step(batch: Dict[str, Union[Tensor, int]], batch_idx: int) → None

Complete test loop.

predict_step(batch: Dict[str, Any], batch_idx: int) → Tensor

Complete predict loop.
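`test_step` and `predict_step` are PyTorch Lightning hooks: during `Trainer.test(...)` and `Trainer.predict(...)` the trainer iterates over a dataloader and calls the hook once per batch, together with the batch index. The plain-Python sketch below only illustrates that calling convention. `run_predict_loop` and `EchoTask` are hypothetical stand-ins, and the real trainer additionally handles devices, callbacks and logging.

```python
from typing import Any, Dict, Iterable, List


def run_predict_loop(task, dataloader: Iterable[Dict[str, Any]]) -> List[Any]:
    """Hypothetical, simplified version of what a trainer does during prediction."""
    predictions = []
    for batch_idx, batch in enumerate(dataloader):
        # Each batch is passed together with its position in the dataloader.
        predictions.append(task.predict_step(batch, batch_idx))
    return predictions


class EchoTask:
    """Stand-in task whose predict_step returns the batch index and input."""

    def predict_step(self, batch: Dict[str, Any], batch_idx: int) -> Any:
        return batch_idx, batch["input"]


loader = [{"input": 10}, {"input": 20}]
print(run_predict_loop(EchoTask(), loader))  # [(0, 10), (1, 20)]
```

With Lightning installed, the corresponding call on a real task would be `trainer.predict(task, dataloaders=loader)`; `test_step` is driven the same way by `trainer.test(...)`, but its return value is `None`.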

training: bool
precision: Union[int, str]
prepare_data_per_node: bool
allow_zero_length_dataloader_with_multiple_devices: bool