model_wrapper
torch_to_nnef.model_wrapper
Wrap a model to work around limitations of torch_to_nnef internals.
This covers cases where the inputs or outputs of a model contain
tuples, lists, dicts, or objects.
UnfoldModelInfo
dataclass
UnfoldModelInfo(model: nn.Module, original_inputs: T.Tuple[torch.Tensor], original_outputs: T.List[torch.Tensor], flat_inputs: T.Tuple[torch.Tensor], flat_outputs: T.Tuple[torch.Tensor], input_names: T.List[str], output_names: T.List[str])
Hold model input/output structure information.
WrapStructIO
Bases: Module
Once traced, it should be a no-op in the final graph.
build_new_names_and_elements
build_new_names_and_elements(original_names: T.Optional[T.List[str]], elms: T.Iterable, default_element_name_tmpl: str)
Build element names based on their parent containers.
Use case 1:
provide: original_names: ["input", "a"] elms: [[tensor, tensor, tensor], {"arm": tensor, "head": tensor}]
Expected output names: ["input_0", "input_1", "input_2", "a", "head"]
Use case 2 (undefined names):
provide: original_names: ["plop"] elms: [[tensor, tensor, tensor], tensor, tensor]
Expected output names: ["plop_0", "plop_1", "plop_2", default_element_name_tmpl % ix=1, default_element_name_tmpl % ix=2]
Use case 3 (dict with prefix):
provide: original_names: ["a", "dic"] elms: [tensor, {"arm": tensor, "head": tensor}]
Expected output names: ["a", "dic_arm", "dic_head"]
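
The naming rule can be illustrated with a small standalone sketch. It is not the library's implementation: the function name `sketch_names_and_elements`, the `"elm_{}"` template format, and the string placeholder used in place of tensors are all assumptions made for illustration. It handles only one level of nesting and reproduces the second and third use cases (position suffix for sequences, `prefix_key` for dicts).

```python
from typing import Any, Iterable, List, Optional, Tuple


def sketch_names_and_elements(
    original_names: Optional[List[str]],
    elms: Iterable[Any],
    default_tmpl: str = "elm_{}",  # hypothetical template format
) -> Tuple[List[str], List[Any]]:
    """Flatten one level of containers and derive a name per leaf."""
    names: List[str] = []
    flat: List[Any] = []
    for ix, elm in enumerate(elms):
        # Use the provided name when one exists, else fall back to the template.
        if original_names is not None and ix < len(original_names):
            base = original_names[ix]
        else:
            base = default_tmpl.format(ix)
        if isinstance(elm, (list, tuple)):
            # Sequence items are suffixed with their position.
            for sub_ix, sub in enumerate(elm):
                names.append(f"{base}_{sub_ix}")
                flat.append(sub)
        elif isinstance(elm, dict):
            # Dict items are suffixed with their key.
            for key, sub in elm.items():
                names.append(f"{base}_{key}")
                flat.append(sub)
        else:
            names.append(base)
            flat.append(elm)
    return names, flat


t = "tensor"  # placeholder standing in for a torch.Tensor
names_2, _ = sketch_names_and_elements(["plop"], [[t, t, t], t, t])
names_3, flat_3 = sketch_names_and_elements(["a", "dic"], [t, {"arm": t, "head": t}])
print(names_2)
print(names_3)
```

Deriving each leaf name from its parent's name keeps the flattened names traceable to the original argument they came from.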
build_structured_inputs
Rebuild structured inputs from a flat args sequence.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| flat_args | | Flat sequence of tensor values (non-tensor constants are automatically re-inserted from input_infos). | required |
| input_infos | | Flattened element descriptors produced by :func: | required |
Returns:
| Type | Description |
|---|---|
| | Tuple of structured arguments matching the original model signature. |
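
One way to picture the rebuild step is the sketch below. It is an assumption-laden illustration, not the library's code: the `PathInfo` descriptor (a tuple of indices and keys locating each value) and the function `rebuild_structured_sketch` are hypothetical, strings stand in for tensors, and the flat values are assumed to already align 1-1 with the descriptors.

```python
from dataclasses import dataclass
from typing import Any, Sequence, Tuple, Union

Key = Union[int, str]


@dataclass(frozen=True)
class PathInfo:
    """Hypothetical descriptor: where a flat value sits in the original structure."""

    path: Tuple[Key, ...]


class _Node(dict):
    """Intermediate container, kept distinct from leaf values."""


def _finalize(node: Any) -> Any:
    if not isinstance(node, _Node):
        return node  # leaf value
    if all(isinstance(key, int) for key in node):
        # Integer keys mean the original container was a sequence.
        return [_finalize(node[key]) for key in sorted(node)]
    return {key: _finalize(value) for key, value in node.items()}


def rebuild_structured_sketch(
    flat_args: Sequence[Any], infos: Sequence[PathInfo]
) -> Tuple[Any, ...]:
    if len(flat_args) != len(infos):
        raise ValueError("flat_args and infos must align 1-1")
    root = _Node()
    for value, info in zip(flat_args, infos):
        node = root
        # Walk (and create) the parent containers, then store the leaf.
        for key in info.path[:-1]:
            node = node.setdefault(key, _Node())
        node[info.path[-1]] = value
    # The top level becomes the tuple of positional arguments.
    return tuple(_finalize(root[ix]) for ix in sorted(root))


infos = [
    PathInfo((0, 0)),
    PathInfo((0, 1)),
    PathInfo((1, "arm")),
    PathInfo((1, "head")),
    PathInfo((2,)),
]
structured = rebuild_structured_sketch(["t0", "t1", "t2", "t3", 3], infos)
print(structured)
```

In this sketch every sequence is rebuilt as a list, so the distinction between tuples and lists in the original inputs is not preserved; a fuller descriptor would need to record the container type as well.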
flatten_structured_outputs
Flatten structured model outputs to a flat list of tensors.
If the output is already a simple tuple of tensors, it is returned as-is.
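
The behaviour described above can be sketched as a depth-first walk. This is a minimal illustration under stated assumptions: `flatten_outputs_sketch` is a hypothetical name, strings stand in for tensors so the example runs without torch, and anything that is not a list, tuple, or dict is treated as a leaf.

```python
from typing import Any, List, Sequence, Union


def _is_leaf(obj: Any) -> bool:
    # Assumption: every non-container value counts as a tensor-like leaf.
    return not isinstance(obj, (list, tuple, dict))


def flatten_outputs_sketch(outputs: Any) -> Union[Sequence[Any], List[Any]]:
    # A plain tuple of leaves is already flat: return the same object.
    if isinstance(outputs, tuple) and all(_is_leaf(o) for o in outputs):
        return outputs
    flat: List[Any] = []

    def visit(obj: Any) -> None:
        if isinstance(obj, (list, tuple)):
            for item in obj:
                visit(item)
        elif isinstance(obj, dict):
            # Dict values are visited in insertion order.
            for value in obj.values():
                visit(value)
        else:
            flat.append(obj)

    visit(outputs)
    return flat


already_flat = ("a", "b")
same = flatten_outputs_sketch(already_flat)
nested = flatten_outputs_sketch({"x": ["a", ("b", "c")], "y": "d"})
print(nested)
```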
insert_fixed_nontraceable_args
Re-insert non-tensor constant values into the flat args list.
During flattening, non-tensor elements (ints, bools, …) are recorded in input_infos but excluded from the dynamic flat_args. This function splices them back at the correct positions so that flat_args aligns 1-1 with input_infos again.
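
The splicing step can be sketched as follows. The descriptor class `InfoSketch`, its `constant` field, and the function `reinsert_constants_sketch` are hypothetical stand-ins for illustration; the actual structure of input_infos is not shown here, and strings again stand in for tensors.

```python
from dataclasses import dataclass
from typing import Any, List, Sequence

_DYNAMIC = object()  # sentinel: this slot is filled from flat_args


@dataclass(frozen=True)
class InfoSketch:
    """Hypothetical descriptor: a name plus an optional recorded constant."""

    name: str
    constant: Any = _DYNAMIC


def reinsert_constants_sketch(
    flat_args: Sequence[Any], infos: Sequence[InfoSketch]
) -> List[Any]:
    dynamic = iter(flat_args)
    merged: List[Any] = []
    for info in infos:
        if info.constant is _DYNAMIC:
            # Dynamic slot: consume the next tensor-like value.
            try:
                merged.append(next(dynamic))
            except StopIteration:
                raise ValueError(f"missing dynamic value for {info.name!r}") from None
        else:
            # Constant slot: splice the recorded value back in place.
            merged.append(info.constant)
    if list(dynamic):
        raise ValueError("more flat_args than dynamic slots")
    return merged


infos_c = [
    InfoSketch("x"),
    InfoSketch("use_cache", True),
    InfoSketch("y"),
    InfoSketch("k", 3),
]
merged = reinsert_constants_sketch(["tx", "ty"], infos_c)
print(merged)
```

After this step the merged list has one entry per descriptor, which is the 1-1 alignment the description refers to.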