8. Custom operators

Goals

At the end of this tutorial you will be able to:

  1. Control how an nn.Module is transformed to NNEF as you wish. This is often useful when the module is not representable in the PyTorch jit graph, or when you wish to use a custom NNEF operator in your inference engine.

Prerequisite

  • PyTorch and Python basics
  • 5 min to read this page

Sometimes you want to control how a specific torch.nn.Module expands to NNEF. This may be because you want to use a specific custom operator on the inference target instead of basic primitives, or simply because you need to map to something that is not traceable, for example (but not limited to) a physics engine.

For this purpose, torch_to_nnef lets you create a subclass of torch_to_nnef.op.custom_extractors.ModuleInfoExtractor, which is defined as follows:

  • torch_to_nnef.op.custom_extractors.ModuleInfoExtractor

    ModuleInfoExtractor()
    

    Class to take manual control of the NNEF expansion of an nn.Module.

    You need to subclass it and set MODULE_CLASS to your targeted module class.

    Then implement .convert_to_nnef according to your needs.

    convert_to_nnef

    convert_to_nnef(g, node, name_to_tensor, null_ref, torch_graph, inference_target, **kwargs)
    

    Control the NNEF content to be written for each MODULE_CLASS.

    This happens at the macro level, at the stage converting the internal IR to the NNEF IR.

    This is the core method to override in your subclass.

    It is no different from any other op implemented in the torch_to_nnef op module.

    generate_in_torch_graph

    generate_in_torch_graph(torch_graph, *args, **kwargs)
    

    Internal method called by torch_to_nnef ir_graph.

    get_by_kind classmethod

    get_by_kind(kind: str)
    

    Get ModuleInfoExtractor by kind in torch_to_nnef internal IR.

    get_by_module classmethod

    get_by_module(module: nn.Module)
    

    Search whether the module is one of the registered MODULE_CLASS entries.

    Returns the appropriate ModuleInfoExtractor subclass if found.

    ordered_args

    ordered_args(torch_graph)
    

    Ordered args for the module call.

    Sometimes torch jit may reorder inputs compared to the targeted Python op; in such cases the ordering needs to be re-addressed.
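To picture why such re-ordering matters, here is a minimal, self-contained sketch (plain Python, not the library's actual implementation; the `reorder_args` helper and the names in it are hypothetical) of mapping jit-emitted inputs back to the order the targeted Python op expects:

```python
def reorder_args(jit_names, jit_values, python_order):
    """Re-address values emitted by torch.jit in `jit_names` order
    back to the order the targeted Python op expects."""
    by_name = dict(zip(jit_names, jit_values))
    return [by_name[name] for name in python_order]


# jit traced the call with inputs in one order...
jit_names = ["weight", "x", "bias"]
jit_values = ["w_tensor", "x_tensor", "b_tensor"]
# ...but the targeted Python op expects (x, weight, bias)
print(reorder_args(jit_names, jit_values, ["x", "weight", "bias"]))
# -> ['x_tensor', 'w_tensor', 'b_tensor']
```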

To make it work you need to accomplish 4 steps:

  1. sub-class it
  2. define its MODULE_CLASS attribute
  3. define its convert_to_nnef method
  4. ensure that the subclass you just defined is imported in your export script

This would look something like this:

from torch_to_nnef.op.custom_extractors import ModuleInfoExtractor

class MyCustomHandler(ModuleInfoExtractor):
    # the nn.Module subclass whose NNEF expansion you take over
    MODULE_CLASS = MyModuleToCustomConvert

    def convert_to_nnef(
        self,
        g,  # NNEF graph being built
        node,  # internal IR node of the module call
        name_to_tensor,  # mapping from names to NNEF tensors
        null_ref,
        torch_graph,
        **kwargs,  # extra context such as inference_target
    ):
        # write the NNEF ops for MODULE_CLASS into the graph g here
        pass

You can take inspiration from our own management of RNN layers like:

  • torch_to_nnef.op.custom_extractors.LSTMExtractor

    LSTMExtractor()
    

    Bases: _RNNMixin, ModuleInfoExtractor

But ultimately this is just a chain of ops that needs to be written inside the g graph, just as we do when adding a new aten operator.
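The get_by_module lookup described above can be pictured with a small stand-in sketch (plain Python, not the library's actual code; every class name here is hypothetical) that walks registered subclasses and matches the module against each MODULE_CLASS:

```python
class FakeModule:
    """Stand-in for torch.nn.Module so the sketch stays dependency-free."""


class MyModule(FakeModule):
    """Stand-in for a module you want to convert by hand."""


class ExtractorBase:
    """Stand-in for ModuleInfoExtractor's subclass lookup."""

    MODULE_CLASS = None

    @classmethod
    def get_by_module(cls, module):
        # search registered subclasses for one whose MODULE_CLASS matches
        for sub in cls.__subclasses__():
            if sub.MODULE_CLASS is not None and isinstance(module, sub.MODULE_CLASS):
                return sub
        return None


class MyExtractor(ExtractorBase):
    MODULE_CLASS = MyModule


print(ExtractorBase.get_by_module(MyModule()).__name__)  # -> MyExtractor
```

This also shows why step 4 above matters: a subclass that is never imported never gets registered, so the lookup cannot find it.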