Overview:

  • The DynamicBatchSampler groups data samples into batches.
  • It takes into account the length of each sample, a target batch size, and a maximum token limit per batch.
  • A batch is emitted as soon as it already holds batch_size samples, or when adding the next sample would push the sum of sample lengths past max_tokens.

Example Code:

import random

class DynamicBatchSampler:
    def __init__(self, data_lengths, batch_size, max_tokens):
        """
        Initializes the DynamicBatchSampler with data lengths, target batch size, and max tokens constraint.

        :param data_lengths: List of lengths of each data sample.
        :param batch_size: The maximum number of samples per batch.
        :param max_tokens: The maximum number of tokens (or length sum) allowed per batch.
        """
        self.data_lengths = data_lengths
        self.batch_size = batch_size
        self.max_tokens = max_tokens

    def __iter__(self):
        # Shuffle the indices to get a different sampling order each time.
        indices = list(range(len(self.data_lengths)))
        random.shuffle(indices)

        batch = []
        total_length = 0

        for idx in indices:
            length = self.data_lengths[idx]
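            # Close the current batch if adding this sample would exceed max_tokens
            # or the batch already holds batch_size samples. The `batch and` guard
            # avoids yielding an empty batch when one sample alone exceeds max_tokens.
            if batch and (total_length + length > self.max_tokens or len(batch) >= self.batch_size):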
                # Yield the current batch and reset for the next one.
                yield batch
                batch = []
                total_length = 0

            # Add the current sample to the batch.
            batch.append(idx)
            total_length += length

        # Yield any remaining samples as the last batch.
        if batch:
            yield batch

# Example usage:
data_lengths = [5, 8, 3, 7, 10, 2, 4, 6]  # Example lengths of samples.
batch_size = 3  # Maximum number of samples per batch.
max_tokens = 15  # Maximum total length of samples in a batch.

sampler = DynamicBatchSampler(data_lengths, batch_size, max_tokens)

# Iterate over the batches generated by the DynamicBatchSampler.
for batch in sampler:
    print(f"Batch indices: {batch}, Batch lengths: {[data_lengths[i] for i in batch]}")

Explanation:

  • Data Lengths: [5, 8, 3, 7, 10, 2, 4, 6] represents the lengths of different samples.
  • Batch Size: 3 means each batch can have up to 3 samples.
  • Max Tokens: 15 restricts each batch to a maximum total length of 15.
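
As a rough sanity check on these settings, the two limits together give a lower bound on how many batches any run can produce. The sketch below only illustrates that arithmetic and is not part of the sampler; the names by_count, by_tokens, and min_batches are hypothetical. Because the sampler fills batches greedily in shuffled order, a given run may produce more batches than this bound.

```python
import math

data_lengths = [5, 8, 3, 7, 10, 2, 4, 6]
batch_size = 3
max_tokens = 15

# Each batch holds at most batch_size samples ...
by_count = math.ceil(len(data_lengths) / batch_size)    # ceil(8 / 3) = 3
# ... and at most max_tokens of total length.
by_tokens = math.ceil(sum(data_lengths) / max_tokens)   # ceil(45 / 15) = 3

# No run can use fewer batches than the stricter of the two limits.
min_batches = max(by_count, by_tokens)
print(min_batches)  # 3
```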

Output (one possible run; the batches differ between runs because the indices are shuffled):

Batch indices: [6, 7], Batch lengths: [4, 6]
Batch indices: [4, 0], Batch lengths: [10, 5]
Batch indices: [1, 2, 5], Batch lengths: [8, 3, 2]
Batch indices: [3], Batch lengths: [7]
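
Since batches are yielded in shuffled order, the run above implies the shuffle produced the index order 6, 7, 4, 0, 1, 2, 5, 3. To trace how that order turns into these four batches, the sketch below applies the same grouping rule to a fixed order with no shuffling. The helper batches_for_order is a hypothetical name introduced for illustration only.

```python
def batches_for_order(order, data_lengths, batch_size, max_tokens):
    """Apply the greedy grouping rule to a fixed index order (no shuffling)."""
    batches, batch, total = [], [], 0
    for idx in order:
        length = data_lengths[idx]
        # Close the current batch when the next sample would not fit.
        if batch and (total + length > max_tokens or len(batch) >= batch_size):
            batches.append(batch)
            batch, total = [], 0
        batch.append(idx)
        total += length
    if batch:
        batches.append(batch)
    return batches

data_lengths = [5, 8, 3, 7, 10, 2, 4, 6]
order = [6, 7, 4, 0, 1, 2, 5, 3]  # order implied by the sample output
result = batches_for_order(order, data_lengths, batch_size=3, max_tokens=15)
print(result)  # [[6, 7], [4, 0], [1, 2, 5], [3]]
```

Note that [4, 0] sums to exactly 15 and is still accepted, because a batch is closed only when the total would exceed max_tokens.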

How It Works:

  • The sampler iterates over the data indices and groups them into batches.
  • A batch is finalized when adding another sample would push its total length past max_tokens, or when it already holds batch_size samples.
  • Because both limits are checked for every sample, batches of short samples fill up to batch_size, while batches of long samples are cut off early by max_tokens.
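
One way to gain confidence in this behavior is to check a few invariants over many shuffles: every index appears exactly once, no batch holds more than batch_size samples, and no multi-sample batch exceeds max_tokens. The sketch below is a minimal illustration under stated assumptions: dynamic_batches and check_batches are hypothetical helpers, the seed parameter is added only to make runs repeatable, and a sample longer than max_tokens is assumed to be placed in a batch of its own.

```python
import random

def dynamic_batches(data_lengths, batch_size, max_tokens, seed=None):
    """Generator version of the grouping rule with a repeatable shuffle."""
    rng = random.Random(seed)  # seed is an assumption added for repeatability
    indices = list(range(len(data_lengths)))
    rng.shuffle(indices)
    batch, total = [], 0
    for idx in indices:
        length = data_lengths[idx]
        if batch and (total + length > max_tokens or len(batch) >= batch_size):
            yield batch
            batch, total = [], 0
        batch.append(idx)
        total += length
    if batch:
        yield batch

def check_batches(batches, data_lengths, batch_size, max_tokens):
    """Raise AssertionError if any invariant is violated; return True otherwise."""
    # Every index appears exactly once across all batches.
    seen = sorted(i for b in batches for i in b)
    assert seen == list(range(len(data_lengths)))
    for b in batches:
        # Batches are never empty and never exceed batch_size.
        assert 1 <= len(b) <= batch_size
        # Only a lone oversized sample may exceed max_tokens.
        total = sum(data_lengths[i] for i in b)
        assert total <= max_tokens or len(b) == 1
    return True

# The final length, 20, is larger than max_tokens on its own.
data_lengths = [5, 8, 3, 7, 10, 2, 4, 6, 20]
all_ok = all(
    check_batches(list(dynamic_batches(data_lengths, 3, 15, seed=s)),
                  data_lengths, 3, 15)
    for s in range(100)
)
print(all_ok)  # True
```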
