torch_to_nnef.torch_graph
torch_graph is intended to extract a full representation of the PyTorch graph.
It converts the PyTorch graph into a stable intermediate representation suitable for subsequent translation to NNEF. This means that not all of the original PyTorch graph is translated: for example, we ignore the parts related to device placement information, memory-specific operations, or parameters linked to gradients.
This choice differs from the torch.onnx module and is due to the absence of control (on our side) over the evolution of PyTorch internals. If some PyTorch internals are modified, ideally only this module should be impacted.
There is NO notion of dynamic axes here: all shapes are assumed to be defined by the provided input example. At a later stage, in other modules, dynamic shapes are introduced if requested by the user.
FixedTensorList
dataclass
TensorVariable
dataclass
TensorVariable(name: str, data: T.Optional[torch.Tensor], shape: T.Optional[T.List[int]], dtype: T.Optional[torch.dtype], quant: T.Optional[T.Dict[str, T.Any]] = None, _traced_data: T.Optional[torch.Tensor] = None)
Bases: Data
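A minimal construction sketch based only on the signature above; the field values are illustrative, not taken from the library:

```python
import torch

from torch_to_nnef.torch_graph import TensorVariable

# Illustrative only: a data node describing a float tensor of known shape,
# with no concrete data attached yet.
x = TensorVariable(
    name="input_0",
    data=None,
    shape=[1, 3, 224, 224],
    dtype=torch.float32,
)
```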
TorchModuleIRGraph
TorchModuleIRGraph(torch_module_tracer: TorchModuleTracer, omit_useless_nodes: bool = True, is_root_module: bool = False)
Torch graph intermediate representation built from jit.trace, applied recursively.
This is not the raw torch._C.Graph but a simpler abstraction, with:
A list of data nodes in self.data_nodes
A list of operation nodes in self.op_nodes
self.inputs
is a list of references to some of the self.data_nodes
self.outputs
is a list of references to some of the self.data_nodes
This abstraction of the vanilla Torch graph allows manipulating the graph in order to check and complete missing data information, and to ignore operations that are useless for our transcription needs.
It also makes us less reliant on the base graph in case PyTorch internals change (think Adapter pattern).
Warning! Only non-nested data containers (TupleTensors, FixedTensorList, ...) are supported for now.
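A minimal usage sketch based only on the constructor signatures documented on this page; the example module, and the assumption that the IR graph is populated at construction time, are illustrative rather than guaranteed by the library:

```python
import torch
from torch import nn

from torch_to_nnef.torch_graph import TorchModuleIRGraph, TorchModuleTracer

# Illustrative module and example input; all shapes are fixed by this example.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU())
example_input = torch.randn(1, 4)

# Build the tracer, then the intermediate representation of the traced graph.
tracer = TorchModuleTracer(model, args=(example_input,))
ir_graph = TorchModuleIRGraph(tracer, is_root_module=True)

# Assumed inspection points described above (data_nodes, op_nodes, inputs, outputs).
print(len(ir_graph.data_nodes), len(ir_graph.op_nodes))
```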
TorchModuleTracer
TorchModuleTracer(module: nn.Module, traced_module: T.Optional[torch.jit.TracedModule] = None, fn_name: str = 'forward', args: T.Optional[T.Tuple[T.Any, ...]] = None)
Evaluate the optimized traced function code so that the signature always matches.
The original module is passed in order to do proper un-boxing later on. This is needed because re-routing is based on the actual module class type.
Create a tracer for the module.
The tracer stores the original module, an optional pre-traced torch.jit.TracedModule (which allows re-use of a previously computed trace), the name of the forward method to trace, and the arguments used for tracing. The arguments are post-processed by maybe_quantize_args_tensor to ensure compatibility with quantized modules.
torch_graph
property
Return the underlying PyTorch graph object.
The actual torch.Graph is retrieved from the traced module. When a different forward method is requested (fn_name differs from "forward"), the corresponding sub-graph is returned instead.
traced_module
property
Return the traced module, computing it lazily if required.
If self._traced_module is None, the method performs a jit.trace on self.mod with self.args while handling possible PyTorch version nuances. Any RuntimeError raised by torch.jit.trace is wrapped into a torch_to_nnef.exceptions.T2NErrorTorchJitTraceFailed exception.
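A minimal sketch of the lazy-tracing pattern described above; this is an illustration, not the library's actual code, and the cached attribute name and error handling only mirror the description:

```python
import torch


class LazyTracerSketch:
    """Illustrative only: trace on first access, cache the result, wrap failures."""

    def __init__(self, mod: torch.nn.Module, args: tuple):
        self.mod = mod
        self.args = args
        self._traced_module = None

    @property
    def traced_module(self) -> torch.jit.TracedModule:
        if self._traced_module is None:
            try:
                self._traced_module = torch.jit.trace(self.mod, self.args)
            except RuntimeError as exc:
                # the real implementation raises T2NErrorTorchJitTraceFailed here
                raise RuntimeError(f"torch.jit.trace failed: {exc}") from exc
        return self._traced_module
```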
TorchOp
dataclass
TorchOp(kind: str, module_path: str, inputs: T.List[Data], outputs: T.List[TtupleOrVar], scope: str, op_ref: T.Optional[T.Callable], call_name: T.Optional[str])
call_op
Produce the operation output based on traced inputs with a real torch call.
This operation call is done via the self.args arguments (for now), which means that all required args must be present in parameter order, matching at least one underlying torch operation signature.
NOTE: we use a different approach than the original torch.onnx, which passes parameters by keyword arguments. This is because we are not aware of the argument names provided in the exported graph (from what we understand, torch.onnx solves this via explicit re-routing of all signatures, which might be a bit bulky in most cases).
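A small sketch of the by-position call strategy described above; the resolved callable and tensors are hypothetical examples, not values produced by the library:

```python
import torch

# Resolved torch callable and traced input values (illustrative only).
op_ref = torch.nn.functional.linear
positional_args = [torch.randn(1, 4), torch.randn(8, 4), torch.randn(8)]

# Arguments are passed by position, in the order of the callable's signature,
# rather than by keyword as torch.onnx does.
output = op_ref(*positional_args)
```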
realise_output_type_and_size
Trace the output and try to determine its type, shape, and constant realisation.