torch_to_nnef.op.custom_extractors
op.custom_extractors provides a mechanism to control extraction to NNEF
while bypassing the full PyTorch expansion of a torch.nn.Module within torch_graph,
which by default uses torch.jit.trace.
This is useful for three main reasons:
- Some layers, such as LSTM/GRU, have a complex expansion that is better handled by encapsulation than by spreading it across a large number of variables.
- Some layers might not be serializable to .jit.
- There might be edge cases where you prefer to keep full control over the exported NNEF subgraph.
LSTMCellExtractor
Bases: ModuleInfoExtractor
Decompose nn.LSTMCell into primitive NNEF ops.
Unlike nn.LSTM, an LSTMCell processes a single time step. We emit:
preact = matmul(input, w_ih, T) + matmul(h, w_hh, T) + b_ih + b_hh
i, f, g, o = chunk(preact, 4, axis=-1)
c_new = sigmoid(f) * c + sigmoid(i) * tanh(g)
h_new = sigmoid(o) * tanh(c_new)
Input order from the user-facing wrapper is (input, h, c) -- the
internal nn.LSTMCell call expects (input, (h, c)) which is handled by
_call_original_mod_with_args.
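The decomposition above can be checked numerically with a small sketch. This is a pure-Python, single-sample illustration (no batch dimension) written for this page; it is not the code emitted by LSTMCellExtractor, and the helper names `matvec_t` and `lstm_cell_step` are hypothetical. The weight shapes follow the nn.LSTMCell layout, with the four gates stacked along the first axis.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def matvec_t(x, w):
    """Compute x @ w.T for a vector x and a matrix w given as a list of rows."""
    return [sum(xi * wi for xi, wi in zip(x, row)) for row in w]


def lstm_cell_step(x, h, c, w_ih, w_hh, b_ih, b_hh):
    """One LSTM cell step for a single sample (hypothetical helper)."""
    hidden = len(h)
    # preact = matmul(input, w_ih, T) + matmul(h, w_hh, T) + b_ih + b_hh
    preact = [
        a + b + bi + bh
        for a, b, bi, bh in zip(matvec_t(x, w_ih), matvec_t(h, w_hh), b_ih, b_hh)
    ]
    # i, f, g, o = chunk(preact, 4, axis=-1)
    i, f, g, o = (preact[k * hidden:(k + 1) * hidden] for k in range(4))
    # c_new = sigmoid(f) * c + sigmoid(i) * tanh(g)
    c_new = [
        sigmoid(fk) * ck + sigmoid(ik) * math.tanh(gk)
        for ik, fk, gk, ck in zip(i, f, g, c)
    ]
    # h_new = sigmoid(o) * tanh(c_new)
    h_new = [sigmoid(ok) * math.tanh(cn) for ok, cn in zip(o, c_new)]
    return h_new, c_new


# input_size = 2, hidden_size = 1, all weights and biases set to zero
x, h, c = [0.3, -0.7], [0.1], [2.0]
w_ih = [[0.0, 0.0] for _ in range(4)]  # shape (4 * hidden_size, input_size)
w_hh = [[0.0] for _ in range(4)]       # shape (4 * hidden_size, hidden_size)
b_ih = [0.0] * 4
b_hh = [0.0] * 4
h_new, c_new = lstm_cell_step(x, h, c, w_ih, w_hh, b_ih, b_hh)
```

With zero weights every gate pre-activation is 0, so each sigmoid evaluates to 0.5 and tanh(g) to 0. The result is therefore c_new = 0.5 * c and h_new = 0.5 * tanh(c_new), which makes the sketch easy to verify by hand.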
ordered_args
Reorder args so the first one is input (shape (B, input_size)).
t2n's IR sometimes reorders the cell's inputs after FixedTensorList /
tuple expansion, surfacing them as e.g. (h, input, c). The cell's
input_size (= weight_ih.shape[1]) lets us pick the input tensor by
shape; the relative order of (h, c) follows the JIT graph's
prim::ListConstruct that builds hx.
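The shape-based selection described above might look roughly like the following sketch. It works on plain shape tuples rather than on the library's IR nodes, and `input_first` is a hypothetical helper, not the actual ordered_args implementation. Raising an error when the shape does not single out one tensor (for example when hidden_size equals input_size) is a choice made for this sketch only; the source does not describe how that case is resolved.

```python
def input_first(shapes, input_size):
    """Return indices into `shapes`, reordered so the input tensor comes first.

    `shapes` is a list of tensor shapes, e.g. [(B, hidden), (B, input_size), (B, hidden)].
    The relative order of the remaining tensors (h, c) is preserved.
    """
    candidates = [idx for idx, shape in enumerate(shapes) if shape[-1] == input_size]
    if len(candidates) != 1:
        # Sketch-only choice: refuse to guess when the shape is ambiguous.
        raise ValueError(
            f"expected exactly one tensor with last dim {input_size}, "
            f"found {len(candidates)}"
        )
    first = candidates[0]
    rest = [idx for idx in range(len(shapes)) if idx != first]
    return [first] + rest


# Arguments surfaced as (h, input, c) with hidden_size=16 and input_size=32
order = input_first([(8, 16), (8, 32), (8, 16)], input_size=32)
```

Here `order` places index 1 (the input) first and keeps h before c, matching the (input, h, c) convention.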
ModuleInfoExtractor
Class to take manual control of the NNEF expansion of an nn.Module.
To use it, subclass it and set MODULE_CLASS to the module class you want to target.
Then implement .convert_to_nnef according to your needs.
convert_to_nnef
Control the NNEF content to be written for each MODULE_CLASS.
This happens at the macro level, during the conversion stage from the internal IR to the NNEF IR.
This is the core method to override in subclasses.
It is no different than any op implemented in torch_to_nnef
in the module
generate_in_torch_graph
Internal method called by torch_to_nnef ir_graph.
get_by_kind
classmethod
Get ModuleInfoExtractor by kind in torch_to_nnef internal IR.
get_by_module
classmethod
Check whether the module matches one of the registered MODULE_CLASS values.
Returns the matching ModuleInfoExtractor subclass if found.
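The subclass-and-lookup pattern described on this page can be illustrated with a self-contained stand-in. Everything in this sketch is hypothetical: `ExtractorBase`, `FakeLSTMCell`, the `custom::` kind naming scheme, registration through `__init_subclass__`, and returning None on a miss are assumptions made for illustration. It does not reproduce how torch_to_nnef registers extractors or the real signature of convert_to_nnef.

```python
class ExtractorBase:
    """Hypothetical stand-in for a base class that keeps a subclass registry."""

    MODULE_CLASS = None
    _registry = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Register every subclass that declares a target module class.
        if cls.MODULE_CLASS is not None:
            ExtractorBase._registry[cls.MODULE_CLASS] = cls

    @classmethod
    def kind(cls):
        # Assumed naming scheme for the IR "kind" string.
        return f"custom::{cls.__name__}"

    @classmethod
    def get_by_kind(cls, kind):
        for extractor in cls._registry.values():
            if extractor.kind() == kind:
                return extractor
        return None

    @classmethod
    def get_by_module(cls, module):
        for module_class, extractor in cls._registry.items():
            if isinstance(module, module_class):
                return extractor
        return None

    def convert_to_nnef(self, *args, **kwargs):
        raise NotImplementedError


class FakeLSTMCell:
    """Placeholder standing in for a layer class such as nn.LSTMCell."""


class FakeLSTMCellExtractor(ExtractorBase):
    MODULE_CLASS = FakeLSTMCell

    def convert_to_nnef(self, *args, **kwargs):
        return "lstm_cell subgraph"  # placeholder result for the sketch


found = ExtractorBase.get_by_module(FakeLSTMCell())
```

Declaring the subclass is enough to make it discoverable in this sketch: `found` is `FakeLSTMCellExtractor`, and the same class is returned by `ExtractorBase.get_by_kind("custom::FakeLSTMCellExtractor")`.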