Full SelfAttention

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class FullSelfAttention(nn.Module):
    """
    A vanilla multi-head self-attention layer with no masking and a projection at the end.
    This implementation doesn't use causal masking, meaning all tokens can attend to each other.
    """

    def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        # key, query, value projections for all heads
        self.key = nn.Linear(n_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)
        # regularization
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(resid_pdrop)
        # output projection
        self.proj = nn.Linear(n_embd, n_embd)

    def forward(self, x):
        B, T, C = x.size()

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # full self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = F.softmax(att, dim=-1)  # no masking here, full attention
        att = self.attn_drop(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_drop(self.proj(y))
        return y
```
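
A minimal usage sketch (the batch size, sequence length, and hyperparameters below are arbitrary, chosen only for illustration):

```python
n_embd, n_head = 64, 8
attn = FullSelfAttention(n_embd, n_head, attn_pdrop=0.1, resid_pdrop=0.1)
x = torch.randn(16, 10, n_embd)   # (batch, sequence length, embedding dim)
print(attn(x).shape)              # torch.Size([16, 10, 64])
```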

Causal SelfAttention


```python
class CausalSelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    It is possible to use torch.nn.MultiheadAttention here, but I am including an
    explicit implementation to show that there is nothing too scary going on.
    """

    def __init__(self, n_embd, block_size, n_head, attn_pdrop, resid_pdrop):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        # key, query, value projections for all heads
        self.key = nn.Linear(n_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)
        # regularization
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(resid_pdrop)
        # output projection
        self.proj = nn.Linear(n_embd, n_embd)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size))
                                     .view(1, 1, block_size, block_size))

    def forward(self, x, layer_past=None):
        B, T, C = x.size()

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_drop(self.proj(y))
        return y
```
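
A minimal usage sketch (shapes and hyperparameters are assumed for illustration; `block_size` just needs to be at least the longest sequence length you plan to feed in):

```python
n_embd, n_head, block_size = 64, 8, 128
attn = CausalSelfAttention(n_embd, block_size, n_head, attn_pdrop=0.1, resid_pdrop=0.1)
x = torch.randn(16, 10, n_embd)   # sequence length 10 <= block_size
print(attn(x).shape)              # torch.Size([16, 10, 64])
```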

Look-ahead MHSA

To implement causal attention with a "look-ahead" mechanism, i.e., allowing each token to attend to a limited number of future tokens while still bounding how far ahead it can see, you can widen the standard causal mask by a fixed number of future positions.

Here's a PyTorch implementation of causal attention with look-ahead:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class CausalAttentionWithLookAhead(nn.Module):
    """
    Causal self-attention mechanism with a look-ahead window, allowing each token to attend to a limited number
    of future tokens, while still maintaining the causal nature (i.e., no attention to tokens further ahead).

    Usage:

        n_embd = 64  # embedding dimension
        n_head = 8   # number of attention heads
        attn_pdrop = 0.1  # dropout for attention weights
        resid_pdrop = 0.1  # dropout for output projection
        look_ahead_size = 2  # allow attention up to 2 future tokens

        model = CausalAttentionWithLookAhead(n_embd, n_head, attn_pdrop, resid_pdrop, look_ahead_size)
        x = torch.randn(16, 10, n_embd)  # Batch of 16 sequences, each of length 10, embedding size 64
        output = model(x)
        print(output.size())  # should return (16, 10, 64)
    """

    def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop, look_ahead_size):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.look_ahead_size = look_ahead_size

        # key, query, value projections for all heads
        self.key = nn.Linear(n_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)

        # regularization
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(resid_pdrop)

        # output projection
        self.proj = nn.Linear(n_embd, n_embd)

    def forward(self, x):
        B, T, C = x.size()

        # calculate query, key, values for all heads and move head forward to be the batch dim
        k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        # Causal attention with look ahead: Self-attend (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))

        # Create the causal mask with a look-ahead window
        causal_mask = torch.tril(torch.ones(T, T, device=x.device), diagonal=self.look_ahead_size).view(1, 1, T, T)  # (1, 1, T, T)
        att = att.masked_fill(causal_mask == 0, float('-inf'))

        # Apply softmax and dropout
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)

        # Apply attention to the value: (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # Output projection
        y = self.resid_drop(self.proj(y))
        return y
```

### Key Points:
1. **Causal Masking with Look-Ahead**:
   - The mask is created with `torch.tril`, which generates a lower triangular matrix; the `diagonal=self.look_ahead_size` argument extends it so that attention to future tokens is allowed within a window of size `look_ahead_size`.
   - For example, if `look_ahead_size=2`, each token can attend to up to 2 future tokens in addition to the current and all past tokens (see the mask sketch after this list).

2. **Attention Calculation**:
   - As usual, the queries (`q`), keys (`k`), and values (`v`) are computed and reshaped for the multi-head attention operation.
   - The attention scores are computed by multiplying the query matrix with the transpose of the key matrix.
   - After masking, softmax is applied to obtain the attention weights, and these weights are used to compute a weighted sum of the values.

3. **Handling the Future Look-Ahead**:
   - The `look_ahead_size` controls how many future tokens can be attended to. The larger the value, the further ahead the model can look while still being restricted by causality.
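
To make the mask concrete, here is a small sketch (the values `T=5` and `look_ahead_size=2` are arbitrary) of the matrix produced by `torch.tril(torch.ones(T, T), diagonal=look_ahead_size)`; row `i`, column `j` is 1 if token `i` may attend to token `j`:

```python
import torch

print(torch.tril(torch.ones(5, 5), diagonal=2))
# tensor([[1., 1., 1., 0., 0.],
#         [1., 1., 1., 1., 0.],
#         [1., 1., 1., 1., 1.],
#         [1., 1., 1., 1., 1.],
#         [1., 1., 1., 1., 1.]])
```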

ChunkBasedAttention

To modify the `ChunkBasedAttention` implementation for parallel processing during training, the key change is to avoid sequential processing of chunks and instead process all chunks in parallel. This requires reshaping the input tensors so that the attention computation can be performed on all chunks simultaneously. Here's how you can adjust the `ChunkBasedAttention` class for parallel processing:

### Modified `ChunkBasedAttention` with Parallel Processing:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class ChunkBasedAttention(nn.Module):
    """
    Chunk-based self-attention mechanism with configurable left and right attention chunk sizes.
    This version allows for parallel processing of chunks during training.

    Example Usage:

      n_embd = 64  # embedding dimension
      n_head = 8   # number of attention heads
      attn_pdrop = 0.1  # dropout for attention weights
      resid_pdrop = 0.1  # dropout for output projection
      attention_chunk_size = 4  # size of each chunk
      left_chunk_size = 1  # allow attention to 1 chunk on the left
      right_chunk_size = 1  # allow attention to 1 chunk on the right

      model = ChunkBasedAttention(n_embd, n_head, attn_pdrop, resid_pdrop, attention_chunk_size, left_chunk_size, right_chunk_size)
      x = torch.randn(16, 12, n_embd)  # Batch of 16 sequences, each of length 12, embedding size 64
      output = model(x)
      print(output.size())  # should return (16, 12, 64)

    """

    def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop, attention_chunk_size, left_chunk_size, right_chunk_size):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.attention_chunk_size = attention_chunk_size
        self.left_chunk_size = left_chunk_size
        self.right_chunk_size = right_chunk_size

        # key, query, value projections for all heads
        self.key = nn.Linear(n_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)

        # regularization
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(resid_pdrop)

        # output projection
        self.proj = nn.Linear(n_embd, n_embd)

    def forward(self, x):
        B, T, C = x.size()  # B: batch size, T: sequence length, C: embedding dimension
        hs = C // self.n_head
        chunk_size = self.attention_chunk_size
        assert T % chunk_size == 0, "sequence length must be a multiple of the chunk size (pad the input otherwise)"
        num_chunks = T // chunk_size

        # calculate query, key, values for all heads and move head forward to be the batch dim
        k = self.key(x).view(B, T, self.n_head, hs).transpose(1, 2)    # (B, nh, T, hs)
        q = self.query(x).view(B, T, self.n_head, hs).transpose(1, 2)  # (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head, hs).transpose(1, 2)  # (B, nh, T, hs)

        # Reshape the queries, keys, and values into chunks for parallel processing
        q_chunks = q.reshape(B, self.n_head, num_chunks, chunk_size, hs)  # (B, nh, nc, cs, hs)
        k_chunks = k.reshape(B, self.n_head, num_chunks, chunk_size, hs)  # (B, nh, nc, cs, hs)
        v_chunks = v.reshape(B, self.n_head, num_chunks, chunk_size, hs)  # (B, nh, nc, cs, hs)

        # Construct the chunk-level visibility mask: chunk i may attend to chunks
        # i - left_chunk_size, ..., i, ..., i + right_chunk_size
        chunk_mask = torch.zeros(num_chunks, num_chunks, device=x.device)
        for i in range(num_chunks):
            start_idx = max(0, i - self.left_chunk_size)
            end_idx = min(num_chunks, i + self.right_chunk_size + 1)
            chunk_mask[i, start_idx:end_idx] = 1

        # Attention scores between every query chunk and every key chunk, computed in parallel:
        # (B, nh, nc, cs, hs) x (B, nh, nc, cs, hs) -> (B, nh, nc_q, cs_q, nc_k, cs_k)
        attn_scores = torch.einsum('bhqic,bhkjc->bhqikj', q_chunks, k_chunks) / math.sqrt(hs)

        # Apply the chunk-based mask (broadcast over the within-chunk positions)
        attn_scores = attn_scores.masked_fill(chunk_mask.view(1, 1, num_chunks, 1, num_chunks, 1) == 0, float('-inf'))

        # Softmax over all visible key positions: flatten the key-chunk and within-chunk dimensions
        attn_scores = attn_scores.reshape(B, self.n_head, num_chunks, chunk_size, num_chunks * chunk_size)
        attn_probs = F.softmax(attn_scores, dim=-1)
        attn_probs = self.attn_drop(attn_probs)
        attn_probs = attn_probs.view(B, self.n_head, num_chunks, chunk_size, num_chunks, chunk_size)

        # Apply attention to the values: -> (B, nh, nc, cs, hs)
        y_chunks = torch.einsum('bhqikj,bhkjc->bhqic', attn_probs, v_chunks)
        y = y_chunks.reshape(B, self.n_head, T, hs)       # (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # Output projection
        y = self.resid_drop(self.proj(y))
        return y
```


### Explanation of Changes:
1. **Parallel Processing**:
   - Instead of processing each chunk sequentially, the queries, keys, and values are reshaped into chunks and attention is computed for every pair of chunks at once with `torch.einsum`, so the GPU parallelizes the whole computation instead of looping over chunks in Python.

2. **Chunk Mask**:
   - A chunk-level mask is constructed so that each chunk can attend only to itself and its allowed left and right neighboring chunks.
   - The mask is broadcast over the within-chunk positions and applied to the attention scores, restricting attention to the permitted range of chunks.

3. **Efficient Attention Calculation**:
   - The attention scores are computed with `torch.einsum`, which performs the matrix multiplication between every query chunk and every key chunk in a single batched operation.
   - Softmax is applied over all visible key positions (the key-chunk and within-chunk dimensions flattened together), and the resulting attention probabilities are used to weight the values.

4. **Reshaping for Output**:
   - After computing the attention-weighted values for all chunks, the output is reshaped back into the original sequence length to ensure consistency with the input format.


### Key Points:
- **Chunk-Based Masking**: Each chunk can attend only to itself and a specified range of chunks to its left and right, enforced by the chunk mask (see the sketch after this list).
- **Parallel Computation**: By reshaping the input into chunks and using `torch.einsum`, we can compute attention for all chunks in parallel, which speeds up training.
- **Flexibility**: The chunk size, left, and right attention window sizes are flexible and can be adjusted based on the model's requirements.
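
For intuition, here is the chunk-level mask that the loop in `forward` builds for an assumed `num_chunks=4`, `left_chunk_size=1`, `right_chunk_size=1` (row `i`, column `j` is 1 if chunk `i` may attend to chunk `j`):

```python
import torch

num_chunks, left, right = 4, 1, 1
mask = torch.zeros(num_chunks, num_chunks)
for i in range(num_chunks):
    mask[i, max(0, i - left):min(num_chunks, i + right + 1)] = 1
print(mask)
# tensor([[1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [0., 1., 1., 1.],
#         [0., 0., 1., 1.]])
```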

 
