8. Custom operators
Goals
At the end of this tutorial you will be able to:
- Control the transformation of an nn.Module to NNEF as you wish. This is often useful when those modules are not representable in the PyTorch jit graph, or when you wish to use a custom NNEF operator for your inference engine.
Prerequisite
- PyTorch and Python basics
- 5 min to read this page
Sometimes you want to control how a specific torch.nn.Module is expanded to NNEF.
This may be because you want to use a specific custom operator on the inference target instead of
basic primitives, or simply because you need to map to something that is not traceable,
for example (but not limited to) a physics engine.
For this purpose, torch_to_nnef lets you create a subclass of torch_to_nnef.op.custom_extractors.ModuleInfoExtractor,
which is defined as follows:
- torch_to_nnef.op.custom_extractors.ModuleInfoExtractor
  Class to take manual control of the NNEF expansion of an nn.Module.
  You need to subclass it and set MODULE_CLASS according to your targeted module.
  Then write .convert_to_nnef according to your needs.
  - convert_to_nnef: Control the NNEF content to be written for each MODULE_CLASS. This happens at the macro level, during the stage that converts the internal IR to the NNEF IR. This is the core method to override in a subclass. It is no different from any op implemented in torch_to_nnef.
  - generate_in_torch_graph: Internal method called by the torch_to_nnef ir_graph.
  - get_by_kind (classmethod): Get a ModuleInfoExtractor by kind in the torch_to_nnef internal IR.
  - get_by_module (classmethod): Search whether the module is an instance of one of the registered MODULE_CLASS values. Returns the appropriate ModuleInfoExtractor subclass if found.
To make it work, you need to complete 4 steps:
- subclass it
- define its MODULE_CLASS attribute
- define its convert_to_nnef method
- ensure that the subclass you just defined is imported in your export script
This would look something like this:
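from torch_to_nnef.op.custom_extractors import ModuleInfoExtractor


class MyCustomHandler(ModuleInfoExtractor):
    # MyModuleToCustomConvert stands for your own nn.Module subclass,
    # defined or imported elsewhere in your project.
    MODULE_CLASS = MyModuleToCustomConvert

    def convert_to_nnef(
        self,
        g,
        node,
        name_to_tensor,
        null_ref,
        torch_graph,
        **kwargs
    ):
        # Write the NNEF ops for this module into the graph `g` here.
        pass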
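The last of the four steps (importing the subclass in your export script) is easy to forget. One way to see why it could matter: if handlers are discovered when their class body executes, a handler that is never imported is never found. The sketch below is a toy stand-in in plain Python, not torch_to_nnef's actual implementation. ToyExtractor, its _REGISTRY, PhysicsStep, Untouched, and PhysicsStepHandler are hypothetical names used only to illustrate subclass registration followed by a get_by_module-style lookup.

```python
class ToyExtractor:
    """Hypothetical stand-in for ModuleInfoExtractor (NOT the real class)."""

    MODULE_CLASS = None
    _REGISTRY = {}  # maps subclass name -> subclass

    def __init_subclass__(cls, **kwargs):
        # Runs once, when a subclass body is executed (that is, at import time).
        super().__init_subclass__(**kwargs)
        if cls.MODULE_CLASS is None:
            raise TypeError(f"{cls.__name__} must define MODULE_CLASS")
        ToyExtractor._REGISTRY[cls.__name__] = cls

    @classmethod
    def get_by_module(cls, module):
        # Return an instance of the first registered handler whose
        # MODULE_CLASS matches the module, or None if there is no match.
        for extractor_cls in cls._REGISTRY.values():
            if isinstance(module, extractor_cls.MODULE_CLASS):
                return extractor_cls()
        return None


class PhysicsStep:
    """Hypothetical module that we want to convert by hand."""


class Untouched:
    """Hypothetical module with no custom handler."""


class PhysicsStepHandler(ToyExtractor):
    # Defining the class is enough to register it in this toy model.
    MODULE_CLASS = PhysicsStep


found = ToyExtractor.get_by_module(PhysicsStep())
missing = ToyExtractor.get_by_module(Untouched())
print(type(found).__name__, missing)
```

In this toy model, deleting the PhysicsStepHandler definition (the equivalent of never importing it) makes the lookup for PhysicsStep return None as well.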
You can take inspiration from our own handling of RNN layers, such as:
- torch_to_nnef.op.custom_extractors.LSTMExtractor
  Bases: _RNNMixin, ModuleInfoExtractor
But ultimately, this is just a chain of ops that needs to be written
inside the g graph, as we do when adding a new aten operator.
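To make "a chain of ops written inside the g graph" more concrete, here is a toy sketch in plain Python. Everything in it is an assumption made for illustration: ToyGraph, ToyTensor, ToyNode, add_op, and the op names scale and add_bias are hypothetical stand-ins, not the NNEF graph classes or helpers that torch_to_nnef uses, and the function signature is simplified. It only shows the general shape of the work: look up the input tensor in name_to_tensor, append operations to g, and record the output tensor under the node's output name.

```python
from dataclasses import dataclass, field


@dataclass
class ToyTensor:
    """Hypothetical tensor handle, identified only by its name."""
    name: str


@dataclass
class ToyGraph:
    """Hypothetical graph: an ordered list of (op, input names, output name)."""
    ops: list = field(default_factory=list)

    def add_op(self, op_name, inputs, output_name):
        # Append one operation and return a handle to its output tensor.
        out = ToyTensor(output_name)
        self.ops.append((op_name, [t.name for t in inputs], out.name))
        return out


@dataclass
class ToyNode:
    """Hypothetical IR node describing one module call."""
    input_names: list
    output_name: str


def convert_to_nnef(g, node, name_to_tensor, **kwargs):
    # Simplified signature: the real method takes more arguments.
    x = name_to_tensor[node.input_names[0]]
    # Chain two ops: the output of the first feeds the second.
    scaled = g.add_op("scale", [x], f"{node.output_name}_scaled")
    out = g.add_op("add_bias", [scaled], node.output_name)
    # Make the result visible to whatever consumes this node's output.
    name_to_tensor[node.output_name] = out


g = ToyGraph()
name_to_tensor = {"x": ToyTensor("x")}
convert_to_nnef(g, ToyNode(["x"], "y"), name_to_tensor)
for op in g.ops:
    print(op)
```

The point of the design is that the handler does not return anything: it mutates the graph and the name-to-tensor mapping, so later nodes can find the tensor it produced.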