rnn
torch_to_nnef.op.custom_extractors.rnn
LSTMCellExtractor
Bases: ModuleInfoExtractor
Decompose nn.LSTMCell into primitive NNEF ops.
Unlike nn.LSTM, an LSTMCell carries a single time-step. We emit:
preact = matmul(input, w_ih, transposeB = true) + matmul(h, w_hh, transposeB = true) + b_ih + b_hh
i, f, g, o = chunk(preact, 4, axis=-1)   (PyTorch gate order: input, forget, cell, output)
c_new = sigmoid(f) * c + sigmoid(i) * tanh(g)
h_new = sigmoid(o) * tanh(c_new)
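The four equations above can be checked numerically outside the exporter. The following is a minimal NumPy sketch, assuming PyTorch's weight layout (w_ih of shape (4*hidden_size, input_size), w_hh of shape (4*hidden_size, hidden_size)); the function name lstm_cell_step is hypothetical and not part of torch_to_nnef.

```python
import numpy as np


def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))


def lstm_cell_step(x, h, c, w_ih, w_hh, b_ih, b_hh):
    """Hypothetical reference for one LSTMCell step.

    x: (B, input_size); h, c: (B, hidden_size)
    w_ih: (4*hidden_size, input_size); w_hh: (4*hidden_size, hidden_size)
    b_ih, b_hh: (4*hidden_size,)
    """
    # Both matmuls use the transposed weight, as nn.LSTMCell does.
    preact = x @ w_ih.T + h @ w_hh.T + b_ih + b_hh  # (B, 4*hidden_size)
    # Split the last axis into the four gates in PyTorch's order.
    i, f, g, o = np.split(preact, 4, axis=-1)
    c_new = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    h_new = sigmoid(o) * np.tanh(c_new)
    return h_new, c_new
```

With all-zero weights and biases every gate pre-activation is 0, so sigmoid gives 0.5 and tanh gives 0; the step then reduces to c_new = 0.5 * c and h_new = 0.5 * tanh(0.5 * c), which is a convenient sanity check.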
The user-facing wrapper passes inputs as a flat (input, h, c), whereas
nn.LSTMCell expects (input, (h, c)); _call_original_mod_with_args
regroups the arguments accordingly.
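The regrouping described above amounts to one tuple repack. This is a hypothetical sketch (call_cell_with_flat_args is not the signature of _call_original_mod_with_args, which is not shown here); the cell is any callable taking (input, hx).

```python
from typing import Any, Callable, Sequence


def call_cell_with_flat_args(cell: Callable[..., Any], args: Sequence[Any]) -> Any:
    """Hypothetical helper: turn flat (input, h, c) into nn.LSTMCell's (input, (h, c))."""
    if len(args) != 3:
        raise ValueError(f"expected (input, h, c), got {len(args)} args")
    x, h, c = args
    # nn.LSTMCell.forward takes the hidden state as a single (h, c) tuple.
    return cell(x, (h, c))
```

With a real torch.nn.LSTMCell(input_size, hidden_size) as cell, the call returns the pair (h_1, c_1).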
ordered_args
Reorder args so the first one is input (shape (B, input_size)).
torch_to_nnef's (t2n) IR sometimes reorders the cell's inputs after
FixedTensorList / tuple expansion, so they may surface as, e.g.,
(h, input, c). The cell's input_size (= weight_ih.shape[1]) lets us
pick out the input tensor by shape; the relative order of (h, c)
follows the prim::ListConstruct in the JIT graph that builds hx.
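A shape-based pick of the input tensor could look like the sketch below. It is hypothetical (reorder_cell_args is not the library's code) and assumes the remaining two tensors already appear in the (h, c) order given by prim::ListConstruct, so it only moves the input to the front.

```python
from typing import Sequence, Tuple

import numpy as np


def reorder_cell_args(args: Sequence[np.ndarray], input_size: int) -> Tuple[np.ndarray, ...]:
    """Hypothetical sketch: put the (B, input_size) tensor first, keep h, c in order."""
    # The input is the only argument whose last dimension equals input_size,
    # provided hidden_size != input_size.
    matches = [k for k, a in enumerate(args) if a.shape[-1] == input_size]
    if len(matches) != 1:
        raise ValueError(
            f"cannot identify the input by shape: {len(matches)} candidates "
            f"with last dim {input_size}"
        )
    k = matches[0]
    rest = [a for j, a in enumerate(args) if j != k]
    return (args[k], *rest)
```

Shape alone cannot tell the tensors apart when hidden_size == input_size, which is why the sketch raises instead of guessing; in that case the graph structure has to decide.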