torch_to_nnef.torch_graph.jit_passes

Reusable JIT-graph passes for hardening JIT-only models against t2n parsing.

These helpers are useful when the source model arrives as a torch.jit artifact (e.g. Silero-VAD's silero_vad.jit) whose Python source isn't on the import path. After torch._C._jit_pass_inline flattens the graph, PyTorch's compiled-in dim/shape assertions (notably inside nn.LSTMCell and STFT helpers) leave behind prim::If nodes whose only effect on one branch is to raise an exception. Those branches feed scalar-typed arithmetic that t2n's tensor-oriented parser cannot represent, so we drop them.

fold_constant_ifs

fold_constant_ifs(graph: 'torch._C.Graph') -> int

Fold prim::If nodes whose condition is a prim::Constant[bool].

Replaces the If with the chosen block's nodes. Returns the count folded.
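The splice can be illustrated on a toy IR (plain dicts for nodes, plain lists for branches; this is a hedged sketch, not the real torch._C.Graph API):

```python
def fold_constant_ifs(nodes):
    """Splice the taken branch of every If whose condition is a constant bool.

    Returns (new_nodes, count_folded), mirroring "replace the If with the
    chosen block's nodes" on this toy representation.
    """
    out, folded = [], 0
    for n in nodes:
        if n["kind"] == "prim::If" and isinstance(n["cond"], bool):
            branch = n["then"] if n["cond"] else n["else"]
            # Recurse so constant Ifs nested inside the taken branch fold too.
            inner, k = fold_constant_ifs(branch)
            out.extend(inner)
            folded += 1 + k
        else:
            out.append(n)
    return out, folded


graph = [
    {"kind": "aten::relu"},
    {"kind": "prim::If", "cond": False,
     "then": [{"kind": "prim::RaiseException"}],
     "else": [{"kind": "aten::add"}]},
]
new_graph, count = fold_constant_ifs(graph)
# new_graph keeps relu and the else-branch's add; count == 1
```

On the real graph the same idea applies block-wise: the chosen block's nodes take the If's place and the If node is destroyed.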

fold_constant_scalar_arithmetic

fold_constant_scalar_arithmetic(graph: 'torch._C.Graph') -> int

Fold scalar arithmetic on prim::Constant operands.

Walks aten::eq/ne/lt/le/gt/ge, aten::__not__, aten::__contains__, and the unary aten::Bool/Int/Float casts. Standalone replacement for _jit_pass_constant_propagation: used in the JIT-only export chain to avoid a torch internal assertion that fires when the upstream pass walks a graph mixing Phase 1 inlined submodules and Phase 2 size-fold constants.

Returns the number of nodes folded.
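The handled ops map directly onto Python semantics. A minimal sketch of the evaluation table (toy code, not the module's actual implementation; aten::__contains__ is omitted here rather than guessing its operand order):

```python
import operator

# Hypothetical folding table: each handled aten op evaluated with Python
# semantics on already-constant operands.
_FOLDERS = {
    "aten::eq": operator.eq, "aten::ne": operator.ne,
    "aten::lt": operator.lt, "aten::le": operator.le,
    "aten::gt": operator.gt, "aten::ge": operator.ge,
    "aten::__not__": operator.not_,
    "aten::Bool": bool, "aten::Int": int, "aten::Float": float,
}


def fold_scalar_node(kind, args):
    """Return the folded Python constant, or None when the op is not
    handled by this pass (i.e. the node is left in place)."""
    fn = _FOLDERS.get(kind)
    return None if fn is None else fn(*args)
```

For example, `fold_scalar_node("aten::eq", (2, 2))` yields `True`, while an unhandled kind such as `aten::mul` yields `None` and the node survives.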

fold_data_dependent_ifs

fold_data_dependent_ifs(graph: 'torch._C.Graph', example_inputs: T.Sequence[T.Any]) -> int

Fold prim::If nodes whose condition is data-dependent on the input.

PyTorch's JIT shape-analysis passes do not propagate shapes through prim::If nodes that produce tensors, leaving runtime dim/shape checks (e.g. nn.LSTMCell's if input.dim() == 1: ...) unresolved by replace_size_calls_with_constants + fold_constant_ifs. To specialize the graph for the user's example inputs, we evaluate each remaining prim::If's condition by running the graph itself with the example, observing the chosen branch, and inlining it.

Only top-level Ifs are folded each pass; nested Ifs surface to the top once their parent is removed, so a fixed-point loop catches them too.

Returns the number of Ifs folded.

Requires torch._C._jit_interpret_graph, exposed since torch 1.10. The probe is deferred until a candidate If is actually found, so callers on older torch with already-clean graphs (no remaining Ifs) return 0 without raising. The rest of torch_to_nnef still works on torch 1.8+; only this one pass needs the newer API.
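The observe-and-inline fixed-point loop can be sketched on a toy IR where each If's condition is modelled as a callable of the example input (standing in for interpreting the real graph via torch._C._jit_interpret_graph):

```python
def fold_data_dependent_ifs(nodes, example):
    """Evaluate each top-level If's condition against the example input and
    splice in the observed branch; loop to a fixed point so nested Ifs are
    caught once their parent is removed. Returns (new_nodes, total_folded)."""
    total = 0
    while True:
        out, folded = [], 0
        for n in nodes:
            if n["kind"] == "prim::If":
                branch = n["then"] if n["cond"](example) else n["else"]
                out.extend(branch)
                folded += 1
            else:
                out.append(n)
        nodes = out
        total += folded
        if folded == 0:
            return nodes, total
```

With a nested If inside the outer else-branch, the first iteration folds the outer If and surfaces the inner one to the top level, and the second iteration folds it, demonstrating why a single pass over top-level nodes plus a fixed-point loop suffices.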

fold_tuple_index_through_tuple_construct

fold_tuple_index_through_tuple_construct(graph: 'torch._C.Graph') -> int

Fold prim::TupleIndex(tuple_const, k) into the k-th tuple input.

JIT artifacts whose Python source builds a tuple at the call site (e.g. return (h, c)) and consumes it later via positional indexing (pair[0], pair[1]) leave behind prim::TupleConstruct -> prim::TupleIndex chains in the inlined graph. t2n's parser already knows about TupleConstruct and TupleUnpack, but TupleIndex is unsupported. When the index is a static prim::Constant int and the tuple value is the direct output of a TupleConstruct, we rewire TupleIndex's output to the tuple's k-th input verbatim, leaving the TupleConstruct itself in place (DCE removes it later if it has no other consumers).

Returns the count of nodes folded.
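On a toy IR (dict nodes, string value names; a hedged sketch, not the real torch._C API), the rewiring reduces to aliasing each TupleIndex output to the construct's k-th input and letting downstream consumers pick up the alias:

```python
def fold_tuple_index(nodes):
    """Alias TupleIndex(tuple_const, k) outputs to the construct's k-th
    input; the TupleConstruct itself is left in place for DCE."""
    constructs = {n["out"]: n["ins"]
                  for n in nodes if n["kind"] == "prim::TupleConstruct"}
    alias, out, folded = {}, [], 0
    for n in nodes:
        if (n["kind"] == "prim::TupleIndex"
                and n["ins"][0] in constructs
                and isinstance(n["idx"], int)):
            alias[n["out"]] = constructs[n["ins"][0]][n["idx"]]
            folded += 1
        else:
            out.append(n)
    # Rewire remaining consumers to the aliased values.
    for n in out:
        n["ins"] = [alias.get(v, v) for v in n.get("ins", [])]
    return out, folded
```

So `pair = (h, c); pair[0] + pair[1]` becomes an add over `h` and `c` directly, with the construct kept until dead-code elimination.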

fold_tuple_unpack_through_tuple_construct

fold_tuple_unpack_through_tuple_construct(graph: 'torch._C.Graph') -> int

Fold prim::TupleUnpack(prim::TupleConstruct(...)) into the inputs.

Sibling of fold_tuple_index_through_tuple_construct. JIT artifacts that build a tuple at the call site and consume it via destructuring assignment (a, b = my_pair()) leave behind a prim::TupleConstruct -> prim::TupleUnpack chain after inlining. The k-th unpack output is exactly the k-th construct input; we rewire each unpack output verbatim and destroy the unpack. The TupleConstruct is left in place; DCE removes it if its only consumers (the unpack outputs) are now gone.

Returns the count of TupleUnpack nodes folded.
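The same aliasing trick applies per unpack output (toy IR again, not the real torch._C API):

```python
def fold_tuple_unpack(nodes):
    """Alias each TupleUnpack output to the matching TupleConstruct input
    and drop the unpack node; the construct is left for DCE."""
    constructs = {n["out"]: n["ins"]
                  for n in nodes if n["kind"] == "prim::TupleConstruct"}
    alias, out, folded = {}, [], 0
    for n in nodes:
        if n["kind"] == "prim::TupleUnpack" and n["ins"][0] in constructs:
            for k, o in enumerate(n["outs"]):
                alias[o] = constructs[n["ins"][0]][k]
            folded += 1
        else:
            out.append(n)
    for n in out:
        n["ins"] = [alias.get(v, v) for v in n.get("ins", [])]
    return out, folded
```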

replace_size_calls_with_constants

replace_size_calls_with_constants(graph: 'torch._C.Graph', example_inputs: T.Sequence[T.Any]) -> int

Fold size queries whose values flow only into control flow.

Reach analysis: walks forward from each candidate source. A source is folded only when every reach path terminates in a prim::If condition, a prim::Loop trip count, or a prim::RaiseException without ever crossing a node that produces a tensor-typed output. Sources whose value flows into tensor production (via aten::reshape, aten::view, aten::expand, aten::zeros, ...) are left alone, so the standard aten::size handler in op/aten/other.py can route them through tract_core_shape_of under inference_target.has_dynamic_axes.

This makes the pass safe by default for any export target, including those declaring dynamic axes: a dim consumed by aten::view will not be baked into the NNEF graph as a constant.

Returns the count of size-call nodes folded.
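The reach analysis can be sketched as a forward DFS over a toy IR. Here tensor production is approximated by a fixed set of op kinds, where the real pass inspects output types; both the IR shape and the helper names are illustrative assumptions:

```python
CONTROL_SINKS = {"prim::If", "prim::Loop", "prim::RaiseException"}
TENSOR_PRODUCERS = {"aten::view", "aten::reshape", "aten::expand", "aten::zeros"}


def flows_only_into_control_flow(value, nodes):
    """Forward DFS from `value`: True iff every use-chain terminates at a
    control-flow sink without crossing a tensor-producing node, meaning the
    size query can be baked to a constant even under dynamic axes."""
    stack, seen = [value], set()
    while stack:
        v = stack.pop()
        if v in seen:
            continue
        seen.add(v)
        for n in nodes:
            if v not in n.get("ins", []):
                continue
            if n["kind"] in CONTROL_SINKS:
                continue          # this path ends in control flow: safe
            if n["kind"] in TENSOR_PRODUCERS:
                return False      # value feeds tensor production: keep size call
            # Scalar consumer: keep walking through its outputs.
            stack.extend(n.get("outs", []))
    return True
```

A dim feeding only an `aten::eq` whose result conditions a `prim::If` passes the check; the same dim feeding an `aten::view` fails it and is left for the dynamic-shape path.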

strip_assertion_ifs

strip_assertion_ifs(graph: 'torch._C.Graph') -> int

Drop prim::If nodes whose one branch is purely a RaiseException.

Replace uses of the If's outputs with the non-raising block's outputs, then destroy the If. Walks nested blocks (assertion ifs are often inside other prim::If branches). Returns the count of stripped nodes.
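A toy-IR sketch of the strip rule (dict nodes with then/else branch lists; not the real torch._C API), recursing into branches first so assertion Ifs nested inside other Ifs are caught:

```python
def _pure_raise(block):
    """True iff the block is non-empty and contains only RaiseException."""
    return bool(block) and all(n["kind"] == "prim::RaiseException" for n in block)


def strip_assertion_ifs(nodes):
    """Drop Ifs where exactly one branch purely raises, keeping the
    non-raising branch's nodes. Returns (new_nodes, count_stripped)."""
    out, stripped = [], 0
    for n in nodes:
        if n["kind"] == "prim::If":
            n["then"], a = strip_assertion_ifs(n["then"])
            n["else"], b = strip_assertion_ifs(n["else"])
            stripped += a + b
            if _pure_raise(n["then"]) != _pure_raise(n["else"]):
                keep = n["else"] if _pure_raise(n["then"]) else n["then"]
                out.extend(keep)
                stripped += 1
                continue
        out.append(n)
    return out, stripped
```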

strip_prim_data

strip_prim_data(graph: 'torch._C.Graph') -> int

Replace prim::data(t) nodes with their input.

prim::data is the IR form of Tensor .data access (detaches from autograd). In inference it is a no-op; t2n's parser doesn't have a handler for it, so we elide it.
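Eliding the node is a pure rewiring; on a toy IR (a hedged sketch, not the real pass):

```python
def strip_prim_data(nodes):
    """Alias each prim::data output to its input and drop the node:
    .data only detaches from autograd, a no-op at inference time."""
    alias, out, stripped = {}, [], 0
    for n in nodes:
        if n["kind"] == "prim::data":
            alias[n["out"]] = n["ins"][0]
            stripped += 1
        else:
            out.append(n)
    # Point former consumers of the prim::data output at its input.
    for n in out:
        n["ins"] = [alias.get(v, v) for v in n.get("ins", [])]
    return out, stripped
```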