
exp_norm

torch_to_nnef.op.extras.exp_norm

Handlers for t2n_extra::exp_unit_norm / exp_mean_norm ops.

Both lower to the matching tract_extra_exp_unit_norm / tract_extra_exp_mean_norm NNEF ops. Tract already registers an OpPulsifier for ExpUnitNorm (see tract/extra/src/exp_unit_norm.rs), so a graph traced along a streaming axis can be pulsified end-to-end.
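Pulsification works for these ops because they are causal recurrences with an explicit state. The sketch below illustrates that property with a hypothetical scalar helper, exp_unit_norm_1d. It assumes the EMA update state = alpha * state + (1 - alpha) * max(|x|, epsilon); the exact convention inside tract is not documented on this page. Feeding the sequence in pulses and carrying the state between calls reproduces the whole-sequence output.

```python
import math
from typing import List, Tuple


def exp_unit_norm_1d(
    xs: List[float], state: float, alpha: float, epsilon: float
) -> Tuple[List[float], float]:
    """Hypothetical scalar sketch: normalised frames plus the final EMA state."""
    out = []
    for x in xs:
        # Assumed conventions: EMA of max(|x|, epsilon), weight alpha on the old state.
        state = alpha * state + (1.0 - alpha) * max(abs(x), epsilon)
        out.append(x / math.sqrt(state))
    return out, state


signal = [1.0, -3.0, 2.0, 0.5, -4.0, 2.5]
alpha, eps = 0.9, 1e-6

# Whole sequence in a single call, starting from a zero state.
full, _ = exp_unit_norm_1d(signal, 0.0, alpha, eps)

# Same sequence in pulses of 2 frames, carrying the state between calls.
pulsed: List[float] = []
state = 0.0
for start in range(0, len(signal), 2):
    chunk, state = exp_unit_norm_1d(signal[start:start + 2], state, alpha, eps)
    pulsed.extend(chunk)

assert pulsed == full  # identical arithmetic in identical order
```

The eager op returns only the normalised tensor. Here the state is returned explicitly to make the carried state visible.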

DPDFNet's ErbNorm (subtraction of a per-frame EMA mean, followed by a fixed-std scale) maps to exp_mean_norm with scaling_factor=40.0. SpecNorm (the complex spectrum normalised by a per-frame magnitude EMA) maps to exp_unit_norm with complex=True.

Signatures (matching the eager torch.library declarations in example/test code):

t2n_extra::exp_unit_norm(input, state_init, axis, alpha, epsilon, complex)
t2n_extra::exp_mean_norm(input, state_init, axis, alpha, scaling_factor)

state_init has the shape of input with axis removed (and, for complex=True, with the trailing size-2 real/imaginary axis also removed). Pass zeros at trace time; the pulsifier overrides the skip setting with the runtime delay.
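The shape rule for state_init can be written as a small helper. state_init_shape is a hypothetical name used only for illustration and is not part of torch_to_nnef. It assumes axis indexes the full-rank input, including the trailing real/imaginary axis when complex=True.

```python
from typing import Sequence, Tuple


def state_init_shape(
    input_shape: Sequence[int], axis: int, complex: bool = False
) -> Tuple[int, ...]:
    """Hypothetical helper: input_shape with `axis` dropped, and the trailing
    real/imaginary axis also dropped when complex=True."""
    ndim = len(input_shape)
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} out of range for rank {ndim}")
    axis %= ndim  # normalise negative axes
    if complex:
        if input_shape[-1] != 2:
            raise ValueError("complex=True expects a trailing axis of size 2")
        if axis == ndim - 1:
            raise ValueError("axis must not be the trailing real/imaginary axis")
    shape = [d for i, d in enumerate(input_shape) if i != axis]
    if complex:
        shape = shape[:-1]  # drop the (real, imag) pair
    return tuple(shape)


# (batch, time, bands) streamed over axis=1 -> (batch, bands)
print(state_init_shape((1, 100, 32), axis=1))
# (batch, channels, time, freq, 2) complex spectrum over axis=2 -> (batch, channels, freq)
print(state_init_shape((1, 1, 100, 96, 2), axis=2, complex=True))
```

At trace time the zero state would then be built from this shape, for example with torch.zeros(state_init_shape(x.shape, axis)).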

exp_mean_norm

exp_mean_norm(g, node, name_to_tensor, op_helper, inference_target, **kwargs) -> T.List[str]

Lower t2n_extra::exp_mean_norm to tract_extra_exp_mean_norm.

Eager signature:

t2n_extra::exp_mean_norm(input, state_init, axis: int,
                         alpha: float, scaling_factor: float)
                         -> Tensor

Centres input with a per-time-step EMA mean, then divides by scaling_factor. state_init has the shape of input with axis removed.
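The following is a NumPy reference sketch of these semantics, under stated assumptions. exp_mean_norm_ref is a hypothetical name. The EMA convention (weight alpha on the previous state, with the frame centred on the already-updated state) is assumed, not taken from the tract source, and should be checked against tract_extra_exp_mean_norm before use as a test oracle.

```python
import numpy as np


def exp_mean_norm_ref(
    x: np.ndarray,
    state_init: np.ndarray,
    axis: int,
    alpha: float,
    scaling_factor: float,
) -> np.ndarray:
    """Hypothetical eager reference for exp_mean_norm (assumed EMA convention)."""
    # Bring the time axis to the front so each frames[t] has state_init's shape.
    frames = np.moveaxis(np.asarray(x, dtype=np.float64), axis, 0)
    state = np.array(state_init, dtype=np.float64)  # copy of the initial state
    out = np.empty_like(frames)
    for t in range(frames.shape[0]):
        # Assumed update: state <- alpha * state + (1 - alpha) * x_t
        state = alpha * state + (1.0 - alpha) * frames[t]
        # Centre the frame on the running mean, then apply the fixed scale.
        out[t] = (frames[t] - state) / scaling_factor
    return np.moveaxis(out, 0, axis)


# A constant input decays towards zero as the running mean catches up.
print(exp_mean_norm_ref(np.ones((3, 1)), np.zeros(1), axis=0, alpha=0.5, scaling_factor=1.0))
```

For the DPDFNet ErbNorm mapping, scaling_factor would be 40.0. The alpha value depends on the model and is not given here.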

exp_unit_norm

exp_unit_norm(g, node, name_to_tensor, op_helper, inference_target, **kwargs) -> T.List[str]

Lower t2n_extra::exp_unit_norm to tract_extra_exp_unit_norm.

Eager signature (matches the example / test declarations):

t2n_extra::exp_unit_norm(input, state_init, axis: int,
                         alpha: float, epsilon: float,
                         complex: bool) -> Tensor

Computes a per-time-step EMA of the input magnitude along axis, then divides each input frame by sqrt(state). state_init is the initial hidden state (zeros at trace time); its shape equals input with axis removed (and, for complex=True, with the trailing size-2 real/imaginary axis also removed).
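A NumPy reference sketch of the same computation follows, with the assumptions stated up front. exp_unit_norm_ref is a hypothetical name. The EMA weighting (alpha on the previous state) and the use of epsilon as a lower bound on the magnitude are guesses, because this page does not say where tract applies epsilon. For complex=True, the magnitude is the L2 norm over the trailing (real, imag) axis, and the same sqrt(state) divides both components.

```python
import numpy as np


def exp_unit_norm_ref(
    x: np.ndarray,
    state_init: np.ndarray,
    axis: int,
    alpha: float,
    epsilon: float,
    complex: bool,
) -> np.ndarray:
    """Hypothetical eager reference for exp_unit_norm (assumed conventions)."""
    # Time axis first; for complex input the (real, imag) axis stays last.
    frames = np.moveaxis(np.asarray(x, dtype=np.float64), axis, 0)
    state = np.array(state_init, dtype=np.float64)
    out = np.empty_like(frames)
    for t in range(frames.shape[0]):
        frame = frames[t]
        if complex:
            # Magnitude of each complex bin: L2 norm over the trailing pair.
            mag = np.sqrt(np.sum(frame * frame, axis=-1))
        else:
            mag = np.abs(frame)
        # Assumption: epsilon bounds the magnitude from below.
        mag = np.maximum(mag, epsilon)
        state = alpha * state + (1.0 - alpha) * mag
        denom = np.sqrt(state)
        # Broadcast the per-bin denominator over (real, imag) when complex.
        out[t] = frame / (denom[..., None] if complex else denom)
    return np.moveaxis(out, 0, axis)


# One complex bin 3 + 4j with alpha=0: state = |x| = 5, output = x / sqrt(5).
print(exp_unit_norm_ref(np.array([[[3.0, 4.0]]]), np.zeros(1), 0, 0.0, 1e-6, True))
```

With alpha=1.0 and a zero initial state, the state never moves away from zero, so this sketch divides by zero. Real inputs should keep alpha strictly below 1.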