scan_ops
torch_to_nnef.op.extras.scan_ops
Handlers for t2n_extra::* scan-shaped ops.
Currently provides ssm_scan for Mamba's selective state-space scan.
The handler emits a mamba_ssm_scan NNEF fragment call that wraps a tract_core_scan over a per-step mamba_ssm_step body. Tract's pulse declutter compiles the scan into a streaming graph, so a prefill costs one tract call instead of one call per token.
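To make the fragment structure concrete, here is a minimal NumPy sketch of a per-step body and the axis-0 scan that drives it. It assumes the step follows the usual Mamba selective-scan recurrence, h_t = discrete_A_t * h_{t-1} + deltaB_u_t followed by y_t = sum over N of h_t * C_t. The names mamba_ssm_step_ref and scan_axis0 are hypothetical. This is neither the NNEF fragment nor tract's scan operator, only the sequential loop they stand in for.

```python
import numpy as np


def mamba_ssm_step_ref(h, a_t, bu_t, c_t):
    """One recurrence step (hypothetical, for illustration only).

    h, a_t, bu_t: (B, D, N); c_t: (B, N). Returns the new state and y_t (B, D).
    """
    h_new = a_t * h + bu_t                       # elementwise state update
    y_t = np.einsum("bdn,bn->bd", h_new, c_t)    # contract over the state axis N
    return h_new, y_t


def scan_axis0(step, h_init, *xs):
    """Run `step` over axis 0 of every input, carrying the state between steps."""
    h = h_init
    ys = []
    for t in range(xs[0].shape[0]):
        h, y_t = step(h, *(x[t] for x in xs))
        ys.append(y_t)
    return np.stack(ys, axis=0), h               # per-step outputs stacked on axis 0


# Time-major inputs (T first), as a scan over axis 0 expects them.
rng = np.random.default_rng(0)
T, B, D, N = 5, 2, 3, 4
A = rng.uniform(0.5, 1.0, (T, B, D, N))
BU = rng.standard_normal((T, B, D, N))
C = rng.standard_normal((T, B, N))
y, h_final = scan_axis0(mamba_ssm_step_ref, np.zeros((B, D, N)), A, BU, C)
print(y.shape, h_final.shape)  # (5, 2, 3) (2, 3, 4)
```

Keeping the body separate from the driver mirrors the split between a per-step fragment and a generic scan: the scan only needs to know which axis to iterate and which value is carried state.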
ssm_scan
Emit a mamba_ssm_scan fragment call.
Signature on the torch side
t2n_extra::ssm_scan(discrete_A, deltaB_u, C, h_init) -> (scan_outputs, h_final)
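A reference for what this signature computes, written as a plain NumPy loop so it runs without torch. The recurrence is an assumption modeled on the sequential fallback path of the Hugging Face transformers Mamba implementation; ssm_scan_ref is a hypothetical name, not part of torch_to_nnef.

```python
import numpy as np


def ssm_scan_ref(discrete_A, deltaB_u, C, h_init):
    """Hypothetical reference for t2n_extra::ssm_scan in torch-side layout.

    discrete_A, deltaB_u: (B, D, T, N); C: (B, T, N); h_init: (B, D, N).
    Returns scan_outputs (B, D, T) and h_final (B, D, N).
    """
    h = h_init
    outputs = []
    for t in range(discrete_A.shape[2]):               # time is axis 2 here
        h = discrete_A[:, :, t, :] * h + deltaB_u[:, :, t, :]
        y_t = np.matmul(h, C[:, t, :, None])[..., 0]   # (B, D, N) @ (B, N, 1) -> (B, D)
        outputs.append(y_t)
    return np.stack(outputs, axis=-1), h               # stack on the last axis -> (B, D, T)


rng = np.random.default_rng(0)
B, D, T, N = 2, 3, 5, 4
discrete_A = rng.uniform(0.5, 1.0, (B, D, T, N))
deltaB_u = rng.standard_normal((B, D, T, N))
C = rng.standard_normal((B, T, N))
h_init = np.zeros((B, D, N))
scan_outputs, h_final = ssm_scan_ref(discrete_A, deltaB_u, C, h_init)
print(scan_outputs.shape, h_final.shape)  # (2, 3, 5) (2, 3, 4)
```

In this layout the time axis is not leading, which is why a scan that iterates over axis 0 needs the inputs permuted first.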
The fragment scans along axis 0 of its inputs, matching the GRU/LSTM convention. The handler therefore pre-transposes the SSM tensors so that the time axis sits at position 0 before the scan:
discrete_A (B, D, T, N) -> (T, B, D, N)
deltaB_u (B, D, T, N) -> (T, B, D, N)
C (B, T, N) -> (T, B, N)
h_init (B, D, N) -- unchanged (state)
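The pre-transposes are fixed axis permutations. In NumPy they look like this; the permutation tuples are read off the shape list above rather than copied from the handler, and the sizes are arbitrary.

```python
import numpy as np

# Arbitrary sizes; distinct values make the axis mapping checkable.
B, D, T, N = 2, 3, 5, 4
discrete_A = np.arange(B * D * T * N).reshape(B, D, T, N)
deltaB_u = discrete_A + 1000
C = np.arange(B * T * N).reshape(B, T, N)
h_init = np.zeros((B, D, N))

# Move the time axis (T) to position 0 so the scan can iterate over axis 0.
A_tm = np.transpose(discrete_A, (2, 0, 1, 3))   # (B, D, T, N) -> (T, B, D, N)
BU_tm = np.transpose(deltaB_u, (2, 0, 1, 3))    # (B, D, T, N) -> (T, B, D, N)
C_tm = np.transpose(C, (1, 0, 2))               # (B, T, N)    -> (T, B, N)
# h_init is carried state, not a scanned input, so it keeps (B, D, N).
print(A_tm.shape, C_tm.shape, h_init.shape)  # (5, 2, 3, 4) (5, 2, 4) (2, 3, 4)
```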
After the scan
scan_y (T, B, D) -> (B, D, T), matching the PyTorch shape of scan_outputs (stacked on the last axis).
h_final (B, D, N)
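On the way out, stacking per-step outputs on axis 0 and then permuting with (1, 2, 0) yields the same tensor as stacking on the last axis, which is the layout the PyTorch-side loop produces. h_final needs no transpose because the state never carried a time axis. A quick NumPy check with arbitrary sizes:

```python
import numpy as np

rng = np.random.default_rng(0)
T, B, D = 5, 2, 3
# Per-step outputs y_t, each (B, D), as a scan body would emit them.
steps = [rng.standard_normal((B, D)) for _ in range(T)]

scan_y = np.stack(steps, axis=0)                 # scan stacks on axis 0 -> (T, B, D)
scan_outputs = np.transpose(scan_y, (1, 2, 0))   # (T, B, D) -> (B, D, T)

# The tensor PyTorch-side code gets by stacking on the last axis.
reference = np.stack(steps, axis=-1)             # (B, D, T)
print(scan_outputs.shape, np.array_equal(scan_outputs, reference))  # (2, 3, 5) True
```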