Overview:
- The DynamicBatchSampler groups data samples into batches of varying size.
- It takes into account the length of each sample, a maximum number of samples per batch (batch_size), and a maximum total length per batch (max_tokens).
- The current batch is closed when it already holds batch_size samples, or when adding the next sample would push the sum of sample lengths above max_tokens.
Example Code:
import random


class DynamicBatchSampler:
    def __init__(self, data_lengths, batch_size, max_tokens):
        """
        Initializes the DynamicBatchSampler with data lengths, a maximum batch size, and a max tokens constraint.

        :param data_lengths: List of lengths of each data sample.
        :param batch_size: The maximum number of samples per batch.
        :param max_tokens: The maximum number of tokens (sum of sample lengths) allowed per batch.
            A single sample longer than max_tokens is still yielded, alone in its own batch.
        """
        self.data_lengths = data_lengths
        self.batch_size = batch_size
        self.max_tokens = max_tokens

    def __iter__(self):
        # Shuffle the indices to get a different sampling order each time.
        indices = list(range(len(self.data_lengths)))
        random.shuffle(indices)

        batch = []
        total_length = 0
        for idx in indices:
            length = self.data_lengths[idx]
            # Close the current batch if it is full, or if adding this sample
            # would exceed max_tokens. The `batch and` guard prevents yielding
            # an empty batch when a single sample is longer than max_tokens.
            if batch and (total_length + length > self.max_tokens
                          or len(batch) >= self.batch_size):
                yield batch
                batch = []
                total_length = 0
            # Add the current sample to the batch.
            batch.append(idx)
            total_length += length

        # Yield any remaining samples as the last batch.
        if batch:
            yield batch


# Example usage:
data_lengths = [5, 8, 3, 7, 10, 2, 4, 6]  # Example lengths of samples.
batch_size = 3  # Maximum number of samples per batch.
max_tokens = 15  # Maximum total length of samples in a batch.

sampler = DynamicBatchSampler(data_lengths, batch_size, max_tokens)

# Iterate over the batches generated by the DynamicBatchSampler.
for batch in sampler:
    print(f"Batch indices: {batch}, Batch lengths: {[data_lengths[i] for i in batch]}")
Explanation:
- Data lengths: [5, 8, 3, 7, 10, 2, 4, 6] are the lengths of the eight samples (indices 0 to 7).
- Batch size: 3 means each batch can hold at most 3 samples.
- Max tokens: 15 means the lengths of the samples in a batch may sum to at most 15.
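These two limits together put a floor on how many batches are needed. The sketch below computes that floor with a hypothetical helper, min_batches, assuming every sample is no longer than max_tokens (an over-long sample would get a batch of its own and break the bound). For the example data, 8 samples at 3 per batch and a total length of 45 at 15 per batch both give at least 3 batches; the greedy, shuffled order used by the sampler can need more.

```python
import math


def min_batches(data_lengths, batch_size, max_tokens):
    """Hypothetical helper: lower bound on the number of batches.

    Assumes no single sample is longer than max_tokens.
    """
    by_count = math.ceil(len(data_lengths) / batch_size)   # at most batch_size samples per batch
    by_tokens = math.ceil(sum(data_lengths) / max_tokens)  # at most max_tokens length per batch
    return max(by_count, by_tokens)


data_lengths = [5, 8, 3, 7, 10, 2, 4, 6]
print(sum(data_lengths))                 # 45
print(min_batches(data_lengths, 3, 15))  # 3
```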
Output (one possible run; the batches change from run to run because the indices are shuffled):
Batch indices: [6, 7], Batch lengths: [4, 6]
Batch indices: [4, 0], Batch lengths: [10, 5]
Batch indices: [1, 2, 5], Batch lengths: [8, 3, 2]
Batch indices: [3], Batch lengths: [7]
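Concatenating the batch indices above shows that this run used the shuffled order [6, 7, 4, 0, 1, 2, 5, 3]. The sketch below replays the sampler's batching rule over an explicit order with a hypothetical helper, greedy_batches, to reproduce that output. It also shows one way (an assumption, not something the original class supports) to make runs repeatable: shuffle with a seeded random.Random instead of the module-level random.shuffle.

```python
import random


def greedy_batches(order, data_lengths, batch_size, max_tokens):
    """Hypothetical helper: apply the sampler's batching rule to a fixed index order."""
    batches, batch, total = [], [], 0
    for idx in order:
        length = data_lengths[idx]
        # Close a non-empty batch that is full or would exceed max_tokens with this sample.
        if batch and (total + length > max_tokens or len(batch) >= batch_size):
            batches.append(batch)
            batch, total = [], 0
        batch.append(idx)
        total += length
    if batch:
        batches.append(batch)
    return batches


data_lengths = [5, 8, 3, 7, 10, 2, 4, 6]

# The order behind the output listed above.
print(greedy_batches([6, 7, 4, 0, 1, 2, 5, 3], data_lengths, 3, 15))
# [[6, 7], [4, 0], [1, 2, 5], [3]]

# A seeded generator yields the same shuffled order, and so the same batches, on every run.
order = list(range(len(data_lengths)))
random.Random(0).shuffle(order)
print(greedy_batches(order, data_lengths, 3, 15))
```

Passing the seed to a dedicated random.Random keeps the sampler's order independent of any other code that uses the global random module.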
How It Works:
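- The sampler shuffles the sample indices, then walks through them in that order, adding each index to the current batch.
- A batch is closed (yielded) when it already holds batch_size samples, or when adding the next sample would push its total length above max_tokens; that next sample then starts a new batch.
- Because batch boundaries depend on the sample lengths, the number of samples per batch varies: short samples are packed up to batch_size per batch, while long samples produce smaller batches. In the output above, the batches hold 2, 2, 3, and 1 samples.
- A single sample longer than max_tokens cannot be split, so it ends up in a batch by itself.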
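When changing the batching rule, it helps to check the properties described above directly. The sketch below uses a hypothetical helper, check_batches, which confirms that every index appears exactly once, no batch is empty or holds more than batch_size samples, and no multi-sample batch exceeds max_tokens (a lone over-long sample is allowed). It is applied here to the sample output and to a deliberately invalid batching.

```python
def check_batches(batches, data_lengths, batch_size, max_tokens):
    """Hypothetical helper: True if the batches cover every index once and respect the limits."""
    flat = [i for batch in batches for i in batch]
    if sorted(flat) != list(range(len(data_lengths))):
        return False  # an index is missing or duplicated
    for batch in batches:
        if not batch or len(batch) > batch_size:
            return False  # empty or over-full batch
        total = sum(data_lengths[i] for i in batch)
        # A single over-long sample may exceed max_tokens on its own.
        if total > max_tokens and len(batch) > 1:
            return False
    return True


data_lengths = [5, 8, 3, 7, 10, 2, 4, 6]
output = [[6, 7], [4, 0], [1, 2, 5], [3]]
print(check_batches(output, data_lengths, 3, 15))                             # True
print(check_batches([[4, 1], [0, 2, 3], [5, 6, 7]], data_lengths, 3, 15))     # False: 10 + 8 > 15
```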