
model_wrapper

torch_to_nnef.model_wrapper

Wrap a model to bypass limitations of torch_to_nnef internals.

This covers cases where the inputs or outputs of a model contain:

tuples, lists, dicts, or objects.
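To make the idea of flattening concrete, here is a minimal sketch that turns nested tuples, lists and dicts into leaf descriptors of the form (types, indexes, value). The helper name flatten_nested and this exact descriptor layout are assumptions for illustration, not the library's flatten_dict_tuple_or_list; arbitrary objects are not handled, and strings stand in for tensors so the sketch runs without torch.

```python
from typing import Any, List, Tuple

# Hypothetical descriptor: (types, indexes, value), where `types` lists the
# container types from the outermost to the innermost level and `indexes`
# lists the position or key used at each level.
Descriptor = Tuple[Tuple[type, ...], Tuple[Any, ...], Any]


def flatten_nested(obj: Any, types: Tuple[type, ...] = (), indexes: Tuple[Any, ...] = ()) -> List[Descriptor]:
    """Flatten tuples, lists and dicts into leaf descriptors (sketch)."""
    if isinstance(obj, (tuple, list)):
        out: List[Descriptor] = []
        for i, item in enumerate(obj):
            out.extend(flatten_nested(item, types + (type(obj),), indexes + (i,)))
        return out
    if isinstance(obj, dict):
        out = []
        for key, item in obj.items():
            out.extend(flatten_nested(item, types + (dict,), indexes + (key,)))
        return out
    # Leaf: a tensor, or a non-tensor constant such as an int or bool.
    return [(types, indexes, obj)]


# Strings stand in for tensors; 2 is a non-tensor constant.
args = (["t0", "t1"], {"mask": "t2", "scale": 2})
for types, indexes, value in flatten_nested(args):
    print([t.__name__ for t in types], indexes, value)
# ['tuple', 'list'] (0, 0) t0
# ['tuple', 'list'] (0, 1) t1
# ['tuple', 'dict'] (1, 'mask') t2
# ['tuple', 'dict'] (1, 'scale') 2
```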

UnfoldModelInfo dataclass

UnfoldModelInfo(model: nn.Module, original_inputs: T.Tuple[torch.Tensor], original_outputs: T.List[torch.Tensor], flat_inputs: T.Tuple[torch.Tensor], flat_outputs: T.Tuple[torch.Tensor], input_names: T.List[str], output_names: T.List[str])

Hold model input/output structure information.

WrapStructIO

WrapStructIO(model: nn.Module, input_infos, output_infos)

Bases: Module

Once traced, it should be a no-op in the final graph.

build_new_names_and_elements

build_new_names_and_elements(original_names: T.Optional[T.List[str]], elms: T.Iterable, default_element_name_tmpl: str)

Build element names based on their parent containers.

Use case 1 (nested list and dict):

Provide: original_names: ["input", "a"] elms: [[tensor, tensor, tensor], {"arm": tensor, "head": tensor}]

Expected output names: ["input_0", "input_1", "input_2", "a", "head"]

Use case 2 (undefined names):

Provide: original_names: ["plop"] elms: [[tensor, tensor, tensor], tensor, tensor]

Expected output names: ["plop_0", "plop_1", "plop_2", default_element_name_tmpl % 1, default_element_name_tmpl % 2]

Use case 3 (dict with prefix):

Provide: original_names: ["a", "dic"] elms: [tensor, {"arm": tensor, "head": tensor}]

Expected output names: ["a", "dic_arm", "dic_head"]
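A minimal sketch of this naming rule, assuming only one level of nesting below each top-level element. It reproduces the undefined-names and dict-with-prefix cases; build_names and the "input_%d" template are hypothetical, and strings stand in for tensors.

```python
from typing import Any, List, Optional, Sequence, Tuple


def build_names(
    original_names: Optional[Sequence[str]],
    elms: Sequence[Any],
    default_element_name_tmpl: str,
) -> Tuple[List[str], List[Any]]:
    """Sketch: name each leaf after its top-level parent (hypothetical helper)."""
    given = list(original_names or [])
    names: List[str] = []
    leaves: List[Any] = []
    for ix, elm in enumerate(elms):
        # Positions without a provided name fall back to the template.
        prefix = given[ix] if ix < len(given) else default_element_name_tmpl % ix
        if isinstance(elm, (list, tuple)):
            for i, sub in enumerate(elm):
                names.append(f"{prefix}_{i}")
                leaves.append(sub)
        elif isinstance(elm, dict):
            for key, sub in elm.items():
                names.append(f"{prefix}_{key}")
                leaves.append(sub)
        else:
            names.append(prefix)
            leaves.append(elm)
    return names, leaves


names, _ = build_names(["plop"], [["t", "t", "t"], "t", "t"], "input_%d")
print(names)  # ['plop_0', 'plop_1', 'plop_2', 'input_1', 'input_2']

names, _ = build_names(["a", "dic"], ["t", {"arm": "t", "head": "t"}], "input_%d")
print(names)  # ['a', 'dic_arm', 'dic_head']
```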

build_structured_inputs

build_structured_inputs(flat_args, input_infos)

Rebuild structured inputs from a flat args sequence.

Parameters:

flat_args (required): Flat sequence of tensor values (non-tensor constants are automatically re-inserted from input_infos).

input_infos (required): Flattened element descriptors produced by flatten_dict_tuple_or_list; each entry is (types, indexes, original_value).

Returns:

Tuple of structured arguments matching the original model signature.
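A recursive sketch of the reconstruction, assuming each descriptor is (types, indexes, original_value) with one container type and one index per nesting level, and that constants have already been spliced back so values and descriptors align one-to-one. Unlike a build-lists-then-convert approach, it creates each container directly with its recorded type. rebuild and build_structured are hypothetical names, and strings stand in for tensors.

```python
from typing import Any, Dict, List, Sequence, Tuple

# Hypothetical descriptor: (types, indexes, original_value), with types[d] the
# container type at depth d and indexes[d] the position or key inside it.
Descriptor = Tuple[Tuple[type, ...], Tuple[Any, ...], Any]


def rebuild(entries: List[Descriptor], depth: int = 0) -> Any:
    """Rebuild the container at `depth` shared by all `entries`."""
    container_type = entries[0][0][depth]
    groups: Dict[Any, List[Descriptor]] = {}
    for entry in entries:
        # Group leaves by their index at this depth (dicts keep insertion order).
        groups.setdefault(entry[1][depth], []).append(entry)
    children: Dict[Any, Any] = {}
    for key, group in groups.items():
        _types, indexes, value = group[0]
        if len(indexes) == depth + 1:
            children[key] = value  # leaf sits directly in this container
        else:
            children[key] = rebuild(group, depth + 1)
    if container_type is dict:
        return children
    # tuple or list: order children by integer position
    return container_type(children[k] for k in sorted(children))


def build_structured(flat_values: Sequence[Any], infos: Sequence[Descriptor]) -> tuple:
    """Pair each value with its descriptor, then rebuild from the top level."""
    entries = [(types, indexes, value) for value, (types, indexes, _) in zip(flat_values, infos)]
    return rebuild(entries)


infos = [
    ((tuple, list), (0, 0), None),
    ((tuple, list), (0, 1), None),
    ((tuple, dict), (1, "mask"), None),
    ((tuple, tuple), (2, 0), None),
]
print(build_structured(["t0", "t1", "t2", "t3"], infos))
# (['t0', 't1'], {'mask': 't2'}, ('t3',))
```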

flatten_structured_outputs

flatten_structured_outputs(struct_output, output_infos)

Flatten structured model outputs to a flat list of tensors.

If the output is already a simple tuple of tensors, it is returned as-is.
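A sketch of this flattening rule that treats anything other than a tuple, list or dict as a leaf; in real code the leaf check would typically be isinstance(x, torch.Tensor). flatten_outputs is a hypothetical name and, unlike the documented function, it ignores output_infos.

```python
from typing import Any, List, Union


def _leaves(obj: Any) -> List[Any]:
    """Depth-first leaves of nested tuples, lists and dicts."""
    if isinstance(obj, (tuple, list)):
        return [leaf for item in obj for leaf in _leaves(item)]
    if isinstance(obj, dict):
        return [leaf for value in obj.values() for leaf in _leaves(value)]
    return [obj]


def flatten_outputs(struct_output: Any) -> Union[tuple, List[Any]]:
    """Sketch: flatten nested outputs, passing simple tuples through untouched."""
    containers = (tuple, list, dict)
    if isinstance(struct_output, tuple) and not any(isinstance(o, containers) for o in struct_output):
        return struct_output  # already a simple tuple of leaves
    return _leaves(struct_output)


print(flatten_outputs(("y0", "y1")))                          # ('y0', 'y1')
print(flatten_outputs({"logits": "y0", "cache": ["k", "v"]}))  # ['y0', 'k', 'v']
```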

insert_fixed_nontraceable_args

insert_fixed_nontraceable_args(flat_args, input_infos)

Re-insert non-tensor constant values into the flat args list.

During flattening, non-tensor elements (ints, bools, …) are recorded in input_infos but excluded from the dynamic flat_args. This function splices them back at the correct positions so that flat_args aligns one-to-one with input_infos again.
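The splice can be sketched as a single pass over the descriptors: each dynamic slot takes the next traced value, and every other slot reuses the constant recorded at flatten time. insert_constants, the is_dynamic predicate and the FakeTensor stand-in are assumptions for illustration; with torch the predicate would typically be lambda v: isinstance(v, torch.Tensor).

```python
from typing import Any, Callable, List, Sequence, Tuple

Descriptor = Tuple[Tuple[type, ...], Tuple[Any, ...], Any]
_SENTINEL = object()


def insert_constants(
    flat_args: Sequence[Any],
    input_infos: Sequence[Descriptor],
    is_dynamic: Callable[[Any], bool],
) -> List[Any]:
    """Return a list aligned one-to-one with input_infos (sketch)."""
    dynamic = iter(flat_args)
    out: List[Any] = []
    for _types, _indexes, original_value in input_infos:
        if is_dynamic(original_value):
            out.append(next(dynamic))  # traced tensor value
        else:
            out.append(original_value)  # constant recorded at flatten time
    if next(dynamic, _SENTINEL) is not _SENTINEL:
        raise ValueError("more flat args than dynamic slots in input_infos")
    return out


class FakeTensor:
    """Stand-in for torch.Tensor so the sketch runs without torch."""


infos = [
    ((tuple,), (0,), FakeTensor()),
    ((tuple,), (1,), 3),
    ((tuple,), (2,), FakeTensor()),
    ((tuple,), (3,), True),
]
print(insert_constants(["x", "y"], infos, lambda v: isinstance(v, FakeTensor)))
# ['x', 3, 'y', True]
```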

tupleize_structure

tupleize_structure(inps, input_infos)

Convert mutable lists back to tuples where the original had tuples.

During reconstruction lists are used because tuples are immutable. This pass converts them back based on the type information recorded in input_infos.
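One way to sketch this pass: collect the path (index prefix) of every container whose recorded type is tuple, then walk the rebuilt structure and convert the lists found at those paths. tuple_paths, tupleize and the (types, indexes, value) descriptor layout are assumptions for illustration.

```python
from typing import Any, Sequence, Set, Tuple

Descriptor = Tuple[Tuple[type, ...], Tuple[Any, ...], Any]


def tuple_paths(input_infos: Sequence[Descriptor]) -> Set[Tuple[Any, ...]]:
    """Paths (index prefixes) of every container originally recorded as a tuple."""
    return {
        tuple(indexes[:depth])
        for types, indexes, _ in input_infos
        for depth, container_type in enumerate(types)
        if container_type is tuple
    }


def tupleize(obj: Any, paths: Set[Tuple[Any, ...]], path: Tuple[Any, ...] = ()) -> Any:
    """Convert lists at recorded tuple paths back into tuples, recursively."""
    if isinstance(obj, list):
        items = [tupleize(item, paths, path + (i,)) for i, item in enumerate(obj)]
        return tuple(items) if path in paths else items
    if isinstance(obj, dict):
        return {key: tupleize(value, paths, path + (key,)) for key, value in obj.items()}
    return obj


infos = [
    ((tuple, tuple), (0, 0), None),
    ((tuple, tuple), (0, 1), None),
    ((tuple, dict, list), (1, "k", 0), None),
]
rebuilt = [["a", "b"], {"k": ["c"]}]
print(tupleize(rebuilt, tuple_paths(infos)))
# (('a', 'b'), {'k': ['c']})
```

The inner ["c"] stays a list because its descriptor recorded list at that depth.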