
torch_to_nnef.nemo_tract.wrappers

CollapseBatchDimWrapper

CollapseBatchDimWrapper(module: torch.nn.Module, sym_dynamic_axes: T.Dict[str, T.Dict[int, str]])

Bases: Module

Wrap a NeMo exportable subnet to remove the batch dimension from its interface.
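A dependency-free sketch of the idea (illustrative names, not the library's implementation): the wrapper adds a singleton batch dimension before calling the wrapped subnet and strips it from the result, so callers see a batch-free interface.

```python
class CollapseBatchSketch:
    def __init__(self, subnet):
        self.subnet = subnet  # callable expecting a batched input

    def __call__(self, x):
        out = self.subnet([x])  # wrap in a batch of size 1
        return out[0]           # drop the batch dimension again

# toy "subnet" that doubles every element of each batch item
subnet = lambda batch: [[2 * v for v in item] for item in batch]
wrapped = CollapseBatchSketch(subnet)
result = wrapped([1, 2, 3])  # → [2, 4, 6]
```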

DecoderWithoutTargetLength

DecoderWithoutTargetLength(decoder: torch.nn.Module, *, nemo_asr: InjectedNemoModule = INJECTED)

Bases: Module

Wrap a decoder (or joint+decoder) to remove the 'target_length' argument and output.
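The shape of the idea, sketched without NeMo (the real signatures differ; names below are illustrative): a decoder that takes and returns a target length is wrapped so the length is derived internally and dropped from the output.

```python
class DropTargetLengthSketch:
    def __init__(self, decoder):
        self.decoder = decoder

    def __call__(self, targets):
        # supply target_length internally, discard it from the output
        out, _target_length = self.decoder(targets, len(targets))
        return out  # 'target_length' no longer part of the interface

# toy decoder: increments each target, echoes the length back
decoder = lambda targets, target_length: (
    [t + 1 for t in targets], target_length)
wrapped = DropTargetLengthSketch(decoder)
result = wrapped([1, 2])  # → [2, 3]
```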

RenameOutputs

RenameOutputs(module: torch.nn.Module, rename_map: T.Dict[str, str])

Bases: Module

Wrapper that renames output tensor names at export time only.

Leaves computation unchanged and preserves input names. Useful to avoid name collisions between inputs and outputs (e.g., both named 'length').
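A minimal sketch of the renaming behavior (illustrative names, not the library's code), using the docstring's own example of an output name 'length' that collides with an input name:

```python
class RenameOutputsSketch:
    """Sketch of export-time output renaming."""

    def __init__(self, module, rename_map):
        self.module = module
        self.rename_map = rename_map

    def __call__(self, **inputs):
        outputs = self.module(**inputs)  # computation unchanged
        return {self.rename_map.get(k, k): v for k, v in outputs.items()}

# toy module whose output name 'length' collides with its input name
mod = lambda length: {"tokens": [1, 2], "length": length}
wrapped = RenameOutputsSketch(mod, {"length": "out_length"})
result = wrapped(length=2)  # → {'tokens': [1, 2], 'out_length': 2}
```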

WrapAudioPreprocessor

WrapAudioPreprocessor(preprocessor: torch.nn.Module)

Bases: Module

Wraps the AudioPreprocessor to fix an empty input_example.
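One plausible reading of the fix, sketched with illustrative names (an assumption, the actual wrapper may work differently): if the wrapped preprocessor yields an empty input_example, substitute a non-empty default so export tracing has concrete inputs.

```python
class AudioPreprocessorSketch:
    def __init__(self, preprocessor, default_example):
        self.preprocessor = preprocessor
        self.default_example = default_example

    def input_example(self):
        example = self.preprocessor.input_example()
        # fall back to a non-empty default when the example is empty
        return example if example else self.default_example

class EmptyExamplePreprocessor:
    def input_example(self):
        return []  # the problematic empty example

wrapped = AudioPreprocessorSketch(EmptyExamplePreprocessor(), [[0.0] * 16])
example = wrapped.input_example()  # non-empty default is used
```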

WrapPreprocessorCast

WrapPreprocessorCast(preprocessor: torch.nn.Module, dtype: torch.dtype)

Bases: Module

Wraps the preprocessor to cast its output to float16 or float32.
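A dependency-free sketch of the cast-at-the-output pattern, with Python's `float` standing in for a torch `Tensor.to(dtype)` cast to float16/float32 (names are illustrative):

```python
class PreprocessorCastSketch:
    def __init__(self, preprocessor, cast=float):
        self.preprocessor = preprocessor
        self.cast = cast  # stand-in for a dtype cast

    def __call__(self, *args):
        # forward unchanged, then cast every output value
        return [self.cast(v) for v in self.preprocessor(*args)]

preprocessor = lambda xs: [x * 2 for x in xs]  # returns ints here
wrapped = PreprocessorCastSketch(preprocessor)
result = wrapped([1, 2])  # → [2.0, 4.0]
```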

decoder_fix_input_example_batch_size

decoder_fix_input_example_batch_size(input_example: T.List[torch.Tensor], batch_size: int) -> T.List[torch.Tensor]

Fix the batch size of the input example for decoder models.
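One plausible strategy, sketched with nested lists in place of tensors (the slice/repeat policy is an assumption, not necessarily the library's): adjust the leading batch dimension of each example to the requested size.

```python
def fix_batch_size_sketch(input_example, batch_size):
    fixed = []
    for t in input_example:
        if len(t) >= batch_size:
            fixed.append(t[:batch_size])        # truncate down
        else:
            reps = t * (batch_size // len(t) + 1)
            fixed.append(reps[:batch_size])     # repeat up, then trim
    return fixed

grown = fix_batch_size_sketch([[1, 2], [10]], 3)
shrunk = fix_batch_size_sketch([[1, 2, 3]], 2)
```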

use_pytorch_sdpa

use_pytorch_sdpa(model: torch.nn.Module, *, nemo: InjectedNemoModule = INJECTED)

Modify the model to use PyTorch scaled dot-product attention (SDPA) implementations where applicable.
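For reference, the operation being swapped in is scaled dot-product attention, softmax(q·kᵢ / √d) applied as weights over the values; PyTorch exposes a fused version as `torch.nn.functional.scaled_dot_product_attention`. A plain-Python reference computation for a single query:

```python
import math

def sdpa(q, k, v):
    # scores: dot product of the query with each key, scaled by sqrt(d)
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, key)) / math.sqrt(d)
              for key in k]
    # numerically stable softmax over the scores
    mx = max(scores)
    exps = [math.exp(s - mx) for s in scores]
    z = sum(exps)
    weights = [e / z for e in exps]
    # weighted sum of the value vectors
    return [sum(w * val[j] for w, val in zip(weights, v))
            for j in range(len(v[0]))]

# identical keys → uniform weights → output is the mean of the values
out = sdpa([1.0, 0.0], [[1.0, 0.0], [1.0, 0.0]], [[1.0, 0.0], [3.0, 4.0]])
```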