updater
torch_to_nnef.tensor.updater
ModTensorUpdater
ModTensorUpdater(model: torch.nn.Module, add_parameter_if_unset: bool = True, add_buffers: bool = False, add_unregistred_tensor: bool = False, disable_requires_grad: bool = False)
Helper to cleanly update the parameters, buffers, and unregistered tensors of a model.
Here, cleanly means without breaking references that are shared between tensors.
For example, transformer models often share a single weight tensor between the first layer (the input_ids embedding) and the last linear projection layer.
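The weight-tying case can be made concrete with a small pure-PyTorch sketch. It is not ModTensorUpdater's implementation and does not use its API. It shows how naive attribute assignment unties a shared weight, and how a hypothetical helper, `replace_tensor_everywhere`, keeps the tie by putting one new Parameter into every slot that held the old one. The toy model `TinyLM` is also illustrative.

```python
import torch
from torch import nn


class TinyLM(nn.Module):
    """Toy model whose output projection reuses the embedding weight (weight tying)."""

    def __init__(self, vocab: int = 10, dim: int = 4):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)            # weight shape: (vocab, dim)
        self.lm_head = nn.Linear(dim, vocab, bias=False)  # weight shape: (vocab, dim)
        # Tie weights: both modules now hold the *same* Parameter object.
        self.lm_head.weight = self.embed.weight


def replace_tensor_everywhere(model: nn.Module, old: torch.Tensor, new: torch.Tensor) -> list[str]:
    """Hypothetical helper: swap every parameter slot holding `old` for ONE shared new Parameter."""
    new_param = nn.Parameter(new, requires_grad=old.requires_grad)
    replaced = []
    for mod_name, module in model.named_modules():
        # Look only at this module's own parameters; list() because we mutate while iterating.
        for attr, param in list(module.named_parameters(recurse=False)):
            if param is old:
                setattr(module, attr, new_param)
                replaced.append(f"{mod_name}.{attr}" if mod_name else attr)
    return replaced


# Naive replacement: only one slot is updated, so the tie is silently broken.
naive = TinyLM()
naive.embed.weight = nn.Parameter(torch.zeros_like(naive.embed.weight))
print(naive.lm_head.weight is naive.embed.weight)  # False

# Reference-aware replacement: every slot pointing at the old tensor gets the same new one.
model = TinyLM()
old = model.embed.weight
names = replace_tensor_everywhere(model, old, torch.zeros_like(old))
print(names)                                        # ['embed.weight', 'lm_head.weight']
print(model.lm_head.weight is model.embed.weight)   # True
```

The key design point is matching by object identity (`is`) rather than by name. A name-based update only touches one slot, while an identity-based update finds every alias of the shared tensor.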
Initialize ModTensorUpdater.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | Module | The nn.Module whose tensors will be updated by this helper | required |
| add_parameter_if_unset | bool | If a tensor is set at a location where the model does not yet have a torch.nn.Parameter, add it as a new parameter | True |
| add_buffers | bool | Include all nn.Buffer objects of the model in the scope of updatable tensors | False |
| add_unregistred_tensor | bool | Include all tensors of the model that are not registered as nn.Parameter or nn.Buffer in the scope of updatable tensors | False |
| disable_requires_grad | bool | If set, force replaced tensors to have requires_grad disabled at update time | False |