axis_registry
torch_to_nnef.nemo_tract.axis_registry
AxisSymbolRegistry
dataclass
AxisSymbolRegistry(symbols_per_input: T.Dict[str, AxisSymbolMap], rank_per_input: T.Dict[str, int], bind_to_dim: T.Dict[str, str], input_collapse_dims: T.Dict[str, T.List[str]], renamed_symbols_per_subnet: T.Dict[str, T.Dict[str, T.List[str]]], outputs_keep_per_subnet: T.Dict[str, T.List[str]], output_collapse_dims: T.Dict[str, T.List[int]] = dict(), eval_symbols_per_input: T.Dict[str, T.Dict[str, int]] = dict(), original_shape_per_input: T.Dict[str, T.List[T.Union[int, str]]] = dict(), extensions_per_subnet: T.Dict[str, T.List[str]] = dict())
Registry mapping input names to symbolic axis annotations.
Attributes:
| Name | Type | Description |
|---|---|---|
| `symbols_per_input` | `Dict[str, AxisSymbolMap]` | Map from a fully qualified input name (e.g., `"encoder.audio_signal"`) to an axis-index → symbol map. |
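As an illustration of the shape of `symbols_per_input`, the sketch below assumes that `AxisSymbolMap` behaves like a plain `Dict[int, str]` (axis index → symbol name). This is an assumption: the exact type is not documented here. It also defines a hypothetical helper, `axes_for_symbol`, that is not part of `torch_to_nnef`. The helper lists every input axis that carries a given symbol, which is how shared dimensions such as a batch axis `B` can be identified across inputs.

```python
from typing import Dict, List, Tuple

# Assumption: AxisSymbolMap is equivalent to Dict[int, str] (axis index -> symbol).
AxisSymbolMap = Dict[int, str]

# Example registry content; fixed-size axes (e.g. 128, 1024) carry no symbol.
symbols_per_input: Dict[str, AxisSymbolMap] = {
    "encoder.audio_signal": {0: "B", 2: "S"},
    "encoder.length": {0: "B"},
    "joiner.encoder_outputs": {0: "B", 2: "R"},
}


def axes_for_symbol(
    symbols: Dict[str, AxisSymbolMap], symbol: str
) -> List[Tuple[str, int]]:
    """Hypothetical helper: list every (input name, axis index) tagged with `symbol`."""
    return sorted(
        (name, axis)
        for name, axis_map in symbols.items()
        for axis, sym in axis_map.items()
        if sym == symbol
    )


print(axes_for_symbol(symbols_per_input, "B"))
# [('encoder.audio_signal', 0), ('encoder.length', 0), ('joiner.encoder_outputs', 0)]
```

Keying the inner map by axis index rather than storing a full list keeps fixed-size axes out of the map, so only the symbolic axes need to be looked up.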
load_axis_symbol_registry
Load a YAML/JSON shape config into an AxisSymbolRegistry.
The expected structure is a mapping of input name → list of dims, e.g.:

    encoder.audio_signal: [B, 128, S]
    encoder.length: [B]
    joiner.encoder_outputs: [B, 1024, R]
    joiner.decoder_outputs: [B, 640, U]
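A minimal sketch of how such a config could be split into per-input symbol maps and ranks follows. It assumes that integer entries are fixed sizes and string entries are axis symbols. The function name `parse_shape_config` is hypothetical and is not the body of `load_axis_symbol_registry`. The sketch reads the JSON form so that it stays within the standard library; `yaml.safe_load` from PyYAML would turn the YAML example above into the same Python structure (e.g. `[B, 128, S]` becomes `['B', 128, 'S']`).

```python
import json
from typing import Dict, List, Tuple, Union

Dim = Union[int, str]


def parse_shape_config(
    text: str,
) -> Tuple[Dict[str, Dict[int, str]], Dict[str, int]]:
    """Hypothetical sketch: derive axis-symbol maps and ranks from a shape config.

    Assumes integer dims are fixed sizes and string dims are symbols.
    """
    config: Dict[str, List[Dim]] = json.loads(text)
    symbols_per_input: Dict[str, Dict[int, str]] = {}
    rank_per_input: Dict[str, int] = {}
    for name, dims in config.items():
        # The rank is simply the number of listed dims.
        rank_per_input[name] = len(dims)
        # Keep only the symbolic axes, keyed by their axis index.
        symbols_per_input[name] = {
            axis: dim for axis, dim in enumerate(dims) if isinstance(dim, str)
        }
    return symbols_per_input, rank_per_input


CONFIG = """{
  "encoder.audio_signal": ["B", 128, "S"],
  "encoder.length": ["B"],
  "joiner.encoder_outputs": ["B", 1024, "R"],
  "joiner.decoder_outputs": ["B", 640, "U"]
}"""

symbols, ranks = parse_shape_config(CONFIG)
print(symbols["encoder.audio_signal"])  # {0: 'B', 2: 'S'}
print(ranks["joiner.decoder_outputs"])  # 3
```

The real loader also populates the other `AxisSymbolRegistry` fields (bindings, collapsed dims, per-subnet renames). How those are derived from the config is not described here, so the sketch leaves them out.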