axis_registry

torch_to_nnef.nemo_tract.axis_registry

AxisSymbolRegistry dataclass

AxisSymbolRegistry(
    symbols_per_input: T.Dict[str, AxisSymbolMap],
    rank_per_input: T.Dict[str, int],
    bind_to_dim: T.Dict[str, str],
    input_collapse_dims: T.Dict[str, T.List[str]],
    renamed_symbols_per_subnet: T.Dict[str, T.Dict[str, T.List[str]]],
    outputs_keep_per_subnet: T.Dict[str, T.List[str]],
    output_collapse_dims: T.Dict[str, T.List[int]] = dict(),
    eval_symbols_per_input: T.Dict[str, T.Dict[str, int]] = dict(),
    original_shape_per_input: T.Dict[str, T.List[T.Union[int, str]]] = dict(),
    extensions_per_subnet: T.Dict[str, T.List[str]] = dict(),
)

Registry mapping input names to symbolic axis annotations.

Attributes:

Name | Type | Description
symbols_per_input | Dict[str, AxisSymbolMap] | Map from fully qualified input name (e.g., "encoder.audio_signal") to an axis-index→symbol map.

load_axis_symbol_registry

load_axis_symbol_registry(config_path: Path) -> AxisSymbolRegistry

Load a YAML/JSON shape config into an AxisSymbolRegistry.

The expected structure is a mapping of input-name → list of dims, e.g.:

    encoder.audio_signal: [B, 128, S]
    encoder.length: [B]
    joiner.encoder_outputs: [B, 1024, R]
    joiner.decoder_outputs: [B, 640, U]
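To make the relationship between the shape config and the registry fields concrete, here is a minimal sketch of how such a mapping could be split into per-input axis-index→symbol maps and ranks. It does not call `load_axis_symbol_registry`, and it is not the library's implementation: the helper `split_symbolic_axes` is hypothetical, `AxisSymbolMap` is assumed here to behave like a plain `Dict[int, str]`, and the config is written as JSON (with symbols quoted) so the example needs only the standard library.

```python
import json
from typing import Dict, List, Tuple, Union

Dim = Union[int, str]


def split_symbolic_axes(
    shapes: Dict[str, List[Dim]],
) -> Tuple[Dict[str, Dict[int, str]], Dict[str, int]]:
    """Hypothetical helper: derive axis-index -> symbol maps and ranks.

    Assumption: a string dim (e.g. "B") is a symbolic axis and an
    integer dim (e.g. 128) is a fixed size.
    """
    symbols_per_input: Dict[str, Dict[int, str]] = {}
    rank_per_input: Dict[str, int] = {}
    for name, dims in shapes.items():
        rank_per_input[name] = len(dims)
        # Keep only the symbolic axes, keyed by their position.
        symbols_per_input[name] = {
            axis: dim for axis, dim in enumerate(dims) if isinstance(dim, str)
        }
    return symbols_per_input, rank_per_input


# JSON rendering of the example config shown above.
config_text = """
{
  "encoder.audio_signal": ["B", 128, "S"],
  "encoder.length": ["B"],
  "joiner.encoder_outputs": ["B", 1024, "R"],
  "joiner.decoder_outputs": ["B", 640, "U"]
}
"""

shapes = json.loads(config_text)
symbols_per_input, rank_per_input = split_symbolic_axes(shapes)

print(symbols_per_input["encoder.audio_signal"])  # {0: 'B', 2: 'S'}
print(rank_per_input["joiner.decoder_outputs"])   # 3
```

Under these assumptions, the fixed dimensions (128, 1024, 640) do not appear in the symbol maps; only the positions of `B`, `S`, `R`, and `U` are recorded, while the rank preserves the full length of each shape.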