8. Custom operators
Goals
At the end of this tutorial you will be able to:
- Control the transformation of an nn.Module to NNEF as you wish. This is often useful when those modules are not representable in the PyTorch jit graph, or when you wish to use a custom NNEF operator for your inference engine.
Prerequisites
- PyTorch and Python basics
- 5 min to read this page
Sometimes you want to control how a specific torch.nn.Module
expands to NNEF.
This may happen because you want to use a specific custom operator on your inference target instead of
basic primitives, or simply because you need to map to something that is not traceable,
for example (but not limited to) a physics engine.
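As an illustration, here is a minimal, hypothetical module (not part of torch_to_nnef) that the JIT tracer cannot see through, because its forward leaves the world of torch tensors:

import torch
from torch import nn

class PhysicsStep(nn.Module):
    """Toy stand-in for a module wrapping a non-torch physics engine."""

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        # The round trip through NumPy (standing in for an external engine
        # call) is opaque to the PyTorch JIT tracer, so this module cannot
        # be expanded into aten primitives automatically.
        next_state = state.detach().cpu().numpy() * 0.5  # placeholder "physics"
        return torch.from_numpy(next_state)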
For this purpose, torch_to_nnef lets you create a subclass of torch_to_nnef.op.custom_extractors.ModuleInfoExtractor, which is defined as follows:
- torch_to_nnef.op.custom_extractors.ModuleInfoExtractor

  Class to take manual control of the NNEF expansion of an nn.Module.
  You need to subclass it and set MODULE_CLASS according to your targeted module,
  then write .convert_to_nnef according to your needs.

  - convert_to_nnef
    Controls the NNEF content to be written for each MODULE_CLASS.
    This happens at the macro level, at the stage converting from the internal IR to the NNEF IR.
    This is the core method to override in your subclass; it is no different from any
    op implemented in torch_to_nnef itself.
  - generate_in_torch_graph
    Internal method called by the torch_to_nnef ir_graph.
  - get_by_kind (classmethod)
    Get the ModuleInfoExtractor by kind in the torch_to_nnef internal IR.
  - get_by_module (classmethod)
    Search whether the module is one of the registered MODULE_CLASS;
    returns the appropriate ModuleInfoExtractor subclass if found.
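To give an intuition of the lookup that get_by_module performs, here is a rough sketch assuming resolution walks the registered subclasses; it is an illustration only, not the library's actual implementation:

from torch import nn
from torch_to_nnef.op.custom_extractors import ModuleInfoExtractor

def get_by_module_sketch(module: nn.Module):
    # Illustrative only: return the first ModuleInfoExtractor subclass whose
    # declared MODULE_CLASS matches the type of the given nn.Module.
    for extractor_cls in ModuleInfoExtractor.__subclasses__():
        module_class = getattr(extractor_cls, "MODULE_CLASS", None)
        if module_class is not None and isinstance(module, module_class):
            return extractor_cls
    return None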
To make it work you need to accomplish 4 steps:
- subclass it
- define its MODULE_CLASS attribute
- define its convert_to_nnef method
- ensure that the subclass you just defined is imported in your export script
This would look something like this:
from torch_to_nnef.op.custom_extractors import ModuleInfoExtractor


class MyCustomHandler(ModuleInfoExtractor):
    # The nn.Module subclass whose NNEF expansion you want to control.
    MODULE_CLASS = MyModuleToCustomConvert

    def convert_to_nnef(
        self,
        g,
        node,
        name_to_tensor,
        null_ref,
        torch_graph,
        **kwargs
    ):
        # Write the NNEF operators for MODULE_CLASS into the graph `g` here.
        pass
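For the last step, the only requirement is that the file defining your subclass is imported before the export call, so that registration takes place. A hypothetical export script could look like the sketch below; the my_project module paths are made up, and the export_model_to_nnef call (including its keyword arguments) should be checked against the export API of your installed torch_to_nnef version:

import torch
from torch_to_nnef import export_model_to_nnef

# Importing the module that defines MyCustomHandler is what registers it;
# without this import the custom expansion is silently not used.
from my_project.custom_handlers import MyCustomHandler  # noqa: F401
from my_project.models import MyModuleToCustomConvert

model = MyModuleToCustomConvert().eval()
export_model_to_nnef(
    model=model,
    args=(torch.rand(1, 16),),          # example input, adjust to your model
    file_path_export="model.nnef.tgz",  # destination of the exported NNEF archive
)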
You can take inspiration from our own handling of RNN layers, such as:
- torch_to_nnef.op.custom_extractors.LSTMExtractor (bases: _RNNMixin, ModuleInfoExtractor)
But ultimately this is just a chain of ops that needs to be written into the g graph, just like when adding a new aten operator.