

torch_to_nnef.tensor.updater

ModTensorUpdater

ModTensorUpdater(model: torch.nn.Module, add_parameter_if_unset: bool = True, add_buffers: bool = False, add_unregistred_tensor: bool = False, disable_requires_grad: bool = False)

Helper to cleanly update the parameters, buffers, and unregistered tensors of a model.

Cleanly means without breaking shared references between tensors.

For example, in many transformer models the first layer (the input_ids embedding) and the last linear projection layer share the same weight tensor.
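To make the tied-weights case concrete, here is a minimal sketch in plain Python (no PyTorch), assuming `SimpleNamespace` objects stand in for modules and bare `object()` instances stand in for tensors. It shows how naively rebinding one attribute breaks the shared reference that ModTensorUpdater is meant to preserve. The names `embed` and `lm_head` are illustrative only.

```python
from types import SimpleNamespace

# Stand-ins: modules are namespaces, a "tensor" is any object.
shared_weight = object()  # hypothetical tied weight
model = SimpleNamespace(
    embed=SimpleNamespace(weight=shared_weight),    # input_ids embedding
    lm_head=SimpleNamespace(weight=shared_weight),  # last linear projection
)
assert model.embed.weight is model.lm_head.weight  # tied: one object, two names

# Naive update: only one attribute is rebound to the new object.
new_weight = object()
model.lm_head.weight = new_weight

# The tie is now broken: embed still holds the old object.
tied = model.embed.weight is model.lm_head.weight
print(tied)  # False
```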

Init ModTensorUpdater.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | Module | The nn.Module whose tensors will be updated through this class. | required |
| add_parameter_if_unset | bool | If True, setting a tensor at a location where the model does not yet have a torch.nn.Parameter adds it as a new parameter. | True |
| add_buffers | bool | Include all nn.Buffer objects of the model in the updatable scope. | False |
| add_unregistred_tensor | bool | Include all tensors of the model that are registered neither as nn.Parameter nor as nn.Buffer in the updatable scope. | False |
| disable_requires_grad | bool | If True, replacement tensors are forced to requires_grad=False at update time. | False |
get_by_name
get_by_name(name: str) -> torch.Tensor

Get a tensor by its reference name.
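A minimal sketch of how a dotted reference name can be resolved, assuming plain Python stand-ins rather than real modules; `get_by_name` here is a hypothetical free function, not the library's implementation. In PyTorch itself, `nn.Module.get_parameter` and `nn.Module.get_submodule` resolve dotted names in a similar way.

```python
from functools import reduce
from types import SimpleNamespace


def get_by_name(model, name: str):
    """Resolve a dotted reference name such as 'decoder.proj.weight'.

    Hypothetical helper: walks the attribute chain one segment at a time,
    mirroring how dotted PyTorch names map onto nested submodules.
    """
    return reduce(getattr, name.split("."), model)


# Usage with stand-in objects.
w = object()
model = SimpleNamespace(decoder=SimpleNamespace(proj=SimpleNamespace(weight=w)))
assert get_by_name(model, "decoder.proj.weight") is w
```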

update_by_name
update_by_name(name: str, new_tensor: torch.Tensor, tie_replacements: bool = True, enforce_tensor_consistency: bool = True) -> torch.Tensor

Update a tensor by its reference name.
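One plausible reading of `tie_replacements=True` is that every name aliasing the same tensor object is rebound together. The sketch below is a hypothetical illustration of that behavior with plain Python stand-ins; `named_entries` and `update_by_name` are illustrative helpers, not the torch_to_nnef implementation, and `enforce_tensor_consistency` is left out.

```python
from types import SimpleNamespace


def named_entries(node, prefix=""):
    """Yield (dotted_name, parent, attr, value) for every tensor leaf.

    Hypothetical traversal: SimpleNamespace values are treated as
    submodules, anything else as a tensor.
    """
    for attr, value in vars(node).items():
        name = f"{prefix}{attr}"
        if isinstance(value, SimpleNamespace):
            yield from named_entries(value, prefix=f"{name}.")
        else:
            yield name, node, attr, value


def update_by_name(model, name, new_tensor, tie_replacements=True):
    entries = list(named_entries(model))  # snapshot before mutating
    old = next(value for n, _, _, value in entries if n == name)
    for n, parent, attr, value in entries:
        # With tie_replacements, every alias of the old object is rebound;
        # otherwise only the requested name changes.
        if n == name or (tie_replacements and value is old):
            setattr(parent, attr, new_tensor)
    return new_tensor


# Usage: rebinding the head also rebinds the tied embedding.
shared = object()
model = SimpleNamespace(
    embed=SimpleNamespace(weight=shared),
    lm_head=SimpleNamespace(weight=shared),
)
new = update_by_name(model, "lm_head.weight", object())
assert model.embed.weight is new and model.lm_head.weight is new
```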

update_by_ref
update_by_ref(ref: torch.nn.Parameter, new_tensor: torch.Tensor, enforce_tensor_consistency: bool = True) -> torch.Tensor

Update a tensor by its reference object.
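`update_by_ref` can be pictured as an identity search: every slot whose value is the given reference object receives the new tensor. The hypothetical sketch below uses a `FakeTensor` stand-in and assumes, for illustration only, that `enforce_tensor_consistency` amounts to a shape check; the checks actually performed by the library may differ.

```python
from dataclasses import dataclass
from types import SimpleNamespace


@dataclass(eq=False)  # eq=False keeps identity-based ==/hash, like tensor refs
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: only a shape is tracked."""
    shape: tuple


def update_by_ref(model, ref, new_tensor, enforce_tensor_consistency=True):
    """Rebind every attribute in `model` that holds exactly `ref`."""
    # Assumption: "consistency" is modeled here as an unchanged shape.
    if enforce_tensor_consistency and new_tensor.shape != ref.shape:
        raise ValueError(f"shape mismatch: {ref.shape} != {new_tensor.shape}")
    stack, seen = [model], set()
    while stack:
        node = stack.pop()
        if id(node) in seen:  # a submodule may be reachable twice
            continue
        seen.add(id(node))
        for attr, value in list(vars(node).items()):
            if isinstance(value, SimpleNamespace):
                stack.append(value)
            elif value is ref:  # identity, not equality
                setattr(node, attr, new_tensor)
    return new_tensor


# Usage: two modules tied to the same weight object.
shared = FakeTensor(shape=(8, 4))
model = SimpleNamespace(
    embed=SimpleNamespace(weight=shared),
    lm_head=SimpleNamespace(weight=shared),
)
replacement = update_by_ref(model, shared, FakeTensor(shape=(8, 4)))
assert model.embed.weight is replacement is model.lm_head.weight
```

Matching on identity rather than on name means callers only need a handle to the old tensor, and every alias is updated in one pass.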