
norm

torch_to_nnef.op.aten.norm

batch_norm

batch_norm(g, node, name_to_tensor, null_ref, inference_target, **kwargs)

Translate operator aten::batch_norm to NNEF.

NNEF inputs:

  - input: tensor
  - mean: tensor
  - variance: tensor
  - offset: tensor
  - scale: tensor
  - epsilon: scalar

NNEF op:

output = offset + scale * (input - mean) / sqrt(variance + epsilon);
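
A quick equivalence check in plain PyTorch (a sketch; the tensors, shapes, and tolerance here are illustrative, not torch_to_nnef code):

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 3, 4, 4)
    mean, var = torch.randn(3), torch.rand(3) + 0.1
    scale, offset = torch.randn(3), torch.randn(3)
    eps = 1e-5

    # NNEF expression, with the per-channel stats broadcast over batch and spatial dims
    shape = (1, 3, 1, 1)
    nnef_out = offset.view(shape) + scale.view(shape) * (x - mean.view(shape)) / torch.sqrt(var.view(shape) + eps)

    # aten::batch_norm in inference mode
    torch_out = F.batch_norm(x, mean, var, weight=scale, bias=offset, training=False, eps=eps)
    assert torch.allclose(nnef_out, torch_out, atol=1e-5)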

group_norm

group_norm(g, node, name_to_tensor, inference_target, **kwargs)

Translate operator aten::group_norm to NNEF.

Decomposed flow (sketched in code after this list):

  1. Reshape input from (B, C, *spatial) to (B, C, S) where S = prod(spatial): the t2n emitter knows the spatial shape statically and does the flatten here.
  2. Call the group_norm fragment, which works entirely in 3D (B, num_groups, C/num_groups * S) then projects back to (B, C, S). The fragment does NOT apply scale/offset.
  3. Reshape the 3D result back to (B, C, *spatial).
  4. Multiply by scale and add offset: both pre-unsqueezed to trailing-1 shape so NNEF's left-aligned broadcast extends them cleanly to the full input rank (this is the same pattern other norms use).
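
The whole flow in plain PyTorch (a sketch under the shapes above; `group_norm_decomposed` and its argument names are illustrative, not the emitter's code):

    import torch

    def group_norm_decomposed(x, num_groups, scale, offset, eps=1e-5):
        b, c, *spatial = x.shape
        s = 1
        for d in spatial:
            s *= d
        y = x.reshape(b, c, s)                       # 1. flatten spatial dims statically
        y = y.reshape(b, num_groups, (c // num_groups) * s)
        mean = y.mean(dim=-1, keepdim=True)          # 2. the fragment's 3D normalization
        var = y.var(dim=-1, unbiased=False, keepdim=True)
        y = ((y - mean) / torch.sqrt(var + eps)).reshape(b, c, s)
        y = y.reshape(b, c, *spatial)                # 3. restore (B, C, *spatial)
        aff = (1, c) + (1,) * len(spatial)           # 4. scale/offset broadcast shape
        return y * scale.reshape(aff) + offset.reshape(aff)

    x = torch.randn(2, 6, 5, 5)
    w, b = torch.randn(6), torch.randn(6)
    ref = torch.nn.functional.group_norm(x, 3, w, b)
    assert torch.allclose(group_norm_decomposed(x, 3, w, b), ref, atol=1e-5)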

instance_norm

instance_norm(node, op_helper, **kwargs)

Map PyTorch: 'aten::instance_norm' to NNEF via a fragment.

instance_norm(input, weight?, bias?, running_mean?, running_var?, use_input_stats, momentum, eps, cudnn_enabled) normalises each (n, c) plane independently using (x - mean) / sqrt(var + eps) over the spatial axes. The optional affine pair is reshaped to (1, C, 1, ..., 1) and applied as a post-multiply / add. The fragment lives in op/fragment/instance_norm.nnef and uses only NNEF stdlib (moments / sub / add / sqrt / div).
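
The per-plane math written out in plain PyTorch (a sketch; `instance_norm_ref` is a local name for this example, not the fragment itself):

    import torch

    def instance_norm_ref(x, weight=None, bias=None, eps=1e-5):
        spatial_axes = tuple(range(2, x.dim()))      # moments over the spatial axes only
        mean = x.mean(dim=spatial_axes, keepdim=True)
        var = x.var(dim=spatial_axes, unbiased=False, keepdim=True)
        y = (x - mean) / torch.sqrt(var + eps)       # one (mean, var) per (n, c) plane
        if weight is not None:
            aff = (1, x.shape[1]) + (1,) * (x.dim() - 2)   # (1, C, 1, ..., 1)
            y = y * weight.reshape(aff) + bias.reshape(aff)
        return y

    x = torch.randn(2, 3, 8, 8)
    w, b = torch.randn(3), torch.randn(3)
    ref = torch.nn.functional.instance_norm(x, weight=w, bias=b)
    assert torch.allclose(instance_norm_ref(x, w, b), ref, atol=1e-5)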

layer_norm

layer_norm(g, node, name_to_tensor, null_ref, inference_target, **kwargs)

Map PyTorch: 'aten::layer_norm', 'aten::native_layer_norm' to NNEF.

When the input is fp16 and inference_target.force_norm_in_f32 is set, sandwich the fragment between an upcast to f32 and a downcast back to the traced output dtype. This keeps the variance/rsqrt and the affine (x - mean) * weight + bias in f32 for stability, and aligns dtypes when an f32 attention residual flows in, which would otherwise hit tract's RmsNorm-folded layer_norm op with mismatched operand dtypes (the "tensor is F32, accessed as F16" crash).
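
The sandwich, sketched in plain PyTorch (the function and argument names are illustrative; the real code wraps the NNEF fragment between cast ops rather than calling torch):

    import torch

    def layer_norm_f32_sandwich(x_f16, weight, bias, eps=1e-5):
        x = x_f16.to(torch.float32)                  # upcast to f32
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, unbiased=False, keepdim=True)
        y = (x - mean) / torch.sqrt(var + eps)       # variance/rsqrt kept in f32
        y = y * weight.to(torch.float32) + bias.to(torch.float32)
        return y.to(x_f16.dtype)                     # downcast to the traced output dtype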

norm

norm(g, node, name_to_tensor, inference_target, **kwargs)

NOTE: this is only the normed vector.

prefer_native_tract_rms_norm

prefer_native_tract_rms_norm(inference_target, mean_axes) -> bool

Return True when we should emit tract's native rms_norm primitive.

Native tract_transformers_rms_norm is registered through tract's transformers extension, which t2n only auto-enables (via the --nnef-tract-transformers CLI flag) for tract >= 0.22.0. The native op also takes a single integer axis, so multi-axis normalized_shape keeps the fragment fallback.
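
The gate as a hedged sketch (the attributes read off inference_target here are assumptions for illustration, not the exact torch_to_nnef API):

    def prefer_native_tract_rms_norm_sketch(inference_target, mean_axes) -> bool:
        # assumed attrs: .is_tract flag and .version as a comparable (major, minor, patch) tuple
        if not getattr(inference_target, "is_tract", False):
            return False
        if inference_target.version < (0, 22, 0):    # transformers extension auto-enabled from 0.22.0
            return False
        return len(mean_axes) == 1                   # native op takes a single integer axis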

rms_norm

rms_norm(g, node, name_to_tensor, inference_target, null_ref, **kwargs)

Map PyTorch: 'aten::rms_norm' to NNEF.

Signature from torch.nn.functional.rms_norm: rms_norm(input, normalized_shape, weight, eps)

On tract >= 0.22.0 with a single normalized dim, emit the native tract_transformers_rms_norm op (gives tract access to its optimized GPU kernels and rewrite rules) and chain a mul for elementwise affine. Multi-axis normalized_shape and non-tract targets fall back to the custom rms_norm{,_with_affine} fragments.
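
Reference math for both paths in plain PyTorch (a sketch; `rms_norm_ref` is a local name, and the check against torch.nn.functional.rms_norm needs torch >= 2.4):

    import torch

    def rms_norm_ref(x, normalized_shape, weight=None, eps=1e-6):
        dims = tuple(range(x.dim() - len(normalized_shape), x.dim()))
        rms = torch.sqrt(x.pow(2).mean(dim=dims, keepdim=True) + eps)
        y = x / rms
        if weight is not None:
            y = y * weight       # the chained mul for elementwise affine
        return y

    x = torch.randn(2, 5, 16)
    w = torch.randn(16)
    ref = torch.nn.functional.rms_norm(x, (16,), weight=w, eps=1e-6)
    assert torch.allclose(rms_norm_ref(x, (16,), w), ref, atol=1e-5)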