
dyn_axes

torch_to_nnef.remodeler.dyn_axes

Provider-agnostic helpers for dynamic-axes manipulation.

These utilities operate on the generic dyn mapping ({input_name: {axis_index: symbol}}) and assertion/extension strings. They are used by NeMo and can be reused by any provider.
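For illustration, here is what a dyn mapping for a hypothetical two-input subnet might look like. The input names "audio_signal" and "length" and the symbols "BATCH" and "T" are made up for this example and are not taken from the library.

```python
# Hypothetical dyn mapping: {input_name: {axis_index: symbol}}.
# Axis 0 of both inputs shares the BATCH symbol; axis 2 of audio_signal is T.
dyn = {
    "audio_signal": {0: "BATCH", 2: "T"},
    "length": {0: "BATCH"},
}

# Set of symbols that appear anywhere in the mapping.
present = {symbol for axes in dyn.values() for symbol in axes.values()}
print(sorted(present))  # ['BATCH', 'T']
```

Sharing a symbol across inputs (BATCH here) is how two axes are declared to have the same dynamic size.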

apply_eval_symbols

apply_eval_symbols(test_input: list, input_names: list[str], subnet_name: str, dyn: T.Dict[str, T.Dict[int, str]], eval_symbols: T.Dict[str, T.Dict[str, int]]) -> list

Resize test_input tensors according to eval_symbols.
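The following is a minimal sketch of the shape arithmetic involved. It assumes that eval_symbols is keyed by subnet name and then maps each symbol to a concrete size, and that test_input is aligned index-by-index with input_names. The helper target_shapes is hypothetical. It only computes the shape each input would be resized to and works on plain tuples, so it does not show how the real function materializes the resized tensors.

```python
import typing as T


def target_shapes(
    shapes: T.List[T.Tuple[int, ...]],
    input_names: T.List[str],
    subnet_name: str,
    dyn: T.Dict[str, T.Dict[int, str]],
    eval_symbols: T.Dict[str, T.Dict[str, int]],
) -> T.List[T.Tuple[int, ...]]:
    """Hypothetical helper: the shape each test input would be resized to."""
    # Assumption: eval_symbols[subnet_name] maps symbol -> concrete size.
    pinned = eval_symbols.get(subnet_name, {})
    result = []
    for shape, name in zip(shapes, input_names):
        new_shape = list(shape)
        # Overwrite only the axes whose symbol has a pinned evaluation value.
        for axis, symbol in dyn.get(name, {}).items():
            if symbol in pinned:
                new_shape[axis] = pinned[symbol]
        result.append(tuple(new_shape))
    return result


shapes = target_shapes(
    [(1, 80, 400), (1,)],
    ["audio_signal", "length"],
    "encoder",
    {"audio_signal": {0: "BATCH", 2: "T"}, "length": {0: "BATCH"}},
    {"encoder": {"T": 128}},
)
print(shapes)  # [(1, 80, 128), (1,)]
```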

apply_symbol_renames_to_dyn

apply_symbol_renames_to_dyn(dyn: T.Dict[str, T.Dict[int, str]], rename_map: T.Dict[str, T.List[str]]) -> T.Dict[str, T.Dict[int, str]]

Apply symbol renames directly to a dynamic axes mapping.

This is the lightweight alternative to BoundaryAdapter when only symbol renames are needed (no collapse, bind, or output filtering).
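The following is a minimal sketch of such a rename. It assumes the same conventions documented for rewrite_assertions_with_renames below: rename_map maps each target symbol to its source symbols, comparison is case-insensitive, and renamed symbols are emitted uppercased. rename_dyn_symbols is a hypothetical stand-in, not the library function.

```python
import typing as T


def rename_dyn_symbols(
    dyn: T.Dict[str, T.Dict[int, str]],
    rename_map: T.Dict[str, T.List[str]],
) -> T.Dict[str, T.Dict[int, str]]:
    """Hypothetical stand-in: return a renamed copy of dyn (input is not mutated)."""
    # Invert {target: [sources]} into {SOURCE: TARGET} for case-insensitive lookup.
    source_to_target = {
        source.upper(): target.upper()
        for target, sources in rename_map.items()
        for source in sources
    }
    return {
        name: {
            # Symbols that are not a rename source are kept exactly as they were.
            axis: source_to_target.get(symbol.upper(), symbol)
            for axis, symbol in axes.items()
        }
        for name, axes in dyn.items()
    }


dyn = {"audio": {0: "b", 2: "T"}, "tokens": {0: "N", 1: "U"}}
renamed = rename_dyn_symbols(dyn, {"batch": ["B", "n"]})
print(renamed)  # {'audio': {0: 'BATCH', 2: 'T'}, 'tokens': {0: 'BATCH', 1: 'U'}}
```

Inverting the mapping once makes each per-axis lookup constant time, which matters little for small maps but keeps the rename a single pass over dyn.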

remove_eval_symbols_from_dyn

remove_eval_symbols_from_dyn(input_names: list[str], subnet_name: str, dyn: T.Dict[str, T.Dict[int, str]], eval_symbols: T.Dict[str, T.Dict[str, int]]) -> None

Remove pinned axes from dyn so the backend treats them as constant.

Must be called after the BoundaryAdapter is built, because the adapter needs the symbols to resolve bindings.
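The following is a minimal sketch of the removal step. It assumes that the keys of eval_symbols[subnet_name] are the pinned symbols and that dyn is mutated in place, which is consistent with the None return type. drop_pinned_axes is hypothetical. Whether the real function also drops inputs left with no dynamic axes is not documented here, so the sketch keeps them as empty dicts.

```python
import typing as T


def drop_pinned_axes(
    input_names: T.List[str],
    subnet_name: str,
    dyn: T.Dict[str, T.Dict[int, str]],
    eval_symbols: T.Dict[str, T.Dict[str, int]],
) -> None:
    """Hypothetical helper: delete pinned axes from dyn in place."""
    # Assumption: the pinned symbols are the keys of eval_symbols[subnet_name].
    pinned = set(eval_symbols.get(subnet_name, {}))
    for name in input_names:
        axes = dyn.get(name)
        if axes is None:
            continue
        # Collect first, then delete, to avoid mutating the dict while iterating.
        for axis in [a for a, symbol in axes.items() if symbol in pinned]:
            del axes[axis]


dyn = {"audio_signal": {0: "BATCH", 2: "T"}, "length": {0: "BATCH"}}
# Anything that still needs "T" (such as building the adapter) must run before this call.
drop_pinned_axes(["audio_signal", "length"], "encoder", dyn, {"encoder": {"T": 128}})
print(dyn)  # {'audio_signal': {0: 'BATCH'}, 'length': {0: 'BATCH'}}
```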

rewrite_and_filter_assertions

rewrite_and_filter_assertions(assertions: list[str], rename_map: T.Optional[dict[str, list[str]]], dyn: T.Optional[dict[str, dict[int, str]]]) -> list[str]

Rewrite assertions and drop those referencing removed symbols.

  • Applies symbol renames so source symbols map to their target alias.
  • Computes the set of present symbols from the current dynamic axes and discards any assertion that mentions a symbol not present after rewriting.
  • Returns de-duplicated assertions.
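The three steps above can be sketched as follows. The sketch assumes that assertion symbols are identifier-like tokens, that the leading tract_assert keyword is not a symbol, that dyn already carries the renamed symbols, and that passing None for rename_map or dyn disables the corresponding step. rewrite_and_filter is a hypothetical helper, and the tokenization used by the real implementation may differ.

```python
import re
import typing as T

# Assumption: symbols are identifier-like tokens; "tract_assert" is the keyword.
_TOKEN = re.compile(r"\b[A-Za-z_][A-Za-z0-9_]*")
_KEYWORD = "tract_assert"


def rewrite_and_filter(
    assertions: T.List[str],
    rename_map: T.Optional[T.Dict[str, T.List[str]]],
    dyn: T.Optional[T.Dict[str, T.Dict[int, str]]],
) -> T.List[str]:
    """Hypothetical helper: rename, drop stale assertions, de-duplicate."""
    # 1. Invert {target: [sources]} into {SOURCE: TARGET}; None means no renames.
    source_to_target = {
        source.upper(): target.upper()
        for target, sources in (rename_map or {}).items()
        for source in sources
    }
    # 2. Symbols present in the (already renamed) dyn; None disables filtering.
    present = (
        None
        if dyn is None
        else {symbol.upper() for axes in dyn.values() for symbol in axes.values()}
    )
    kept: T.List[str] = []
    for assertion in assertions:
        rewritten = _TOKEN.sub(
            lambda m: source_to_target.get(m.group(0).upper(), m.group(0)),
            assertion,
        )
        symbols = {t.upper() for t in _TOKEN.findall(rewritten) if t != _KEYWORD}
        # 3. Drop assertions that mention a symbol no longer present.
        if present is not None and not symbols <= present:
            continue
        # 4. De-duplicate while preserving first-seen order.
        if rewritten not in kept:
            kept.append(rewritten)
    return kept


dyn = {"audio": {0: "BATCH", 2: "T"}, "tokens": {0: "BATCH", 1: "U"}}
result = rewrite_and_filter(
    [
        "tract_assert U = T",
        "tract_assert S = T",  # S is not in dyn, so this is dropped
        "tract_assert b = n",  # both sides renamed to BATCH
        "tract_assert U = T",  # duplicate
    ],
    {"BATCH": ["B", "N"]},
    dyn,
)
print(result)  # ['tract_assert U = T', 'tract_assert BATCH = BATCH']
```

Filtering runs on the rewritten text, so an assertion that only mentions a source symbol survives as long as its target alias is still present in dyn.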

rewrite_assertions_with_renames

rewrite_assertions_with_renames(assertions: list[str], rename_map: T.Optional[dict[str, list[str]]]) -> list[str]

Rewrite assertion symbol names based on a rename mapping.

Parameters:

  • assertions (list[str], required): List of assertion strings, e.g. "tract_assert U = BATCH".
  • rename_map (Optional[dict[str, list[str]]], required): Mapping of each target symbol to the list of source symbols that should be rewritten to that target. Comparison is case-insensitive; rewritten symbols are emitted uppercased.

Returns:

  • list[str]: A list of assertions with symbols rewritten according to the provided mapping. Unknown tokens are left unchanged.
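The following is a minimal sketch of the documented behavior: case-insensitive matching, uppercased output, and unknown tokens left alone. It assumes symbols are identifier-like tokens within the assertion string. rewrite_with_renames is hypothetical and is not the library's implementation.

```python
import re
import typing as T

# Assumption: symbols are identifier-like tokens inside the assertion string.
_TOKEN = re.compile(r"\b[A-Za-z_][A-Za-z0-9_]*")


def rewrite_with_renames(
    assertions: T.List[str],
    rename_map: T.Optional[T.Dict[str, T.List[str]]],
) -> T.List[str]:
    """Hypothetical helper: rename symbols in each assertion string."""
    if not rename_map:
        return list(assertions)
    # Case-insensitive lookup: {SOURCE: TARGET}, with targets emitted uppercased.
    source_to_target = {
        source.upper(): target.upper()
        for target, sources in rename_map.items()
        for source in sources
    }
    return [
        # Tokens not listed as a source (keywords, other symbols) are left unchanged.
        _TOKEN.sub(lambda m: source_to_target.get(m.group(0).upper(), m.group(0)), a)
        for a in assertions
    ]


print(rewrite_with_renames(["tract_assert U = BATCH"], {"tokens": ["u"]}))
# ['tract_assert TOKENS = BATCH']
```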