torch_to_nnef.remodeler
Provider-agnostic boundary remodeler scaffold.
This module defines small, typed building blocks to describe IO signatures and boundary-only transforms (collapse, bind, and backend-facing symbol renames), plus helpers to load/save a strict nested config.
Notes:
- The concrete YAML/JSON schema is parsed by domain-specific loaders (e.g. AxisSymbolRegistry in the NeMo package).
- Providers are expected to discover per-subnet signatures and to apply a remodel plan by wrapping inner modules with an adapter that enforces the external boundary while preserving the internal contract.
BoundaryAdapter
BoundaryAdapter(module: torch.nn.Module, subnet_name: str, input_example: list, dynamic_axes: T.Optional[dict[str, dict[int, str]]], collapse_by_input: T.Optional[dict[str, set[str]]], binds_by_input: T.Optional[dict[str, str]] = None, renamed_map: T.Optional[dict[str, list[str]]] = None, outputs_keep: T.Optional[list[str]] = None, output_collapse_dims: T.Optional[dict[str, list[int]]] = None, *, apply_symbol_renames: bool = True)
Bases: Module
Boundary adapter applying tuple flattening and collapse at export time.
Both inputs and outputs are flattened to flat tensor IO using build_new_names_and_elements, so that the external interface (names, examples, dynamic axes) is always expressed in terms of flat tensors.
When outputs_keep is set, the adapter runs one forward pass at construction time to discover the output structure, then validates outputs_keep against the flattened output names and filters the outputs accordingly.
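The flatten-then-filter behaviour described above can be sketched in plain Python. This is a minimal illustration, not the library's code: flatten_io and filter_outputs are hypothetical helpers, strings stand in for tensors, and the "<name>_<index>" naming scheme is an assumption rather than the one build_new_names_and_elements actually produces.

```python
from typing import Any, List, Tuple


def flatten_io(name: str, value: Any) -> List[Tuple[str, Any]]:
    """Flatten nested tuples/lists into (flat_name, leaf) pairs.

    Hypothetical stand-in for build_new_names_and_elements: the
    "<name>_<index>" naming scheme is an assumption.
    """
    if isinstance(value, (tuple, list)):
        flat: List[Tuple[str, Any]] = []
        for i, item in enumerate(value):
            # Recurse so arbitrarily nested structures become flat leaves.
            flat.extend(flatten_io(f"{name}_{i}", item))
        return flat
    return [(name, value)]


def filter_outputs(
    flat: List[Tuple[str, Any]], outputs_keep: List[str]
) -> List[Tuple[str, Any]]:
    """Validate outputs_keep against flat names, then keep matching leaves."""
    names = [n for n, _ in flat]
    unknown = [k for k in outputs_keep if k not in names]
    if unknown:
        raise ValueError(f"unknown outputs {unknown}; available: {names}")
    keep = set(outputs_keep)
    # Preserve the flattened output order (an assumption).
    return [(n, v) for n, v in flat if n in keep]


# Strings stand in for tensors; a real adapter would obtain this
# structure from one forward pass at construction time.
flat_outputs = flatten_io("out", ("logits", ("lengths", "state")))
print(flat_outputs)
# [('out_0', 'logits'), ('out_1_0', 'lengths'), ('out_1_1', 'state')]
print(filter_outputs(flat_outputs, ["out_1_0"]))
# [('out_1_0', 'lengths')]
```

Validating outputs_keep against the flattened names, rather than the nested structure, is what keeps the external interface purely in terms of flat tensors.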
IODescriptor
dataclass
IODescriptor(name: str, shape: list[T.Union[int, str]], dtype: T.Optional[str] = None, notes: T.Optional[list[str]] = None)
Description of a single input or output.
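A shape mixes static sizes (int) and axis symbols (str). The sketch below mirrors the documented IODescriptor fields and adds a hypothetical symbolic_axes helper (not part of the module) that extracts the symbolic dimensions in the same axis-to-symbol form used by the dynamic-axes mappings elsewhere on this page. The input name and the dtype string format are assumptions.

```python
import typing as T
from dataclasses import dataclass


@dataclass
class IODescriptor:
    # Mirrors the documented fields: int entries are static sizes,
    # str entries are axis symbols.
    name: str
    shape: list[T.Union[int, str]]
    dtype: T.Optional[str] = None
    notes: T.Optional[list[str]] = None


def symbolic_axes(desc: IODescriptor) -> dict[int, str]:
    """Hypothetical helper: map axis index -> symbol for symbolic dims."""
    return {i: d for i, d in enumerate(desc.shape) if isinstance(d, str)}


audio = IODescriptor(name="audio_signal", shape=["B", 80, "T"], dtype="float32")
print(symbolic_axes(audio))  # {0: 'B', 2: 'T'}
```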
PreparedSubnet
dataclass
PreparedSubnet(model: torch.nn.Module, test_input: list, input_names: list[str], output_names: list[str], dyn: dict[str, dict[int, str]], custom_extensions: list[str])
Result of applying registry-driven transforms to a raw subnet.
Provider
Bases: Protocol
Provider SPI to discover signatures and apply remodel plans.
Implementations should be small adapters around an existing provider (e.g., NeMo, plain PyTorch) that can:
- discover raw and post-processed signatures for inspection;
- apply the boundary remodel plan to return wrapped modules ready for export.
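As a rough illustration of the two responsibilities listed above, the sketch below defines a structural protocol and a toy provider that satisfies it. The method names discover_signatures and apply_plan, the string stage argument, and the EchoProvider class are all hypothetical; the real Provider methods and signatures are not listed on this page.

```python
import typing as T


@T.runtime_checkable
class ProviderSketch(T.Protocol):
    """Hypothetical Provider shape; method names are illustrative only."""

    def discover_signatures(self, stage: str) -> dict[str, T.Any]:
        """Return per-subnet signatures for inspection at the given stage."""
        ...

    def apply_plan(self, plan: T.Any) -> dict[str, T.Any]:
        """Return subnet modules wrapped so they are ready for export."""
        ...


class EchoProvider:
    """Toy provider exposing a single identity 'subnet'."""

    def __init__(self) -> None:
        self.subnets = {"encoder": lambda x: x}

    def discover_signatures(self, stage: str) -> dict[str, T.Any]:
        return {name: {"stage": stage} for name in self.subnets}

    def apply_plan(self, plan: T.Any) -> dict[str, T.Any]:
        # A real provider would wrap each module here (e.g. in an adapter
        # enforcing the external boundary); this toy returns them as-is.
        return dict(self.subnets)


print(isinstance(EchoProvider(), ProviderSketch))  # True
```

Because Protocol uses structural typing, a provider does not need to inherit from anything; note that a runtime_checkable isinstance check only verifies that the methods exist, not their signatures.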
RemodelPlan
dataclass
A remodel plan built from a validated axis-symbol registry.
- registry: The validated, parsed axis registry (nested schema). Typically an AxisSymbolRegistry instance provided by a domain package (e.g. torch_to_nnef_nemo).
RenameOutputs
Bases: Module
Wrapper that renames output tensor names for export-time only.
Leaves computation unchanged and preserves input names. Useful to avoid name collisions between inputs and outputs (e.g., both named 'length').
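The name-collision problem RenameOutputs addresses can be shown without torch. The helper below is hypothetical: it only computes export-time output names, renaming outputs that clash with an input (or an earlier output) by appending an assumed "_out" suffix. The actual wrapper may instead take explicit target names.

```python
def rename_colliding_outputs(
    input_names: list[str], output_names: list[str], suffix: str = "_out"
) -> list[str]:
    """Hypothetical helper: rename only output names that collide.

    Computation is untouched; only the export-time names change.
    """
    taken = set(input_names)
    renamed = []
    for name in output_names:
        new = name
        # Append the suffix until the name is unique among inputs and
        # the outputs already assigned.
        while new in taken:
            new += suffix
        taken.add(new)
        renamed.append(new)
    return renamed


print(rename_colliding_outputs(["audio_signal", "length"], ["logits", "length"]))
# ['logits', 'length_out']
```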
Stage
SubnetSignature
dataclass
SubnetSignature(name: str, stage: Stage, inputs: list[IODescriptor], outputs: list[IODescriptor], symbol_axes: T.Optional[dict[str, dict[int, str]]] = None, applied_flags: T.List[str] = list())
Per-subnet signature snapshot at a given stage.
apply_eval_symbols
apply_eval_symbols(test_input: list, input_names: list[str], subnet_name: str, dyn: T.Dict[str, T.Dict[int, str]], eval_symbols: T.Dict[str, T.Dict[str, int]]) -> list
Resize test_input tensors according to eval_symbols.
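The resizing step can be illustrated by computing the target shape of each test input. This sketch assumes eval_symbols is keyed first by subnet name and then by axis symbol (suggested, but not confirmed, by the subnet_name parameter). It works on plain shape lists, and eval_shapes is a hypothetical helper; how apply_eval_symbols actually builds the resized tensors is not documented here.

```python
import typing as T


def eval_shapes(
    shapes: list[list[int]],
    input_names: list[str],
    subnet_name: str,
    dyn: T.Dict[str, T.Dict[int, str]],
    eval_symbols: T.Dict[str, T.Dict[str, int]],
) -> list[list[int]]:
    """Compute the resized shape of each test input.

    Assumes eval_symbols is keyed by subnet name, then by axis symbol.
    """
    pinned = eval_symbols.get(subnet_name, {})
    out = []
    for name, shape in zip(input_names, shapes):
        new_shape = list(shape)
        # Only axes whose symbol has an evaluation value are resized.
        for axis, symbol in dyn.get(name, {}).items():
            if symbol in pinned:
                new_shape[axis] = pinned[symbol]
        out.append(new_shape)
    return out


dyn = {"audio_signal": {0: "B", 2: "T"}, "length": {0: "B"}}
resized = eval_shapes(
    [[1, 80, 100], [1]], ["audio_signal", "length"], "encoder", dyn,
    {"encoder": {"T": 256}},
)
print(resized)  # [[1, 80, 256], [1]]
```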
apply_symbol_renames_to_dyn
apply_symbol_renames_to_dyn(dyn: T.Dict[str, T.Dict[int, str]], rename_map: T.Dict[str, T.List[str]]) -> T.Dict[str, T.Dict[int, str]]
Apply symbol renames directly to a dynamic axes mapping.
This is the lightweight alternative to BoundaryAdapter when only symbol renames are needed (no collapse, bind, or output filtering).
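A minimal sketch of applying a target-to-sources rename map to a dynamic axes mapping. The case-insensitive matching and upper-cased targets are assumed here to mirror the documented behaviour of rewrite_assertions_with_renames; rename_dyn_symbols is a hypothetical name, and this version returns a copy rather than mutating its input.

```python
import typing as T


def rename_dyn_symbols(
    dyn: T.Dict[str, T.Dict[int, str]],
    rename_map: T.Dict[str, T.List[str]],
) -> T.Dict[str, T.Dict[int, str]]:
    """Return a copy of dyn with source symbols replaced by their target.

    rename_map maps target -> [sources]; case-insensitive matching and
    upper-cased output are assumed, mirroring the assertion rewriter.
    """
    # Invert target -> [sources] into a source -> target lookup.
    source_to_target = {
        src.upper(): target.upper()
        for target, sources in rename_map.items()
        for src in sources
    }
    return {
        name: {axis: source_to_target.get(sym.upper(), sym) for axis, sym in axes.items()}
        for name, axes in dyn.items()
    }


dyn = {"audio_signal": {0: "B", 2: "T"}, "length": {0: "b"}}
renamed = rename_dyn_symbols(dyn, {"BATCH": ["B"]})
print(renamed)  # {'audio_signal': {0: 'BATCH', 2: 'T'}, 'length': {0: 'BATCH'}}
```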
plan_from_registry
Build a remodel plan from a validated registry (provider-specific).
prepare_subnet_export
prepare_subnet_export(model: torch.nn.Module, test_input: list, input_names: list[str], output_names: list[str], subnet_name: str, dyn: dict[str, dict[int, str]], custom_extensions: list[str], axis_registry: T.Optional[T.Any] = None) -> PreparedSubnet
Apply all registry-driven transforms and return export-ready data.
This consolidates eval-symbol resizing, boundary adaptation (collapse, bind, rename, output filtering), assertion rewriting, extension merging, and eval-symbol pinning into a single call.
Providers feed in raw subnet data; the remodeler returns everything needed to build final export parameters.
remove_eval_symbols_from_dyn
remove_eval_symbols_from_dyn(input_names: list[str], subnet_name: str, dyn: T.Dict[str, T.Dict[int, str]], eval_symbols: T.Dict[str, T.Dict[str, int]]) -> None
Remove pinned axes from dyn so the backend treats them as constant.
Must be called after the BoundaryAdapter is built, because the adapter needs the symbols to resolve bindings.
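The in-place removal of pinned axes can be sketched as follows, again assuming eval_symbols is keyed by subnet name and then by symbol. remove_pinned_axes is a hypothetical name; this version leaves an input's entry in place even when all of its axes are removed, which may differ from the real function.

```python
import typing as T


def remove_pinned_axes(
    input_names: list[str],
    subnet_name: str,
    dyn: T.Dict[str, T.Dict[int, str]],
    eval_symbols: T.Dict[str, T.Dict[str, int]],
) -> None:
    """Drop axes whose symbol is pinned, mutating dyn in place."""
    pinned = set(eval_symbols.get(subnet_name, {}))
    for name in input_names:
        axes = dyn.get(name)
        if axes is None:
            continue
        # Collect first, then delete, to avoid mutating while iterating.
        for axis in [a for a, sym in axes.items() if sym in pinned]:
            del axes[axis]


# Per the ordering constraint above, this runs only after the adapter
# has used the full symbol set to resolve bindings.
dyn = {"audio_signal": {0: "B", 2: "T"}, "length": {0: "B"}}
remove_pinned_axes(["audio_signal", "length"], "encoder", dyn, {"encoder": {"T": 256}})
print(dyn)  # {'audio_signal': {0: 'B'}, 'length': {0: 'B'}}
```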
rewrite_and_filter_assertions
rewrite_and_filter_assertions(assertions: list[str], rename_map: T.Optional[dict[str, list[str]]], dyn: T.Optional[dict[str, dict[int, str]]]) -> list[str]
Rewrite assertions and drop those referencing removed symbols.
- Applies symbol renames so source symbols map to their target alias.
- Computes the set of present symbols from the current dynamic axes and discards any assertion that mentions a symbol not present after rewriting.
- Returns de-duplicated assertions.
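The three steps above can be sketched as a single pass. Assumptions, all hypothetical: every assertion starts with a keyword such as "tract_assert" followed by a space, the rest contains only symbols, integers, and operators (no function calls), and symbols present in dyn are themselves mapped through the renames. filter_assertions is an illustrative name, not the module's function.

```python
import re
import typing as T

# Identifier-like tokens; integers and operators are not matched.
_TOKEN = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")


def filter_assertions(
    assertions: list[str],
    rename_map: T.Optional[dict[str, list[str]]],
    dyn: T.Optional[dict[str, dict[int, str]]],
) -> list[str]:
    """Rename symbols, drop assertions on absent symbols, de-duplicate."""
    # Invert target -> [sources] into an upper-cased source -> target lookup.
    source_to_target = {
        src.upper(): tgt.upper()
        for tgt, srcs in (rename_map or {}).items()
        for src in srcs
    }
    # Symbols still present on some dynamic axis, after renaming.
    present = {
        source_to_target.get(sym.upper(), sym.upper())
        for axes in (dyn or {}).values()
        for sym in axes.values()
    }
    kept: list[str] = []
    for line in assertions:
        # Assumption: the first word is a keyword such as "tract_assert".
        keyword, _, body = line.partition(" ")
        body = _TOKEN.sub(
            lambda m: source_to_target.get(m.group(0).upper(), m.group(0)), body
        )
        symbols = {tok.upper() for tok in _TOKEN.findall(body)}
        rewritten = f"{keyword} {body}"
        # Keep only assertions whose symbols all survive; skip duplicates.
        if symbols <= present and rewritten not in kept:
            kept.append(rewritten)
    return kept


result = filter_assertions(
    ["tract_assert U = BATCH", "tract_assert T >= 1", "tract_assert U = BATCH"],
    {"BATCH": ["U"]},
    {"audio_signal": {0: "BATCH"}},
)
print(result)  # ['tract_assert BATCH = BATCH']
```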
rewrite_assertions_with_renames
rewrite_assertions_with_renames(assertions: list[str], rename_map: T.Optional[dict[str, list[str]]]) -> list[str]
Rewrite assertion symbol names based on a rename mapping.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| assertions | list[str] | List of assertion strings, e.g. "tract_assert U = BATCH". | required |
| rename_map | Optional[dict[str, list[str]]] | Mapping of target symbol to list of source symbols that should be rewritten to the target. Comparison is case-insensitive; rewritten symbols are emitted uppercased. | required |
Returns:
| Type | Description |
|---|---|
| list[str] | A list of assertions with symbols rewritten according to the provided mapping. Unknown tokens are left unchanged. |
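A minimal sketch of the rewrite described in the tables above: identifier tokens are compared case-insensitively against the sources, replaced by the upper-cased target, and left unchanged otherwise. The regex-based tokenisation and the name rewrite_assertions are assumptions, not the module's implementation.

```python
import re
import typing as T


def rewrite_assertions(
    assertions: list[str], rename_map: T.Optional[dict[str, list[str]]]
) -> list[str]:
    """Replace source symbols with their (upper-cased) target symbol."""
    if not rename_map:
        return list(assertions)
    # Case-insensitive lookup: upper-cased source -> upper-cased target.
    source_to_target = {
        src.upper(): tgt.upper() for tgt, srcs in rename_map.items() for src in srcs
    }

    def swap(match: re.Match) -> str:
        token = match.group(0)
        # Tokens without a mapping are returned unchanged.
        return source_to_target.get(token.upper(), token)

    return [re.sub(r"\b[A-Za-z_]\w*\b", swap, a) for a in assertions]


print(rewrite_assertions(["tract_assert u = BATCH"], {"batch": ["U"]}))
# ['tract_assert BATCH = BATCH']
```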
save_config
save_config(path: T.Union[Path, str, None], registry: T.Any, *, flow_seq: bool = True, stream: T.Optional[T.TextIO] = None) -> None
Save an AxisSymbolRegistry to YAML or JSON.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path | Union[Path, str, None] | Output file path (.yml/.yaml/.json), or None when using stream. | required |
| registry | Any | Parsed registry to serialize. | required |
| flow_seq | bool | When writing YAML, render short lists in flow style. | True |
| stream | Optional[TextIO] | Optional text stream to write into (YAML or JSON). When provided, | None |