
base

torch_to_nnef.llm_tract.models.base

BaseCausalWithDynCacheAndTriu

BaseCausalWithDynCacheAndTriu(model: AutoModelForCausalLM, num_logits_to_keep: int = 1)

Bases: TorchToNNEFWrappedLLM

Assumes the common AutoModelForCausalLM architecture, with:

- .model
- .lm_head

forward
forward(input_ids: torch.Tensor, *args)

Forward of BaseCausalWithDynCacheAndTriu.

Same as calling self.model.model followed by lm_head and softmax, without any smart caching mechanism.
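A minimal pure-Python sketch of that uncached computation, for illustration only. It assumes `body` stands in for self.model.model and `lm_head` for the output projection, and it assumes (from the constructor parameter name, not confirmed by this page) that only the last `num_logits_to_keep` positions are projected. The names `causal_lm_forward` and `softmax` are hypothetical, and the real module works on torch tensors rather than lists.

```python
import math
from typing import Callable, List, Sequence

Vector = List[float]


def softmax(xs: Sequence[float]) -> Vector:
    # Subtract the max for numerical stability before exponentiating.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def causal_lm_forward(
    input_ids: Sequence[int],
    body: Callable[[Sequence[int]], List[Vector]],  # stand-in for self.model.model
    lm_head: Callable[[Vector], Vector],  # stand-in for lm_head
    num_logits_to_keep: int = 1,
) -> List[Vector]:
    if num_logits_to_keep < 1:
        raise ValueError("num_logits_to_keep must be >= 1")
    # Recompute hidden states for the whole sequence: nothing is cached.
    hidden = body(input_ids)
    # Keep only the trailing positions, then project to vocab and normalize.
    kept = hidden[-num_logits_to_keep:]
    return [softmax(lm_head(h)) for h in kept]
```

With a toy `body` that maps each id to a one-element vector and a two-entry `lm_head`, `causal_lm_forward([0, 1, 2], body, lm_head)` returns a single probability row for the last position.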

This export module is extremely inefficient because no caching can be provided ...
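To make the cost of missing caching concrete, here is a toy cost model (an assumption for illustration, not a measurement of this exporter) that counts query-key score computations in causal self-attention while generating tokens one at a time. Without a KV cache every step re-attends over the full sequence; with one, only the newest token is a query. The function name `attention_pair_count` is hypothetical.

```python
def attention_pair_count(prompt_len: int, new_tokens: int, use_kv_cache: bool) -> int:
    """Count query-key dot products needed to generate `new_tokens` tokens."""
    total = 0
    seq_len = prompt_len
    for step in range(new_tokens):
        if use_kv_cache and step > 0:
            # Only the newest token is a query; past keys come from the cache.
            total += seq_len
        else:
            # Full causal pass: query i sees keys 0..i, i.e. 1 + 2 + ... + seq_len.
            total += seq_len * (seq_len + 1) // 2
        # The generated token is appended to the sequence.
        seq_len += 1
    return total


print(attention_pair_count(4, 3, use_kv_cache=False))  # 46
print(attention_pair_count(4, 3, use_kv_cache=True))  # 21
```

In this model the uncached cost grows quadratically per step (cubically over a whole generation), while the cached cost grows linearly per step.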

TorchToNNEFWrappedLLM

TorchToNNEFWrappedLLM()

Bases: Module

Base module class for all LLM wrappers.

These wrappers are needed to ensure a deterministic input/output graph signature and to allow some modeling optimizations for a few architectures.
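One common way to get a deterministic graph signature is to replace a nested KV-cache structure with a flat, ordered tuple whose entries have stable names. The sketch below illustrates that idea in pure Python; the helpers `flat_cache_names`, `flatten_kv` and `unflatten_kv` and the `<prefix>_key_<layer>` naming scheme are hypothetical and not taken from torch_to_nnef.

```python
from typing import List, Sequence, Tuple

KVPair = Tuple[object, object]


def flat_cache_names(num_layers: int, prefix: str) -> List[str]:
    # Stable, layer-ordered names: key then value for each layer.
    names: List[str] = []
    for layer in range(num_layers):
        names.append(f"{prefix}_key_{layer}")
        names.append(f"{prefix}_value_{layer}")
    return names


def flatten_kv(cache: Sequence[KVPair]) -> tuple:
    # Per-layer (key, value) pairs -> one flat tuple in the same order as the names.
    flat: List[object] = []
    for key, value in cache:
        flat.extend((key, value))
    return tuple(flat)


def unflatten_kv(flat: Sequence[object]) -> List[KVPair]:
    # Inverse of flatten_kv.
    if len(flat) % 2:
        raise ValueError("expected an even number of cache entries")
    return [(flat[i], flat[i + 1]) for i in range(0, len(flat), 2)]


input_names = ["input_ids"] + flat_cache_names(2, "in_cache")
```

Because the order and names depend only on the number of layers, two exports of the same architecture produce the same input list.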

ctx_dtype_dyn_cache

ctx_dtype_dyn_cache()

Context manager to handle inconsistent device types during KV-cache updates.

This can happen, for example, when accelerate places tensors on the 'meta' device.

This manager is stackable; in such a case, only the largest context is applied.

update_forward_signature

update_forward_signature(self)

Trickery to help tracing with the new export API of torch > 2.0.
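The page does not say how the signature is updated. One plausible mechanism, shown here purely as an assumption, is to replace a variadic `forward(input_ids, *args)` signature with explicit parameter names by setting `__signature__`, which `inspect.signature` honors. `with_explicit_signature` and the parameter names below are hypothetical.

```python
import inspect
from typing import Callable, Sequence


def with_explicit_signature(fn: Callable, arg_names: Sequence[str]) -> Callable:
    # Advertise explicit positional-or-keyword parameters instead of *args.
    params = [
        inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD)
        for name in arg_names
    ]
    fn.__signature__ = inspect.Signature(params)
    return fn


def forward(input_ids, *args):
    # Runtime behavior is unchanged; only the advertised signature differs.
    return (input_ids, *args)


with_explicit_signature(forward, ["input_ids", "in_cache_key_0", "in_cache_value_0"])
print(inspect.signature(forward))  # (input_ids, in_cache_key_0, in_cache_value_0)
```

This sets the attribute on a plain function; bound methods do not accept new attributes, so a real implementation would need to target the underlying function (and would then affect every instance sharing it).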

use_dtype_dyn_cache

use_dtype_dyn_cache(f)

Decorator for a forward function that applies ctx_dtype_dyn_cache.
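A decorator that runs a method inside a context manager can be sketched with `functools.wraps`. In this self-contained example `ctx_example` is a hypothetical stand-in for ctx_dtype_dyn_cache that only records enter and exit events, and `use_ctx` and `Wrapped` are likewise illustrative names.

```python
import contextlib
import functools
from typing import Callable, Iterator, List

EVENTS: List[str] = []


@contextlib.contextmanager
def ctx_example() -> Iterator[None]:
    """Hypothetical stand-in for ctx_dtype_dyn_cache: records enter/exit."""
    EVENTS.append("enter")
    try:
        yield
    finally:
        EVENTS.append("exit")


def use_ctx(f: Callable) -> Callable:
    @functools.wraps(f)  # keep the original name and docstring
    def wrapper(*args, **kwargs):
        # Every call to the decorated function runs inside the context.
        with ctx_example():
            return f(*args, **kwargs)

    return wrapper


class Wrapped:
    @use_ctx
    def forward(self, x):
        EVENTS.append("forward")
        return x * 2
```

Calling `Wrapped().forward(3)` returns 6 and records `enter`, `forward`, `exit` in that order. Combined with a stackable context manager, nested decorated calls would not re-apply the patch.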