rnn

torch_to_nnef.op.aten.rnn

Aten-level RNN op handlers, adapters, and shared orchestration.

Canonical home for the RNN export math. Both the module-level extractors in op/custom_extractors/rnn.py (LSTMExtractor, GRUExtractor, RNNExtractor, LSTMCellExtractor) and the aten op handlers registered below go through the same set of free functions, so exports are byte-identical regardless of the entry point.

Layout:

  • Orchestration (multi-layer, bidirectional, batch_first, state setup): emit_rnn_via_fragment and its private helpers. Variant-agnostic; drives the per-layer / per-direction fragment call.
  • Per-variant params extraction: _lstm_tensor_params, _gru_tensor_params, _rnn_tensor_params. Read named weight attributes off a "module-like" object.
  • Aten adapters: _LSTMAtenAdapter, _GRUAtenAdapter, _RNNAtenAdapter. Build an attribute interface compatible with the per-variant tensor_params from the aten op's flat params: Tensor[] argument.
  • Aten op handlers: aten::lstm, aten::gru, aten::rnn_tanh, aten::rnn_relu, plus the single-step cell variants (aten::lstm_cell, aten::gru_cell, aten::rnn_tanh_cell, aten::rnn_relu_cell), each routed through a one-call NNEF fragment (lstm_cell.nnef / gru_cell.nnef / rnn_tanh_cell.nnef / rnn_relu_cell.nnef).
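The adapter idea in the layout above can be sketched in a few lines. This is an illustration only, not the code of `_LSTMAtenAdapter`: the helper name `flat_params_to_named` is hypothetical, and it assumes the flat `params` list follows the `torch.nn.LSTM` ordering (per layer, then per direction: `w_ih`, `w_hh`, and optionally `b_ih`, `b_hh`) with no projection weights.

```python
from types import SimpleNamespace


def flat_params_to_named(flat, num_layers, bidirectional, has_biases):
    """Hypothetical helper: expose a flat weight list under nn.LSTM-style names.

    Assumes the ordering used by torch.nn.LSTM's flat weights and proj_size == 0.
    """
    per_slice = 4 if has_biases else 2  # w_ih, w_hh[, b_ih, b_hh]
    suffixes = ["", "_reverse"] if bidirectional else [""]
    expected = num_layers * len(suffixes) * per_slice
    if len(flat) != expected:
        raise ValueError(f"expected {expected} tensors, got {len(flat)}")

    names = ["weight_ih", "weight_hh", "bias_ih", "bias_hh"][:per_slice]
    module_like = SimpleNamespace(
        num_layers=num_layers, bidirectional=bidirectional, bias=has_biases
    )
    remaining = iter(flat)
    for layer in range(num_layers):
        for suffix in suffixes:
            for name in names:
                # e.g. weight_ih_l0, bias_hh_l1_reverse
                setattr(module_like, f"{name}_l{layer}{suffix}", next(remaining))
    return module_like


# Strings stand in for tensors so the mapping is easy to read.
flat = [f"t{i}" for i in range(8)]
m = flat_params_to_named(flat, num_layers=1, bidirectional=True, has_biases=True)
print(m.weight_ih_l0, m.bias_hh_l0_reverse)  # t0 t7
```

Once the attributes exist under these names, a per-variant params function that reads named weights off a module can, in principle, treat the adapter and a real `torch.nn.LSTM` the same way.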

emit_gru_cell_via_fragment

emit_gru_cell_via_fragment(g, name_to_tensor, base: str, nnef_dtype, batch_dim: int, hidden: int, input_ref: NTensor, h_prev_ref: NTensor, w_ih: torch.Tensor, w_hh: torch.Tensor, b_ih: T.Optional[torch.Tensor], b_hh: T.Optional[torch.Tensor], h_new_tv: T.Optional[TensorVariable]) -> T.List[str]

Emit a single gru_cell NNEF fragment call.

Unlike lstm_cell, GRU keeps b_ih and b_hh separate because the new-gate biases are split across the reset-gated branch (see the fragment docstring for the equations).
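To see why the two biases cannot be merged, here is a minimal NumPy reference of a GRU step. It is a sketch under assumptions: it follows the equations and the (r, z, n) gate order documented for `torch.nn.GRU`, and the function name `gru_cell_reference` is made up for this example; the NNEF fragment itself is not reproduced here.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gru_cell_reference(x, h_prev, w_ih, w_hh, b_ih, b_hh):
    """x: (B, I), h_prev: (B, H), w_ih: (3H, I), w_hh: (3H, H), biases: (3H,)."""
    gi = x @ w_ih.T + b_ih       # input-side preactivations, (B, 3H)
    gh = h_prev @ w_hh.T + b_hh  # hidden-side preactivations, (B, 3H)
    i_r, i_z, i_n = np.split(gi, 3, axis=1)
    h_r, h_z, h_n = np.split(gh, 3, axis=1)
    r = sigmoid(i_r + h_r)
    z = sigmoid(i_z + h_z)
    n = np.tanh(i_n + r * h_n)   # the hidden-side new-gate bias is scaled by r
    return (1.0 - z) * n + z * h_prev


B, I, H = 1, 2, 2
x, h = np.zeros((B, I)), np.zeros((B, H))
w_ih, w_hh = np.zeros((3 * H, I)), np.zeros((3 * H, H))
b_ih = np.zeros(3 * H)
b_hh = np.concatenate([np.zeros(2 * H), np.ones(H)])  # only the new-gate part is non-zero

separate = gru_cell_reference(x, h, w_ih, w_hh, b_ih, b_hh)
merged = gru_cell_reference(x, h, w_ih, w_hh, b_ih + b_hh, np.zeros(3 * H))
print(separate)  # 0.5 * tanh(0.5) in every entry
print(merged)    # 0.5 * tanh(1.0) in every entry
```

With all weights at zero, the reset gate is 0.5, so moving the hidden-side new-gate bias to the input side changes the preactivation from 0.5 to 1.0 and the outputs differ.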

emit_lstm_cell_via_fragment

emit_lstm_cell_via_fragment(g, name_to_tensor, base: str, nnef_dtype, batch_dim: int, hidden: int, input_ref: NTensor, h_prev_ref: NTensor, c_prev_ref: NTensor, w_ih: torch.Tensor, w_hh: torch.Tensor, b_ih: T.Optional[torch.Tensor], b_hh: T.Optional[torch.Tensor], h_new_tv: T.Optional[TensorVariable], c_new_tv: T.Optional[TensorVariable]) -> T.List[str]

Emit a single lstm_cell NNEF fragment call.

Internally, the fragment computes the grouped products (B, I) @ (4H, I).T and (B, H) @ (4H, H).T, sums them, adds the unsqueezed combined bias, slices the result into the 4 gates, and computes (h_new, c_new).
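The same computation can be written as a short NumPy reference. This is a sketch, not the fragment: it assumes the (i, f, g, o) gate order used by `torch.nn.LSTM`, and `lstm_cell_reference` is a name invented for the example.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def lstm_cell_reference(x, h_prev, c_prev, w_ih, w_hh, b_ih=None, b_hh=None):
    """x: (B, I), h_prev/c_prev: (B, H), w_ih: (4H, I), w_hh: (4H, H)."""
    hidden = h_prev.shape[1]
    gates = x @ w_ih.T + h_prev @ w_hh.T  # (B, 4H), all four gates at once
    if b_ih is not None and b_hh is not None:
        gates = gates + (b_ih + b_hh)[None, :]  # combined bias, unsqueezed to (1, 4H)
    # Assumed gate order: input, forget, cell candidate, output.
    i, f, g, o = (gates[:, k * hidden:(k + 1) * hidden] for k in range(4))
    i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)
    g = np.tanh(g)
    c_new = f * c_prev + i * g
    h_new = o * np.tanh(c_new)
    return h_new, c_new


rng = np.random.default_rng(0)
B, I, H = 2, 3, 4
h_new, c_new = lstm_cell_reference(
    rng.standard_normal((B, I)),
    rng.standard_normal((B, H)),
    rng.standard_normal((B, H)),
    rng.standard_normal((4 * H, I)),
    rng.standard_normal((4 * H, H)),
    rng.standard_normal(4 * H),
    rng.standard_normal(4 * H),
)
print(h_new.shape, c_new.shape)  # (2, 4) (2, 4)
```

Summing `b_ih` and `b_hh` up front is safe here because both are added to the same preactivation before any gate nonlinearity is applied.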

emit_rnn_cell_via_fragment

emit_rnn_cell_via_fragment(g, name_to_tensor, base: str, nnef_dtype, batch_dim: int, hidden: int, input_ref: NTensor, h_prev_ref: NTensor, w_ih: torch.Tensor, w_hh: torch.Tensor, b_ih: T.Optional[torch.Tensor], b_hh: T.Optional[torch.Tensor], h_new_tv: T.Optional[TensorVariable], nonlinearity: str) -> T.List[str]

Emit a single rnn_{tanh,relu}_cell NNEF fragment call.

Like lstm_cell, the biases are pre-summed to a single (1, H) term: the Elman cell's nonlinearity sits on the full preactivation, so b_ih and b_hh are interchangeable in the math.
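A quick numerical check of the pre-summing argument, written as a NumPy sketch. The helper name `rnn_cell_reference` and the random test shapes are assumptions for illustration; the math is the standard Elman update.

```python
import numpy as np


def rnn_cell_reference(x, h_prev, w_ih, w_hh, bias, nonlinearity="tanh"):
    """Elman step with a single pre-summed bias of shape (1, H)."""
    activations = {"tanh": np.tanh, "relu": lambda v: np.maximum(v, 0.0)}
    return activations[nonlinearity](x @ w_ih.T + h_prev @ w_hh.T + bias)


rng = np.random.default_rng(0)
B, I, H = 2, 3, 4
x = rng.standard_normal((B, I))
h = rng.standard_normal((B, H))
w_ih = rng.standard_normal((H, I))
w_hh = rng.standard_normal((H, H))
b_ih = rng.standard_normal(H)
b_hh = rng.standard_normal(H)

# Two-bias form, as in the aten signature.
two_bias = np.tanh(x @ w_ih.T + b_ih + h @ w_hh.T + b_hh)
# Single pre-summed (1, H) bias.
pre_summed = rnn_cell_reference(x, h, w_ih, w_hh, (b_ih + b_hh)[None, :])
print(np.allclose(two_bias, pre_summed))  # True
```

The two forms agree up to floating-point rounding because addition is the only operation applied to either bias before the nonlinearity.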

emit_rnn_via_fragment

emit_rnn_via_fragment(g, node, name_to_tensor, module, nnef_fragment_name: str, argument_names_order: T.Sequence[str], tensor_params_fn: T.Callable, inference_target=None, **tensor_params_kwargs) -> T.List[str]

Multi-layer / bidirectional RNN orchestration around a fragment call.

Variant-agnostic: drives the per-layer-and-direction loop, calls tensor_params_fn to materialize weights / states per slice, and issues one nnef_fragment_name call per (layer, direction). Bidi packing and multi-layer concat are handled internally.
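The per-(layer, direction) loop can be pictured with a small planning function. This is a hypothetical sketch, not the internals of `emit_rnn_via_fragment`: `plan_fragment_calls` and its dictionary keys are invented for the example, and the input-size rule assumes the usual PyTorch stacking without projection, where layers above the first consume the concatenated outputs of all directions.

```python
def plan_fragment_calls(num_layers, bidirectional, input_size, hidden_size):
    """Hypothetical: list one fragment call per (layer, direction) slice."""
    num_directions = 2 if bidirectional else 1
    calls = []
    for layer in range(num_layers):
        # Layer 0 sees the raw input; deeper layers see every direction's output.
        layer_input = input_size if layer == 0 else hidden_size * num_directions
        for direction in range(num_directions):
            calls.append(
                {
                    "layer": layer,
                    "reverse": direction == 1,
                    "input_size": layer_input,
                    "hidden_size": hidden_size,
                }
            )
    return calls


for call in plan_fragment_calls(
    num_layers=2, bidirectional=True, input_size=16, hidden_size=32
):
    print(call)
```

A two-layer bidirectional network yields four entries, which matches the one-call-per-slice behaviour described above.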

gru

gru(g, node, name_to_tensor, **kwargs)

Map aten::gru.input to NNEF via the existing gru fragment.

gru_cell

gru_cell(g, node, name_to_tensor, **kwargs)

Map aten::gru_cell(input, hx, w_ih, w_hh, b_ih?, b_hh?) to NNEF.

hx is a single 2D state tensor (unlike LSTM's 2-element list).

lstm

lstm(g, node, name_to_tensor, **kwargs)

Map aten::lstm.input to NNEF via the existing lstm fragment.

lstm_cell

lstm_cell(g, node, name_to_tensor, **kwargs)

Map aten::lstm_cell(input, hx_list, w_ih, w_hh, b_ih?, b_hh?) to NNEF.

hx_list is a t2n FixedTensorList of [h_prev, c_prev]. The output is (h_new, c_new).

rnn_relu

rnn_relu(g, node, name_to_tensor, **kwargs)

Map aten::rnn_relu.input to NNEF via the existing rnn_relu fragment.

rnn_relu_cell

rnn_relu_cell(g, node, name_to_tensor, **kwargs)

Map aten::rnn_relu_cell(input, hx, w_ih, w_hh, b_ih?, b_hh?) to NNEF.

rnn_tanh

rnn_tanh(g, node, name_to_tensor, **kwargs)

Map aten::rnn_tanh.input to NNEF via the existing rnn_tanh fragment.

rnn_tanh_cell

rnn_tanh_cell(g, node, name_to_tensor, **kwargs)

Map aten::rnn_tanh_cell(input, hx, w_ih, w_hh, b_ih?, b_hh?) to NNEF.