dynaxes
torch_to_nnef.nemo_tract.dynaxes
Dynamic-axes helpers shared by wrappers and export code.
This module centralizes dynamic-axis symbol generation, input-name expansion, normalization of NeMo mappings, and rank-based filtering to avoid duplication.
build_dynamic_axes
build_dynamic_axes(subnet, nemo_dynamic_axes: T.Mapping[str, T.Any], input_example: T.Optional[T.Sequence[object]] = None) -> T.Tuple[T.Dict[str, T.Dict[int, str]], T.Set[str]]
Build the dynamic_axes mapping for a subnet from its input example.
- Expand tuple/list inputs based on input_example.
- Normalize NeMo mapping to base-name -> set(indices).
- Emit symbols via make_axis_symbol for each expanded name and index.
- Return (dynamic_axes, custom_extensions_set).
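The normalization and symbol-emission steps can be sketched roughly as follows. This is a minimal illustration, not the module's actual implementation: `normalize_axes_sketch`, `axis_symbol_sketch`, and `build_axes_sketch` are hypothetical names, the symbol format is invented for the example, and the sketch assumes NeMo mapping values are either a dict keyed by axis index or a plain sequence of indices. The custom extensions set is left out because its contents are not described here.

```python
from collections.abc import Mapping
from typing import Any, Dict, List, Set


def normalize_axes_sketch(nemo_dynamic_axes: Mapping) -> Dict[str, Set[int]]:
    """Normalize a NeMo-style mapping to base-name -> set(indices)."""
    normalized: Dict[str, Set[int]] = {}
    for name, axes in nemo_dynamic_axes.items():
        # Assumption: a dict value is keyed by axis index; any other
        # value is treated as an iterable of axis indices.
        indices = axes.keys() if isinstance(axes, Mapping) else axes
        normalized[name] = {int(i) for i in indices}
    return normalized


def axis_symbol_sketch(name: str, index: int) -> str:
    """Hypothetical stand-in for make_axis_symbol (format is invented)."""
    return f"{name}_dim{index}"


def build_axes_sketch(
    expand_map: Dict[str, List[str]],
    nemo_dynamic_axes: Mapping,
) -> Dict[str, Dict[int, str]]:
    """Emit one symbol per (expanded name, axis index) pair."""
    normalized = normalize_axes_sketch(nemo_dynamic_axes)
    dynamic_axes: Dict[str, Dict[int, str]] = {}
    for base_name, expanded_names in expand_map.items():
        indices = normalized.get(base_name, set())
        if not indices:
            continue  # no dynamic axes declared for this input
        for expanded in expanded_names:
            dynamic_axes[expanded] = {
                i: axis_symbol_sketch(expanded, i) for i in sorted(indices)
            }
    return dynamic_axes


axes = build_axes_sketch(
    {"audio_signal": ["audio_signal"], "cache": ["cache_0", "cache_1"]},
    {"audio_signal": {0: "batch", 2: "time"}, "cache": [0]},
)
```

In this sketch every expanded name inherits the axis indices of its base name, which is why a later rank-based filtering pass is useful when the expanded tensors have different ranks.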
expand_input_names
expand_input_names(input_names: T.Sequence[str], input_example: T.Optional[T.Sequence[object]]) -> T.Tuple[T.Dict[str, T.List[str]], T.Dict[str, int]]
Expand tuple/list inputs and infer ranks for each expanded name.
Returns (expand_map, ranks_by_name).
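As an illustration of what such an expansion might look like, here is a small self-contained sketch. It is not the library's code: `expand_input_names_sketch` and `FakeTensor` are hypothetical, the `name_0`, `name_1` suffix convention is an assumption, and ranks are inferred from a `shape` attribute, which is how a tensor-like object would typically expose them.

```python
from typing import Dict, List, Optional, Sequence, Tuple


class FakeTensor:
    """Hypothetical stand-in for a tensor; only exposes `shape`."""

    def __init__(self, *shape: int) -> None:
        self.shape = shape


def expand_input_names_sketch(
    input_names: Sequence[str],
    input_example: Optional[Sequence[object]],
) -> Tuple[Dict[str, List[str]], Dict[str, int]]:
    expand_map: Dict[str, List[str]] = {}
    ranks_by_name: Dict[str, int] = {}
    examples = list(input_example or [])
    # Pad with None so that names without an example still get an entry.
    examples += [None] * (len(input_names) - len(examples))
    for name, example in zip(input_names, examples):
        if isinstance(example, (tuple, list)):
            # Assumption: expanded names use an index suffix.
            expanded = [f"{name}_{i}" for i in range(len(example))]
            items = list(example)
        else:
            expanded = [name]
            items = [example]
        expand_map[name] = expanded
        for expanded_name, item in zip(expanded, items):
            shape = getattr(item, "shape", None)
            if shape is not None:
                ranks_by_name[expanded_name] = len(shape)
    return expand_map, ranks_by_name


expand_map, ranks_by_name = expand_input_names_sketch(
    ["audio_signal", "cache"],
    [FakeTensor(1, 80, 100), (FakeTensor(1, 4), FakeTensor(1, 4, 8))],
)
```

Here `cache` is a tuple of two tensors, so it expands to two names, each with its own rank, while `audio_signal` maps to itself.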
filter_dynamic_axes_by_ranks
filter_dynamic_axes_by_ranks(dynamic_axes: T.Dict[str, T.Dict[int, str]], ranks_by_name: T.Mapping[str, int]) -> T.Dict[str, T.Dict[int, str]]
Drop axis indices that are >= rank for each input tensor name.
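The filtering rule can be illustrated with a short sketch. This is an assumed reading of the one-line description, not the actual implementation: `filter_by_ranks_sketch` is a hypothetical name, and two behaviors are guesses made for the example, namely that names with no known rank are kept unchanged and that names left with no axes are dropped. Negative axis indices are not handled.

```python
from typing import Dict, Mapping


def filter_by_ranks_sketch(
    dynamic_axes: Dict[str, Dict[int, str]],
    ranks_by_name: Mapping[str, int],
) -> Dict[str, Dict[int, str]]:
    filtered: Dict[str, Dict[int, str]] = {}
    for name, axes in dynamic_axes.items():
        rank = ranks_by_name.get(name)
        if rank is None:
            # Assumption: without a known rank, nothing can be checked.
            filtered[name] = dict(axes)
            continue
        # Keep only indices that are valid for a tensor of this rank.
        kept = {idx: sym for idx, sym in axes.items() if idx < rank}
        if kept:
            # Assumption: entries with no remaining axes are omitted.
            filtered[name] = kept
    return filtered


filtered = filter_by_ranks_sketch(
    {"audio_signal": {0: "B", 2: "T", 3: "X"}, "length": {0: "B", 1: "L"}},
    {"audio_signal": 3, "length": 1},
)
```

With a rank-3 `audio_signal`, index 3 is out of range and is removed; with a rank-1 `length`, only index 0 survives.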