torch_to_nnef.remodeler

Provider-agnostic boundary remodeler scaffold.

This module defines small, typed building blocks to describe IO signatures and boundary-only transforms (collapse, bind, and backend-facing symbol renames), plus helpers to load/save a strict nested config.

Notes:

  • The concrete YAML/JSON schema is parsed by domain-specific loaders (e.g. AxisSymbolRegistry in the NeMo package).
  • Providers are expected to discover per-subnet signatures and to apply a remodel plan by wrapping inner modules with an adapter that enforces the external boundary while preserving the internal contract.

BoundaryAdapter

BoundaryAdapter(module: torch.nn.Module, subnet_name: str, input_example: list, dynamic_axes: T.Optional[dict[str, dict[int, str]]], collapse_by_input: T.Optional[dict[str, set[str]]], binds_by_input: T.Optional[dict[str, str]] = None, renamed_map: T.Optional[dict[str, list[str]]] = None, outputs_keep: T.Optional[list[str]] = None, output_collapse_dims: T.Optional[dict[str, list[int]]] = None, *, apply_symbol_renames: bool = True)

Bases: Module

Boundary adapter applying tuple flattening and collapse at export time.

Both inputs AND outputs are flattened to flat tensor IO using build_new_names_and_elements, so the external interface (names, examples, dynamic axes) is always expressed in terms of flat tensors.

When outputs_keep is set, the adapter runs one forward pass at construction time to discover the output structure, then validates outputs_keep against the flattened output names and filters the outputs accordingly.
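The flatten-then-filter behaviour can be sketched in plain Python. This is a standalone illustration, not the adapter's code: `flatten_outputs` and `filter_outputs` are hypothetical helpers, strings stand in for tensors, and the `<name>_<index>` naming scheme is an assumption (the real names come from build_new_names_and_elements).

```python
from typing import Any, List, Tuple


def flatten_outputs(name: str, value: Any) -> List[Tuple[str, Any]]:
    """Flatten nested tuples/lists into (flat_name, leaf) pairs.

    Hypothetical naming scheme: each nesting level appends "_<index>".
    """
    if isinstance(value, (tuple, list)):
        flat: List[Tuple[str, Any]] = []
        for i, item in enumerate(value):
            flat.extend(flatten_outputs(f"{name}_{i}", item))
        return flat
    return [(name, value)]


def filter_outputs(flat: List[Tuple[str, Any]], outputs_keep: List[str]) -> List[Tuple[str, Any]]:
    """Validate outputs_keep against flat names, then keep matches in flat order."""
    known = [n for n, _ in flat]
    unknown = [k for k in outputs_keep if k not in known]
    if unknown:
        raise ValueError(f"outputs_keep names not found: {unknown}; available: {known}")
    keep = set(outputs_keep)
    return [(n, v) for n, v in flat if n in keep]


# A single forward pass (here a stand-in value) reveals the nested output structure.
raw_output = ("logits", ("lengths", "hidden"))
flat = flatten_outputs("out", raw_output)
print([n for n, _ in flat])  # ['out_0', 'out_1_0', 'out_1_1']
kept = filter_outputs(flat, ["out_0", "out_1_0"])
print(kept)  # [('out_0', 'logits'), ('out_1_0', 'lengths')]
```

Validating before filtering means a typo in outputs_keep fails loudly at construction time instead of silently dropping an output.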

IODescriptor dataclass

IODescriptor(name: str, shape: list[T.Union[int, str]], dtype: T.Optional[str] = None, notes: T.Optional[list[str]] = None)

Description of a single input or output.

PreparedSubnet dataclass

PreparedSubnet(model: torch.nn.Module, test_input: list, input_names: list[str], output_names: list[str], dyn: dict[str, dict[int, str]], custom_extensions: list[str])

Result of applying registry-driven transforms to a raw subnet.

Provider

Bases: Protocol

Provider SPI to discover signatures and apply remodel plans.

Implementations should be small adapters around an existing provider (e.g., NeMo, plain PyTorch) that can:

  • discover raw and post-processed signatures for inspection
  • apply the boundary remodel plan to return wrapped modules ready for export

RemodelPlan dataclass

RemodelPlan(registry: T.Any)

A remodel plan built from a validated axis-symbol registry.

  • registry: The validated, parsed axis registry (nested schema). Typically an AxisSymbolRegistry instance provided by a domain package (e.g. torch_to_nnef_nemo).

RenameOutputs

RenameOutputs(module: torch.nn.Module, rename_map: T.Dict[str, str])

Bases: Module

Wrapper that renames output tensor names for export-time only.

Leaves computation unchanged and preserves input names. Useful to avoid name collisions between inputs and outputs (e.g., both named 'length').
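The collision case can be illustrated at the level of name lists. This sketch is hypothetical: `rename_output_names` is not part of torch_to_nnef, the real RenameOutputs is an nn.Module wrapper, and the duplicate/collision checks are assumptions added for illustration.

```python
from typing import Dict, List


def rename_output_names(
    input_names: List[str], output_names: List[str], rename_map: Dict[str, str]
) -> List[str]:
    """Return export-time output names; input names and computation are untouched."""
    renamed = [rename_map.get(name, name) for name in output_names]
    # Reject output names that still collide with an input or with each other.
    clashes = sorted(set(input_names) & set(renamed))
    if clashes:
        raise ValueError(f"output names collide with inputs: {clashes}")
    if len(set(renamed)) != len(renamed):
        raise ValueError(f"duplicate output names after renaming: {renamed}")
    return renamed


inputs = ["audio", "length"]
outputs = ["logits", "length"]  # 'length' clashes with an input name
print(rename_output_names(inputs, outputs, {"length": "out_length"}))
# ['logits', 'out_length']
```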

Stage

Bases: Enum

Logical stages for signature inspection.

order property
order: int

Stable sort order (RAW < COLLAPSED < BOUND < FINAL).
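A stable stage ordering like this lets signature snapshots be sorted regardless of the order in which they were collected. The enum below is a hypothetical mirror of Stage: the member values are assumptions, and only the RAW < COLLAPSED < BOUND < FINAL ordering comes from the reference above.

```python
from enum import Enum


class StageSketch(Enum):
    # Hypothetical mirror of Stage; member values are placeholders.
    RAW = "raw"
    COLLAPSED = "collapsed"
    BOUND = "bound"
    FINAL = "final"

    @property
    def order(self) -> int:
        # Declaration order gives RAW < COLLAPSED < BOUND < FINAL.
        return list(type(self)).index(self)


snapshots = [StageSketch.FINAL, StageSketch.RAW, StageSketch.BOUND, StageSketch.COLLAPSED]
ordered = sorted(snapshots, key=lambda stage: stage.order)
print([stage.name for stage in ordered])  # ['RAW', 'COLLAPSED', 'BOUND', 'FINAL']
```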

SubnetSignature dataclass

SubnetSignature(name: str, stage: Stage, inputs: list[IODescriptor], outputs: list[IODescriptor], symbol_axes: T.Optional[dict[str, dict[int, str]]] = None, applied_flags: T.List[str] = list())

Per-subnet signature snapshot at a given stage.

apply_eval_symbols

apply_eval_symbols(test_input: list, input_names: list[str], subnet_name: str, dyn: T.Dict[str, T.Dict[int, str]], eval_symbols: T.Dict[str, T.Dict[str, int]]) -> list

Resize test_input tensors according to eval_symbols.
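The resizing rule can be sketched by computing target shapes only. This is an illustration under assumptions: `eval_shapes` is a hypothetical helper, eval_symbols is assumed to be keyed by subnet name and then by symbol, and the real function resizes tensors rather than returning shape lists.

```python
from typing import Dict, List


def eval_shapes(
    shapes: List[List[int]],
    input_names: List[str],
    subnet_name: str,
    dyn: Dict[str, Dict[int, str]],
    eval_symbols: Dict[str, Dict[str, int]],
) -> List[List[int]]:
    """Return the shape each test input should have once eval symbols are applied."""
    pinned = eval_symbols.get(subnet_name, {})  # assumed layout: symbol -> size
    resized = []
    for name, shape in zip(input_names, shapes):
        new_shape = list(shape)
        # Only axes declared dynamic with a pinned symbol change size.
        for axis, symbol in dyn.get(name, {}).items():
            if symbol in pinned:
                new_shape[axis] = pinned[symbol]
        resized.append(new_shape)
    return resized


dyn = {"audio": {0: "B", 2: "T"}, "length": {0: "B"}}
print(eval_shapes([[1, 80, 500], [1]], ["audio", "length"], "encoder", dyn, {"encoder": {"T": 16}}))
# [[1, 80, 16], [1]]
```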

apply_symbol_renames_to_dyn

apply_symbol_renames_to_dyn(dyn: T.Dict[str, T.Dict[int, str]], rename_map: T.Dict[str, T.List[str]]) -> T.Dict[str, T.Dict[int, str]]

Apply symbol renames directly to a dynamic axes mapping.

This is the lightweight alternative to BoundaryAdapter when only symbol renames are needed (no collapse, bind, or output filtering).
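A minimal sketch of renaming symbols in a dynamic axes mapping, assuming the same target-to-sources layout that rewrite_assertions_with_renames documents. `rename_dyn_symbols` is hypothetical, and borrowing that function's case-insensitive, uppercased behaviour is an assumption.

```python
from typing import Dict, List


def rename_dyn_symbols(
    dyn: Dict[str, Dict[int, str]], rename_map: Dict[str, List[str]]
) -> Dict[str, Dict[int, str]]:
    """Return a new dyn mapping with source symbols replaced by their target."""
    # Invert target -> [sources] into a case-insensitive source -> TARGET lookup.
    lookup = {
        src.upper(): target.upper()
        for target, sources in rename_map.items()
        for src in sources
    }
    return {
        name: {axis: lookup.get(sym.upper(), sym) for axis, sym in axes.items()}
        for name, axes in dyn.items()
    }


dyn = {"audio": {0: "B", 2: "T"}, "length": {0: "batch"}}
print(rename_dyn_symbols(dyn, {"BATCH": ["B", "batch"]}))
# {'audio': {0: 'BATCH', 2: 'T'}, 'length': {0: 'BATCH'}}
```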

plan_from_registry

plan_from_registry(registry: T.Any) -> RemodelPlan

Build a remodel plan from a validated registry (provider-specific).

prepare_subnet_export

prepare_subnet_export(model: torch.nn.Module, test_input: list, input_names: list[str], output_names: list[str], subnet_name: str, dyn: dict[str, dict[int, str]], custom_extensions: list[str], axis_registry: T.Optional[T.Any] = None) -> PreparedSubnet

Apply all registry-driven transforms and return export-ready data.

This consolidates eval-symbol resizing, boundary adaptation (collapse, bind, rename, output filtering), assertion rewriting, extension merging, and eval-symbol pinning into a single call.

Providers feed in raw subnet data; the remodeler returns everything needed to build final export parameters.

remove_eval_symbols_from_dyn

remove_eval_symbols_from_dyn(input_names: list[str], subnet_name: str, dyn: T.Dict[str, T.Dict[int, str]], eval_symbols: T.Dict[str, T.Dict[str, int]]) -> None

Remove pinned axes from dyn so the backend treats them as constant.

Must be called after the BoundaryAdapter is built, because the adapter needs the symbols to resolve bindings.
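The in-place removal can be sketched as follows, assuming eval_symbols is keyed by subnet name and then by symbol. `drop_pinned_axes` is a hypothetical stand-in; like the documented function, it mutates dyn and returns None, so it should run only after the adapter has consumed the symbols.

```python
from typing import Dict, List


def drop_pinned_axes(
    input_names: List[str],
    subnet_name: str,
    dyn: Dict[str, Dict[int, str]],
    eval_symbols: Dict[str, Dict[str, int]],
) -> None:
    """Delete axes whose symbol is pinned, so the backend sees them as constant."""
    pinned = eval_symbols.get(subnet_name, {})
    for name in input_names:
        axes = dyn.get(name)
        if not axes:
            continue
        # Collect first, then delete, to avoid mutating the dict while iterating.
        for axis in [a for a, sym in axes.items() if sym in pinned]:
            del axes[axis]


dyn = {"audio": {0: "B", 2: "T"}, "length": {0: "B"}}
drop_pinned_axes(["audio", "length"], "encoder", dyn, {"encoder": {"T": 16}})
print(dyn)  # {'audio': {0: 'B'}, 'length': {0: 'B'}}
```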

rewrite_and_filter_assertions

rewrite_and_filter_assertions(assertions: list[str], rename_map: T.Optional[dict[str, list[str]]], dyn: T.Optional[dict[str, dict[int, str]]]) -> list[str]

Rewrite assertions and drop those referencing removed symbols.

  • Applies symbol renames so source symbols map to their target alias.
  • Computes the set of present symbols from the current dynamic axes and discards any assertion that mentions a symbol not present after rewriting.
  • Returns de-duplicated assertions.
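The three steps above can be sketched together. This is an illustration under assumptions, not the library's implementation: `rewrite_and_filter` is hypothetical, symbols are assumed to be identifier-like tokens, and `tract_assert` is assumed to be the only non-symbol keyword.

```python
import re
from typing import Dict, List, Optional

# Identifier-like tokens; numbers and operators are never symbols.
_TOKEN = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")
# Assumption: directive keywords are the only identifiers that are not symbols.
_KEYWORDS = {"tract_assert"}


def rewrite_and_filter(
    assertions: List[str],
    rename_map: Optional[Dict[str, List[str]]],
    dyn: Optional[Dict[str, Dict[int, str]]],
) -> List[str]:
    lookup = {
        src.upper(): target.upper()
        for target, sources in (rename_map or {}).items()
        for src in sources
    }
    # Symbols still present on some dynamic axis.
    present = {sym.upper() for axes in (dyn or {}).values() for sym in axes.values()}
    kept: List[str] = []
    for text in assertions:
        rewritten = _TOKEN.sub(lambda m: lookup.get(m.group(0).upper(), m.group(0)), text)
        symbols = {t.upper() for t in _TOKEN.findall(rewritten) if t not in _KEYWORDS}
        # Keep only assertions whose symbols all survive, without duplicates.
        if symbols <= present and rewritten not in kept:
            kept.append(rewritten)
    return kept


assertions = [
    "tract_assert U = BATCH",
    "tract_assert B = BATCH",  # duplicate once renamed
    "tract_assert T <= 4096",
    "tract_assert S >= 1",  # S is no longer on any dynamic axis
]
dyn = {"audio": {0: "BATCH", 2: "T"}}
print(rewrite_and_filter(assertions, {"BATCH": ["U", "B"]}, dyn))
# ['tract_assert BATCH = BATCH', 'tract_assert T <= 4096']
```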

rewrite_assertions_with_renames

rewrite_assertions_with_renames(assertions: list[str], rename_map: T.Optional[dict[str, list[str]]]) -> list[str]

Rewrite assertion symbol names based on a rename mapping.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| assertions | list[str] | List of assertion strings, e.g. "tract_assert U = BATCH". | required |
| rename_map | Optional[dict[str, list[str]]] | Mapping of target symbol to list of source symbols that should be rewritten to the target. Comparison is case-insensitive; rewritten symbols are emitted uppercased. | required |

Returns:

| Type | Description |
|---|---|
| list[str] | A list of assertions with symbols rewritten according to the provided mapping. Unknown tokens are left unchanged. |
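The documented rules (case-insensitive matching, uppercased replacements, unknown tokens untouched) can be sketched with a single regular-expression pass. `rewrite_with_renames` is a hypothetical stand-in, and treating identifier-like tokens as candidate symbols is an assumption.

```python
import re
from typing import Dict, List, Optional

_TOKEN = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")


def rewrite_with_renames(
    assertions: List[str], rename_map: Optional[Dict[str, List[str]]]
) -> List[str]:
    if not rename_map:
        return list(assertions)
    # Case-insensitive source -> uppercased target lookup.
    lookup = {
        src.upper(): target.upper()
        for target, sources in rename_map.items()
        for src in sources
    }

    def replace(match: re.Match) -> str:
        token = match.group(0)
        return lookup.get(token.upper(), token)  # unknown tokens pass through

    return [_TOKEN.sub(replace, text) for text in assertions]


out = rewrite_with_renames(
    ["tract_assert u = BATCH", "tract_assert N >= 1", "tract_assert T >= 1"],
    {"batch": ["u", "n"]},
)
print(out)
# ['tract_assert BATCH = BATCH', 'tract_assert BATCH >= 1', 'tract_assert T >= 1']
```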

save_config

save_config(path: T.Union[Path, str, None], registry: T.Any, *, flow_seq: bool = True, stream: T.Optional[T.TextIO] = None) -> None

Save an AxisSymbolRegistry to YAML or JSON.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| path | Union[Path, str, None] | Output file path (.yml/.yaml/.json) or None when using stream. | required |
| registry | Any | Parsed registry to serialize. | required |
| flow_seq | bool | When YAML, render short lists in flow style. | True |
| stream | Optional[TextIO] | Optional text stream to write into (YAML or JSON). When provided, path can be None and the function infers YAML vs JSON based on the filename if available; defaults to YAML behavior. | None |
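The format selection described for path and stream can be sketched as a small helper. `infer_format` is hypothetical, and rejecting an explicit path with an unknown extension is an assumption; reading the filename from the stream's `name` attribute is likewise assumed.

```python
import io
from pathlib import Path
from typing import Optional, TextIO, Union


def infer_format(
    path: Union[Path, str, None] = None, stream: Optional[TextIO] = None
) -> str:
    """Pick "json" or "yaml" from the path, else from the stream's filename."""
    name = path if path is not None else getattr(stream, "name", None)
    suffix = Path(name).suffix.lower() if isinstance(name, (str, Path)) else ""
    if suffix == ".json":
        return "json"
    if suffix in (".yml", ".yaml"):
        return "yaml"
    if path is None:
        return "yaml"  # stream without a usable filename: default to YAML
    raise ValueError(f"unsupported config extension: {suffix!r}")


print(infer_format("plan.json"))  # json
print(infer_format(Path("plan.YAML")))  # yaml
print(infer_format(None, io.StringIO()))  # yaml
```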