torch_to_nnef.op.custom_extractors
Provides a mechanism to control extraction to NNEF while bypassing the full expansion of a torch.nn.Module within torch_graph, which by default uses torch.jit.trace.
This may be useful for three main reasons:
- Some layers, such as LSTM/GRU, have a complex expansion that is better handled by encapsulation than by spreading it across a large number of variables.
- Some layers might not be serializable to .jit.
- There might be edge cases where you prefer to keep full control over the exported NNEF subgraph.
ModuleInfoExtractor
Class to take manual control of the NNEF expansion of an nn.Module.
Subclass it and set MODULE_CLASS to the module class you are targeting.
Then implement .convert_to_nnef according to your needs.
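The subclassing pattern described above can be pictured with a small, self-contained sketch. It does not import torch_to_nnef: ExtractorSketch is a hypothetical stand-in for ModuleInfoExtractor, FakeLSTM is a placeholder for a real module class, and the registry dictionary, the exact-type lookup, and the single-argument convert_to_nnef signature are all assumptions made for illustration, not the library's actual implementation.

```python
class ExtractorSketch:
    """Hypothetical stand-in for ModuleInfoExtractor (NOT the library class)."""

    MODULE_CLASS = None  # each subclass sets this to the module type it handles
    _registry = {}  # assumption: maps module type -> extractor subclass

    def __init_subclass__(cls, **kwargs):
        # Register every subclass that declares a MODULE_CLASS.
        super().__init_subclass__(**kwargs)
        if cls.MODULE_CLASS is not None:
            ExtractorSketch._registry[cls.MODULE_CLASS] = cls

    @classmethod
    def get_by_module(cls, module):
        # Assumption: exact type match; returns None when nothing is registered.
        return cls._registry.get(type(module))

    def convert_to_nnef(self, node):
        # Subclasses must provide their own conversion.
        raise NotImplementedError


class FakeLSTM:
    """Placeholder for a module type such as an LSTM layer."""

    hidden_size = 8


class FakeLSTMExtractor(ExtractorSketch):
    MODULE_CLASS = FakeLSTM

    def convert_to_nnef(self, node):
        # Describe the layer as one encapsulated op instead of many small ones.
        return {"op": "lstm", "hidden_size": node.hidden_size}


found = ExtractorSketch.get_by_module(FakeLSTM())
missing = ExtractorSketch.get_by_module(object())
fragment = found().convert_to_nnef(FakeLSTM())
```

In this sketch, declaring MODULE_CLASS is enough for the lookup to find the subclass, which mirrors the "subclass, set MODULE_CLASS, write convert_to_nnef" workflow. The real convert_to_nnef works on the library's internal IR and NNEF graph objects rather than on a plain dictionary.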
convert_to_nnef
Control the NNEF content to be written for each MODULE_CLASS.
This happens at the macro level, during the conversion stage from the internal IR to the NNEF IR.
This is the core method to override in a subclass.
It is no different than any op implemented in torch_to_nnef
in the module
generate_in_torch_graph
Internal method called by torch_to_nnef ir_graph.
get_by_kind
classmethod
Get ModuleInfoExtractor by kind in torch_to_nnef internal IR.
get_by_module
classmethod
Check whether the module matches one of the registered MODULE_CLASS entries.
Returns the appropriate ModuleInfoExtractor subclass if found.