torch_to_nnef

Top-level package for torch_to_nnef.

KhronosNNEF

KhronosNNEF(version: T.Union[SemanticVersion, str], check_io: bool = True)

Bases: InferenceTarget

Khronos-specification-compliant NNEF asset builder.

When check_io=True, the exported asset is evaluated with the nnef_tools NNEF-to-PyTorch converter, asserting that the original and reloaded PyTorch models produce the same outputs.
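A minimal sketch of selecting this target (hedged: the "1.0" version string is illustrative, pass whichever specification version you target):

>>> inference_target = KhronosNNEF("1.0")

The resulting object is then passed as inference_target to export_model_to_nnef.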

post_export
post_export(model: nn.Module, nnef_graph: NGraph, args: T.List[T.Any], exported_filepath: Path, debug_bundle_path: T.Optional[Path] = None)

Check io via the Torch interpreter of NNEF-Tools.

TractNNEF

TractNNEF(version: T.Union[str, SemanticVersion], feature_flags: T.Optional[T.Set[TractFeatureFlag]] = None, check_io: bool = True, dynamic_axes: T.Optional[T.Dict[str, T.Dict[int, str]]] = None, specific_tract_binary_path: T.Optional[Path] = None, check_io_tolerance: TractCheckTolerance = TractCheckTolerance.APPROXIMATE, specific_properties: T.Optional[T.Dict[str, str]] = None, dump_identity_properties: bool = True, force_attention_inner_in_f32: bool = False, force_linear_accumulation_in_f32: bool = False, force_norm_in_f32: bool = False, reify_sdpa_operator: bool = False, upsample_with_debox: bool = False)

Bases: InferenceTarget

Tract NNEF inference target.

Init.

Parameters:

Name Type Description Default
version Union[str, SemanticVersion]

tract version targeted for export

required
feature_flags Optional[Set[TractFeatureFlag]]

set of optional tract feature flags to enable (for example, complex-number support)

None
check_io bool

check that, for the provided inputs, the tract CLI and the original PyTorch model produce similar outputs

True
dynamic_axes Optional[Dict[str, Dict[int, str]]]

Optional specification of dynamic dimensions. By default the exported model has the shapes of all input and output tensors set to exactly match those given in args. To mark tensor axes as dynamic (i.e. known only at runtime), set dynamic_axes to a dict with schema: KEY (str): an input or output name; each name must also be provided in input_names or output_names. VALUE (dict or list): if a dict, keys are axis indices and values are axis names; if a list, each element is an axis index. A sketch using this parameter follows this parameter list.

None
specific_tract_binary_path Optional[Path]

filepath of the tract CLI binary, for custom, non-released builds of tract (for testing purposes)

None
check_io_tolerance TractCheckTolerance

tolerance level applied when comparing original output values with those generated by tract (the levels are defined by tract)

APPROXIMATE
specific_properties Optional[Dict[str, str]]

custom tract_properties to embed inside the NNEF asset (parsed by tract as metadata)

None
dump_identity_properties bool

add tract_properties describing the user identity (host, username, OS...), helpful for debugging

True
force_attention_inner_in_f32 bool

control whether attention runs internally in f32 (even if inputs are all f16), useful for unstable networks like Qwen2.5

False
force_linear_accumulation_in_f32 bool

useful for f16 models: ensures the outputs of f16 × f16 matmuls are accumulated in f32.

False
force_norm_in_f32 bool

ensure that all normalization layers run in f32, regardless of the original PyTorch modeling.

False
reify_sdpa_operator bool

enable the conversion of scaled_dot_product_attention to a tract operator (instead of an NNEF fragment). Experimental feature.

False
upsample_with_debox bool

use the debox upsample operator instead of deconvolution, which should be faster (if the tract version supports it). Experimental feature.

False
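As mentioned in the dynamic_axes description above, here is a hedged sketch of building a tract target with a dynamic batch axis (the "0.21.0" version and the "inp"/"out"/"batch" names are illustrative placeholders, not values mandated by the library):

>>> inference_target = TractNNEF(
...     "0.21.0",  # placeholder: the tract release you target
...     dynamic_axes={"inp": {0: "batch"}, "out": {0: "batch"}},
... )

The "inp" and "out" keys must match the input_names/output_names later given to export_model_to_nnef.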
post_export
post_export(model: nn.Module, nnef_graph: NGraph, args: T.List[T.Any], exported_filepath: Path, debug_bundle_path: T.Optional[Path] = None)

Perform the IO check and build a debug bundle on failure.

post_trace
post_trace(nnef_graph, active_custom_extensions)

Add dynamic axes to the NNEF graph.

pre_trace
pre_trace(model: nn.Module, input_names: T.Optional[T.List[str]], output_names: T.Optional[T.List[str]])

Check that dynamic_axes are correctly formatted.

specific_fragments
specific_fragments(model: nn.Module) -> T.Dict[str, str]

Optional custom NNEF fragments to pass along with the export.

export_model_to_nnef

export_model_to_nnef(model: torch.nn.Module, args, file_path_export: T.Union[Path, str], inference_target: InferenceTarget, input_names: T.Optional[T.List[str]] = None, output_names: T.Optional[T.List[str]] = None, compression_level: int = 0, log_level: int = log.INFO, nnef_variable_naming_scheme: VariableNamingScheme = DEFAULT_VARNAME_SCHEME, check_io_names_qte_match: bool = True, debug_bundle_path: T.Optional[Path] = None, custom_extensions: T.Optional[T.List[str]] = None)

Main entrypoint of this library.

Export any torch.nn.Module to an NNEF file-format archive.

Parameters:

Name Type Description Default
model Module

an nn.Module that has a .forward function with only tensor arguments and outputs (no tuple, list, dict or objects); only this function will be serialized

required
args

a flat, ordered list of tensors, one per forward input of model; this list cannot be of dynamic size (at serialization it is fixed to the quantity of tensors provided). WARNING: larger tensor sizes in args increase export time, so take that into consideration when choosing sample shapes for dynamic axes

required
file_path_export Union[Path, str]

a Path to the exported NNEF serialized model archive; by convention it must end with the .nnef.tgz suffix

required
inference_target InferenceTarget

can be torch_to_nnef.TractNNEF or torch_to_nnef.KhronosNNEF, for each of which you can specify the targeted version: KhronosNNEF is the least maintained so far and is checked against the nnef-tools PyTorch interpreter; TractNNEF is our main focus at SONOS and is checked against the tract inference engine. Among its key parameters are feature_flags: Optional[Set[str]], which may contain tract specifics, and dynamic_axes. By default the exported model has the shapes of all input and output tensors set to exactly match those given in args. To mark tensor axes as dynamic (i.e. known only at runtime), set dynamic_axes to a dict with schema: KEY (str): an input or output name; each name must also be provided in input_names or output_names. VALUE (dict or list): if a dict, keys are axis indices and values are axis names; if a list, each element is an axis index.

required
input_names Optional[List[str]]

Optional list of names for args; it replaces the variable input names traced from the graph (if set, it must have the same size as the number of args)

None
output_names Optional[List[str]]

Optional list of names for the outputs of model.forward; it replaces the variable output names traced from the graph (if set, it must have the same size as the number of outputs)

None
compression_level int

compression level of the tar.gz archive (int >= 0; higher is more compressed)

0
log_level int

logger level for torch_to_nnef, following Python's standard logging levels; can be set to INFO, WARN, DEBUG, ...

INFO
nnef_variable_naming_scheme VariableNamingScheme

Possible NNEF variable naming schemes are: - "raw": takes variable names directly from the traced graph debugName - "natural_verbose": tries to provide consistent naming for exported nn.Module variables - "natural_verbose_camel": same consistency, with a more concise camelCase pattern - "numeric": tries to be as concise as possible

DEFAULT_VARNAME_SCHEME
check_io_names_qte_match bool

during tracing of the torch graph, one or more provided inputs can be removed if they do not contribute to the outputs; while check_io_names_qte_match is True, we ensure that the input and output quantities stay consistent with the numbers of input_names and output_names.

True
debug_bundle_path Optional[Path]

if specified, an archive bundle is created with all the information needed to ease debugging.

None
custom_extensions Optional[List[str]]

allows adding a set of extensions as defined in the NNEF specification (https://registry.khronos.org/NNEF/specs/1.0/nnef-1.0.5.html). Useful to set specific extensions, for example: 'extension tract_assert S >= 0'. Such assertions add constraints on dynamic shapes that are not expressed in the traced graph (for example, a maximum number of tokens for an LLM). A sketch using this extension follows the examples below.

None

Examples:

For example, this function can be used to export a simple perceptron model:

>>> import os
>>> import tarfile
>>> import tempfile
>>> import torch
>>> from torch import nn
>>> mod = nn.Sequential(nn.Linear(1, 5), nn.ReLU())
>>> export_path = tempfile.mktemp(suffix=".nnef.tgz")
>>> inference_target = TractNNEF.latest()
>>> export_model_to_nnef(
...   mod,
...   torch.rand(3, 1),
...   export_path,
...   inference_target,
...   input_names=["inp"],
...   output_names=["out"]
... )
>>> os.chdir(export_path.rsplit("/", maxsplit=1)[0])
>>> tarfile.open(export_path).extract("graph.nnef")
>>> "graph network(inp) -> (out)" in open("graph.nnef").read()
True
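As mentioned in the custom_extensions description above, here is a hedged sketch combining a dynamic axis with a tract_assert extension; the "S" symbol and its bound are illustrative and presume the same axis name is declared in dynamic_axes (the "0.21.0" version is a placeholder):

>>> target = TractNNEF(
...     "0.21.0",  # placeholder: the tract release you target
...     dynamic_axes={"inp": {0: "S"}, "out": {0: "S"}},
... )
>>> export_model_to_nnef(
...     mod,
...     torch.rand(3, 1),
...     tempfile.mktemp(suffix=".nnef.tgz"),
...     target,
...     input_names=["inp"],
...     output_names=["out"],
...     custom_extensions=["extension tract_assert S >= 0"],
... )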

export_tensors_from_disk_to_nnef

export_tensors_from_disk_to_nnef(store_filepath: T.Union[Path, str], output_dir: T.Union[Path, str], filter_key: T.Optional[T.Callable[[str], bool]] = None, fn_check_found_tensors: T.Optional[T.Callable[[T.Dict[str, _Tensor]], bool]] = None) -> T.Dict[str, _Tensor]

Export torch.Tensors from any state dict or safetensors file to NNEF .dat files.

Parameters:

Name Type Description Default
store_filepath Union[Path, str]

the filepath holding the .safetensors, .pt or .bin file containing the state dict

required
output_dir Union[Path, str]

directory to dump the NNEF tensor .dat files

required
filter_key Optional[Callable[[str], bool]]

An optional function to filter specific keys to be exported

None
fn_check_found_tensors Optional[Callable[[Dict[str, _Tensor]], bool]]

post-check function to ensure all requested tensors have effectively been dumped

None

Returns:

Type Description
Dict[str, _Tensor]

a dict with tensor names as keys and torch.Tensor values, identical to torch_to_nnef.export.export_tensors_to_nnef

Examples:

Simple filtered example

>>> import tempfile
>>> import torch
>>> from torch import nn
>>> class Mod(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.a = nn.Linear(1, 5)
...         self.b = nn.Linear(5, 1)
...
...     def forward(self, x):
...         return self.b(self.a(x))
>>> mod = Mod()
>>> pt_path = tempfile.mktemp(suffix=".pt")
>>> nnef_dir = tempfile.mkdtemp(suffix="_nnef")
>>> torch.save(mod.state_dict(), pt_path)
>>> def check(ts):
...     assert all(_.startswith("a.") for _ in ts)
>>> exported_tensors = export_tensors_from_disk_to_nnef(
...     pt_path,
...     nnef_dir,
...     lambda x: x.startswith("a."),
...     check
... )
>>> list(exported_tensors.keys())
['a.weight', 'a.bias']

export_tensors_to_nnef

export_tensors_to_nnef(name_to_torch_tensors: T.Dict[str, _Tensor], output_dir: Path) -> T.Dict[str, _Tensor]

Export any named map of torch.Tensors to NNEF .dat files.

Parameters:

Name Type Description Default
name_to_torch_tensors Dict[str, _Tensor]

a map from name (used to define the .dat filename) to tensor value (values can also be special torch_to_nnef tensors)

required
output_dir Path

directory to dump the NNEF tensor .dat files

required

Returns:

Type Description
Dict[str, _Tensor]

a dict with tensor names as keys and torch.Tensor values, echoing name_to_torch_tensors

Examples:

Simple example

>>> import tempfile
>>> from torch import nn
>>> class Mod(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.a = nn.Linear(1, 5)
...         self.b = nn.Linear(5, 1)
...
...     def forward(self, x):
...         return self.b(self.a(x))
>>> mod = Mod()
>>> nnef_dir = tempfile.mkdtemp(suffix="_nnef")
>>> exported_tensors = export_tensors_to_nnef(
...     {k: v for k, v in mod.named_parameters() if k.startswith("b.")},
...     nnef_dir,
... )
>>> list(exported_tensors.keys())
['b.weight', 'b.bias']