
dynaxes

torch_to_nnef.nemo_tract.dynaxes

Dynamic-axes helpers shared by wrappers and export code.

This module centralizes dynamic-axis symbol generation, input-name expansion, normalization of NeMo mappings, and rank-based filtering, so that wrappers and export code share one implementation instead of duplicating it.

build_dynamic_axes

build_dynamic_axes(subnet, nemo_dynamic_axes: T.Mapping[str, T.Any], input_example: T.Optional[T.Sequence[object]] = None) -> T.Tuple[T.Dict[str, T.Dict[int, str]], T.Set[str]]

Build the dynamic_axes mapping for a subnet using its input example.

  • Expand tuple/list inputs based on input_example.
  • Normalize NeMo mapping to base-name -> set(indices).
  • Emit symbols via make_axis_symbol for each expanded name and index.
  • Return (dynamic_axes, custom_extensions_set).
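The four steps above can be illustrated with a minimal, self-contained sketch. It is not the library's implementation: it takes input_names directly instead of a subnet (how the real function reads names from the subnet is not documented here), make_axis_symbol is a hypothetical stand-in with an assumed "<name>_ax<index>" scheme, tuple/list inputs are expanded to hypothetical "<name>_<j>" names, NeMo values are assumed to be either index lists or {index: label} dicts, and the second return value is a placeholder set of symbols rather than the real custom-extensions set.

```python
from typing import Any, Dict, List, Mapping, Optional, Sequence, Set, Tuple


def make_axis_symbol(name: str, index: int) -> str:
    # Hypothetical stand-in: the real make_axis_symbol's naming scheme is not documented here.
    return f"{name}_ax{index}"


def build_dynamic_axes_sketch(
    input_names: Sequence[str],
    nemo_dynamic_axes: Mapping[str, Any],
    input_example: Optional[Sequence[object]] = None,
) -> Tuple[Dict[str, Dict[int, str]], Set[str]]:
    # 1. Expand tuple/list inputs into one name per element ("<name>_<j>" is an assumed scheme).
    expand_map: Dict[str, List[str]] = {}
    for i, name in enumerate(input_names):
        value = input_example[i] if input_example is not None and i < len(input_example) else None
        if isinstance(value, (tuple, list)):
            expand_map[name] = [f"{name}_{j}" for j in range(len(value))]
        else:
            expand_map[name] = [name]

    # 2. Normalize NeMo values (index list or {index: label} dict) to base name -> set of indices.
    indices: Dict[str, Set[int]] = {}
    for name, axes in nemo_dynamic_axes.items():
        if name in expand_map:
            indices[name] = set(axes.keys() if isinstance(axes, dict) else axes)

    # 3. Emit one symbol per (expanded name, dynamic index) pair.
    dynamic_axes: Dict[str, Dict[int, str]] = {}
    symbols: Set[str] = set()
    for base, idxs in indices.items():
        for expanded in expand_map[base]:
            dynamic_axes[expanded] = {k: make_axis_symbol(expanded, k) for k in sorted(idxs)}
            symbols.update(dynamic_axes[expanded].values())

    # 4. Placeholder: the real second return value is a custom-extensions set, not bare symbols.
    return dynamic_axes, symbols


# Usage: "states" is a 2-tuple in the example, so it expands to two names.
axes, syms = build_dynamic_axes_sketch(
    ["audio", "states"],
    {"audio": [0, 2], "states": {0: "batch"}},
    input_example=[object(), (object(), object())],
)
# axes == {"audio": {0: "audio_ax0", 2: "audio_ax2"},
#          "states_0": {0: "states_0_ax0"}, "states_1": {0: "states_1_ax0"}}
```

Giving each expanded element its own symbol is a simplifying choice in this sketch; the real helper may instead share symbols across tuple elements (for example a common batch axis).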

expand_input_names

expand_input_names(input_names: T.Sequence[str], input_example: T.Optional[T.Sequence[object]]) -> T.Tuple[T.Dict[str, T.List[str]], T.Dict[str, int]]

Expand tuple/list inputs and infer ranks for each expanded name.

Returns (expand_map, ranks_by_name).
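A minimal sketch of how such an expansion could work, assuming tuple/list example values fan out to hypothetical "<name>_<j>" names and that ranks are read from an `ndim` attribute (which torch.Tensor and numpy arrays both provide). FakeTensor is a hypothetical stand-in so the sketch runs without PyTorch; the real expand_input_names may name expanded inputs differently.

```python
from typing import Dict, List, Optional, Sequence, Tuple


class FakeTensor:
    """Hypothetical stand-in for a tensor; only the .ndim attribute is used."""

    def __init__(self, ndim: int) -> None:
        self.ndim = ndim


def expand_input_names_sketch(
    input_names: Sequence[str],
    input_example: Optional[Sequence[object]],
) -> Tuple[Dict[str, List[str]], Dict[str, int]]:
    expand_map: Dict[str, List[str]] = {}
    ranks_by_name: Dict[str, int] = {}
    for i, name in enumerate(input_names):
        value = input_example[i] if input_example is not None and i < len(input_example) else None
        if isinstance(value, (tuple, list)):
            # One expanded name per element; the "<name>_<j>" scheme is an assumption.
            items = list(value)
            names = [f"{name}_{j}" for j in range(len(items))]
        else:
            items, names = [value], [name]
        expand_map[name] = names
        for expanded, item in zip(names, items):
            rank = getattr(item, "ndim", None)
            if isinstance(rank, int):  # skip non-tensors and missing examples
                ranks_by_name[expanded] = rank
    return expand_map, ranks_by_name


expand_map, ranks = expand_input_names_sketch(
    ["x", "cache"], [FakeTensor(3), (FakeTensor(2), FakeTensor(4))]
)
# expand_map == {"x": ["x"], "cache": ["cache_0", "cache_1"]}
# ranks == {"x": 3, "cache_0": 2, "cache_1": 4}
```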

filter_dynamic_axes_by_ranks

filter_dynamic_axes_by_ranks(dynamic_axes: T.Dict[str, T.Dict[int, str]], ranks_by_name: T.Mapping[str, int]) -> T.Dict[str, T.Dict[int, str]]

Drop axis indices that are >= rank for each input tensor name.
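The rule is simple: a tensor of rank r only has axis indices 0 to r - 1, so any larger index must be dropped. The sketch below assumes that names missing from ranks_by_name are passed through unchanged, which may not match the real function.

```python
from typing import Dict, Mapping


def filter_dynamic_axes_by_ranks_sketch(
    dynamic_axes: Dict[str, Dict[int, str]],
    ranks_by_name: Mapping[str, int],
) -> Dict[str, Dict[int, str]]:
    filtered: Dict[str, Dict[int, str]] = {}
    for name, axes in dynamic_axes.items():
        rank = ranks_by_name.get(name)
        if rank is None:
            # Assumption: names with unknown rank are kept as-is.
            filtered[name] = dict(axes)
        else:
            # Keep only indices that exist on a tensor of this rank.
            filtered[name] = {idx: sym for idx, sym in axes.items() if idx < rank}
    return filtered


# A rank-1 length tensor cannot carry a dynamic axis at index 1.
result = filter_dynamic_axes_by_ranks_sketch(
    {"audio": {0: "B", 2: "T"}, "length": {0: "B", 1: "T"}},
    {"audio": 3, "length": 1},
)
# result == {"audio": {0: "B", 2: "T"}, "length": {0: "B"}}
```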

normalize_dynamic_indices

normalize_dynamic_indices(nemo_dynamic_axes: T.Mapping[str, T.Any], input_names: T.Sequence[str]) -> T.Dict[str, T.Set[int]]

Normalize NeMo dynamic mapping to base-name -> set(axis indices).
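A minimal sketch of the normalization, under stated assumptions: NeMo values may be a list/tuple/set of indices, an ONNX-style {index: label} dict, or a single int, and only keys that appear in input_names are kept. How the real helper derives a "base name" from a NeMo key is not documented here, so this sketch simply matches keys exactly.

```python
from typing import Any, Dict, Mapping, Sequence, Set


def normalize_dynamic_indices_sketch(
    nemo_dynamic_axes: Mapping[str, Any],
    input_names: Sequence[str],
) -> Dict[str, Set[int]]:
    wanted = set(input_names)
    normalized: Dict[str, Set[int]] = {}
    for name, axes in nemo_dynamic_axes.items():
        if name not in wanted:
            continue  # ignore entries for names this subnet does not take
        if isinstance(axes, dict):
            indices = axes.keys()  # {index: label} form
        elif isinstance(axes, int):
            indices = [axes]  # a single bare index
        else:
            indices = axes  # list/tuple/set of indices
        normalized.setdefault(name, set()).update(int(i) for i in indices)
    return normalized


result = normalize_dynamic_indices_sketch(
    {"audio_signal": [0, 2], "length": {0: "B"}, "unused": [0]},
    ["audio_signal", "length"],
)
# result == {"audio_signal": {0, 2}, "length": {0}}
```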

symbols_from_input_types

symbols_from_input_types(input_types) -> T.Dict[str, T.Dict[int, str]]

Build axis symbols for each input name from a module's input_types.

Each input maps to {axis_index: symbol} via make_axis_symbol.
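A hedged sketch of this mapping. FakeNeuralType is a hypothetical stand-in that only loosely mimics a NeMo NeuralType (here, axis labels as strings, or None when axes are unspecified), make_axis_symbol is a hypothetical stand-in for the module's helper, and emitting a symbol for every declared axis is an assumption; the real function may select only some axes.

```python
from dataclasses import dataclass
from typing import Dict, Mapping, Optional, Tuple


def make_axis_symbol(name: str, index: int) -> str:
    # Hypothetical stand-in for the module's make_axis_symbol.
    return f"{name}_ax{index}"


@dataclass
class FakeNeuralType:
    # Loosely mimics a NeMo NeuralType: axis labels, or None when axes are unspecified.
    axes: Optional[Tuple[str, ...]]


def symbols_from_input_types_sketch(
    input_types: Mapping[str, FakeNeuralType],
) -> Dict[str, Dict[int, str]]:
    symbols: Dict[str, Dict[int, str]] = {}
    for name, ntype in input_types.items():
        axes = getattr(ntype, "axes", None)
        if axes is None:
            continue  # no axis information: nothing to symbolize
        # Assumption: every declared axis gets a symbol.
        symbols[name] = {i: make_axis_symbol(name, i) for i in range(len(axes))}
    return symbols


result = symbols_from_input_types_sketch({
    "audio_signal": FakeNeuralType(("B", "D", "T")),
    "length": FakeNeuralType(("B",)),
    "flag": FakeNeuralType(None),
})
# result == {"audio_signal": {0: "audio_signal_ax0", 1: "audio_signal_ax1", 2: "audio_signal_ax2"},
#            "length": {0: "length_ax0"}}
```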