torch_to_nnef.op.custom_extractors

op.custom_extractors provides a mechanism to control extraction to NNEF while bypassing the full PyTorch expansion of a torch.Module within torch_graph, which by default uses torch.jit.trace.

This may be for three main reasons:
  • Some layers such as LSTM/GRU have a complex expansion that is better handled by encapsulation instead of spreading a high number of variables.
  • Some layers might not be serializable to .jit.
  • There might be edge cases where you prefer to keep full control over the exported NNEF subgraph.

LSTMExtractor

LSTMExtractor()

Bases: _RNNMixin, ModuleInfoExtractor

ModuleInfoExtractor

ModuleInfoExtractor()

Class to take manual control of the NNEF expansion of an nn.Module.

You need to subclass it and set MODULE_CLASS to your targeted module class.

Then write .convert_to_nnef according to your needs, as sketched below.
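
Here is a minimal sketch of such a subclass, assuming the import path from this page. The MyNorm module, the extractor name, and the body of convert_to_nnef are hypothetical placeholders, not part of the library:

```python
# Hedged sketch: a custom extractor for a hypothetical MyNorm module.
# The class/attribute/method names follow this page; everything inside
# convert_to_nnef is an illustrative placeholder.
import torch.nn as nn

from torch_to_nnef.op.custom_extractors import ModuleInfoExtractor


class MyNorm(nn.Module):
    """Hypothetical module we want to export as a single NNEF fragment."""

    def forward(self, x):
        return x / (x.norm(dim=-1, keepdim=True) + 1e-6)


class MyNormExtractor(ModuleInfoExtractor):
    # The nn.Module type whose jit expansion this extractor bypasses.
    MODULE_CLASS = MyNorm

    def convert_to_nnef(
        self,
        g,
        node,
        name_to_tensor,
        null_ref,
        torch_graph,
        inference_target,
        **kwargs,
    ):
        # Emit the NNEF operations representing MyNorm here, instead of
        # letting torch.jit.trace expand its forward() into primitives.
        raise NotImplementedError("illustrative placeholder")
```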

convert_to_nnef
convert_to_nnef(g, node, name_to_tensor, null_ref, torch_graph, inference_target, **kwargs)

Control NNEF content to be written for each MODULE_CLASS.

This happens at the macro level, at the stage of converting from the internal IR to the NNEF IR.

This is the core method to override in a subclass.

It is no different from any op implemented in the op module of torch_to_nnef.

generate_in_torch_graph
generate_in_torch_graph(torch_graph, *args, **kwargs)

Internal method called by the torch_to_nnef ir_graph.

get_by_kind classmethod
get_by_kind(kind: str)

Get a ModuleInfoExtractor by node kind in the torch_to_nnef internal IR.

get_by_module classmethod
get_by_module(module: nn.Module)

Check whether the module is an instance of one of the registered MODULE_CLASS types.

Returns the appropriate ModuleInfoExtractor subclass if found.
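
A hedged usage sketch follows. What get_by_module returns when no extractor matches is not specified on this page, so the None check is an assumption:

```python
# Hedged usage sketch: dispatching on an arbitrary module instance.
# Only get_by_module itself is documented on this page; the None check
# for unregistered modules is an assumption.
import torch.nn as nn

from torch_to_nnef.op.custom_extractors import ModuleInfoExtractor

lstm = nn.LSTM(input_size=32, hidden_size=64)

extractor_cls = ModuleInfoExtractor.get_by_module(lstm)
if extractor_cls is not None:
    # nn.LSTM is expected to match LSTMExtractor (documented above), so
    # its NNEF expansion is handled manually rather than via torch.jit.trace.
    print(f"manual NNEF expansion via {extractor_cls}")
```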

ordered_args
ordered_args(torch_graph)

Ordered args for the module call.

Sometimes torch.jit may reorder inputs compared to the targeted Python op; in such cases the ordering needs to be re-addressed.