ONNX inference
- class torchok.tasks.onnx.ONNXTask(hparams: DictConfig, path_to_onnx: str, providers, **kwargs)
Bases: BaseTask
A task class for ONNX model inference.
- str_type2numpy_type = {'tensor(double)': 'float64', 'tensor(float)': 'float32', 'tensor(float16)': 'float16', 'tensor(int16)': 'int16', 'tensor(int32)': 'int32', 'tensor(int64)': 'int64', 'tensor(int8)': 'int8', 'tensor(uint8)': 'uint8'}
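`str_type2numpy_type` maps the type strings that ONNX Runtime reports for graph inputs (for example `'tensor(float)'`) to NumPy dtype names. The sketch below shows one way such a table can be used to cast an array to the dtype an input expects. The helper `cast_for_onnx` is hypothetical and not part of torchok.

```python
import numpy as np

# Copy of the ONNXTask.str_type2numpy_type table documented above.
STR_TYPE2NUMPY_TYPE = {
    'tensor(double)': 'float64',
    'tensor(float)': 'float32',
    'tensor(float16)': 'float16',
    'tensor(int16)': 'int16',
    'tensor(int32)': 'int32',
    'tensor(int64)': 'int64',
    'tensor(int8)': 'int8',
    'tensor(uint8)': 'uint8',
}


def cast_for_onnx(array: np.ndarray, onnx_type: str) -> np.ndarray:
    """Hypothetical helper: cast `array` to the dtype named by an ONNX type string."""
    if onnx_type not in STR_TYPE2NUMPY_TYPE:
        raise ValueError(f"Unsupported ONNX input type: {onnx_type}")
    return array.astype(STR_TYPE2NUMPY_TYPE[onnx_type])


x = np.arange(4)                      # platform-default integer dtype
y = cast_for_onnx(x, 'tensor(float)')
print(y.dtype)                        # float32
```

Types missing from the table (for example `'tensor(bool)'`) raise an error in this sketch rather than passing through silently.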
- __init__(hparams: DictConfig, path_to_onnx: str, providers, **kwargs)
Initialize ONNXTask.
- Parameters
hparams – Hyperparameters set in the YAML config file.
path_to_onnx – Path to the ONNX model file.
providers – Optional sequence of providers in order of decreasing precedence. Values can either be provider names or tuples of (provider name, options dict). If not provided, then all available providers are used with the default precedence.
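The `providers` description matches the one used by ONNX Runtime's `onnxruntime.InferenceSession(path, providers=...)`, so the argument is presumably forwarded there. The following is a minimal sketch of the documented semantics: entries may be plain names or `(name, options)` tuples, order means decreasing precedence, and omitting the argument selects all available providers in default order. The helper `resolve_providers` is hypothetical. Dropping unavailable providers is an assumption of this sketch; ONNX Runtime itself may warn or raise instead.

```python
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

# A provider is either a name or a (name, options dict) tuple.
ProviderSpec = Union[str, Tuple[str, Dict[str, Any]]]


def resolve_providers(
    requested: Optional[Sequence[ProviderSpec]],
    available: Sequence[str],
) -> List[Tuple[str, Dict[str, Any]]]:
    """Hypothetical helper: normalize providers to (name, options) pairs.

    The returned list keeps the requested order (decreasing precedence).
    With no request, every available provider is used in its default order.
    Providers not in `available` are dropped (an assumption of this sketch).
    """
    if not requested:
        return [(name, {}) for name in available]
    resolved = []
    for spec in requested:
        name, options = (spec, {}) if isinstance(spec, str) else spec
        if name in available:
            resolved.append((name, dict(options)))
    return resolved


# In real code `available` would come from onnxruntime.get_available_providers().
available = ["CUDAExecutionProvider", "CPUExecutionProvider"]
print(resolve_providers(None, available))
print(resolve_providers([("CUDAExecutionProvider", {"device_id": 0}), "CPUExecutionProvider"], available))
```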
- forward(x: Tensor) → Tensor
- forward_with_gt(batch: Dict[str, Any]) → Dict[str, Tensor]
- as_module() → Sequential
- foward_infer(inputs: Dict[str, Tensor]) → Dict[str, Tensor]
Run a forward pass of the ONNX model.
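A minimal sketch of what a forward pass over named inputs can look like with an ONNX Runtime session. It assumes NumPy arrays in place of the `Tensor` values in the signature and uses only the session methods `get_inputs()`, `get_outputs()` and `run(output_names, input_feed)`, which `onnxruntime.InferenceSession` provides. `forward_infer_sketch` and the stub `DoubleSession` are hypothetical stand-ins, not torchok's implementation.

```python
from types import SimpleNamespace
from typing import Dict

import numpy as np

# Subset of ONNXTask.str_type2numpy_type, enough for this example.
STR_TYPE2NUMPY_TYPE = {'tensor(float)': 'float32', 'tensor(int64)': 'int64'}


def forward_infer_sketch(session, inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """Hypothetical: run one forward pass and return outputs keyed by name."""
    feed = {}
    for node in session.get_inputs():
        # Cast each named input to the dtype declared in the ONNX graph.
        feed[node.name] = np.asarray(inputs[node.name]).astype(STR_TYPE2NUMPY_TYPE[node.type])
    output_names = [node.name for node in session.get_outputs()]
    # run() returns a list ordered like output_names.
    outputs = session.run(output_names, feed)
    return dict(zip(output_names, outputs))


class DoubleSession:
    """Stub standing in for onnxruntime.InferenceSession: output = 2 * input."""

    def get_inputs(self):
        return [SimpleNamespace(name="input", type="tensor(float)")]

    def get_outputs(self):
        return [SimpleNamespace(name="output", type="tensor(float)")]

    def run(self, output_names, input_feed):
        return [input_feed["input"] * 2]


result = forward_infer_sketch(DoubleSession(), {"input": np.array([1, 2, 3])})
print(result["output"])  # [2. 4. 6.]
```

Returning a dictionary keyed by output name, rather than the positional list from `run()`, matches the `Dict[str, Tensor]` return type in the signature above.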
- forward_infer_with_gt(batch: Dict[str, Any]) → Dict[str, Tensor]
Forward method for the test stage.
- test_step(batch: Dict[str, Union[Tensor, int]], batch_idx: int) → None
Run one step of the test loop.
- predict_step(batch: Dict[str, Any], batch_idx: int) → Tensor
Run one step of the predict loop.
- training: bool
- precision: Union[int, str]
- prepare_data_per_node: bool
- allow_zero_length_dataloader_with_multiple_devices: bool