
base

torch_to_nnef.inference_target.base

InferenceTarget

InferenceTarget(version: T.Union[SemanticVersion, str], check_io: bool = False)

Abstract base class for implementing a new inference engine target.

Initialize InferenceTarget.

Each inference engine target is expected to define at least a version and a way to check its outputs for a given input (enabled with check_io).
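The sketch below makes the constructor contract concrete: a target carries a version and a check_io flag, and concrete engines subclass an abstract base. It does not import torch_to_nnef. MiniInferenceTarget, EchoTarget, the abstract describe method, and the tuple-based version parsing are all hypothetical simplifications. The real class accepts a SemanticVersion or a str.

```python
import abc
import typing as T


class MiniInferenceTarget(abc.ABC):
    """Hypothetical stand-in mirroring InferenceTarget's constructor."""

    def __init__(self, version: T.Union[T.Tuple[int, ...], str], check_io: bool = False):
        # Accept "X.Y.Z" strings as well as already-parsed tuples
        # (assumption: the real class uses a SemanticVersion type instead).
        if isinstance(version, str):
            version = tuple(int(part) for part in version.split("."))
        self.version = version
        self.check_io = check_io

    @abc.abstractmethod
    def describe(self) -> str:
        """Hypothetical abstract hook, so the base cannot be instantiated directly."""


class EchoTarget(MiniInferenceTarget):
    """Hypothetical concrete engine target."""

    def describe(self) -> str:
        version_str = ".".join(str(part) for part in self.version)
        return f"echo-engine v{version_str} (check_io={self.check_io})"


target = EchoTarget("0.21.3", check_io=True)
print(target.describe())  # echo-engine v0.21.3 (check_io=True)
```

Keeping check_io off by default in this sketch follows the documented signature (check_io: bool = False). Output checking usually needs the engine itself to be available at export time.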

has_dynamic_axes property
has_dynamic_axes: bool

Whether the user requested dynamic axes in the NNEF graph.

Some inference engines may not support dynamic axes, hence it is False by default.
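Here is one way a target could expose has_dynamic_axes. It assumes, hypothetically, that the value comes from a constructor flag. MiniTarget, StaticOnlyTarget and the dynamic_axes keyword are illustrative names only and are not part of the documented signature.

```python
class MiniTarget:
    """Hypothetical target exposing a has_dynamic_axes property."""

    def __init__(self, dynamic_axes: bool = False):
        # Default to False: some engines cannot handle dynamic axes.
        self._dynamic_axes = dynamic_axes

    @property
    def has_dynamic_axes(self) -> bool:
        return self._dynamic_axes


class StaticOnlyTarget(MiniTarget):
    """Hypothetical engine that never supports dynamic axes."""

    @property
    def has_dynamic_axes(self) -> bool:
        return False


print(MiniTarget().has_dynamic_axes)                        # False
print(MiniTarget(dynamic_axes=True).has_dynamic_axes)       # True
print(StaticOnlyTarget(dynamic_axes=True).has_dynamic_axes) # False
```

Overriding the property in a subclass lets an engine with no dynamic-shape support ignore the user's request outright.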

post_export
post_export(model: nn.Module, nnef_graph: NGraph, args: T.List[T.Any], exported_filepath: Path, debug_bundle_path: T.Optional[Path] = None)

Called after the NNEF model asset is generated.

This is typically where check_io is actually applied.
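The sketch below shows how check_io might be applied in post_export. It compares the original model's outputs with the engine's outputs on the same inputs, within a tolerance. The model is a plain Python callable instead of an nn.Module, and the graph argument is unused. CheckingTarget, the run_engine callable, the atol parameter and the ValueError on mismatch are all hypothetical and not the library's implementation.

```python
import math
import typing as T
from pathlib import Path

# Hypothetical engine runner: (exported file, model inputs) -> flat list of outputs.
EngineRunner = T.Callable[[Path, T.List[T.Any]], T.List[float]]


class CheckingTarget:
    """Hypothetical target that applies check_io inside post_export."""

    def __init__(self, run_engine: EngineRunner, check_io: bool = False, atol: float = 1e-5):
        self.run_engine = run_engine
        self.check_io = check_io
        self.atol = atol

    def outputs_match(self, expected: T.List[float], obtained: T.List[float]) -> bool:
        # Same number of values, and each pair within the absolute tolerance.
        return len(expected) == len(obtained) and all(
            math.isclose(ref, out, rel_tol=0.0, abs_tol=self.atol)
            for ref, out in zip(expected, obtained)
        )

    def post_export(self, model: T.Callable[..., T.List[float]], nnef_graph: T.Any,
                    args: T.List[T.Any], exported_filepath: Path,
                    debug_bundle_path: T.Optional[Path] = None) -> None:
        if not self.check_io:
            return  # checking disabled: nothing to do
        expected = model(*args)  # reference output from the original model
        obtained = self.run_engine(exported_filepath, args)  # output from the engine
        if not self.outputs_match(expected, obtained):
            # A real target might also write debug artifacts to debug_bundle_path.
            raise ValueError(f"IO check failed for {exported_filepath}: {expected} != {obtained}")


def double(xs: T.List[float]) -> T.List[float]:
    return [2.0 * x for x in xs]


def fake_engine(path: Path, args: T.List[T.Any]) -> T.List[float]:
    return [2.0 * x for x in args[0]]  # stands in for running the exported model


target = CheckingTarget(fake_engine, check_io=True)
target.post_export(double, None, [[1.0, 2.0]], Path("model.nnef.tgz"))  # passes silently
```

Running the check only after the asset exists means the engine is tested on the exact file that will ship.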

post_trace
post_trace(nnef_graph: NGraph, active_custom_extensions: T.List[str])

Called just after the PyTorch graph is parsed.
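As an example of a post_trace hook, the sketch below inspects active_custom_extensions and fails early if the graph needs an extension the engine lacks. ExtensionAwareTarget, SUPPORTED_EXTENSIONS and the extension names ext_alpha and ext_beta are hypothetical placeholders, not real torch_to_nnef or NNEF extension names.

```python
import typing as T


class ExtensionAwareTarget:
    """Hypothetical target that validates custom extensions after tracing."""

    # Assumption: names of extensions this imaginary engine understands.
    SUPPORTED_EXTENSIONS = frozenset({"ext_alpha", "ext_beta"})

    def __init__(self) -> None:
        self.seen_extensions: T.List[str] = []

    def post_trace(self, nnef_graph: T.Any, active_custom_extensions: T.List[str]) -> None:
        # Record what the traced graph needs, then fail early on anything unsupported.
        self.seen_extensions = sorted(active_custom_extensions)
        unsupported = set(active_custom_extensions) - self.SUPPORTED_EXTENSIONS
        if unsupported:
            raise NotImplementedError(f"unsupported custom extensions: {sorted(unsupported)}")


target = ExtensionAwareTarget()
target.post_trace(nnef_graph=None, active_custom_extensions=["ext_beta", "ext_alpha"])
print(target.seen_extensions)  # ['ext_alpha', 'ext_beta']
```

Checking at this stage stops the export before an asset is written that the engine could not load.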

pre_trace
pre_trace(model: nn.Module, input_names: T.Optional[T.List[str]], output_names: T.Optional[T.List[str]])

Called just before the PyTorch graph is traced (after the auto wrapper is applied).
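A pre_trace hook can validate user-supplied names before any tracing work starts. The sketch assumes, purely for illustration, an engine that only accepts identifier-like tensor names. NameCheckingTarget, is_valid_name and the naming rule are hypothetical.

```python
import re
import typing as T

# Assumption: this imaginary engine only accepts identifier-like tensor names.
_NAME_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")


def is_valid_name(name: str) -> bool:
    return _NAME_RE.fullmatch(name) is not None


class NameCheckingTarget:
    """Hypothetical target that validates I/O names before tracing."""

    def __init__(self) -> None:
        self.checked_names: T.List[str] = []

    def pre_trace(self, model: T.Any, input_names: T.Optional[T.List[str]],
                  output_names: T.Optional[T.List[str]]) -> None:
        # Either list may be None, so treat a missing list as empty.
        names = list(input_names or []) + list(output_names or [])
        bad = [n for n in names if not is_valid_name(n)]
        if bad:
            raise ValueError(f"tensor names not accepted by this engine: {bad}")
        self.checked_names = names


target = NameCheckingTarget()
target.pre_trace(model=None, input_names=["input_0"], output_names=None)
print(target.checked_names)  # ['input_0']
```

Rejecting bad names here is cheaper than discovering them after the graph has been traced and serialized.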

specific_fragments
specific_fragments(model: nn.Module) -> T.Dict[str, str]

Return optional custom fragments to pass, as a dictionary of strings.
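The sketch below shows the shape of a specific_fragments override. It assumes, without confirmation from this page, that keys are fragment names and values are NNEF fragment source text. FragmentTarget and the my_relu6 fragment, a ReLU6 written with NNEF's standard clamp operation, are hypothetical.

```python
import typing as T

# Hypothetical NNEF fragment: a ReLU6 expressed with the standard clamp operation.
RELU6_FRAGMENT = """\
fragment my_relu6( x: tensor<scalar> ) -> ( y: tensor<scalar> )
{
    y = clamp(x, 0.0, 6.0);
}
"""


class FragmentTarget:
    """Hypothetical target that ships one custom fragment with every export."""

    def specific_fragments(self, model: T.Any) -> T.Dict[str, str]:
        # Assumption: keys are fragment names, values are NNEF fragment sources.
        # This sketch ignores `model`; a real target could inspect it first.
        return {"my_relu6": RELU6_FRAGMENT}


fragments = FragmentTarget().specific_fragments(model=None)
print(sorted(fragments))  # ['my_relu6']
```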