In PyTorch, when defining an `nn.Module`, attributes can either be trainable parameters (like weights and biases) or non-trainable values (such as constants, buffers, or pre-computed values).
Trainable Parameters vs Non-Trainable Attributes:
- Trainable Parameters:
- These are parameters that the model updates during training via backpropagation (e.g., weights in a neural network).
- PyTorch stores these parameters using `nn.Parameter` and registers them with the model.
- Example: weights and biases in layers like `nn.Linear` or `nn.Conv2d`.
```python
self.weight = nn.Parameter(torch.randn(10, 10))  # Trainable
```
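As a quick illustration (the class name below is made up for this sketch), a tensor wrapped in `nn.Parameter` is registered automatically, shows up in `model.parameters()`, and has gradients enabled:

```python
import torch
import torch.nn as nn

class TinyLinear(nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(10, 10))  # registered as a parameter automatically

    def forward(self, x):
        return x @ self.weight

model = TinyLinear()
print([name for name, _ in model.named_parameters()])  # ['weight']
print(model.weight.requires_grad)                       # True
```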
- Non-Trainable Attributes:
- These are attributes that do not change during training. They are useful for storing constants, lookup tables, pre-initialized matrices, etc.
- If you don’t want these values to be updated via backpropagation, you typically register them as a buffer or store them as regular attributes of the module.
- Example: a normalization constant, a precomputed matrix, or a codebook in vector quantization.
```python
self.constant = torch.randn(10, 10)  # Non-trainable, regular attribute
```
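For comparison, here is a minimal sketch (again with a hypothetical class name) showing that a plain tensor attribute is invisible to PyTorch's module machinery: it is neither trained nor saved, and `model.to(device)` will not move it:

```python
import torch
import torch.nn as nn

class WithConstant(nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.constant = torch.randn(10, 10)  # plain attribute: not a parameter, not a buffer

    def forward(self, x):
        return x + self.constant

model = WithConstant()
print(list(model.named_parameters()))   # [] -- nothing to train
print(list(model.state_dict().keys()))  # [] -- not saved; .to(device) won't move it either
```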
`register_buffer`:
- PyTorch provides `register_buffer` to store non-trainable tensors in a model. This is useful because buffers are automatically moved to the correct device (e.g., GPU) when the model is moved and are saved in the model's `state_dict`, but they are not updated during training.
- However, if you don't want or need this behavior, you can simply store non-trainable values as regular attributes.
```python
import torch
import torch.nn as nn

# Example module (class wrapper and name added to make the original snippet self-contained)
class CausalMask(nn.Module):
    def __init__(self, block_size):
        super().__init__()
        # Lower-triangular mask stored as a buffer: moves with the model, never trained
        self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size))
                                          .view(1, 1, block_size, block_size))

    def forward(self, x):
        B, T, C = x.size()
        return self.mask[:, :, :T, :T]  # slice the mask to the current sequence length
```
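Assuming the `CausalMask` module above, a quick check confirms the buffer is saved with the model and stays non-trainable:

```python
model = CausalMask(block_size=8)
print(list(model.state_dict().keys()))  # ['mask'] -- buffers are included in the state dict
print(model.mask.requires_grad)         # False -- never touched by the optimizer
# model.to("cuda") would also move the mask buffer to the GPU automatically
```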