Internal design
Internals of the torch to NNEF export are mostly segmented into 6 steps, as shown below:

Each of these steps has a specific aim:
- Make sense of complex inputs and outputs, such as dicts, dict-like objects containing tensors, or tensors nested inside containers of containers...
- Name each tensor after the module it is assigned to (if it is shared across multiple modules, the first name encountered is retained).
- Trace the PyTorch graph module by module, starting from the provided model. Each sub-module call is resolved after the module itself has been traced, and each sub-module is collapsed into its parent. This tracing builds a specific internal representation (IR) in torch_to_nnef which is NOT the torch graph, but a simplified version of it that is no longer tied to torch C++ internals and is stripped of operators useless for inference.
- Translate the torch_to_nnef internal IR into NNEF, depending on the selected inference target.
- Save each tensor on disk as .dat and serialize the associated graph.nnef and graph.quant.
- Allow a series of tests to be performed after the NNEF model asset has been generated (typically checking output similarities).

Note
These steps only apply to the torch_to_nnef.export_model_to_nnef
export function, which exports the graph + the tensors.
To observe them, setting the log level to INFO for this library is helpful; a default logger is available in torch_to_nnef.log.init_log.
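For example, with the standard logging module (a minimal sketch; it assumes the library's loggers are named after the torch_to_nnef package, as is conventional, and does not show init_log's exact signature):

```python
import logging

# Enable INFO-level logs for the torch_to_nnef package loggers
# (assumes the conventional logging.getLogger(__name__) naming scheme).
logging.basicConfig()
logging.getLogger("torch_to_nnef").setLevel(logging.INFO)
```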
1. Auto wrapper
The auto wrapper is available at torch_to_nnef.model_wrapper. In essence,
this step tries hard to make sense of the inputs and outputs provided by the
user as parameters, by 'flattening' complex data structures and extracting from them a proper
list of tensors to be passed. Some examples can be seen in our multi inputs/outputs tutorial.
Still, note that as of today the graph is traced statically, with Python primitives turned into constants.
Also, raw objects passed to the forward
function are not supported yet (there is uncertainty about the order in which the tensors found in them should be passed), as sketched below.
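To give an idea of the flattening at play, here is a simplified, self-contained sketch (the real logic in torch_to_nnef.model_wrapper handles more cases, and its actual traversal order may differ):

```python
import torch

def flatten_tensors(obj):
    """Recursively extract tensors from nested dicts/lists/tuples
    in a deterministic order, so they can be passed as a flat list."""
    if isinstance(obj, torch.Tensor):
        return [obj]
    if isinstance(obj, dict):
        return [t for key in sorted(obj) for t in flatten_tensors(obj[key])]
    if isinstance(obj, (list, tuple)):
        return [t for item in obj for t in flatten_tensors(item)]
    return []  # non-tensor leaves are ignored in this sketch

nested = {"ids": torch.ones(1, 8), "extras": [torch.zeros(1, 2), {"mask": torch.ones(1)}]}
print([tuple(t.shape) for t in flatten_tensors(nested)])
# [(1, 2), (1,), (1, 8)]
```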
2. Tensor naming
This replaces each tensor in the graph (code can be found in torch_to_nnef.tensor.named)
by a named tensor holding the name it will have in the different intermediate representations.
This is helpful to keep tensor naming consistent between the PyTorch parameter/buffer names
and the NNEF archive we build, allowing confident cross-references between the 2 worlds. In practice this
tensor acts just like a classical torch.Tensor, so it can even be used beyond the torch_to_nnef use case,
if you want to name tensors.
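The idea can be sketched as a torch.Tensor subclass carrying an extra name (class and attribute names below are hypothetical; the real implementation lives in torch_to_nnef.tensor.named):

```python
import torch

class NamedTensor(torch.Tensor):
    """Tensor subclass carrying the name used in exported representations."""

    @staticmethod
    def __new__(cls, data, nnef_name):
        instance = torch.Tensor._make_subclass(cls, data)
        instance.nnef_name = nnef_name
        return instance

weight = NamedTensor(torch.randn(64, 3), nnef_name="classifier_1_weight")
print(weight.nnef_name)  # 'classifier_1_weight'
print(weight.mean())     # behaves like a regular torch.Tensor
```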
3. Internal IR representation
While tracing the graph recursively, you may debug its parsed representation as follows:
let's imagine you set a breakpoint in the torch_to_nnef.torch_graph.ir_graph.TorchModuleIRGraph.parse
method; you could then call self.tracer.torch_graph
to observe the
PyTorch representation:
graph(%self.1 : __torch__.torchvision.models.alexnet.___torch_mangle_39.AlexNet,
%x.1 : Float(1, 3, 224, 224, strides=[150528, 50176, 224, 1], requires_grad=0, device=cpu)):
%classifier : __torch__.torch.nn.modules.container.___torch_mangle_38.Sequential = prim::GetAttr[name="classifier"](%self.1)
%avgpool : __torch__.torch.nn.modules.pooling.___torch_mangle_30.AdaptiveAvgPool2d = prim::GetAttr[name="avgpool"](%self.1)
%features : __torch__.torch.nn.modules.container.___torch_mangle_29.Sequential = prim::GetAttr[name="features"](%self.1)
%394 : Tensor = prim::CallMethod[name="forward"](%features, %x.1)
%395 : Tensor = prim::CallMethod[name="forward"](%avgpool, %394)
%277 : int = prim::Constant[value=1]() # /Users/julien.balian/SONOS/src/torch-to-nnef/.venv/lib/python3.12/site-packages/torchvision/models/alexnet.py:50:0
%278 : int = prim::Constant[value=-1]() # /Users/julien.balian/SONOS/src/torch-to-nnef/.venv/lib/python3.12/site-packages/torchvision/models/alexnet.py:50:0
%input.19 : Float(1, 9216, strides=[9216, 1], requires_grad=0, device=cpu) = aten::flatten(%395, %277, %278) # /Users/julien.balian/SONOS/src/torch-to-nnef/.venv/lib/python3.12/site-packages/torchvision/models/alexnet.py:50:0
%396 : Tensor = prim::CallMethod[name="forward"](%classifier, %input.19)
return (%396)
Then call self.printall()
to observe the current torch to NNEF representation:
___________________________________[PyTorch JIT Graph '<class 'torchvision.models.alexnet.AlexNet'>']___________________________________
inputs: (AlexNet_x_1: torch.float32@[1, 3, 224, 224])
Static Constants:
int AlexNet_277 := 1
int AlexNet_278 := -1
Static Tensor:
Blob TorchScript:
List:
TupleTensors:
Directed Acyclic Graph:
None AlexNet_394 := prim::CallMethod<Sequential.forward>( AlexNet_x_1 )
None AlexNet_395 := prim::CallMethod<AdaptiveAvgPool2d.forward>( AlexNet_394 )
torch.float32 AlexNet_input_19 := aten::flatten( AlexNet_395, AlexNet_277, AlexNet_278 )
None AlexNet_396 := prim::CallMethod<Sequential.forward>( AlexNet_input_19 )
outputs: (AlexNet_396: None@None)
____________________________________________________________________________________________________
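For reference, the PyTorch-side listing shown first is a standard TorchScript trace graph; a minimal standalone sketch (plain PyTorch and torchvision, no torch_to_nnef needed) reproduces a similar listing:

```python
import torch
import torchvision

# Trace AlexNet and print its TorchScript graph, which contains the same
# kind of prim::GetAttr / prim::CallMethod nodes as the listing above.
model = torchvision.models.alexnet().eval()
traced = torch.jit.trace(model, torch.randn(1, 3, 224, 224))
print(traced.graph)
```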
Since the process is recursive, you can see this representation evolve as each submodule gets parsed.
If you want to learn more about the representation data structures we use, you can look at
torch_to_nnef.torch_graph.ir_data and torch_to_nnef.torch_graph.ir_op.
This step is crucial in order to get an accurate representation of the graph.
A lot of things can go wrong: it interfaces with internal parts of PyTorch
which are not guaranteed to be stable. This is one of the reasons we have a dedicated IR in torch_to_nnef.
When code breaks in this part, a good understanding of PyTorch internals is often required, and due to the lack of documentation,
reading their source code is necessary.
4. NNEF translation
This step is probably the one that needs the most code, but it is often rather straightforward.
It is responsible for the mapping between our internal representation and the NNEF graph.
Adding a new operator is a rather simple process as long as the
2 engines (PyTorch and the inference target) share similar operators to compose.
But since there are so many operators in PyTorch,
there is a lot of mapping to do.
In some cases, when there is too much discrepancy between the engines, it may be worth
proposing to reify the operation in the targeted inference engine.
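The mapping can be pictured as a registry from aten operator names to small emitter functions (a hypothetical illustration; torch_to_nnef's real registration mechanism and signatures differ):

```python
from typing import Callable, Dict

NNEF_EMITTERS: Dict[str, Callable[..., str]] = {}

def register(aten_name: str):
    def decorator(fn):
        NNEF_EMITTERS[aten_name] = fn
        return fn
    return decorator

@register("aten::relu")
def emit_relu(output: str, inp: str) -> str:
    # 1-to-1 case: aten::relu maps directly onto the NNEF `relu` fragment.
    return f"{output} = relu({inp});"

@register("aten::flatten")
def emit_flatten(output: str, inp: str, start: int, end: int) -> str:
    # Composite case: NNEF has no flatten, so it is expressed with `reshape`
    # (shape [0, -1] keeps the batch dim and collapses the rest; start/end
    # are ignored here, assuming the common (1, -1) case seen above).
    return f"{output} = reshape({inp}, shape = [0, -1]);"

print(NNEF_EMITTERS["aten::flatten"]("input_19", "tensor_395", 1, -1))
# input_19 = reshape(tensor_395, shape = [0, -1]);
```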
5. NNEF dump
This step is rather simple. It uses a modernized version of the dump logic proposed by the Khronos Group
in their nnef_tools package, with a few extensions around
custom .dat
format serialization (code is available here).
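As an illustration of what ends up on disk, the Khronos nnef Python package can write a single tensor in the standard .dat binary format (a sketch assuming its write_tensor helper accepts a binary file object; the file name is illustrative):

```python
import numpy as np
import nnef  # Khronos 'nnef' package (pip install nnef)

# The dump step produces one such .dat file per tensor, next to graph.nnef.
weights = np.random.randn(64, 3, 3, 3).astype(np.float32)
with open("features_0_weight.dat", "wb") as f:
    nnef.write_tensor(f, weights)
```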