rnn

torch_to_nnef.op.aten.rnn

Aten-level RNN op handlers, adapters, and shared orchestration.

Canonical home for the RNN export math. Both the module-level extractors in op/custom_extractors/rnn.py (LSTMExtractor, GRUExtractor, RNNExtractor, LSTMCellExtractor) and the aten op handlers registered below go through the same set of free functions, so exports are byte-identical regardless of the entry point.

Layout:

  • Orchestration (multi-layer, bidirectional, batch_first, state setup): emit_rnn_via_fragment and its private helpers. Variant-agnostic; drives the per-layer / per-direction fragment call.
  • Per-variant params extraction: _lstm_tensor_params, _gru_tensor_params, _rnn_tensor_params. Read named weight attributes off a "module-like" object.
  • Aten adapters: _LSTMAtenAdapter, _GRUAtenAdapter, _RNNAtenAdapter. Build an attribute interface compatible with the per-variant tensor_params from the aten op's flat params: Tensor[] argument.
  • Aten op handlers: aten::lstm, aten::gru, aten::rnn_tanh, aten::rnn_relu, plus aten::lstm_cell (single-step, fragment-based via lstm_cell.nnef).

emit_lstm_cell_via_fragment

emit_lstm_cell_via_fragment(g, name_to_tensor, base: str, nnef_dtype, batch_dim: int, hidden: int, input_ref: NTensor, h_prev_ref: NTensor, c_prev_ref: NTensor, w_ih: torch.Tensor, w_hh: torch.Tensor, b_ih: T.Optional[torch.Tensor], b_hh: T.Optional[torch.Tensor], h_new_tv: T.Optional[TensorVariable], c_new_tv: T.Optional[TensorVariable]) -> T.List[str]

Emit a single lstm_cell NNEF fragment call.

Internally computes the grouped projections (B, I) @ (4H, I).T and (B, H) @ (4H, H).T, adds the unsqueezed combined bias, slices the result into the 4 gates, and computes (h_new, c_new).
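The gate math described above can be sketched in plain PyTorch and checked against `torch.nn.LSTMCell`. This is an illustration of the arithmetic the fragment encodes, not the exporter's API; the names B, I, H follow the docstring.

```python
import torch

torch.manual_seed(0)
B, I, H = 2, 3, 4                      # batch, input size, hidden size
cell = torch.nn.LSTMCell(I, H)
x = torch.randn(B, I)
h_prev = torch.randn(B, H)
c_prev = torch.randn(B, H)

# Grouped projections: (B, I) @ (4H, I).T + (B, H) @ (4H, H).T,
# plus the combined bias.
gates = (x @ cell.weight_ih.T + h_prev @ cell.weight_hh.T
         + cell.bias_ih + cell.bias_hh)
i, f, g, o = gates.chunk(4, dim=1)     # PyTorch gate order: i, f, g, o
i, f, o = i.sigmoid(), f.sigmoid(), o.sigmoid()
c_new = f * c_prev + i * g.tanh()
h_new = o * c_new.tanh()

h_ref, c_ref = cell(x, (h_prev, c_prev))
print(torch.allclose(h_new, h_ref, atol=1e-6))   # True
print(torch.allclose(c_new, c_ref, atol=1e-6))   # True
```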

emit_rnn_via_fragment

emit_rnn_via_fragment(g, node, name_to_tensor, module, nnef_fragment_name: str, argument_names_order: T.Sequence[str], tensor_params_fn: T.Callable, **tensor_params_kwargs) -> T.List[str]

Multi-layer / bidirectional RNN orchestration around a fragment call.

Variant-agnostic: drives the per-layer-and-per-direction loop, calls tensor_params_fn to materialize weights and states for each slice, and issues one nnef_fragment_name call per (layer, direction). Bidirectional packing and multi-layer concatenation are handled internally.
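A hedged sketch of the loop shape this orchestration drives: one slice per (layer, direction), reading the standard PyTorch attribute names `weight_ih_l{k}` / `weight_ih_l{k}_reverse` off the module, which is what a "module-like" params extractor sees. This mirrors the structure only; the real helpers emit NNEF fragment calls rather than running anything.

```python
import torch

rnn = torch.nn.LSTM(input_size=3, hidden_size=4,
                    num_layers=2, bidirectional=True)
num_directions = 2 if rnn.bidirectional else 1
names = []
for layer in range(rnn.num_layers):
    for direction in range(num_directions):
        suffix = "_reverse" if direction == 1 else ""
        # Attribute names a per-variant tensor_params fn reads per slice.
        for prefix in ("weight_ih", "weight_hh", "bias_ih", "bias_hh"):
            name = f"{prefix}_l{layer}{suffix}"
            assert hasattr(rnn, name)
            names.append(name)
print(len(names))  # 2 layers * 2 directions * 4 tensors = 16
```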

gru

gru(g, node, name_to_tensor, **kwargs)

Map aten::gru.input to NNEF via the existing gru fragment.
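For reference, the gate arithmetic the gru fragment must reproduce can be written out in plain PyTorch and checked against `torch.nn.GRUCell` (an illustration only, not the fragment itself; note the hidden-side bias sits inside the reset-gate product):

```python
import torch

torch.manual_seed(0)
B, I, H = 2, 3, 4
cell = torch.nn.GRUCell(I, H)
x = torch.randn(B, I)
h_prev = torch.randn(B, H)

# Gate order in the (3H, *) weights: r, z, n.
gi = x @ cell.weight_ih.T + cell.bias_ih
gh = h_prev @ cell.weight_hh.T + cell.bias_hh
i_r, i_z, i_n = gi.chunk(3, dim=1)
h_r, h_z, h_n = gh.chunk(3, dim=1)
r = (i_r + h_r).sigmoid()
z = (i_z + h_z).sigmoid()
n = (i_n + r * h_n).tanh()       # b_hn is gated by r, per PyTorch docs
h_new = (1 - z) * n + z * h_prev
print(torch.allclose(h_new, cell(x, h_prev), atol=1e-6))  # True
```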

lstm

lstm(g, node, name_to_tensor, **kwargs)

Map aten::lstm.input to NNEF via the existing lstm fragment.
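The aten adapters unpack the op's flat `params: Tensor[]`, which (with biases) holds 4 tensors per (layer, direction) in the order w_ih, w_hh, b_ih, b_hh. A quick way to see the same grouping, assuming standard PyTorch behavior, is `nn.LSTM.all_weights`:

```python
import torch

rnn = torch.nn.LSTM(input_size=3, hidden_size=4,
                    num_layers=2, bidirectional=True)
# One 4-tensor group per (layer, direction).
flat = [t for group in rnn.all_weights for t in group]
print(len(flat))  # 2 layers * 2 directions * 4 = 16
# Layer 0 w_ih is (4H, I); deeper layers take (4H, num_directions * H).
print(rnn.all_weights[0][0].shape, rnn.all_weights[2][0].shape)
```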

lstm_cell

lstm_cell(g, node, name_to_tensor, **kwargs)

Map aten::lstm_cell(input, hx_list, w_ih, w_hh, b_ih?, b_hh?) to NNEF.

hx_list is a t2n FixedTensorList of [h_prev, c_prev]. The output is (h_new, c_new).

rnn_relu

rnn_relu(g, node, name_to_tensor, **kwargs)

Map aten::rnn_relu.input to NNEF via the existing rnn_relu fragment.

rnn_tanh

rnn_tanh(g, node, name_to_tensor, **kwargs)

Map aten::rnn_tanh.input to NNEF via the existing rnn_tanh fragment.
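Both elementary variants share a single recurrence, differing only in the pointwise nonlinearity (tanh vs. relu). A plain-PyTorch check against `torch.nn.RNNCell`, purely as an illustration:

```python
import torch

torch.manual_seed(0)
B, I, H = 2, 3, 4
cell = torch.nn.RNNCell(I, H, nonlinearity="tanh")
x = torch.randn(B, I)
h_prev = torch.randn(B, H)

# h_new = tanh(x @ W_ih.T + b_ih + h_prev @ W_hh.T + b_hh);
# rnn_relu swaps .tanh() for .relu().
h_new = (x @ cell.weight_ih.T + cell.bias_ih
         + h_prev @ cell.weight_hh.T + cell.bias_hh).tanh()
print(torch.allclose(h_new, cell(x, h_prev), atol=1e-6))  # True
```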