[PyTorch] Trainable Parameters, Non-Trainable Values, and register_buffer
KeepPersistStay
2024. 9. 5. 14:15
In PyTorch, when defining an nn.Module, attributes can either be trainable parameters (like weights and biases) or non-trainable values (such as constants, buffers, or pre-computed values).
Trainable Parameters vs Non-Trainable Attributes:
- Trainable Parameters:
- These are parameters that the model updates during training via backpropagation (e.g., weights in a neural network).
- PyTorch stores these parameters using nn.Parameter, which registers them with the model.
- Example: weights and biases in layers like nn.Linear or nn.Conv2d, as shown below.
self.weight = nn.Parameter(torch.randn(10, 10)) # Trainable
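To make this concrete, here is a minimal sketch (TinyModel and its shapes are made up for illustration) showing that a tensor wrapped in nn.Parameter is picked up by model.parameters() and tracked by autograd:

import torch
import torch.nn as nn

class TinyModel(nn.Module):  # hypothetical module for illustration
    def __init__(self):
        super().__init__()
        # Wrapped in nn.Parameter: registered as trainable, requires_grad=True by default
        self.weight = nn.Parameter(torch.randn(10, 10))

    def forward(self, x):
        return x @ self.weight

model = TinyModel()
print([name for name, _ in model.named_parameters()])  # ['weight']
print(model.weight.requires_grad)                      # True

Because the parameter appears in model.parameters(), any optimizer built from them will update it during training.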
- Non-Trainable Attributes:
- These are attributes that do not change during training. They are useful for storing constants, lookup tables, pre-initialized matrices, etc.
- If you don’t want these values to be updated via backpropagation, you typically register them as a buffer or store them as regular attributes of the module.
- Example: a normalization constant, a precomputed matrix, or a codebook in vector quantization (see the sketch below).
self.constant = torch.randn(10, 10) # Non-trainable, regular attribute
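One caveat, sketched below with a made-up WithConstant module: a plain tensor attribute is invisible to PyTorch's module machinery, so it does not appear in parameters() or state_dict(), and model.to(device) will not move it.

import torch
import torch.nn as nn

class WithConstant(nn.Module):  # hypothetical module for illustration
    def __init__(self):
        super().__init__()
        # A regular attribute: not a parameter, not a buffer
        self.constant = torch.randn(10, 10)

model = WithConstant()
print(list(model.parameters()))           # [] -- nothing to train
print('constant' in model.state_dict())   # False -- not saved with the model
# model.to('cuda') would leave self.constant on the CPU; move it manually if needed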
register_buffer:
- PyTorch provides register_buffer to store non-trainable tensors in a model. This is useful because buffers are automatically moved to the correct device (e.g., GPU) when the model is moved, and by default they are saved in the model's state_dict, but they are never updated during training.
- However, if you don't want or need this specific behavior, you can just store non-trainable values as regular attributes.
Example: a causal attention mask registered as a buffer. Since register_buffer is a method of nn.Module, the snippet is wrapped in an illustrative CausalMask class:

import torch
import torch.nn as nn

class CausalMask(nn.Module):  # illustrative wrapper; register_buffer needs an nn.Module
    def __init__(self, block_size):
        super().__init__()
        # Lower-triangular ones, shaped (1, 1, T, T) so it broadcasts over batch and heads
        self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size))
                                          .view(1, 1, block_size, block_size))

    def forward(self, x):
        B, T, C = x.size()
        return self.mask[:, :, :T, :T]  # slice to the current sequence length
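Continuing with the illustrative CausalMask above, a quick usage sketch: buffers travel with .to(...), appear in the state_dict by default, and stay out of parameters():

model = CausalMask(block_size=8)

print(list(model.parameters()))        # [] -- the mask is not trainable
print('mask' in model.state_dict())    # True -- buffers are saved by default

# Buffers follow the module across devices
if torch.cuda.is_available():
    model = model.to('cuda')
    print(model.mask.device)           # cuda:0

# To keep a buffer out of the state_dict (PyTorch 1.5+), pass persistent=False:
#   self.register_buffer("mask", tensor, persistent=False)

The persistent flag is handy for large derived tensors that can simply be rebuilt when the model is loaded.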