
provider

torch_to_nnef.nemo_tract.provider

NemoProvider dataclass

NemoProvider(inference_target: TractNNEF, skip_preprocessor: bool = False, split_joint_decoder: bool = False, float_dtype: T.Optional[torch.dtype] = None, only_subnets: T.Optional[T.Collection[str]] = None)

Bases: Provider

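NeMo remodeler provider: discovers per-subnet signatures of an NVIDIA NeMo model and applies boundary remodel plans to produce exportable subnets.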

Parameters:

inference_target (TractNNEF, required): Export target used for discovery (a Tract NNEF target).
skip_preprocessor (bool, default False): Whether to exclude the preprocessor subnet.
split_joint_decoder (bool, default False): Whether to export the decoder and the joint network as separate subnets.
float_dtype (Optional[dtype], default None): Preferred float dtype for example inputs.
only_subnets (Optional[Collection[str]], default None): Optional subset of subnet names to consider; None considers all subnets.
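To illustrate how the three selection flags could combine, here is a minimal sketch. It is hypothetical: the subnet names ("preprocessor", "encoder", "decoder", "joint", "decoder_joint") and the helper select_subnets are assumptions for illustration, not the provider's actual naming or logic.

```python
from typing import Collection, List, Optional


def select_subnets(
    skip_preprocessor: bool = False,
    split_joint_decoder: bool = False,
    only_subnets: Optional[Collection[str]] = None,
) -> List[str]:
    """Hypothetical subnet selection mirroring the NemoProvider flags."""
    # Assumed subnet names for a transducer-style NeMo model.
    names = ["preprocessor", "encoder"]
    if split_joint_decoder:
        # Decoder and joint network exported as two separate subnets.
        names += ["decoder", "joint"]
    else:
        # Decoder and joint network kept together in one subnet.
        names.append("decoder_joint")
    if skip_preprocessor:
        names.remove("preprocessor")
    if only_subnets is not None:
        # Keep the original order, restricted to the requested subset.
        names = [n for n in names if n in only_subnets]
    return names


print(select_subnets(skip_preprocessor=True, split_joint_decoder=True))
# ['encoder', 'decoder', 'joint']
```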
apply
apply(model: torch.nn.Module, plan: RemodelPlan) -> dict[str, torch.nn.Module]

Apply a boundary remodel plan and return the wrapped subnets.

Uses the export pipeline to build BoundaryAdapter-wrapped modules that follow the external I/O boundary described by plan. The returned dict maps each subnet name to its wrapped module.
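The idea of wrapping a subnet at an external I/O boundary can be sketched without torch. This is a hypothetical stand-in, not torch_to_nnef's BoundaryAdapter: BoundarySpec, FlatBoundaryWrapper, apply_plan and toy_encoder are invented names, plain callables replace torch.nn.Module, and the boundary is assumed to be just ordered input and output names.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Mapping, Sequence, Tuple


@dataclass(frozen=True)
class BoundarySpec:
    """Hypothetical external IO boundary: ordered input and output names."""
    input_names: Sequence[str]
    output_names: Sequence[str]


class FlatBoundaryWrapper:
    """Hypothetical stand-in for a boundary adapter: flat tuple in, flat tuple out."""

    def __init__(self, inner: Callable[..., Mapping[str, Any]], spec: BoundarySpec):
        self.inner = inner
        self.spec = spec

    def __call__(self, *args: Any) -> Tuple[Any, ...]:
        if len(args) != len(self.spec.input_names):
            raise ValueError(
                f"expected {len(self.spec.input_names)} inputs, got {len(args)}"
            )
        # Map positional external inputs onto the inner callable's keyword arguments.
        outputs = self.inner(**dict(zip(self.spec.input_names, args)))
        # Expose only the outputs named by the boundary, in boundary order.
        return tuple(outputs[name] for name in self.spec.output_names)


def apply_plan(
    subnets: Mapping[str, Callable[..., Mapping[str, Any]]],
    plan: Mapping[str, BoundarySpec],
) -> Dict[str, FlatBoundaryWrapper]:
    """Wrap every subnet named in the plan; the result is keyed by subnet name."""
    return {name: FlatBoundaryWrapper(subnets[name], spec) for name, spec in plan.items()}


# Toy "encoder" with an extra internal output that the boundary hides.
def toy_encoder(audio_signal, length):
    return {"encoded": [2 * x for x in audio_signal], "encoded_len": length, "debug": None}


wrapped = apply_plan(
    {"encoder": toy_encoder},
    {"encoder": BoundarySpec(("audio_signal", "length"), ("encoded", "encoded_len"))},
)
print(wrapped["encoder"]([1, 2, 3], 3))  # ([2, 4, 6], 3)
```

Keying the result by subnet name matches the documented dict[str, torch.nn.Module] return type; a real adapter would subclass torch.nn.Module and implement forward instead of __call__.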

discover_signatures
discover_signatures(model: torch.nn.Module, stage: Stage) -> T.List[SubnetSignature]

Discover per-subnet signatures for the given stage.
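A caller will often want to look up discovered signatures by subnet name. The sketch below assumes a hypothetical SubnetSignature with name, inputs and outputs fields; the real class's fields are not documented here, and the subnet names and shapes are made up.

```python
from dataclasses import dataclass
from typing import Dict, Sequence, Tuple


@dataclass(frozen=True)
class SubnetSignature:
    """Hypothetical per-subnet signature (not torch_to_nnef's SubnetSignature)."""
    name: str
    inputs: Tuple[Tuple[str, Tuple[int, ...]], ...]  # (input name, example shape)
    outputs: Tuple[str, ...]


def index_signatures(signatures: Sequence[SubnetSignature]) -> Dict[str, SubnetSignature]:
    """Index discovered signatures by subnet name, rejecting duplicates."""
    index: Dict[str, SubnetSignature] = {}
    for sig in signatures:
        if sig.name in index:
            raise ValueError(f"duplicate subnet signature: {sig.name!r}")
        index[sig.name] = sig
    return index


# Example with made-up subnet names and shapes.
signatures = [
    SubnetSignature(
        "encoder",
        (("audio_signal", (1, 80, 100)), ("length", (1,))),
        ("encoded", "encoded_len"),
    ),
    SubnetSignature("decoder_joint", (("encoded", (1, 512, 25)),), ("logits",)),
]
by_name = index_signatures(signatures)
print(sorted(by_name))  # ['decoder_joint', 'encoder']
```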