
rnn

torch_to_nnef.op.custom_extractors.rnn

LSTMCellExtractor

LSTMCellExtractor()

Bases: ModuleInfoExtractor

Decompose nn.LSTMCell into primitive NNEF ops.

Unlike nn.LSTM, an LSTMCell processes a single time step. We emit:

    preact = matmul(input, w_ih, T) + matmul(h, w_hh, T) + b_ih + b_hh
    i, f, g, o = chunk(preact, 4, axis=-1)
    c_new = sigmoid(f) * c + sigmoid(i) * tanh(g)
    h_new = sigmoid(o) * tanh(c_new)
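As a sanity check on the decomposition above, the sketch below writes the same four steps with plain PyTorch ops and compares the result against nn.LSTMCell itself. The function name lstm_cell_decomposed is hypothetical and only mirrors the emitted equations; it is not torch_to_nnef's code. It assumes PyTorch's gate packing order (input, forget, cell, output) in weight_ih / weight_hh.

```python
import torch
from torch import nn


def lstm_cell_decomposed(cell: nn.LSTMCell, x, h, c):
    """One LSTMCell step with primitive ops (hypothetical helper)."""
    # preact has shape (B, 4 * hidden_size)
    preact = x @ cell.weight_ih.T + h @ cell.weight_hh.T
    if cell.bias:
        preact = preact + cell.bias_ih + cell.bias_hh
    # PyTorch packs gates as (input, forget, cell, output) along the last axis
    i, f, g, o = preact.chunk(4, dim=-1)
    c_new = torch.sigmoid(f) * c + torch.sigmoid(i) * torch.tanh(g)
    h_new = torch.sigmoid(o) * torch.tanh(c_new)
    return h_new, c_new


torch.manual_seed(0)
cell = nn.LSTMCell(input_size=3, hidden_size=5)
x = torch.randn(2, 3)
h = torch.randn(2, 5)
c = torch.randn(2, 5)

with torch.no_grad():
    h_ref, c_ref = cell(x, (h, c))                   # reference implementation
    h_dec, c_dec = lstm_cell_decomposed(cell, x, h, c)

print(torch.allclose(h_ref, h_dec, atol=1e-5), torch.allclose(c_ref, c_dec, atol=1e-5))
```

The outputs agree up to float32 rounding, since the reference kernel sums the two matmuls and biases in a slightly different order.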

The user-facing wrapper takes its inputs in the order (input, h, c), whereas the underlying nn.LSTMCell call expects (input, (h, c)); _call_original_mod_with_args performs this conversion.
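The argument adaptation can be pictured as a thin module that accepts a flat (input, h, c) signature and repacks the state into the tuple nn.LSTMCell expects. FlatArgsLSTMCell is a hypothetical name used only for illustration; it is not the torch_to_nnef wrapper or _call_original_mod_with_args.

```python
import torch
from torch import nn


class FlatArgsLSTMCell(nn.Module):
    """Hypothetical wrapper exposing a flat (input, h, c) signature."""

    def __init__(self, cell: nn.LSTMCell):
        super().__init__()
        self.cell = cell

    def forward(self, x: torch.Tensor, h: torch.Tensor, c: torch.Tensor):
        # nn.LSTMCell expects the state packed as a tuple: (input, (h, c))
        h_new, c_new = self.cell(x, (h, c))
        return h_new, c_new


torch.manual_seed(0)
cell = nn.LSTMCell(input_size=3, hidden_size=5)
wrapped = FlatArgsLSTMCell(cell)
x, h, c = torch.randn(2, 3), torch.randn(2, 5), torch.randn(2, 5)

with torch.no_grad():
    out_flat = wrapped(x, h, c)   # flat call
    out_ref = cell(x, (h, c))     # original tuple-state call
```

A flat signature keeps every input a plain tensor, which is easier to trace and export than a nested tuple.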

ordered_args
ordered_args(torch_graph)

Reorder args so the first one is input (shape (B, input_size)).

t2n's IR sometimes reorders the cell's inputs after FixedTensorList / tuple expansion, surfacing them as, for example, (h, input, c). The cell's input_size (= weight_ih.shape[1]) lets us identify the input tensor by its shape; the relative order of (h, c) follows the JIT graph's prim::ListConstruct that builds hx.
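A minimal sketch of the shape-based selection, working on plain shape tuples rather than t2n IR objects. reorder_cell_args is a hypothetical helper, not part of torch_to_nnef's API. It assumes the remaining two tensors already appear in (h, c) order, and it simply refuses to guess when the shape does not single out one tensor (for instance when input_size equals hidden_size); how the real method handles that case is not described here.

```python
from typing import List, Sequence, Tuple

Shape = Tuple[int, ...]


def reorder_cell_args(shapes: Sequence[Shape], input_size: int) -> List[int]:
    """Return indices that move the input tensor first (hypothetical helper).

    The other tensors keep their relative order, assumed to be (h, c).
    """
    matches = [i for i, s in enumerate(shapes) if s[-1] == input_size]
    if len(matches) != 1:
        # e.g. input_size == hidden_size: shape alone is ambiguous
        raise ValueError(f"cannot identify input by shape: {list(shapes)}")
    idx = matches[0]
    rest = [i for i in range(len(shapes)) if i != idx]
    return [idx] + rest


# IR surfaced the args as (h, input, c) with B=2, input_size=3, hidden_size=5
order = reorder_cell_args([(2, 5), (2, 3), (2, 5)], input_size=3)
print(order)  # [1, 0, 2]
```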

LSTMExtractor

LSTMExtractor()

Bases: _RNNMixin, ModuleInfoExtractor