jit_passes
torch_to_nnef.torch_graph.jit_passes
Reusable JIT-graph passes that harden JIT-only models for t2n parsing.
These helpers are useful when the source model arrives as a serialized
torch.jit artifact (e.g. Silero-VAD's silero_vad.jit) whose Python source isn't on
the import path. After torch._C._jit_pass_inline flattens the graph,
PyTorch's compiled-in dim/shape assertions (notably inside nn.LSTMCell
and STFT helpers) leave behind prim::If nodes whose only effect on one
branch is to raise an exception. Those branches feed scalar-typed
arithmetic that t2n's tensor-oriented parser cannot represent, so we drop
them.
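A minimal end-to-end sketch of how these passes chain together. Only
replace_size_calls_with_constants's signature is documented below; the
other call signatures and the example-input shapes are illustrative
assumptions:

```python
import torch
from torch_to_nnef.torch_graph.jit_passes import (
    fold_constant_ifs,
    fold_data_dependent_ifs,
    replace_size_calls_with_constants,
    strip_assertion_ifs,
)

model = torch.jit.load("silero_vad.jit")  # Python source not importable
graph = model.graph
torch._C._jit_pass_inline(graph)          # flatten submodule calls first

example = (torch.zeros(1, 512),)          # placeholder example inputs
replace_size_calls_with_constants(graph, example)  # documented signature
fold_constant_ifs(graph)                  # assumed: each pass mutates the
strip_assertion_ifs(graph)                # graph in place, returns a count
fold_data_dependent_ifs(graph, example)   # assumed signature
```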
fold_constant_ifs
Fold prim::If nodes whose condition is a prim::Constant[bool].
Replaces the If with the chosen block's nodes. Returns the count folded.
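A minimal sketch of the fold, assuming the torch._C graph-mutation
bindings (findAllNodes, moveBefore, replaceAllUsesWith, destroy).
prim::If blocks take no arguments, so the chosen block's nodes can be
hoisted verbatim:

```python
import torch

def fold_constant_ifs_sketch(graph: "torch._C.Graph") -> int:
    folded = 0
    changed = True
    while changed:                  # re-scan after each fold: destroy()
        changed = False             # invalidates nested nodes found earlier
        for if_node in graph.findAllNodes("prim::If", True):
            cond = next(if_node.inputs())
            if cond.node().kind() != "prim::Constant":
                continue
            blocks = list(if_node.blocks())
            keep = blocks[0] if cond.toIValue() else blocks[1]
            for n in list(keep.nodes()):          # hoist the chosen branch
                n.moveBefore(if_node)
            for out, repl in zip(if_node.outputs(), keep.outputs()):
                out.replaceAllUsesWith(repl)
            if_node.destroy()
            folded += 1
            changed = True
            break
    return folded
```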
fold_constant_scalar_arithmetic
Fold scalar arithmetic on prim::Constant operands.
Walks aten::eq/ne/lt/le/gt/ge, aten::__not__,
aten::__contains__, and the unary aten::Bool/Int/Float casts.
Standalone replacement for _jit_pass_constant_propagation: used
in the JIT-only export chain to avoid a torch internal assertion
that fires when the upstream pass walks a graph mixing Phase 1
inlined submodules and Phase 2 size-fold constants.
Returns the number of nodes folded.
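A minimal sketch, assuming the Graph.setInsertPoint/insertConstant and
Value.toIValue bindings; the real pass may cover more operand forms and
iterate to a fixed point so chains fold regardless of order:

```python
import operator
import torch

SCALAR_OPS = {
    "aten::eq": operator.eq, "aten::ne": operator.ne,
    "aten::lt": operator.lt, "aten::le": operator.le,
    "aten::gt": operator.gt, "aten::ge": operator.ge,
    "aten::__not__": operator.not_,
    "aten::__contains__": lambda container, item: item in container,
    "aten::Bool": bool, "aten::Int": int, "aten::Float": float,
}

def fold_scalar_arith_sketch(graph: "torch._C.Graph") -> int:
    folded = 0
    for kind, fn in SCALAR_OPS.items():
        for node in list(graph.findAllNodes(kind, True)):
            ins = list(node.inputs())
            if any(i.node().kind() != "prim::Constant" for i in ins):
                continue            # fold only fully-constant operands
            result = fn(*[i.toIValue() for i in ins])
            graph.setInsertPoint(node)  # constant must dominate its uses
            node.output().replaceAllUsesWith(graph.insertConstant(result))
            node.destroy()
            folded += 1
    return folded
```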
fold_data_dependent_ifs
Fold prim::If nodes whose condition is data-dependent on the input.
PyTorch's JIT shape-analysis passes do not propagate shapes through
prim::If nodes that produce tensors, leaving runtime dim/shape
checks (e.g. nn.LSTMCell's if input.dim() == 1: ...) unresolved
by replace_size_calls_with_constants + fold_constant_ifs. To
specialize the graph for the user's example inputs, we evaluate
each remaining prim::If's condition by running the graph itself
with the example, observing the chosen branch, and inlining it.
Only top-level Ifs are folded each pass; nested Ifs surface to the
top once their parent is removed, so a fixed-point loop catches them
too.
Returns the number of Ifs folded.
Requires torch._C._jit_interpret_graph, exposed since torch 1.10.
The probe is deferred until a candidate If is actually found, so
callers on older torch with already-clean graphs (no remaining Ifs)
return 0 without raising. The rest of torch_to_nnef still works
on torch 1.8+; only this one pass needs the newer API.
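A minimal sketch of the probe-and-fold loop. torch._C._jit_interpret_graph
is the documented requirement; Graph.registerOutput/eraseOutput and the
shape of the interpreter's return value are assumptions of this sketch:

```python
import torch

def fold_data_dependent_ifs_sketch(graph, example_inputs) -> int:
    folded = 0
    while True:
        if_node = next((n for n in graph.nodes()
                        if n.kind() == "prim::If"), None)  # top level only
        if if_node is None:
            return folded          # fixed point: no candidate Ifs remain
        cond = next(if_node.inputs())
        idx = graph.registerOutput(cond)  # temporarily expose the condition
        outs = torch._C._jit_interpret_graph(graph, tuple(example_inputs))
        graph.eraseOutput(idx)            # restore the original interface
        taken = outs[-1] if isinstance(outs, tuple) else outs
        keep = list(if_node.blocks())[0 if taken else 1]
        for n in list(keep.nodes()):      # inline the observed branch
            n.moveBefore(if_node)
        for out, repl in zip(if_node.outputs(), keep.outputs()):
            out.replaceAllUsesWith(repl)
        if_node.destroy()
        folded += 1
```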
fold_tuple_index_through_tuple_construct
Fold prim::TupleIndex(tuple_const, k) into the k-th tuple input.
JIT artifacts whose Python source builds a tuple at the call site
(e.g. return (h, c)) and consumes it later via positional indexing
(pair[0], pair[1]) leave behind prim::TupleConstruct ->
prim::TupleIndex chains in the inlined graph. t2n's parser already
knows about TupleConstruct and TupleUnpack, but TupleIndex is
unsupported. When the index is a static prim::Constant int and the
tuple value is the direct output of a TupleConstruct, we rewire
TupleIndex's output to the tuple's k-th input verbatim, leaving
the TupleConstruct itself in place (DCE removes it later if it
has no other consumers).
Returns the count of nodes folded.
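A minimal sketch of the rewire; note the TupleConstruct is deliberately
left in place for DCE:

```python
import torch

def fold_tuple_index_sketch(graph: "torch._C.Graph") -> int:
    folded = 0
    for node in list(graph.findAllNodes("prim::TupleIndex", True)):
        tup, idx = node.inputs()
        if (tup.node().kind() != "prim::TupleConstruct"
                or idx.node().kind() != "prim::Constant"):
            continue                # dynamic index or opaque tuple source
        k = idx.toIValue()
        node.output().replaceAllUsesWith(list(tup.node().inputs())[k])
        node.destroy()              # TupleConstruct stays; DCE cleans it up
        folded += 1
    return folded
```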
fold_tuple_unpack_through_tuple_construct
Fold prim::TupleUnpack(prim::TupleConstruct(...)) into the inputs.
Sibling of fold_tuple_index_through_tuple_construct. JIT artifacts
that build a tuple at the call site and consume it via destructuring
assignment (a, b = my_pair()) leave behind a
prim::TupleConstruct -> prim::TupleUnpack chain after inlining.
The k-th unpack output is exactly the k-th construct input; we
rewire each unpack output verbatim and destroy the unpack. The
TupleConstruct is left in place; DCE removes it if its only
consumers (the unpack outputs) are now gone.
Returns the count of TupleUnpack nodes folded.
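A minimal sketch; the only difference from the TupleIndex fold is that
every unpack output is rewired at once:

```python
import torch

def fold_tuple_unpack_sketch(graph: "torch._C.Graph") -> int:
    folded = 0
    for node in list(graph.findAllNodes("prim::TupleUnpack", True)):
        tup = next(node.inputs())
        if tup.node().kind() != "prim::TupleConstruct":
            continue
        # k-th unpack output is exactly the k-th construct input
        for out, repl in zip(node.outputs(), tup.node().inputs()):
            out.replaceAllUsesWith(repl)
        node.destroy()              # TupleConstruct left for DCE
        folded += 1
    return folded
```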
replace_size_calls_with_constants
replace_size_calls_with_constants(graph: 'torch._C.Graph', example_inputs: T.Sequence[T.Any]) -> int
Fold size queries whose values flow only into control flow.
The reach analysis walks forward from each candidate source. A source
is folded only when every reach path terminates in a prim::If
condition, a prim::Loop trip count, or a prim::RaiseException without ever
crossing a node that produces a tensor-typed output. Sources whose
value flows into tensor production (via aten::reshape, aten::view,
aten::expand, aten::zeros, ...) are left alone, so the standard
aten::size handler in op/aten/other.py can route them through
tract_core_shape_of under inference_target.has_dynamic_axes.
This makes the pass safe by default for any export target, including
those declaring dynamic axes: a dim consumed by aten::view will not
be baked into the NNEF graph as a constant.
Returns the count of size-call nodes folded.
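A minimal sketch of just the reach test (the fold itself then evaluates
each dim from example_inputs and replaces the size call with a
constant). Value.uses/Value.unique/Type.kind are assumed torch._C
bindings, and the prim::Loop handling is simplified:

```python
import torch

def only_feeds_control_flow(source_out: "torch._C.Value") -> bool:
    # Forward walk from one candidate size value: every transitive use
    # must terminate in control flow, and no node on the way may produce
    # a tensor-typed output. (Simplified: the real pass distinguishes the
    # prim::Loop trip count from loop-carried values.)
    stack, seen = [source_out], set()
    while stack:
        value = stack.pop()
        if value.unique() in seen:
            continue
        seen.add(value.unique())
        for use in value.uses():
            user = use.user
            if user.kind() in ("prim::If", "prim::Loop",
                               "prim::RaiseException"):
                continue            # terminal control-flow sink
            if any(o.type().kind() == "TensorType"
                   for o in user.outputs()):
                return False        # feeds tensor production: leave alone
            stack.extend(user.outputs())
    return True
```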
strip_assertion_ifs
Drop prim::If nodes where one branch is purely a prim::RaiseException.
Replace uses of the If's outputs with the non-raising block's outputs,
then destroy the If. Walks nested blocks (assertion ifs often sit
inside other prim::If branches). Returns the count of stripped nodes.
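A minimal sketch, with a heuristic stand-in for the "purely a
RaiseException" test:

```python
import torch

def strip_assertion_ifs_sketch(graph: "torch._C.Graph") -> int:
    def is_raising(block) -> bool:
        # Heuristic: the branch's last node raises (constants building
        # the message may precede it); the real criterion may be stricter.
        kinds = [n.kind() for n in block.nodes()]
        return bool(kinds) and kinds[-1] == "prim::RaiseException"

    stripped = 0
    changed = True
    while changed:                  # re-scan: destroy() invalidates
        changed = False             # nested nodes found earlier
        for if_node in graph.findAllNodes("prim::If", True):
            true_b, false_b = if_node.blocks()
            keep = (false_b if is_raising(true_b)
                    else true_b if is_raising(false_b) else None)
            if keep is None:
                continue
            for n in list(keep.nodes()):      # hoist the surviving branch
                n.moveBefore(if_node)
            for out, repl in zip(if_node.outputs(), keep.outputs()):
                out.replaceAllUsesWith(repl)
            if_node.destroy()
            stripped += 1
            changed = True
            break
    return stripped
```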