torch_to_nnef.torch_graph

torch_graph is intended to extract a full representation of a PyTorch graph.

It turns the PyTorch graph into a stable intermediate representation to which the translation to NNEF can then be applied. This means that not all of the original PyTorch graph is translated: for example, we ignore the parts related to device placement, memory-specific operations, or gradient-related parameters.

This choice differs from the torch.onnx module and stems from our lack of control over the evolution of PyTorch internals: if PyTorch internals change, ideally only this module should be impacted.

Here there is NO notion of dynamic axes: all shapes are assumed to be defined by the provided example inputs. Dynamic shapes are introduced at a later stage, in other modules, if requested by the user.
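
A minimal sketch of building the IR directly with the classes documented below (in normal use this is driven by the higher-level export entry points such as export_model_to_nnef); the toy model and example input are hypothetical:

```python
import torch
from torch import nn

from torch_to_nnef.torch_graph import TorchModuleIRGraph, TorchModuleTracer

# All shapes in the resulting IR come from this concrete example input:
# there are no dynamic axes at this stage.
model = nn.Sequential(nn.Linear(8, 4), nn.ReLU())
example_args = (torch.randn(2, 8),)

tracer = TorchModuleTracer(model, args=example_args)   # jit.trace happens lazily
ir_graph = TorchModuleIRGraph(tracer, is_root_module=True)
ir_graph.parse()      # builds ir_graph.data_nodes / ir_graph.op_nodes
ir_graph.printall()   # dump the parsed IR to stdout for inspection
```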

Data dataclass

Data(name: str, data: T.Any)

Bases: NamedItem

Base abstract T2N IR data holder.

FixedTensorList dataclass

FixedTensorList(name: str, data: T.Sequence[T.Union[TensorVariable, PythonConstant]])

Bases: Data

FixedTensorList is a list whose elements are tensor variables or Python constants.

TensorVariable dataclass

TensorVariable(name: str, data: T.Any, shape: T.Optional[T.List[int]], dtype: T.Optional[torch.dtype], quant: T.Optional[T.Dict[str, T.Any]] = None, _traced_data: T.Optional[torch.Tensor] = None)

Bases: Data

tracing_data property
tracing_data

Generate the data from the tensor information if it is not fixed.

We use it to produce the computation trace.

TorchModuleIRGraph

TorchModuleIRGraph(torch_module_tracer: TorchModuleTracer, omit_useless_nodes: bool = True, is_root_module: bool = False)

Torch graph intermediate representation built from jit.trace, applied recursively.

This is not the torch._C.Graph directly but a simpler abstraction, with:

  • a list of data nodes in self.data_nodes
  • a list of operation nodes in self.op_nodes

self.inputs is a list of references to some of self.data_nodes; self.outputs is a list of references to some of self.data_nodes.

This abstraction over the vanilla Torch graph allows manipulating the graph to check/complete missing data information and to ignore operations that are useless for our transcription needs.

It also makes us less reliant on the base graph if PyTorch internals change (think Adapter pattern).

Warning! Only non-nested data containers (TupleTensors, FixedTensorList, ...) are supported for now.

parse
parse(nnef_variable_naming_scheme: VariableNamingScheme = DEFAULT_VARNAME_SCHEME, provided_inputs=None, provided_outputs=None, forced_inputs_names=None, forced_outputs_names=None)

Core parsing transforming nn.Module into torch_to_nnef IR.

printall
printall()

Display helper graph information on stdout.

remap_node
remap_node(from_node, to_node)

Remap a data_node to another.

TorchModuleTracer

TorchModuleTracer(module: nn.Module, traced_module: T.Optional[torch.jit.TracedModule] = None, fn_name: str = 'forward', args: T.Optional[T.Tuple[T.Any, ...]] = None)

Evaluate the optimized traced function code so that signatures always match.

The original module is kept to do proper un-boxing later on. This is needed because re-routing is based on the actual module class type.

Create a tracer for module.

The tracer stores the original module, an optional pre-traced torch.jit.TracedModule (which allows re-use of a previously computed trace), the name of the forward method to trace, and the arguments used for tracing. The arguments are post-processed by maybe_quantize_args_tensor to ensure compatibility with quantized modules.

torch_graph property
torch_graph

Return the underlying PyTorch graph object.

The actual torch.Graph is retrieved from the traced module. When a different forward method is requested (fn_name differs from "forward"), the corresponding sub‑graph is returned instead.

traced_module property
traced_module

Return the traced module, computing it lazily if required.

If self._traced_module is None the method will perform a jit.trace on self.mod with self.args while handling possible PyTorch version nuances. Any RuntimeError raised by torch.jit.trace is wrapped into a torch_to_nnef.exceptions.T2NErrorTorchJitTraceFailed exception.

TorchOp dataclass

TorchOp(kind: str, module_path: str, inputs: T.List[Data], outputs: T.List[TtupleOrVar], scope: str, op_ref: T.Optional[T.Callable], call_name: T.Optional[str])
call_op
call_op()

Produce the operation output from the traced inputs with a real torch call.

This operation call is made with the self.args arguments (for now), which means that all required args must be available in parameter order, matching at least one underlying torch operation signature.

NOTE: we take a different approach than the original torch.onnx, which passes parameters by keyword arguments; this is because argument names are not provided in the exported graph (from what we understand, torch.onnx solves this by explicitly re-routing all signatures, which might be a bit bulky in most cases).
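
As an illustration of the difference (not the actual implementation), a positional re-call only needs the traced values in signature order, while keyword-style dispatch would need argument names that the traced graph does not carry:

```python
import torch

x, y = torch.randn(3), torch.randn(3)

# Positional re-call, as call_op does: traced inputs in signature order.
out_positional = torch.add(x, y)

# Keyword-style dispatch (the torch.onnx approach) needs the argument names,
# which are not recorded in the graph torch_to_nnef works from.
out_keyword = torch.add(input=x, other=y)

assert torch.equal(out_positional, out_keyword)
```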

realise_output_type_and_size
realise_output_type_and_size(approx: bool = True) -> bool

Trace the output and try to determine its type, shape, and constant realisation.

update_call_op_arg_kwargs
update_call_op_arg_kwargs(args)

Custom adaptation to call an aten function through the corresponding torch-exposed Python function.

fold_constant_ifs

fold_constant_ifs(graph: 'torch._C.Graph') -> int

Fold prim::If nodes whose condition is a prim::Constant[bool].

Replaces the If with the chosen block's nodes. Returns the count folded.

fold_constant_scalar_arithmetic

fold_constant_scalar_arithmetic(graph: 'torch._C.Graph') -> int

Fold scalar arithmetic on prim::Constant operands.

Walks aten::eq/ne/lt/le/gt/ge, aten::__not__, aten::__contains__, and the unary aten::Bool/Int/Float casts. Standalone replacement for _jit_pass_constant_propagation: used in the JIT-only export chain to avoid a torch internal assertion that fires when the upstream pass walks a graph mixing Phase 1 inlined submodules and Phase 2 size-fold constants.

Returns the number of nodes folded.
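
For illustration, a hypothetical source pattern this pass can fold once the upstream size pass has turned x.dim() into a constant:

```python
import torch
from torch import nn

class Casts(nn.Module):
    def forward(self, x):
        # `not in` lowers to aten::__contains__ followed by aten::__not__;
        # once x.dim() is a prim::Constant, both become foldable by this pass.
        if x.dim() not in [2, 3]:
            x = x.unsqueeze(0)
        return x.relu()
```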

fold_data_dependent_ifs

fold_data_dependent_ifs(graph: 'torch._C.Graph', example_inputs: T.Sequence[T.Any]) -> int

Fold prim::If nodes whose condition is data-dependent on the input.

PyTorch's JIT shape-analysis passes do not propagate shapes through prim::If nodes that produce tensors, leaving runtime dim/shape checks (e.g. nn.LSTMCell's if input.dim() == 1: ...) unresolved by replace_size_calls_with_constants + fold_constant_ifs. To specialize the graph for the user's example inputs, we evaluate each remaining prim::If's condition by running the graph itself with the example, observing the chosen branch, and inlining it.

Only top-level Ifs are folded each pass; nested Ifs surface to the top once their parent is removed, so a fixed-point loop catches them too.

Returns the number of Ifs folded.

Requires torch._C._jit_interpret_graph, exposed since torch 1.10. The probe is deferred until a candidate If is actually found, so callers on older torch with already-clean graphs (no remaining Ifs) return 0 without raising. The rest of torch_to_nnef still works on torch 1.8+, only this one pass needs the newer API.
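
For illustration, a hypothetical module where the second check cannot be folded statically (its operand comes out of a prim::If, which shape analysis does not see through), so this pass resolves it by running the graph on the example inputs:

```python
import torch
from torch import nn

class TwoChecks(nn.Module):
    def forward(self, x):
        # Foldable statically: replace_size_calls_with_constants + fold_constant_ifs.
        if x.dim() == 1:
            x = x.unsqueeze(0)
        # `x` is now the output of a prim::If, so its dim/shape are unknown to
        # JIT shape analysis; fold_data_dependent_ifs evaluates the condition
        # under the example inputs and inlines the chosen branch.
        if x.size(0) == 1:
            x = x.expand(2, -1)
        return x
```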

fold_tuple_index_through_tuple_construct

fold_tuple_index_through_tuple_construct(graph: 'torch._C.Graph') -> int

Fold prim::TupleIndex(tuple_const, k) into the k-th tuple input.

JIT artifacts whose Python source builds a tuple at the call site (e.g. return (h, c)) and consumes it later via positional indexing (pair[0], pair[1]) leave behind prim::TupleConstruct -> prim::TupleIndex chains in the inlined graph. t2n's parser already knows about TupleConstruct and TupleUnpack, but TupleIndex is unsupported. When the index is a static prim::Constant int and the tuple value is the direct output of a TupleConstruct, we rewire TupleIndex's output to the tuple's k-th input verbatim, leaving the TupleConstruct itself in place (DCE removes it later if it has no other consumers).

Returns the count of nodes folded.
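
For illustration, a hypothetical source pattern that leaves the TupleConstruct -> TupleIndex chain behind after inlining:

```python
import torch
from torch import nn

class PairConsumer(nn.Module):
    def _pair(self, x):
        return (x + 1, x - 1)     # becomes prim::TupleConstruct at the call site

    def forward(self, x):
        pair = self._pair(x)
        # Positional indexing becomes prim::TupleIndex; the fold rewires each
        # index straight to the matching TupleConstruct input.
        return pair[0] * pair[1]
```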

fold_tuple_unpack_through_tuple_construct

fold_tuple_unpack_through_tuple_construct(graph: 'torch._C.Graph') -> int

Fold prim::TupleUnpack(prim::TupleConstruct(...)) into the inputs.

Sibling of fold_tuple_index_through_tuple_construct. JIT artifacts that build a tuple at the call site and consume it via destructuring assignment (a, b = my_pair()) leave behind a prim::TupleConstruct -> prim::TupleUnpack chain after inlining. The k-th unpack output is exactly the k-th construct input; we rewire each unpack output verbatim and destroy the unpack. The TupleConstruct is left in place; DCE removes it if its only consumers (the unpack outputs) are now gone.

Returns the count of TupleUnpack nodes folded.
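
And the sibling destructuring pattern, illustrated the same way with a hypothetical module:

```python
import torch
from torch import nn

class PairSplitter(nn.Module):
    def _pair(self, x):
        return (x + 1, x - 1)     # prim::TupleConstruct at the call site

    def forward(self, x):
        # Destructuring becomes prim::TupleUnpack; the fold rewires each unpack
        # output to the matching construct input and destroys the unpack node.
        a, b = self._pair(x)
        return a * b
```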

harden_jit_for_export

harden_jit_for_export(model: 'torch.jit.ScriptModule', args: T.Union[T.Sequence[T.Any], 'torch.Tensor'], *, freeze: bool = True, diagnostics: T.Optional[T.Dict[str, T.Any]] = None) -> 'torch.jit.ScriptModule'

Specialize a JIT ScriptModule's graph for the given example inputs.

Returns the (possibly frozen) module with the chain applied in place. The chain has two stages.

Resolve CallMethod / CallFunction / GetAttr (one of):

  • torch.jit.freeze (freeze=True, default). Requires the module to be in eval mode; pass a model.eval()-ed instance or set freeze=False. On RuntimeError, the helper logs a warning and falls back to the manual inline.
  • inline_unresolvable_submodules. Used when freeze was disabled or failed; covers CallMethods whose target class isn't on the import path. Skipped when freeze succeeded.

Specialize for the example inputs (always, in this order):

  1. replace_size_calls_with_constants: forward-reach analysis folds aten::dim/size/len/numel whose values flow only into control flow. Tensor-shape consumers are left dynamic.
  2. fold_constant_scalar_arithmetic: cmps, __not__, __contains__, scalar casts.
  3. fold_constant_ifs: drop Ifs with a constant boolean condition.
  4. fold_tuple_index_through_tuple_construct and fold_tuple_unpack_through_tuple_construct: collapse the TupleConstruct -> TupleIndex and TupleConstruct -> TupleUnpack round-trips.
  5. strip_prim_data: drop prim::data(t) (autograd-detach no-op).
  6. strip_assertion_ifs: drop Ifs whose one branch is a pure RaiseException.
  7. fold_data_dependent_ifs: evaluate any remaining If condition under the example inputs and inline the chosen branch.

args is the model's forward arguments (no implicit self); the helper prepends the (possibly frozen) self receiver internally when invoking passes that re-execute the graph through torch._C._jit_interpret_graph. A single torch.Tensor is accepted as a shorthand for (tensor,), matching the permissive convention of export_model_to_nnef(model, args=...).

When diagnostics is a dict, it is populated in place with the count of nodes folded / stripped per pass, keyed by pass name. The froze key records whether freeze succeeded; inline_unresolvable_submodules is only present when the manual inline ran (i.e. freeze was disabled or failed). Useful for debugging unfamiliar JIT artifacts.
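
A minimal usage sketch based on the conventions above; the Gate module is hypothetical:

```python
import torch
from torch import nn

from torch_to_nnef.torch_graph import harden_jit_for_export

class Gate(nn.Module):
    def forward(self, x):
        assert x.dim() >= 1, "expected at least 1 dim"   # stripped as an assertion If
        if x.dim() == 1:                                 # folded for the example input
            x = x.unsqueeze(0)
        return torch.relu(x)

# freeze=True (the default) needs an eval-mode module.
scripted = torch.jit.script(Gate().eval())

diagnostics = {}
hardened = harden_jit_for_export(
    scripted,
    torch.randn(4, 8),           # a bare tensor is shorthand for (tensor,)
    diagnostics=diagnostics,
)
print(diagnostics)               # per-pass fold/strip counts plus the "froze" flag
```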

inline_unresolvable_submodules

inline_unresolvable_submodules(graph: 'torch._C.Graph', model: nn.Module, max_passes: int = 1024) -> int

Inline every prim::CallMethod whose target class is not importable.

Iterates until fixed point: an inlined body may itself contain further CallMethods that also need inlining. max_passes bounds the loop against pathological non-termination on weirdly structured graphs; raise it for unusually deep nested inlines, lower it to fail fast when debugging.

Returns the count of inlined calls.

replace_size_calls_with_constants

replace_size_calls_with_constants(graph: 'torch._C.Graph', example_inputs: T.Sequence[T.Any]) -> int

Fold size queries whose values flow only into control flow.

Reach analysis: walks forward from each candidate source. A source is folded only when every reach path terminates in prim::If condition, prim::Loop trip count, or prim::RaiseException without ever crossing a node that produces a tensor-typed output. Sources whose value flows into tensor production (via aten::reshape, aten::view, aten::expand, aten::zeros, ...) are left alone, so the standard aten::size handler in op/aten/other.py can route them through tract_core_shape_of under inference_target.has_dynamic_axes.

This makes the pass safe by default for any export target, including those declaring dynamic axes: a dim consumed by aten::view will not be baked into the NNEF graph as a constant.

Returns the count of size-call nodes folded.
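
For illustration, a hypothetical module exercising both sides of the reach analysis:

```python
import torch
from torch import nn

class Mixed(nn.Module):
    def forward(self, x):
        # x.dim() only reaches a prim::If condition -> folded to a constant.
        if x.dim() == 3:
            x = x.flatten(0, 1)
        # x.size(0) feeds aten::view (tensor production) -> left dynamic so the
        # regular aten::size handler can keep the axis symbolic downstream.
        return x.view(x.size(0), -1)
```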

strip_assertion_ifs

strip_assertion_ifs(graph: 'torch._C.Graph') -> int

Drop prim::If nodes whose one branch is purely a RaiseException.

Replace uses of the If's outputs with the non-raising block's outputs, then destroy the If. Walks nested blocks (assertion ifs are often inside other prim::If branches). Returns the count of stripped nodes.
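
For illustration, an assert in a hypothetical scripted forward that becomes such an If:

```python
import torch
from torch import nn

class Checked(nn.Module):
    def forward(self, x):
        # After scripting, this assert is a prim::If whose failing branch only
        # raises; strip_assertion_ifs drops the whole If.
        assert x.dim() == 2, "expected a 2D input"
        return x * 2
```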

strip_prim_data

strip_prim_data(graph: 'torch._C.Graph') -> int

Replace prim::data(t) nodes with their input.

prim::data is the IR form of Tensor .data access (detaches from autograd). In inference it is a no-op; t2n's parser doesn't have a handler for it, so we elide it.
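
For illustration, a hypothetical .data access that produces the node this pass elides:

```python
import torch
from torch import nn

class Detached(nn.Module):
    def forward(self, x):
        # `.data` shows up as prim::data in the scripted IR; at inference time
        # it is a no-op, so the pass rewires consumers straight to `x`.
        return x.data + 1
```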