
scan_ops

torch_to_nnef.op.extras.scan_ops

Handlers for t2n_extra::* scan-shaped ops.

Currently provides ssm_scan for Mamba's selective state-space scan, plus the pulse-friendly variant ssm_scan_y. The ssm_scan handler emits a call to the mamba_ssm_scan NNEF fragment, which wraps a tract_core_scan over a per-step mamba_ssm_step body. Tract's pulse declutter compiles the scan into a streaming graph, so prefill costs one tract call instead of one call per token.

ssm_scan

ssm_scan(g, node, name_to_tensor, op_helper, inference_target, **kwargs) -> T.List[str]

Emit a mamba_ssm_scan fragment call.

Signature on the torch side

t2n_extra::ssm_scan(discrete_A, deltaB_u, C, h_init) -> (scan_outputs, h_final)

The fragment scans along axis 0 of its inputs, matching the GRU/LSTM convention. The handler therefore pre-transposes the SSM tensors so that the time axis lands at position 0 before the scan:

discrete_A  (B, D, T, N) -> (T, B, D, N)
deltaB_u    (B, D, T, N) -> (T, B, D, N)
C           (B, T, N)    -> (T, B, N)
h_init      (B, D, N)    -- unchanged (state)
After the scan:

scan_y      (T, B, D)    -> (B, D, T)   -- matches scan_outputs's PyTorch shape (stacked on the last axis)
h_final     (B, D, N)
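
The layout changes above can be sketched in NumPy. This is only an illustrative reference, not the handler's code or the NNEF fragment: the function name ssm_scan_reference is hypothetical, and the per-step body (h = A_t * h + deltaB_u_t, y_t = sum over N of h * C_t) is an assumption based on the usual Mamba selective-scan recurrence rather than on the mamba_ssm_step source.

```python
import numpy as np


def ssm_scan_reference(discrete_A, deltaB_u, C, h_init):
    """Hypothetical reference scan taking batch-major (torch-side) inputs."""
    # Move the time axis to position 0, as the handler does before the scan.
    A_t = np.transpose(discrete_A, (2, 0, 1, 3))   # (B, D, T, N) -> (T, B, D, N)
    Bu_t = np.transpose(deltaB_u, (2, 0, 1, 3))    # (B, D, T, N) -> (T, B, D, N)
    C_t = np.transpose(C, (1, 0, 2))               # (B, T, N)    -> (T, B, N)

    h = h_init                                     # (B, D, N), not transposed
    ys = []
    for a, bu, c in zip(A_t, Bu_t, C_t):           # iterate over axis 0 (time)
        h = a * h + bu                             # assumed state update, (B, D, N)
        ys.append(np.einsum("bdn,bn->bd", h, c))   # assumed readout, (B, D)

    scan_y = np.stack(ys, axis=0)                  # (T, B, D)
    scan_outputs = np.transpose(scan_y, (1, 2, 0)) # (T, B, D) -> (B, D, T)
    return scan_outputs, h


rng = np.random.default_rng(0)
B, D, T, N = 2, 3, 5, 4
discrete_A = rng.uniform(0.0, 1.0, size=(B, D, T, N))
deltaB_u = rng.normal(size=(B, D, T, N))
C = rng.normal(size=(B, T, N))
h_init = np.zeros((B, D, N))

scan_outputs, h_final = ssm_scan_reference(discrete_A, deltaB_u, C, h_init)
print(scan_outputs.shape, h_final.shape)  # (2, 3, 5) (2, 3, 4)
```

The final transpose is what lets the exported graph return scan_outputs in the same (B, D, T) layout that a PyTorch loop stacking per-step outputs on the last axis would produce.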

ssm_scan_y

ssm_scan_y(g, node, name_to_tensor, op_helper, inference_target, **kwargs) -> T.List[str]

Pulse-friendly variant of ssm_scan: emits only y_t (no h_final).

The Scan pulsifier in tract rejects "last" outputs such as h_final. Dropping h_final leaves only the per-step output y_t, which makes the scan body compatible with into_pulse.
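
As a rough intuition for why a y-only scan suits streaming, the NumPy sketch below runs the same recurrence once over the full sequence and once in fixed-size chunks with the state carried between chunks, and checks that the per-step outputs agree. It does not use tract or model its pulsifier: scan_chunk, the chunk size of 2, and the recurrence itself (the usual Mamba update h = A_t * h + deltaB_u_t with readout over N) are assumptions made for illustration.

```python
import numpy as np


def scan_chunk(A_t, Bu_t, C_t, h):
    """Hypothetical time-major scan over one chunk; returns (y, carried state)."""
    ys = []
    for a, bu, c in zip(A_t, Bu_t, C_t):
        h = a * h + bu                             # assumed state update, (B, D, N)
        ys.append(np.einsum("bdn,bn->bd", h, c))   # per-step output y_t, (B, D)
    return np.stack(ys, axis=0), h


rng = np.random.default_rng(1)
T, B, D, N = 6, 2, 3, 4
A_t = rng.uniform(0.0, 1.0, size=(T, B, D, N))
Bu_t = rng.normal(size=(T, B, D, N))
C_t = rng.normal(size=(T, B, N))
h0 = np.zeros((B, D, N))

# One pass over the whole sequence.
y_full, _ = scan_chunk(A_t, Bu_t, C_t, h0)

# Streaming: process 2 steps at a time, keeping the state between chunks.
h = h0
pieces = []
for start in range(0, T, 2):
    sl = slice(start, start + 2)
    y, h = scan_chunk(A_t[sl], Bu_t[sl], C_t[sl], h)
    pieces.append(y)
y_stream = np.concatenate(pieces, axis=0)

assert np.allclose(y_full, y_stream)
```

In this sketch the state only needs to persist between chunks; every value handed back to the caller is a per-step y_t, which has a natural meaning for a stream, whereas a "last state" output is only defined once the sequence ends.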