
registry_utils

torch_to_nnef.nemo_tract.registry_utils

auto_populate_output_collapse_dims

auto_populate_output_collapse_dims(registry: AxisSymbolRegistry) -> AxisSymbolRegistry

Auto-add collapse_dims: [0] to the outputs of subnets whose inputs use batch collapse.

For each subnet where at least one input has a collapse_dims containing a *__BATCH symbol, add collapse_dims: [0] for every output of that subnet (taken from outputs_keep_per_subnet). User-provided entries take precedence and are not overwritten.

Parameters:

Name | Type | Description | Default
registry | AxisSymbolRegistry | Axis symbol registry (possibly loaded from config). | required

Returns:

Type | Description
AxisSymbolRegistry | The same registry instance, with output_collapse_dims updated.
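The rule described above can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not the library's implementation: `ToyRegistry` and its `input_collapse_dims` field are hypothetical stand-ins whose layout is assumed (the real AxisSymbolRegistry may store input collapse information differently). Only `outputs_keep_per_subnet` and `output_collapse_dims` are names taken from the description.

```python
from dataclasses import dataclass, field


@dataclass
class ToyRegistry:
    """Hypothetical stand-in for AxisSymbolRegistry; the field layout is assumed."""

    # subnet -> input name -> symbols collapsed on that input (assumed shape)
    input_collapse_dims: dict[str, dict[str, list[str]]] = field(default_factory=dict)
    # subnet -> names of the outputs kept for that subnet
    outputs_keep_per_subnet: dict[str, list[str]] = field(default_factory=dict)
    # subnet -> output name -> dims to collapse
    output_collapse_dims: dict[str, dict[str, list[int]]] = field(default_factory=dict)


def auto_populate_output_collapse_dims_sketch(reg: ToyRegistry) -> ToyRegistry:
    """Hypothetical sketch: add [0] to every kept output of batch-collapsed subnets."""
    for subnet, inputs in reg.input_collapse_dims.items():
        # Does any input of this subnet collapse a *__BATCH symbol?
        has_batch_collapse = any(
            sym.endswith("__BATCH") for syms in inputs.values() for sym in syms
        )
        if not has_batch_collapse:
            continue
        per_output = reg.output_collapse_dims.setdefault(subnet, {})
        for out_name in reg.outputs_keep_per_subnet.get(subnet, []):
            # setdefault leaves any user-provided entry untouched
            per_output.setdefault(out_name, [0])
    return reg  # same instance, mutated in place


# Usage: "encoder" collapses a batch symbol, "decoder" does not.
reg = ToyRegistry(
    input_collapse_dims={"encoder": {"audio": ["audio__BATCH"]}, "decoder": {"tokens": []}},
    outputs_keep_per_subnet={"encoder": ["enc_out", "enc_len"], "decoder": ["logits"]},
    output_collapse_dims={"encoder": {"enc_len": [1]}},  # user-provided entry
)
result = auto_populate_output_collapse_dims_sketch(reg)
```

Using `dict.setdefault` is what lets user-provided entries take precedence: `enc_len` keeps its configured `[1]`, and `[0]` is written only for outputs with no existing entry.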

tie_batch_symbols_in_registry

tie_batch_symbols_in_registry(registry: AxisSymbolRegistry) -> AxisSymbolRegistry

Alias input-namespaced batch symbols to a unified BATCH per subnet.

For each subnet, collects all dynamic symbols that end with f"{_SEP}BATCH" and records a rename mapping so they are presented to backends as a single BATCH symbol. This keeps provider-facing symbols namespaced while unifying them for tract-facing assertions and adapters.

Parameters:

Name | Type | Description | Default
registry | AxisSymbolRegistry | Axis symbol registry discovered from signatures. | required

Returns:

Type | Description
AxisSymbolRegistry | The same registry instance, with renamed_symbols_per_subnet updated.
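A similarly hedged sketch of the aliasing step follows. `ToySymbolRegistry` and its `symbols_per_subnet` field are hypothetical, and the separator value `"__"` is an assumption inferred from the `*__BATCH` pattern mentioned for auto_populate_output_collapse_dims. Only `renamed_symbols_per_subnet` is named in the description, and its exact value type is also assumed here.

```python
from dataclasses import dataclass, field

SEP = "__"  # assumed value of the private _SEP separator


@dataclass
class ToySymbolRegistry:
    """Hypothetical stand-in for AxisSymbolRegistry; the field layout is assumed."""

    # subnet -> dynamic symbols discovered from its signature (assumed shape)
    symbols_per_subnet: dict[str, list[str]] = field(default_factory=dict)
    # subnet -> {namespaced symbol: name presented to backends}
    renamed_symbols_per_subnet: dict[str, dict[str, str]] = field(default_factory=dict)


def tie_batch_symbols_sketch(reg: ToySymbolRegistry) -> ToySymbolRegistry:
    """Hypothetical sketch: alias every <input>__BATCH symbol to BATCH per subnet."""
    for subnet, symbols in reg.symbols_per_subnet.items():
        batch_syms = [s for s in symbols if s.endswith(f"{SEP}BATCH")]
        if not batch_syms:
            continue
        renames = reg.renamed_symbols_per_subnet.setdefault(subnet, {})
        for sym in batch_syms:
            # Record the alias only; the namespaced symbol itself is kept as-is
            renames[sym] = "BATCH"
    return reg  # same instance, mutated in place


# Usage: two input-namespaced batch symbols in "encoder" collapse to one BATCH.
reg = ToySymbolRegistry(
    symbols_per_subnet={
        "encoder": ["audio__BATCH", "audio__TIME", "length__BATCH"],
        "joint": ["T"],
    }
)
result = tie_batch_symbols_sketch(reg)
```

The symbol lists are never rewritten; only a rename mapping is recorded. This mirrors the description's point that provider-facing symbols stay namespaced while tract-facing consumers see a single BATCH.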