Main export APIs
For details on the exported artifact (directory vs .tar vs .tgz) and how
compression_level and the output path suffix influence it, see
Artifacts and Compression.
See Also
- NeMo ASR export tutorial: Export and run NeMo ASR
- Transformers/LLM export tutorial: LLM export guide
- Shapes remodeler tutorial: Provider-agnostic remodeler
Choosing the Target Runtime
torch_to_nnef exports to an inference-target abstraction. The most common choice is TractNNEF.
- Use `TractNNEF.latest()` for the most recent supported tract version.
- Pin a specific tract version: `TractNNEF(SemanticVersion.from_str("0.23.0"))`.
- Pass dynamic-axes constraints and feature toggles (e.g., SDPA reification) through the inference target when needed (see the sketch after the example below).
Example
```python
from torch_to_nnef.export import export_model_to_nnef
from torch_to_nnef.inference_target import TractNNEF
from torch_to_nnef.utils import SemanticVersion

target = TractNNEF.latest()  # or TractNNEF(SemanticVersion.from_str("0.23.0"))
export_model_to_nnef(
    model=my_model.eval(),  # my_model / x: your nn.Module and an example input
    args=(x,),
    file_path_export="/tmp/model.nnef.tgz",  # suffix expresses archive intent
    inference_target=target,
    input_names=["inp"],
    output_names=["out"],
    compression_level=1,  # 1..9 => .tgz, 0 => .tar, None => .nnef dir
)
```
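The third bullet above routes dynamic-axes constraints through the inference target. A minimal sketch, assuming `TractNNEF.latest()` accepts a `dynamic_axes` mapping in the `torch.onnx.export` style `{tensor_name: {axis_index: symbol}}` (the exact keyword and where it is accepted are assumptions; check the inference-target API reference):

```python
from torch_to_nnef.inference_target import TractNNEF

# Assumed keyword: mark axis 0 of input "inp" as the symbolic batch size "B",
# so the exported graph keeps it dynamic instead of freezing the traced value.
target = TractNNEF.latest(dynamic_axes={"inp": {0: "B"}})
```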
torch_to_nnef.export
export_model_to_nnef
export_model_to_nnef(model: torch.nn.Module, args, file_path_export: T.Union[Path, str], inference_target: InferenceTarget, input_names: T.Optional[T.List[str]] = None, output_names: T.Optional[T.List[str]] = None, compression_level: T.Optional[int] = 0, log_level: int = log.INFO, nnef_variable_naming_scheme: VariableNamingScheme = DEFAULT_VARNAME_SCHEME, check_io_names_qte_match: bool = True, debug_bundle_path: T.Optional[Path] = None, custom_extensions: T.Optional[T.List[str]] = None, allow_same_io_names: bool = False, auto_harden_jit: bool = True) -> Path
Main entrypoint of this library.
Export any torch.nn.Module to an NNEF file-format archive.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | Module | an nn.Module with a forward method, to be exported | required |
| args |  | a flat ordered list of tensors, one per forward input of model | required |
| file_path_export | Union[Path, str] | target path for the exported model; for ".../model.nnef" the base path yields a directory when compression_level is None, a .tar archive when it is 0, and a .tgz archive when it is 1..9 (see Artifacts and Compression) | required |
| inference_target | InferenceTarget | the inference target the export is specialized for, most commonly TractNNEF (see "Choosing the Target Runtime" above) | required |
| input_names | Optional[List[str]] | optional list of names for args; replaces the input variable names traced from the graph (if set, it must have the same length as args) | None |
| output_names | Optional[List[str]] | optional list of names for the outputs of model; replaces the output variable names traced from the graph | None |
| compression_level | Optional[int] | if None, writes an uncompressed .nnef directory; if 0, an uncompressed .tar archive; 1..9 set the gzip compression level of a .tgz archive | 0 |
| log_level | int | logging level used during the export | INFO |
| nnef_variable_naming_scheme | VariableNamingScheme | NNEF variable naming scheme. Choices: "raw" takes variable names directly from the traced graph debugName; "natural_verbose" tries to keep naming consistent with the exported nn.Module; "natural_verbose_camel" does the same with a more concise camelCase pattern; "numeric" is as concise as possible | DEFAULT_VARNAME_SCHEME |
| check_io_names_qte_match | bool | during tracing, inputs that do not contribute to any output may be pruned from the graph; when True, we check that the traced input and output counts still match the number of names given in input_names and output_names | True |
| debug_bundle_path | Optional[Path] | if specified, creates an archive bundle with all the information needed to ease debugging | None |
| custom_extensions | Optional[List[str]] | additional NNEF extensions, as defined in the spec (https://registry.khronos.org/NNEF/specs/1.0/nnef-1.0.5.html). Useful for target-specific extensions such as 'extension tract_assert S >= 0': such assertions add constraints on dynamic shapes that are not expressed in the traced graph (for example a maximum number of tokens for an LLM) | None |
| allow_same_io_names | bool | by default, input and output names must differ, to avoid graph simplifications that would silently merge those tensors; set to True if you really want identical input and output names (some libraries such as 'nvidia/nemo' use this pattern; note it only makes sense for a no-op) | False |
| auto_harden_jit | bool | when True, JIT-only ScriptModules are automatically specialized with harden_jit_for_export before export (see "JIT-only model hardening" below) | True |
Returns:
| Name | Type | Description |
|---|---|---|
| Path | Path | the path to the exported artifact: a .nnef directory if compression_level is None, a .tar if it is 0, a .tgz if it is 1..9 |
Examples:
For example, this function can be used to export a simple perceptron model:
>>> import os
>>> import tarfile
>>> import tempfile
>>> import torch
>>> from torch import nn
>>> from torch_to_nnef.export import export_model_to_nnef
>>> from torch_to_nnef.inference_target import TractNNEF
>>> mod = nn.Sequential(nn.Linear(1, 5), nn.ReLU())
>>> export_path = tempfile.mktemp(suffix=".nnef.tgz")
>>> inference_target = TractNNEF.latest()
>>> _ = export_model_to_nnef(
... mod,
... torch.rand(3, 1),
... export_path,
... inference_target,
... compression_level=0,
... input_names=["inp"],
... output_names=["out"]
... )
>>> os.chdir(export_path.rsplit("/", maxsplit=1)[0])
>>> tarfile.open(export_path).extract("graph.nnef")
>>> "graph network(inp) -> (out)" in open("graph.nnef").read()
True
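The doctest above exercises the default path. To constrain dynamic shapes, you can pass the tract_assert extension string quoted in the custom_extensions row of the parameter table. A sketch, reusing the placeholder names from the example near the top of this page (S is assumed to be a symbolic dimension already known to the graph):

```python
from torch_to_nnef.export import export_model_to_nnef
from torch_to_nnef.inference_target import TractNNEF

# Sketch: bound a symbolic dimension "S" through a tract_assert extension
# string, forwarded as-is into the generated NNEF graph header.
export_model_to_nnef(
    model=my_model.eval(),
    args=(x,),
    file_path_export="/tmp/model.nnef.tgz",
    inference_target=TractNNEF.latest(),
    input_names=["inp"],
    output_names=["out"],
    custom_extensions=["extension tract_assert S >= 0"],
)
```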
export_tensors_from_disk_to_nnef
export_tensors_from_disk_to_nnef(store_filepath: T.Union[Path, str], output_dir: T.Union[Path, str], filter_key: T.Optional[T.Callable[[str], bool]] = None, fn_check_found_tensors: T.Optional[T.Callable[[T.Dict[str, _Tensor]], bool]] = None, map_location: T.Union[str, torch.device] = 'cpu') -> T.Dict[str, _Tensor]
Export torch.Tensors from any state-dict or safetensors file to NNEF .dat files.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| store_filepath | Union[Path, str] | the filepath of the .safetensors, .pt or .bin file containing the state dict | required |
| output_dir | Union[Path, str] | directory in which to dump the NNEF tensor .dat files | required |
| filter_key | Optional[Callable[[str], bool]] | an optional function selecting, by key, which tensors to export | None |
| fn_check_found_tensors | Optional[Callable[[Dict[str, _Tensor]], bool]] | post-check function ensuring all requested tensors have effectively been dumped | None |
| map_location | Union[str, device] | device mapping used by torch.load for .pt/.pth/.bin files | 'cpu' |
Returns:
| Type | Description |
|---|---|
| Dict[str, _Tensor] | a dict of tensor names to torch.Tensor values, identical to the set of tensors dumped to output_dir |
Examples:
Simple filtered example
>>> import tempfile
>>> import torch
>>> from torch import nn
>>> class Mod(nn.Module):
... def __init__(self):
... super().__init__()
... self.a = nn.Linear(1, 5)
... self.b = nn.Linear(5, 1)
...
... def forward(self, x):
... return self.b(self.a(x))
>>> mod = Mod()
>>> pt_path = tempfile.mktemp(suffix=".pt")
>>> nnef_dir = tempfile.mkdtemp(suffix="_nnef")
>>> torch.save(mod.state_dict(), pt_path)
>>> def check(ts):
... assert all(_.startswith("a.") for _ in ts)
>>> exported_tensors = export_tensors_from_disk_to_nnef(
... pt_path,
... nnef_dir,
... lambda x: x.startswith("a."),
... check
... )
>>> list(exported_tensors.keys())
['a.weight', 'a.bias']
export_tensors_to_nnef
export_tensors_to_nnef(name_to_torch_tensors: T.Dict[str, _Tensor], output_dir: Path) -> T.Dict[str, _Tensor]
Export any named collection of torch.Tensors to NNEF .dat files.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name_to_torch_tensors | Dict[str, _Tensor] | a map from names (used to define the .dat filenames) to tensor values (which can also be special torch_to_nnef tensors) | required |
| output_dir | Path | directory in which to dump the NNEF tensor .dat files | required |
Returns:
| Type | Description |
|---|---|
| Dict[str, _Tensor] | a dict of tensor names to torch.Tensor values, identical to name_to_torch_tensors |
Examples:
Simple example
>>> import tempfile
>>> from torch import nn
>>> class Mod(nn.Module):
... def __init__(self):
... super().__init__()
... self.a = nn.Linear(1, 5)
... self.b = nn.Linear(5, 1)
...
... def forward(self, x):
... return self.b(self.a(x))
>>> mod = Mod()
>>> nnef_dir = tempfile.mkdtemp(suffix="_nnef")
>>> exported_tensors = export_tensors_to_nnef(
... {k: v for k, v in mod.named_parameters() if k.startswith("b.")},
... nnef_dir,
... )
>>> list(exported_tensors.keys())
['b.weight', 'b.bias']
iter_torch_tensors_from_disk
iter_torch_tensors_from_disk(store_filepath: Path, filter_key: T.Optional[T.Callable[[str], bool]] = None, map_location: T.Union[str, torch.device] = 'cpu') -> T.Iterator[T.Tuple[str, _Tensor]]
Iterate over torch tensors stored on disk (.safetensors, .pt, .pth, .bin).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| store_filepath | Path | path to the container file holding PyTorch tensors (.pt, .pth, .bin or .safetensors) | required |
| filter_key | Optional[Callable[[str], bool]] | if set, filters the stored tensors by name | None |
| map_location | Union[str, device] | device mapping used by torch.load for .pt/.pth/.bin files | 'cpu' |
Yields:
| Type | Description |
|---|---|
| str | the name of each tensor accepted by filter_key within store_filepath |
| _Tensor | the torch.Tensor itself; tensors are yielded one at a time as (name, tensor) tuples |
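A short usage sketch in the doctest style of the examples above (the import path follows this page's module, torch_to_nnef.export):

>>> import tempfile
>>> import torch
>>> from torch import nn
>>> from torch_to_nnef.export import iter_torch_tensors_from_disk
>>> mod = nn.Sequential(nn.Linear(1, 5), nn.ReLU())
>>> pt_path = tempfile.mktemp(suffix=".pt")
>>> torch.save(mod.state_dict(), pt_path)
>>> for name, tensor in iter_torch_tensors_from_disk(
...     pt_path, filter_key=lambda k: k.endswith("weight")
... ):
...     print(name, tuple(tensor.shape))
0.weight (5, 1)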
JIT-only model hardening
For JIT-only artifacts whose Python source isn't on the import path
(e.g. silero_vad.jit), the standard recursive parser cannot resolve
the inner classes. harden_jit_for_export runs a chain of opt-in
graph passes that specialize the JIT graph for your example inputs,
producing a graph the standard exporter can consume. See the
JIT-only models tutorial for the full
chain and rationale per pass.
torch_to_nnef.torch_graph.harden
High-level helper that runs the JIT-only export hardening chain.
The individual passes in jit_inline and jit_passes are exposed for
fine-grained use (custom orderings, partial chains for debugging). Most
callers want the full chain; harden_jit_for_export wraps it with a
sensible default order and freeze step.
Each pass is a no-op on graphs that don't carry the relevant pattern, so the chain is safe to apply unconditionally.
Supported on torch >= 1.10 (the only API requirement is
torch._C._jit_interpret_graph, used by fold_data_dependent_ifs and
probed lazily). CI gates the chain on torch 2.11.0; older versions work
but are not regression-tested.
harden_jit_for_export
harden_jit_for_export(model: 'torch.jit.ScriptModule', args: T.Union[T.Sequence[T.Any], 'torch.Tensor'], *, freeze: bool = True, diagnostics: T.Optional[T.Dict[str, T.Any]] = None) -> 'torch.jit.ScriptModule'
Specialize a JIT ScriptModule's graph for the given example inputs.
Returns the (possibly frozen) module with the chain applied in place. The chain has two stages.
Resolve CallMethod / CallFunction / GetAttr (one of):
- torch.jit.freeze (freeze=True, the default). Requires the module to be in eval mode; pass a model.eval()-ed instance or set freeze=False. On RuntimeError, the helper logs a warning and falls back to the manual inline.
- inline_unresolvable_submodules. Used when freeze was disabled or failed; covers CallMethods whose target class isn't on the import path. Skipped when freeze succeeded.
Specialize for the example inputs (always, in this order):
- replace_size_calls_with_constants: forward-reach analysis folds aten::dim/size/len/numel whose values flow only into control flow. Tensor-shape consumers are left dynamic.
- fold_constant_scalar_arithmetic: cmps, __not__, __contains__, scalar casts.
- fold_constant_ifs: drop Ifs with a constant boolean condition.
- fold_tuple_index_through_tuple_construct and fold_tuple_unpack_through_tuple_construct: collapse the TupleConstruct -> TupleIndex and TupleConstruct -> TupleUnpack round-trips.
- strip_prim_data: drop prim::data(t) (an autograd-detach no-op).
- strip_assertion_ifs: drop Ifs whose one branch is a pure RaiseException.
- fold_data_dependent_ifs: evaluate any remaining If condition under the example inputs and inline the chosen branch.
args is the model's forward arguments (no implicit self); the
helper prepends the (possibly frozen) self receiver internally
when invoking passes that re-execute the graph through
torch._C._jit_interpret_graph. A single torch.Tensor is accepted
as a shorthand for (tensor,), matching the permissive convention
of export_model_to_nnef(model, args=...).
When diagnostics is a dict, it is populated in place with the
count of nodes folded / stripped per pass, keyed by pass name. The
froze key records whether freeze succeeded;
inline_unresolvable_submodules is only present when the manual
inline ran (i.e. freeze was disabled or failed). Useful for
debugging unfamiliar JIT artifacts.
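A usage sketch tying the pieces together (silero_vad.jit is the example artifact mentioned above; the input shape is an assumption, adapt it to your model):

```python
import torch

from torch_to_nnef.torch_graph.harden import harden_jit_for_export

# Load a JIT-only artifact whose Python source is not on the import path.
model = torch.jit.load("silero_vad.jit").eval()  # eval() so freeze can run
example = torch.rand(1, 512)  # assumed example input shape

diagnostics: dict = {}
hardened = harden_jit_for_export(model, (example,), diagnostics=diagnostics)
# diagnostics now maps pass names to fold/strip counts, plus a "froze" key,
# e.g. {"froze": True, "fold_constant_ifs": 2, ...}
print(diagnostics)
```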
Remodeler
For boundary-only transforms (collapse, bind, alias, outputs_keep), see the dedicated tutorial: Provider-agnostic remodeler.