torch_to_nnef.torch_graph
torch_graph is intended to extract a full representation of a PyTorch graph.
It turns a PyTorch graph into a stable intermediate representation suitable for subsequent translation to NNEF. This means that not all of the original PyTorch graph is translated: for example, we ignore parts linked to device placement information, memory-specific operations, or parameters linked to gradients.
This choice, which differs from the torch.onnx module, is due to the absence of control (on our side) over the evolution of PyTorch internals. If some PyTorch internals are modified, ideally only this module should be impacted.
There is NO notion of dynamic axes here: all shapes are assumed to be defined by the provided input example. At a later stage, other modules introduce dynamic shapes if requested by the user.
FixedTensorList
dataclass
TensorVariable
dataclass
TensorVariable(name: str, data: T.Any, shape: T.Optional[T.List[int]], dtype: T.Optional[torch.dtype], quant: T.Optional[T.Dict[str, T.Any]] = None, _traced_data: T.Optional[torch.Tensor] = None)
Bases: Data
TorchModuleIRGraph
TorchModuleIRGraph(torch_module_tracer: TorchModuleTracer, omit_useless_nodes: bool = True, is_root_module: bool = False)
Torch graph intermediate representation built from jit.trace with recursion.
This is not the raw torch._C.Graph but a simpler abstraction, with:
A list of data nodes in self.data_nodes
A list of operation nodes in self.op_nodes
self.inputs, a list of references to some of self.data_nodes
self.outputs, a list of references to some of self.data_nodes
This abstraction of the vanilla Torch graph allows manipulating the graph in order to check/complete missing data information and to ignore operations that are useless for our transcription needs.
It also makes us less reliant on the base graph in case of modifications to PyTorch internals (think Adapter pattern).
Warning! Only non-nested data containers (TupleTensors, FixedTensorList, ...) are supported for now.
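A minimal usage sketch, assuming the IR graph is fully populated at construction time (not verified here) and that both classes are importable from this module:

```python
import torch
from torch import nn

# assumed import path; adjust to your installed layout
from torch_to_nnef.torch_graph import TorchModuleIRGraph, TorchModuleTracer

model = nn.Linear(4, 2).eval()
tracer = TorchModuleTracer(model, args=(torch.randn(1, 4),))
ir_graph = TorchModuleIRGraph(tracer, is_root_module=True)

# the two node lists of the abstraction; inputs/outputs are
# references into entries of data_nodes
for data_node in ir_graph.data_nodes:
    print(data_node.name)
for op_node in ir_graph.op_nodes:
    print(op_node.kind)
```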
TorchModuleTracer
TorchModuleTracer(module: nn.Module, traced_module: T.Optional[torch.jit.TracedModule] = None, fn_name: str = 'forward', args: T.Optional[T.Tuple[T.Any, ...]] = None)
Evaluate optimized traced function code so that signatures always match.
The original module is passed to do proper un-boxing later on. This is needed because we re-route based on the actual module class type.
Create a tracer for module.
The tracer stores the original module, an optional pre‑traced
torch.jit.TracedModule (which allows re‑use of a previously
computed trace), the name of the forward method to trace, and the
arguments used for tracing. The arguments are post‑processed by
:func:maybe_quantize_args_tensor to ensure compatibility with
quantized modules.
torch_graph
property
Return the underlying PyTorch graph object.
The actual torch.Graph is retrieved from the traced module.
When a different forward method is requested (fn_name differs
from "forward"), the corresponding sub‑graph is returned instead.
traced_module
property
Return the traced module, computing it lazily if required.
If self._traced_module is None the method will perform a
jit.trace on self.mod with self.args while handling
possible PyTorch version nuances. Any RuntimeError raised by
torch.jit.trace is wrapped into a
:class:~torch_to_nnef.exceptions.T2NErrorTorchJitTraceFailed
exception.
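A sketch of the lazy tracing and error wrapping described above (the exception path comes from this page; the tracer import path is assumed):

```python
import torch
from torch import nn

from torch_to_nnef.torch_graph import TorchModuleTracer  # assumed path
from torch_to_nnef.exceptions import T2NErrorTorchJitTraceFailed

tracer = TorchModuleTracer(nn.ReLU(), args=(torch.randn(2, 3),))
try:
    traced = tracer.traced_module  # jit.trace happens lazily, here
    graph = tracer.torch_graph     # underlying torch._C.Graph of "forward"
except T2NErrorTorchJitTraceFailed as exc:
    # RuntimeErrors raised by torch.jit.trace are wrapped into this
    print(f"tracing failed: {exc}")
```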
TorchOp
dataclass
TorchOp(kind: str, module_path: str, inputs: T.List[Data], outputs: T.List[TtupleOrVar], scope: str, op_ref: T.Optional[T.Callable], call_name: T.Optional[str])
call_op
Produce operation output based on traced inputs with real torch call.
This operation call is made via the self.args arguments (for now), which means that we need all required args in positional order, following at least one underlying torch operation signature.
NOTE: we use a different approach than the original torch.onnx, which passes parameters by keyword arguments. This is because argument names are not provided in the exported graph (from what we understand, torch.onnx solves this via explicit rerouting of all signatures, which might be a bit bulky in most cases).
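To make the positional convention concrete, a hypothetical sketch (this is not the library's dispatch code):

```python
import torch

# op_ref: the resolved callable for an aten op (here aten::add)
op_ref = torch.add
# traced inputs kept strictly in positional order, since the exported
# graph carries no argument names
args = [torch.randn(2), torch.randn(2)]
out = op_ref(*args)  # positional call matching the op signature
```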
realise_output_type_and_size
Trace the output and try to determine its type, shape, and constant realisation.
fold_constant_ifs
Fold prim::If nodes whose condition is a prim::Constant[bool].
Replaces each folded If with the chosen block's nodes. Returns the count of nodes folded.
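An illustrative module producing such a pattern (after attribute freezing, the condition below is a prim::Constant[bool]; the exact graph shape may vary by torch version):

```python
import torch
from torch import nn

class Gate(nn.Module):
    def __init__(self):
        super().__init__()
        self.add_one = True  # frozen into a constant boolean

    def forward(self, x):
        if self.add_one:  # prim::If over a constant condition once frozen
            return x + 1.0
        return x - 1.0

frozen = torch.jit.freeze(torch.jit.script(Gate()).eval())
# fold_constant_ifs replaces such a prim::If with the taken branch's nodes
```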
fold_constant_scalar_arithmetic
Fold scalar arithmetic on prim::Constant operands.
Walks aten::eq/ne/lt/le/gt/ge, aten::__not__,
aten::__contains__, and the unary aten::Bool/Int/Float casts.
Standalone replacement for _jit_pass_constant_propagation: used
in the JIT-only export chain to avoid a torch internal assertion
that fires when the upstream pass walks a graph mixing Phase 1
inlined submodules and Phase 2 size-fold constants.
Returns the number of nodes folded.
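For instance (illustrative), once replace_size_calls_with_constants has turned x.dim() into a constant, the comparison feeding the If is pure constant scalar arithmetic:

```python
import torch

@torch.jit.script
def maybe_flatten(x: torch.Tensor):
    # with x.dim() folded to a constant, `x.dim() > 2` becomes
    # aten::gt over two prim::Constant operands and folds away
    if x.dim() > 2:
        return x.flatten(1)
    return x
```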
fold_data_dependent_ifs
Fold prim::If nodes whose condition is data-dependent on the input.
PyTorch's JIT shape-analysis passes do not propagate shapes through
prim::If nodes that produce tensors, leaving runtime dim/shape
checks (e.g. nn.LSTMCell's if input.dim() == 1: ...) unresolved
by replace_size_calls_with_constants + fold_constant_ifs. To
specialize the graph for the user's example inputs, we evaluate
each remaining prim::If's condition by running the graph itself
with the example, observing the chosen branch, and inlining it.
Only top-level Ifs are folded each pass; nested Ifs surface to the top once their parent is removed, so a fixed-point loop catches them too.
Returns the number of Ifs folded.
Requires torch._C._jit_interpret_graph, exposed since torch 1.10.
The probe is deferred until a candidate If is actually found, so
callers on older torch with already-clean graphs (no remaining Ifs)
return 0 without raising. The rest of torch_to_nnef still works
on torch 1.8+, only this one pass needs the newer API.
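The nn.LSTMCell-style pattern mentioned above, sketched:

```python
import torch

@torch.jit.script
def lstm_cell_like(x: torch.Tensor):
    # shape analysis cannot see through this tensor-producing prim::If;
    # the pass runs the graph on the example input, observes the taken
    # branch, and inlines it
    if x.dim() == 1:
        x = x.unsqueeze(0)
    return x
```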
fold_tuple_index_through_tuple_construct
Fold prim::TupleIndex(tuple_const, k) into the k-th tuple input.
JIT artifacts whose Python source builds a tuple at the call site
(e.g. return (h, c)) and consumes it later via positional indexing
(pair[0], pair[1]) leave behind prim::TupleConstruct ->
prim::TupleIndex chains in the inlined graph. t2n's parser already
knows about TupleConstruct and TupleUnpack, but TupleIndex is
unsupported. When the index is a static prim::Constant int and the
tuple value is the direct output of a TupleConstruct, we rewire
TupleIndex's output to the tuple's k-th input verbatim, leaving
the TupleConstruct itself in place (DCE removes it later if it
has no other consumers).
Returns the count of nodes folded.
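A sketch of Python source that leaves this pattern behind after inlining (illustrative):

```python
import torch

@torch.jit.script
def use_pair(h: torch.Tensor, c: torch.Tensor):
    pair = (h, c)             # prim::TupleConstruct
    return pair[0] + pair[1]  # prim::TupleIndex with constant indices
```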
fold_tuple_unpack_through_tuple_construct
Fold prim::TupleUnpack(prim::TupleConstruct(...)) into the inputs.
Sibling of fold_tuple_index_through_tuple_construct. JIT artifacts
that build a tuple at the call site and consume it via destructuring
assignment (a, b = my_pair()) leave behind a
prim::TupleConstruct -> prim::TupleUnpack chain after inlining.
The k-th unpack output is exactly the k-th construct input; we
rewire each unpack output verbatim and destroy the unpack. The
TupleConstruct is left in place; DCE removes it if its only
consumers (the unpack outputs) are now gone.
Returns the count of TupleUnpack nodes folded.
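And the destructuring sibling, sketched:

```python
import torch

@torch.jit.script
def unpack_pair(h: torch.Tensor, c: torch.Tensor):
    pair = (h, c)  # prim::TupleConstruct
    a, b = pair    # prim::TupleUnpack: outputs rewired to (h, c) verbatim
    return a * b
```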
harden_jit_for_export
harden_jit_for_export(model: 'torch.jit.ScriptModule', args: T.Union[T.Sequence[T.Any], 'torch.Tensor'], *, freeze: bool = True, diagnostics: T.Optional[T.Dict[str, T.Any]] = None) -> 'torch.jit.ScriptModule'
Specialize a JIT ScriptModule's graph for the given example inputs.
Returns the (possibly frozen) module with the chain applied in place. The chain has two stages.
Resolve CallMethod / CallFunction / GetAttr (one of):
torch.jit.freeze (freeze=True, the default). Requires the module to be in eval mode; pass a model.eval()-ed instance or set freeze=False. On RuntimeError, the helper logs a warning and falls back to the manual inline.
inline_unresolvable_submodules. Used when freeze was disabled or failed; covers CallMethods whose target class isn't on the import path. Skipped when freeze succeeded.
Specialize for the example inputs (always, in this order):
replace_size_calls_with_constants: forward-reach analysis folds aten::dim/size/len/numel whose values flow only into control flow. Tensor-shape consumers are left dynamic.
fold_constant_scalar_arithmetic: comparisons, __not__, __contains__, scalar casts.
fold_constant_ifs: drop Ifs with a constant boolean condition.
fold_tuple_index_through_tuple_construct and fold_tuple_unpack_through_tuple_construct: collapse the TupleConstruct -> TupleIndex and TupleConstruct -> TupleUnpack round-trips.
strip_prim_data: drop prim::data(t) (an autograd-detach no-op).
strip_assertion_ifs: drop Ifs whose one branch is a pure RaiseException.
fold_data_dependent_ifs: evaluate any remaining If condition under the example inputs and inline the chosen branch.
args is the model's forward arguments (no implicit self); the
helper prepends the (possibly frozen) self receiver internally
when invoking passes that re-execute the graph through
torch._C._jit_interpret_graph. A single torch.Tensor is accepted
as a shorthand for (tensor,), matching the permissive convention
of export_model_to_nnef(model, args=...).
When diagnostics is a dict, it is populated in place with the
count of nodes folded / stripped per pass, keyed by pass name. The
froze key records whether freeze succeeded;
inline_unresolvable_submodules is only present when the manual
inline ran (i.e. freeze was disabled or failed). Useful for
debugging unfamiliar JIT artifacts.
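A usage sketch based on the signature above (import path assumed; the diagnostics values shown are only an example):

```python
import torch
from torch import nn

from torch_to_nnef.torch_graph import harden_jit_for_export  # assumed path

model = torch.jit.script(nn.LSTMCell(10, 20)).eval()
example = (torch.randn(3, 10),)  # forward args, no implicit self

diagnostics = {}
hardened = harden_jit_for_export(model, example, diagnostics=diagnostics)
print(diagnostics)  # e.g. {"froze": True, "fold_data_dependent_ifs": 1, ...}
```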
inline_unresolvable_submodules
inline_unresolvable_submodules(graph: 'torch._C.Graph', model: nn.Module, max_passes: int = 1024) -> int
Inline every prim::CallMethod whose target class is not importable.
Iterates until fixed point: an inlined body may itself contain further
CallMethods that also need inlining. max_passes bounds the loop
against pathological non-termination on weirdly structured graphs;
raise it for unusually deep nested inlines, lower it to fail fast
when debugging.
Returns the count of inlined calls.
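A usage sketch (import path assumed; a graph whose call targets are all importable simply returns 0):

```python
import torch
from torch import nn

from torch_to_nnef.torch_graph import inline_unresolvable_submodules  # assumed

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
scripted = torch.jit.script(model)
n = inline_unresolvable_submodules(scripted.graph, model)
print(f"inlined {n} prim::CallMethod nodes")
```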
replace_size_calls_with_constants
replace_size_calls_with_constants(graph: 'torch._C.Graph', example_inputs: T.Sequence[T.Any]) -> int
Fold size queries whose values flow only into control flow.
Reach analysis: walks forward from each candidate source. A source is
folded only when every reach path terminates in a prim::If condition,
a prim::Loop trip count, or a prim::RaiseException without ever
crossing a node that produces a tensor-typed output. Sources whose
value flows into tensor production (via aten::reshape, aten::view,
aten::expand, aten::zeros, ...) are left alone, so the standard
aten::size handler in op/aten/other.py can route them through
tract_core_shape_of under inference_target.has_dynamic_axes.
This makes the pass safe by default for any export target, including
those declaring dynamic axes: a dim consumed by aten::view will not
be baked into the NNEF graph as a constant.
Returns the count of size-call nodes folded.
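An illustrative contrast between a foldable size query and one that must stay dynamic:

```python
import torch

@torch.jit.script
def f(x: torch.Tensor):
    # x.size(0): every reach path ends in the prim::If condition -> foldable
    if x.size(0) > 1:
        x = x * 2.0
    # x.size(1): feeds aten::view (tensor production) -> left dynamic, so
    # dynamic-axes targets can route it through tract_core_shape_of
    return x.view(-1, x.size(1))
```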
strip_assertion_ifs
Drop prim::If nodes whose one branch is purely a RaiseException.
Replace uses of the If's outputs with the non-raising block's outputs, then destroy the If. Walks nested blocks (assertion ifs are often inside other prim::If branches). Returns the count of stripped nodes.
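A sketch of Python source that scripts into such an assertion If:

```python
import torch

@torch.jit.script
def checked(x: torch.Tensor):
    # compiles to a prim::If whose raising branch is a pure RaiseException
    assert x.dim() == 2, "expected a matrix"
    return x.t()
```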