rnn
torch_to_nnef.op.aten.rnn
Aten-level RNN op handlers, adapters, and shared orchestration.
Canonical home for the RNN export math. Both the module-level extractors
in op/custom_extractors/rnn.py (LSTMExtractor, GRUExtractor,
RNNExtractor, LSTMCellExtractor) and the aten op handlers registered
below go through the same set of free functions, so exports are
byte-identical regardless of the entry point.
Layout:
- Orchestration (multi-layer, bidirectional, batch_first, state setup):
  emit_rnn_via_fragment and its private helpers. Variant-agnostic; drives
  the per-layer / per-direction fragment call.
- Per-variant params extraction: _lstm_tensor_params, _gru_tensor_params,
  _rnn_tensor_params. Read named weight attributes off a "module-like"
  object.
- Aten adapters: _LSTMAtenAdapter, _GRUAtenAdapter, _RNNAtenAdapter. Build
  an attribute interface compatible with the per-variant tensor_params
  from the aten op's flat params: Tensor[] argument (see the sketch after
  this list).
- Aten op handlers: aten::lstm, aten::gru, aten::rnn_tanh, aten::rnn_relu,
  plus aten::lstm_cell (single-step, fragment-based via lstm_cell.nnef).
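To make the adapter role concrete, here is a minimal sketch of the idea,
assuming PyTorch's flat-weights ordering (w_ih, w_hh[, b_ih, b_hh] per layer
and direction); the class name is hypothetical, not this module's actual
adapter:

```python
import typing as T

import torch


class _FlatParamsAdapterSketch:
    """Hypothetical illustration: re-expose a flat params: Tensor[] list as
    nn.RNN-style named attributes (weight_ih_l0, weight_hh_l0_reverse, ...)
    so per-variant tensor_params functions can read them uniformly."""

    def __init__(
        self,
        flat_params: T.List[torch.Tensor],
        num_layers: int,
        bidirectional: bool,
        has_biases: bool,
    ):
        names = ["weight_ih", "weight_hh"]
        if has_biases:
            names += ["bias_ih", "bias_hh"]
        suffixes = ["", "_reverse"] if bidirectional else [""]
        it = iter(flat_params)
        for layer in range(num_layers):
            for suffix in suffixes:
                for name in names:
                    # e.g. self.weight_ih_l1_reverse = next tensor in the list
                    setattr(self, f"{name}_l{layer}{suffix}", next(it))
```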
emit_lstm_cell_via_fragment
emit_lstm_cell_via_fragment(g, name_to_tensor, base: str, nnef_dtype, batch_dim: int, hidden: int, input_ref: NTensor, h_prev_ref: NTensor, c_prev_ref: NTensor, w_ih: torch.Tensor, w_hh: torch.Tensor, b_ih: T.Optional[torch.Tensor], b_hh: T.Optional[torch.Tensor], h_new_tv: T.Optional[TensorVariable], c_new_tv: T.Optional[TensorVariable]) -> T.List[str]
Emit a single lstm_cell NNEF fragment call.
Internally it computes the grouped projections (B, I) @ (4H, I).T and
(B, H) @ (4H, H).T, adds the unsqueezed combined bias, slices the result
into the 4 gates, and computes (h_new, c_new).
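As a cross-check, the same math in plain PyTorch (a reference sketch assuming
PyTorch's i, f, g, o gate order, not the emitted NNEF itself):

```python
import torch


def lstm_cell_reference(x, h_prev, c_prev, w_ih, w_hh, b_ih=None, b_hh=None):
    """Plain-PyTorch mirror of the lstm_cell fragment math.

    Shapes: x (B, I), h_prev/c_prev (B, H), w_ih (4H, I), w_hh (4H, H).
    """
    gates = x @ w_ih.T + h_prev @ w_hh.T      # grouped (B, 4H) projections
    if b_ih is not None:
        gates = gates + (b_ih + b_hh)         # combined bias, broadcast over B
    i, f, g, o = gates.chunk(4, dim=1)        # slice into the 4 gates
    c_new = torch.sigmoid(f) * c_prev + torch.sigmoid(i) * torch.tanh(g)
    h_new = torch.sigmoid(o) * torch.tanh(c_new)
    return h_new, c_new
```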
emit_rnn_via_fragment
emit_rnn_via_fragment(g, node, name_to_tensor, module, nnef_fragment_name: str, argument_names_order: T.Sequence[str], tensor_params_fn: T.Callable, **tensor_params_kwargs) -> T.List[str]
Multi-layer / bidirectional RNN orchestration around a fragment call.
Variant-agnostic: drives the per-layer-and-direction loop, calls
tensor_params_fn to materialize weights / states per slice, and
issues one nnef_fragment_name call per (layer, direction). Bidi
packing and multi-layer concat are handled internally.
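The loop structure, approximated in runnable PyTorch with torch.nn.RNNCell
standing in for one fragment call (all names here are illustrative, not this
module's helpers):

```python
import torch


def orchestration_sketch(x, cells, bidirectional, batch_first):
    """Mimic the (layer, direction) loop; cells[layer][direction]
    plays the role of one nnef_fragment call."""
    if batch_first:
        x = x.transpose(0, 1)                   # -> (T, B, features)
    inp = x
    for layer_cells in cells:                   # one pass per layer
        outs = []
        for d, cell in enumerate(layer_cells):  # one pass per direction
            steps = range(inp.shape[0])
            if d == 1:
                steps = reversed(steps)         # backward direction
            h = torch.zeros(inp.shape[1], cell.hidden_size)
            seq = []
            for t in steps:
                h = cell(inp[t], h)
                seq.append(h)
            if d == 1:
                seq.reverse()                   # re-align the time axis
            outs.append(torch.stack(seq))
        # bidi packing: concat directions on the feature axis
        inp = torch.cat(outs, dim=-1) if len(outs) == 2 else outs[0]
    return inp.transpose(0, 1) if batch_first else inp
```

Note that in a bidirectional stack, layer l+1 consumes 2 * hidden input
features, matching the concat above.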
gru
Map aten::gru.input to NNEF via the existing gru fragment.
lstm
Map aten::lstm.input to NNEF via the existing lstm fragment.
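For example, a traced nn.LSTM lowers to a single aten::lstm node (an
illustrative trace; a module-wise export may instead go through
LSTMExtractor, but both paths end in the same fragment call):

```python
import torch

lstm = torch.nn.LSTM(input_size=4, hidden_size=8, num_layers=2,
                     bidirectional=True, batch_first=True)
traced = torch.jit.trace(lstm, (torch.randn(3, 5, 4),))
print(traced.graph)  # the graph contains an aten::lstm(...) node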
lstm_cell
Map aten::lstm_cell(input, hx_list, w_ih, w_hh, b_ih?, b_hh?) to NNEF.
hx_list is a t2n FixedTensorList of [h_prev, c_prev]. The output is
(h_new, c_new).
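For instance, tracing nn.LSTMCell produces an op of that shape (illustrative;
a module-wise export may instead go through LSTMCellExtractor):

```python
import torch

cell = torch.nn.LSTMCell(input_size=4, hidden_size=8)
x, h, c = torch.randn(2, 4), torch.randn(2, 8), torch.randn(2, 8)
traced = torch.jit.trace(cell, (x, (h, c)))
print(traced.graph)  # contains aten::lstm_cell with hx packed as [h_prev, c_prev]
```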
rnn_relu
Map aten::rnn_relu.input to NNEF via the existing rnn_relu fragment.