harden
torch_to_nnef.torch_graph.harden
High-level helper that runs the JIT-only export hardening chain.
The individual passes in jit_inline and jit_passes are exposed for
fine-grained use (custom orderings, partial chains for debugging). Most
callers want the full chain; harden_jit_for_export wraps it with a
sensible default order and freeze step.
Each pass is a no-op on graphs that don't carry the relevant pattern, so the chain is safe to apply unconditionally.
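To see why unconditional application is safe, here is a toy model. It is not the library's graph representation: a "graph" is a plain list of node-kind strings, and a pass returns the rewritten list plus the number of nodes it touched. `toy_strip_prim_data` is a hypothetical stand-in for the real `strip_prim_data`. When the pattern is absent, it returns its input unchanged with a count of 0.

```python
from typing import List, Tuple


def toy_strip_prim_data(nodes: List[str]) -> Tuple[List[str], int]:
    """Toy pass: drop every 'prim::data' node, return (new_nodes, removed_count)."""
    kept = [n for n in nodes if n != "prim::data"]
    return kept, len(nodes) - len(kept)


with_pattern = ["aten::relu", "prim::data", "aten::add"]
without_pattern = ["aten::relu", "aten::add"]

print(toy_strip_prim_data(with_pattern))     # (['aten::relu', 'aten::add'], 1)
print(toy_strip_prim_data(without_pattern))  # (['aten::relu', 'aten::add'], 0)
```

Running the toy pass a second time on its own output also reports 0, so repeating the chain does no harm in this model.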
Supported on torch >= 1.10 (the only API requirement is
torch._C._jit_interpret_graph, used by fold_data_dependent_ifs and
probed lazily). CI gates the chain on torch 2.11.0; older versions work
but are not regression-tested.
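The lazy probe can be pictured as a cached availability check that runs on first use. The sketch below assumes the probe is a plain `hasattr` test on `torch._C`. `interpret_graph_available` is a hypothetical name, and the library may probe differently. Deferring the check means that defining it costs nothing on builds that lack the private API, and later calls hit the cache.

```python
import functools


@functools.lru_cache(maxsize=None)
def interpret_graph_available() -> bool:
    """Hypothetical probe: does this torch build expose _jit_interpret_graph?

    The import and attribute lookup run only on the first call; the result
    is cached for every later call.
    """
    try:
        import torch
    except ImportError:
        return False  # no torch at all, so the private API is unavailable
    return hasattr(torch._C, "_jit_interpret_graph")


print("torch._C._jit_interpret_graph available:", interpret_graph_available())
```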
harden_jit_for_export
harden_jit_for_export(model: 'torch.jit.ScriptModule', args: T.Union[T.Sequence[T.Any], 'torch.Tensor'], *, freeze: bool = True, diagnostics: T.Optional[T.Dict[str, T.Any]] = None) -> 'torch.jit.ScriptModule'
Specialize a JIT ScriptModule's graph for the given example inputs.
Returns the (possibly frozen) module with the chain applied in place. The chain has two stages.
Resolve CallMethod / CallFunction / GetAttr (one of):
- torch.jit.freeze (freeze=True, default). Requires the module to be in eval mode; pass a model.eval()-ed instance or set freeze=False. On RuntimeError, the helper logs a warning and falls back to the manual inline.
- inline_unresolvable_submodules. Used when freeze was disabled or failed; covers CallMethods whose target class isn't on the import path. Skipped when freeze succeeded.
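The stage-1 decision reduces to a try/fallback. The sketch below injects the freeze and inline steps as callables so that it runs without torch. `resolve_calls`, `freeze_fn`, and `inline_fn` are hypothetical names. The diagnostics keys follow the description of the `diagnostics` argument, and the code is not copied from the library. Treating the inline result as an integer count is also an assumption.

```python
import logging
from typing import Any, Callable, Dict, Optional

log = logging.getLogger(__name__)


def resolve_calls(
    module: Any,
    freeze: bool,
    freeze_fn: Callable[[Any], Any],
    inline_fn: Callable[[Any], int],
    diagnostics: Optional[Dict[str, Any]] = None,
) -> Any:
    """Hypothetical stage 1: freeze if allowed, else fall back to a manual inline."""
    froze = False
    if freeze:
        try:
            module = freeze_fn(module)  # stands in for torch.jit.freeze
            froze = True
        except RuntimeError as exc:
            log.warning("freeze failed (%s); falling back to manual inline", exc)
    if diagnostics is not None:
        diagnostics["froze"] = froze
    if not froze:
        inlined = inline_fn(module)  # stands in for inline_unresolvable_submodules
        if diagnostics is not None:
            # Key is only present when the manual inline actually ran.
            diagnostics["inline_unresolvable_submodules"] = inlined
    return module


def failing_freeze(m: Any) -> Any:
    raise RuntimeError("module must be in eval mode")


diag: Dict[str, Any] = {}
resolve_calls("toy-module", True, failing_freeze, lambda m: 3, diag)
print(diag)  # {'froze': False, 'inline_unresolvable_submodules': 3}
```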
Specialize for the example inputs (always, in this order):
- replace_size_calls_with_constants: forward-reach analysis folds aten::dim / size / len / numel calls whose values flow only into control flow. Tensor-shape consumers are left dynamic.
- fold_constant_scalar_arithmetic: comparisons, __not__, __contains__, scalar casts.
- fold_constant_ifs: drop Ifs with a constant boolean condition.
- fold_tuple_index_through_tuple_construct and fold_tuple_unpack_through_tuple_construct: collapse the TupleConstruct -> TupleIndex and TupleConstruct -> TupleUnpack round-trips.
- strip_prim_data: drop prim::data(t) (an autograd-detach no-op).
- strip_assertion_ifs: drop Ifs in which one branch is a pure RaiseException.
- fold_data_dependent_ifs: evaluate any remaining If condition under the example inputs and inline the chosen branch.
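The fixed stage-2 ordering and the per-pass counts can be sketched as a small driver loop. It rests on several assumptions. Every pass is modelled with a uniform one-argument signature that mutates the graph and returns a count, whereas the real passes may take extra arguments (fold_data_dependent_ifs, for instance, needs the example inputs). The relative order of the two tuple passes is also assumed, since they are listed together. `run_stage2`, `STAGE2_ORDER`, and `make_stub` are hypothetical names.

```python
from typing import Any, Callable, Dict, List, Optional

# Assumed uniform pass signature: mutate the graph, return nodes changed.
PassFn = Callable[[Any], int]

STAGE2_ORDER: List[str] = [
    "replace_size_calls_with_constants",
    "fold_constant_scalar_arithmetic",
    "fold_constant_ifs",
    "fold_tuple_index_through_tuple_construct",
    "fold_tuple_unpack_through_tuple_construct",
    "strip_prim_data",
    "strip_assertion_ifs",
    "fold_data_dependent_ifs",
]


def run_stage2(graph: Any, passes: Dict[str, PassFn],
               diagnostics: Optional[Dict[str, Any]] = None) -> None:
    """Apply each stage-2 pass once, in order, recording its count by name."""
    for name in STAGE2_ORDER:
        count = passes[name](graph)
        if diagnostics is not None:
            diagnostics[name] = count


calls: List[str] = []


def make_stub(name: str) -> PassFn:
    """Build a stub pass that records its name and reports zero changes."""
    def stub(graph: Any) -> int:
        calls.append(name)
        return 0
    return stub


diag: Dict[str, Any] = {}
run_stage2(object(), {n: make_stub(n) for n in STAGE2_ORDER}, diag)
print(calls == STAGE2_ORDER)  # True
```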
args is the model's forward arguments (no implicit self); the
helper prepends the (possibly frozen) self receiver internally
when invoking passes that re-execute the graph through
torch._C._jit_interpret_graph. A single torch.Tensor is accepted
as a shorthand for (tensor,), matching the permissive convention
of export_model_to_nnef(model, args=...).
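The argument convention can be sketched in plain Python. This assumes that anything other than a list or tuple is treated as a lone tensor, which is how a `torch.Tensor` would fall through because it is neither. Strings stand in for the receiver and the tensors. `normalize_args` and `interpreter_inputs` are hypothetical names, not library functions.

```python
from typing import Any, Sequence, Tuple, Union


def normalize_args(args: Union[Sequence[Any], Any]) -> Tuple[Any, ...]:
    """Accept a list/tuple of forward args, or a single tensor as shorthand."""
    if isinstance(args, (list, tuple)):
        return tuple(args)
    return (args,)  # assumption: anything else is a single tensor


def interpreter_inputs(receiver: Any, args: Union[Sequence[Any], Any]) -> Tuple[Any, ...]:
    """Prepend the (possibly frozen) self receiver before the forward args."""
    return (receiver, *normalize_args(args))


print(interpreter_inputs("self", "x"))         # ('self', 'x')
print(interpreter_inputs("self", ["x", "y"]))  # ('self', 'x', 'y')
```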
When diagnostics is a dict, it is populated in place with the
count of nodes folded / stripped per pass, keyed by pass name. The
froze key records whether freeze succeeded;
inline_unresolvable_submodules is only present when the manual
inline ran (i.e. freeze was disabled or failed). Useful for
debugging unfamiliar JIT artifacts.
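When triaging an unfamiliar artifact, a useful first step is to list which passes actually changed something. The helper below works over the dict described above. It assumes every key other than froze maps to an integer count. `changed_passes` and the sample values are hypothetical, and the sample is hand-written for illustration rather than captured from a real run.

```python
from typing import Any, Dict, List


def changed_passes(diag: Dict[str, Any]) -> List[str]:
    """Return keys (other than 'froze') with a non-zero count, in insertion order."""
    return [name for name, count in diag.items() if name != "froze" and count]


sample = {  # hand-written illustration, not real output
    "froze": True,
    "replace_size_calls_with_constants": 4,
    "fold_constant_ifs": 2,
    "strip_prim_data": 0,
}
print(changed_passes(sample))  # ['replace_size_calls_with_constants', 'fold_constant_ifs']
print("manual inline ran:", "inline_unresolvable_submodules" in sample)  # manual inline ran: False
```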