🚀 Main export APIs

See Also

For details on the exported artifact (directory vs .tar vs .tgz) and how compression_level and the output path suffix influence it, see Artifacts and Compression.

Choosing the Target Runtime

torch_to_nnef exports to an inference-target abstraction. The most common choice is TractNNEF.

  • Use TractNNEF.latest() for the most recent supported tract.
  • Pin a specific tract version: TractNNEF(SemanticVersion.from_str("0.23.0")).
  • Pass dynamic-axes constraints and feature toggles (e.g., SDPA reification) through the inference target when needed.

Example

from torch_to_nnef.export import export_model_to_nnef
from torch_to_nnef.inference_target import TractNNEF
from torch_to_nnef.utils import SemanticVersion

# my_model is your torch.nn.Module, x an example input tensor
target = TractNNEF.latest()  # or TractNNEF(SemanticVersion.from_str("0.23.0"))
export_model_to_nnef(
    model=my_model.eval(),
    args=(x,),
    file_path_export="/tmp/model.nnef.tgz",  # suffix expresses archive intent
    inference_target=target,
    input_names=["inp"],
    output_names=["out"],
    compression_level=1,  # 1..9 => .tgz, 0 => .tar, None => .nnef dir
)

torch_to_nnef.export

export_model_to_nnef

export_model_to_nnef(model: torch.nn.Module, args, file_path_export: T.Union[Path, str], inference_target: InferenceTarget, input_names: T.Optional[T.List[str]] = None, output_names: T.Optional[T.List[str]] = None, compression_level: T.Optional[int] = 0, log_level: int = log.INFO, nnef_variable_naming_scheme: VariableNamingScheme = DEFAULT_VARNAME_SCHEME, check_io_names_qte_match: bool = True, debug_bundle_path: T.Optional[Path] = None, custom_extensions: T.Optional[T.List[str]] = None, allow_same_io_names: bool = False, auto_harden_jit: bool = True) -> Path

Main entrypoint of this library.

Export any torch.nn.Module to an NNEF file-format artifact (directory or archive).

Parameters:

Name Type Description Default
model Module

an nn.Module whose .forward function takes only tensor arguments and returns only tensors (no tuples, lists, dicts, or other objects). Only this function will be serialized

required
args

a flat, ordered list of tensors, one per forward input of model. This list cannot be of dynamic size: at serialization the input count is fixed to the number of tensors provided. WARNING! larger tensors in args increase export time, so take that into consideration when choosing example shapes for dynamic axes

required
file_path_export Union[Path, str]

target path for the exported model. Accepted forms are:
  • ".../model.nnef" → base path; creates a directory when compression_level is None, an archive "model.nnef.tar" when compression_level == 0, or an archive "model.nnef.tgz" when compression_level is in 1..9
  • ".../model.nnef.tgz" → treated as a request to use base name "model.nnef"; the actual artifact still follows the rules above (directory, .tar, or .tgz) depending on compression_level
Any other suffix pattern is rejected.

required
inference_target InferenceTarget

can be torch_to_nnef.TractNNEF or torch_to_nnef.KhronosNNEF; for each you can specify the targeted version.
  • KhronosNNEF is the least maintained so far, and is checked against the nnef-tools PyTorch interpreter
  • TractNNEF is our main focus at SONOS, and is checked against the tract inference engine
Among its key parameters are feature_flags: Optional[Set[str]], which may contain tract-specific toggles, and dynamic_axes. By default the exported model will have the shapes of all input and output tensors set to exactly match those given in args. To mark tensor axes as dynamic (i.e. known only at runtime), set dynamic_axes to a dict with the following schema (see the sketch at the end of the Examples section below):
KEY (str): an input or output name; each name must also be provided in input_names or output_names.
VALUE (dict or list): if a dict, keys are axis indices and values are axis names; if a list, each element is a dynamic axis index.

required
input_names Optional[List[str]]

Optional list of names for args; it replaces the input variable names traced from the graph (if set, it must have the same length as the number of args)

None
output_names Optional[List[str]]

Optional list of names for the outputs of model.forward; it replaces the output variable names traced from the graph (if set, it must have the same length as the number of outputs)

None
compression_level Optional[int]

If None, writes an uncompressed .nnef directory. If 0, writes an uncompressed tar archive (.nnef.tar). If 1..9, writes a gzip-compressed tar archive (.nnef.tgz) with the given compression level.

0
log_level int

logger level for torch_to_nnef, following the Python standard logging levels; can be set to INFO, WARN, DEBUG, ...

INFO
nnef_variable_naming_scheme VariableNamingScheme

Possible NNEF variable naming schemes are:
  • "raw": takes variable names directly from the traced graph debugName
  • "natural_verbose": tries to provide variable naming consistent with the exported nn.Module structure
  • "natural_verbose_camel": same consistency goal, with a more concise camelCase variable pattern
  • "numeric": tries to be as concise as possible

DEFAULT_VARNAME_SCHEME
check_io_names_qte_match bool

(default: True) During tracing of the torch graph, one or more of the provided inputs may be removed if they do not contribute to the outputs. While check_io_names_qte_match is True, we ensure that the input and output quantities remain equal to the lengths of input_names and output_names.

True
debug_bundle_path Optional[Path]

if specified, creates an archive bundle with all information needed to make debugging easier.

None
custom_extensions Optional[List[str]]

allows adding a set of extensions as defined in the NNEF specification (https://registry.khronos.org/NNEF/specs/1.0/nnef-1.0.5.html). Useful to set specific extensions, for example 'extension tract_assert S >= 0'; such assertions add constraints on dynamic shapes that are not expressed in the traced graph (for example, a maximum number of tokens for an LLM)

None
allow_same_io_names bool

by default, input and output names must be different, to avoid graph simplifications that would silently merge those tensors. If you really want the same names for inputs and outputs, set this flag to True. Some libraries like 'nvidia/nemo' use this pattern (note that it only makes sense if the model is a no-op)

False
auto_harden_jit bool

When model is a torch.jit.ScriptModule, automatically run harden_jit_for_export to specialize its graph for the given example inputs (freeze + size folds + constant folds + tuple round-trip + data-dependent If folds). Each pass is a no-op on graphs that don't carry the relevant pattern, so the wrapper is safe to apply unconditionally; turn it off to drive the chain manually for fine-grained control.

True

Returns:

Name Type Description
Path Path

the path to the exported artifact:
  • if compression_level is None: the .nnef directory path
  • if compression_level == 0: the .nnef.tar archive path
  • if compression_level in 1..9: the .nnef.tgz archive path

Examples:

For example, this function can be used to export a simple perceptron model:

>>> import os
>>> import tarfile
>>> import tempfile
>>> import torch
>>> from torch import nn
>>> from torch_to_nnef.inference_target import TractNNEF
>>> mod = nn.Sequential(nn.Linear(1, 5), nn.ReLU())
>>> export_path = tempfile.mktemp(suffix=".nnef.tgz")
>>> inference_target = TractNNEF.latest()
>>> artifact_path = export_model_to_nnef(
...   mod,
...   torch.rand(3, 1),
...   export_path,
...   inference_target,
...   compression_level=0,  # actual artifact becomes model.nnef.tar (see above)
...   input_names=["inp"],
...   output_names=["out"]
... )
>>> os.chdir(str(artifact_path.parent))
>>> tarfile.open(artifact_path).extract("graph.nnef")
>>> "graph network(inp) -> (out)" in open("graph.nnef").read()
True
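
The dynamic_axes schema described under inference_target can be sketched as follows. This is a hypothetical illustration, not a tested snippet: it assumes the inference target accepts a dynamic_axes keyword as the parameter table above describes, and the axis name "batch" is arbitrary:

from torch_to_nnef.inference_target import TractNNEF

# assumed keyword: dynamic_axes, per the inference_target parameter docs
target = TractNNEF.latest(
    dynamic_axes={
        "inp": {0: "batch"},  # dict form: axis index -> axis name
        "out": [0],           # list form: dynamic axis indices
    },
)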

export_tensors_from_disk_to_nnef

export_tensors_from_disk_to_nnef(store_filepath: T.Union[Path, str], output_dir: T.Union[Path, str], filter_key: T.Optional[T.Callable[[str], bool]] = None, fn_check_found_tensors: T.Optional[T.Callable[[T.Dict[str, _Tensor]], bool]] = None, map_location: T.Union[str, torch.device] = 'cpu') -> T.Dict[str, _Tensor]

Export the torch.Tensors of a state dict or safetensors file to NNEF .dat files.

Parameters:

Name Type Description Default
store_filepath Union[Path, str]

the filepath that holds the .safetensors, .pt or .bin file containing the state dict

required
output_dir Union[Path, str]

directory to dump the NNEF tensor .dat files

required
filter_key Optional[Callable[[str], bool]]

An optional function to filter specific keys to be exported

None
fn_check_found_tensors Optional[Callable[[Dict[str, _Tensor]], bool]]

a post-check function to ensure that all requested tensors have effectively been dumped

None
map_location Union[str, device]

device mapping used by torch.load for .pt/.pth/.bin files (default: "cpu").

'cpu'

Returns:

Type Description
Dict[str, _Tensor]

a dict with tensor names as keys and torch.Tensor values, identical to the return of torch_to_nnef.export.export_tensors_to_nnef

Examples:

Simple filtered example

>>> import tempfile
>>> import torch
>>> from torch import nn
>>> class Mod(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.a = nn.Linear(1, 5)
...         self.b = nn.Linear(5, 1)
...
...     def forward(self, x):
...         return self.b(self.a(x))
>>> mod = Mod()
>>> pt_path = tempfile.mktemp(suffix=".pt")
>>> nnef_dir = tempfile.mkdtemp(suffix="_nnef")
>>> torch.save(mod.state_dict(), pt_path)
>>> def check(ts):
...     assert all(_.startswith("a.") for _ in ts)
>>> exported_tensors = export_tensors_from_disk_to_nnef(
...     pt_path,
...     nnef_dir,
...     lambda x: x.startswith("a."),
...     check
... )
>>> list(exported_tensors.keys())
['a.weight', 'a.bias']

export_tensors_to_nnef

export_tensors_to_nnef(name_to_torch_tensors: T.Dict[str, _Tensor], output_dir: Path) -> T.Dict[str, _Tensor]

Export a named collection of torch.Tensors to NNEF .dat files.

Parameters:

Name Type Description Default
name_to_torch_tensors Dict[str, _Tensor]

a map from name (used to define the .dat filename) to tensor value (values can also be special torch_to_nnef tensors)

required
output_dir Path

directory to dump the NNEF tensor .dat files

required

Returns:

Type Description
Dict[str, _Tensor]

a dict with tensor names as keys and torch.Tensor values (the tensors that were dumped)

Examples:

Simple example

>>> import tempfile
>>> from torch import nn
>>> class Mod(nn.Module):
...     def __init__(self):
...         super().__init__()
...         self.a = nn.Linear(1, 5)
...         self.b = nn.Linear(5, 1)
...
...     def forward(self, x):
...         return self.b(self.a(x))
>>> mod = Mod()
>>> nnef_dir = tempfile.mkdtemp(suffix="_nnef")
>>> exported_tensors = export_tensors_to_nnef(
...     {k: v for k, v in mod.named_parameters() if k.startswith("b.")},
...     nnef_dir,
... )
>>> list(exported_tensors.keys())
['b.weight', 'b.bias']

iter_torch_tensors_from_disk

iter_torch_tensors_from_disk(store_filepath: Path, filter_key: T.Optional[T.Callable[[str], bool]] = None, map_location: T.Union[str, torch.device] = 'cpu') -> T.Iterator[T.Tuple[str, _Tensor]]

Iterate over torch tensors stored on disk (.safetensors, .pt, .pth, .bin).

Parameters:

Name Type Description Default
store_filepath Path

path to the container file holding PyTorch tensors (.pt, .pth, .bin and .safetensors)

required
filter_key Optional[Callable[[str], bool]]

if set, this function filters the tensors stored in those formats by name

None
map_location Union[str, device]

device mapping used by torch.load for .pt/.pth/.bin files (default: "cpu").

'cpu'

Yields:

Type Description
str

the name of each tensor within store_filepath accepted by filter_key

_Tensor

the corresponding torch.Tensor; pairs are yielded one at a time as (name, tensor) tuples
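
Examples:

A minimal sketch; it assumes iteration follows the state dict's insertion order (weight before bias for an nn.Linear):

>>> import tempfile
>>> import torch
>>> from torch import nn
>>> pt_path = tempfile.mktemp(suffix=".pt")
>>> torch.save(nn.Linear(1, 5).state_dict(), pt_path)
>>> [name for name, _t in iter_torch_tensors_from_disk(pt_path)]
['weight', 'bias']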

JIT-only model hardening

For JIT-only artifacts whose Python source isn't on the import path (e.g. silero_vad.jit), the standard recursive parser cannot resolve the inner classes. harden_jit_for_export runs a chain of opt-in graph passes that specialize the JIT graph for your example inputs, producing a graph the standard exporter can consume. See the JIT-only models tutorial for the full chain and rationale per pass.

torch_to_nnef.torch_graph.harden

High-level helper that runs the JIT-only export hardening chain.

The individual passes in jit_inline and jit_passes are exposed for fine-grained use (custom orderings, partial chains for debugging). Most callers want the full chain; harden_jit_for_export wraps it with a sensible default order and freeze step.

Each pass is a no-op on graphs that don't carry the relevant pattern, so the chain is safe to apply unconditionally.

Supported on torch >= 1.10 (the only API requirement is torch._C._jit_interpret_graph, used by fold_data_dependent_ifs and probed lazily). CI gates the chain on torch 2.11.0; older versions work but are not regression-tested.

harden_jit_for_export

harden_jit_for_export(model: 'torch.jit.ScriptModule', args: T.Union[T.Sequence[T.Any], 'torch.Tensor'], *, freeze: bool = True, diagnostics: T.Optional[T.Dict[str, T.Any]] = None) -> 'torch.jit.ScriptModule'

Specialize a JIT ScriptModule's graph for the given example inputs.

Returns the (possibly frozen) module with the chain applied in place. The chain has two stages.

Resolve CallMethod / CallFunction / GetAttr (one of):

  • torch.jit.freeze (freeze=True, default). Requires the module to be in eval mode; pass a model.eval()-ed instance or set freeze=False. On RuntimeError, the helper logs a warning and falls back to the manual inline.
  • inline_unresolvable_submodules. Used when freeze was disabled or failed; covers CallMethods whose target class isn't on the import path. Skipped when freeze succeeded.

Specialize for the example inputs (always, in this order):

  1. replace_size_calls_with_constants: forward-reach analysis folds aten::dim/size/len/numel whose values flow only into control flow. Tensor-shape consumers are left dynamic.
  2. fold_constant_scalar_arithmetic: comparisons, __not__, __contains__, and scalar casts.
  3. fold_constant_ifs: drop Ifs with a constant boolean condition.
  4. fold_tuple_index_through_tuple_construct and fold_tuple_unpack_through_tuple_construct: collapse the TupleConstruct -> TupleIndex and TupleConstruct -> TupleUnpack round-trips.
  5. strip_prim_data: drop prim::data(t) (autograd-detach no-op).
  6. strip_assertion_ifs: drop Ifs where one branch is a pure RaiseException.
  7. fold_data_dependent_ifs: evaluate any remaining If condition under the example inputs and inline the chosen branch.

args is the model's forward arguments (no implicit self); the helper prepends the (possibly frozen) self receiver internally when invoking passes that re-execute the graph through torch._C._jit_interpret_graph. A single torch.Tensor is accepted as a shorthand for (tensor,), matching the permissive convention of export_model_to_nnef(model, args=...).

When diagnostics is a dict, it is populated in place with the count of nodes folded / stripped per pass, keyed by pass name. The froze key records whether freeze succeeded; inline_unresolvable_submodules is only present when the manual inline ran (i.e. freeze was disabled or failed). Useful for debugging unfamiliar JIT artifacts.
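
A minimal usage sketch: the import path follows the module heading above, while the artifact file name and input shape are illustrative:

import torch
from torch_to_nnef.torch_graph.harden import harden_jit_for_export

scripted = torch.jit.load("silero_vad.jit").eval()  # eval() so the freeze stage can run
diag = {}
hardened = harden_jit_for_export(
    scripted,
    (torch.rand(1, 512),),  # forward args only, no implicit self
    diagnostics=diag,       # filled in place with per-pass fold/strip counts
)
print(diag["froze"])        # whether torch.jit.freeze succeeded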

Remodeler

For boundary‑only transforms (collapse, bind, alias, outputs_keep), see the dedicated tutorial: Provider‑agnostic remodeler.