Below is a latency design framework for running a streaming ASR server in production. The goal is not simply "be fast" but to build a system that reliably meets the p95 SLA (500ms in your case). It assumes an A10 GPU, 10 concurrent users, and 20ms client upload chunks.


1) Break latency into a "budget"

The end-to-end latency of streaming ASR is usually modeled as the sum below.

L_{e2e} = L_{buffer} + L_{batch} + L_{queue} + L_{gpu} + L_{post} + L_{net}

What each term means:

  • (L_{buffer}): input buffering on the server to assemble a hop
  • (L_{batch}): extra wait introduced by the micro-batching window
  • (L_{queue}): time a job waits in the queue while the GPU is busy (the main source of tail latency)
  • (L_{gpu}): actual model inference time
  • (L_{post}): post-processing (text cleanup, punctuation, endpointing)
  • (L_{net}): transport and frame handling (WebSocket/gRPC)

Example p95 = 500ms budget (recommended)

Separate the terms you can control in operations from the ones you cannot (environment/network).

  • (L_{buffer}): 160ms (fix hop = 160ms)

  • (L_{batch}): ≤ 20ms (micro-batching)

  • (L_{net}+L_{post}): 50ms (with headroom)

  • Remaining budget:

    • (500 - (160 + 20 + 50) = 270ms)

That is, at p95 you must keep (L_{queue}+L_{gpu}) within 270ms.


2) The number-one p95 killer: queueing (congestion)

In the field, the most common reason p95 spikes is not that the GPU itself is slow, but that:

  • the GPU falls briefly behind,
  • jobs pile up in the queue,
  • and the tail explodes.

So the core of operational design is:

GPU compute optimization matters, but the policy that controls queueing matters more.


3) Design lever 1: separate hop and window clearly

Recommendation (10 concurrent users, A10, p95 500ms)

  • hop = 160ms (client 20ms × 8)
  • rolling window = 0.96~1.28s

Why separate them:

  • hop determines the user-facing update cadence (UX)
  • window determines the amount of GPU compute

Under congestion, shrinking the window defends p95 better than touching the hop.

# example: shrink the window under overload
if queue_wait_p95_ms > 150:
    window_secs = 0.96
else:
    window_secs = 1.28

4) Design lever 2: make micro-batching "p95-protective"

With 10 concurrent users, the average batch size will not grow much.
So the goal of micro-batching is not throughput but:

  • reducing the number of GPU calls
  • while capping the extra wait time

Recommended:

  • max_latency_ms = 15~20ms
  • max_batch_size = 4~8
  • in_flight = 1 (single GPU stream)

batcher = DynamicBatcher(
    max_latency_ms=20,   # upper bound that feeds into p95
    max_batch_size=8,
    flush_interval_ms=2,
)

5) Design lever 3: lock in "two-stage timing" as an operating rule

In streaming, it is important to keep the operating rules simple.

Stage A: ingress (20ms)

  • the unit the client sends
  • the server only "receives" here

Stage B: inference trigger (160ms)

  • the unit at which the server calls the GPU
  • one job created per hop

If this two-stage timing breaks down:

  • CPU behavior becomes unstable,
  • batching falls apart,
  • and the tail wobbles.
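The two-stage rule can be sketched as a toy buffer loop. This is a minimal illustration, not the actual server; `SessionBuffer`, `FRAME_MS`, and `HOP_FRAMES` are assumed names:

```python
FRAME_MS = 20          # client upload unit (Stage A)
HOP_FRAMES = 8         # 8 x 20ms = 160ms (Stage B trigger)

class SessionBuffer:
    def __init__(self):
        self.frames = []
        self.jobs = []

    def on_frame(self, frame):          # Stage A: receive only
        self.frames.append(frame)
        if len(self.frames) % HOP_FRAMES == 0:
            self.trigger_inference()    # Stage B: one job per hop

    def trigger_inference(self):
        # In a real server this would enqueue a GPU job for the rolling
        # window; here we just record the trigger time in stream ms.
        self.jobs.append(len(self.frames) * FRAME_MS)

buf = SessionBuffer()
for _ in range(24):                     # 480ms of audio
    buf.on_frame(b"\x00" * 640)         # 20ms of 16kHz/16-bit audio
print(buf.jobs)                         # [160, 320, 480]
```

Note that Stage A does nothing but append; all timing decisions live in the hop check, which is what keeps the rule easy to reason about under load.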

6) The core of p95 defense: overload control (degradation policy)

What you must build into operations is a mechanism that "sacrifices a little quality to keep the SLA."

Recommended degradation order (in practice)

  1. Emit throttling: reduce the frequency of partial updates

    • e.g. reduce partial emits from every 160ms to every 320ms (GPU calls can stay as they are)
  2. Shrink the window: 1.28s → 0.96s

  3. Cut decoder cost: beam 4 → 2 → greedy

  4. (Last resort) session admission control: reject or queue new sessions

if queue_wait_p95_ms > 200:
    emit_interval_ms = 320
    window_secs = 0.96
    beam = 2
elif queue_wait_p95_ms > 120:
    window_secs = 0.96

The key points of this policy:

  • act before p95 breaks (proactive)
  • log clearly which threshold value triggered each action

7) Without measurement it isn't design: define operational metrics

To hold p95 you must, at a minimum, track the following metrics separately.

Required metrics

  • hop_buffer_ms (nearly constant)
  • batch_wait_ms (micro-batching wait)
  • queue_wait_ms (wait in front of the GPU)
  • gpu_infer_ms (execution time)
  • e2e_latency_ms (total)
  • batch_size distribution

And record p50/p95 for each.

from dataclasses import dataclass

@dataclass
class LatencyBreakdown:
    hop_buffer_ms: int
    batch_wait_ms: int
    queue_wait_ms: int
    gpu_infer_ms: int
    post_ms: int
    net_ms: int

What matters in operations is not the "overall p95" but

  • checking, every day, which term is producing the p95.
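A sketch of per-term p50/p95 reporting on top of such a breakdown record. `breakdown_report` and `percentile` are assumed helper names, not part of any existing codebase:

```python
from dataclasses import dataclass, fields
from statistics import quantiles

@dataclass
class LatencyBreakdown:
    hop_buffer_ms: int
    batch_wait_ms: int
    queue_wait_ms: int
    gpu_infer_ms: int
    post_ms: int
    net_ms: int

def percentile(xs, p):
    # quantiles with n=100 returns the 99 cut points between percentiles
    return quantiles(xs, n=100)[p - 1]

def breakdown_report(samples):
    # p50/p95 per field, so you can see which term drives the tail
    report = {}
    for f in fields(LatencyBreakdown):
        xs = [getattr(s, f.name) for s in samples]
        report[f.name] = (percentile(xs, 50), percentile(xs, 95))
    return report

# Synthetic samples where only queue_wait_ms varies
samples = [LatencyBreakdown(160, 10, q, 80, 15, 25) for q in range(0, 200, 2)]
print(breakdown_report(samples)["queue_wait_ms"])  # queue drives the tail here
```

Dumping this report once per minute (and alerting on the p95 of `queue_wait_ms` alone) is usually enough to tell queueing problems apart from compute problems.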

8) Realistic expectations at 10 concurrent users

Rough ranges for a well-designed system:

  • p50 e2e: ~220–320ms
  • p95 e2e: ~350–500ms
  • average batch size: 1–2 (occasionally 3–4)

When p95 starts to spike, in most cases:

  • queue_wait_ms rises first.
    Detecting this and shrinking window/beam is the heart of operations.

9) Operational design summary

  1. Decompose the p95 budget and put a cap on each term
  2. Separate hop and window, and under congestion shrink the window first
  3. For micro-batching, a p95-protective cap takes priority over throughput
  4. Queueing is the main cause of tail latency, so a degradation policy is mandatory
  5. Instrument the latency breakdown term by term and trace what causes p95

If you want, I can connect this design directly to code:

  • a production-oriented sample server on FastAPI (WebSocket) that includes latency-breakdown logging, p95 computation, and automatic degradation.

Notes

 

Measuring transcription latency – for each correctly transcribed word, measure the time from when the word ends in the audio stream (x) to when that same word first appears in the partial transcript received by the client (y). Latency = y − x. Only words that are transcribed correctly are included in the calculation.

The latency is measured using audio files balanced across:

  • Clean speech – LibriSpeech excerpts
  • Real-world calls – internal benchmark recordings (retail support, drive-thru, B2B triage)
  • Stress clips – crafted edge-cases (rapid speaker turns, burst noise, long silences)

This mix captures everyday usage and the extreme scenarios that typically break streaming systems.

 

Measuring transcription accuracy - We use word error rate (WER) for accuracy measurement. For comprehensive evaluation, we use datasets totaling more than 205 hours of audio. These datasets cover various domains, including meetings, broadcasts, and call centers, as well as a wide range of English accents. They also encompass various audio conditions, in terms of duration, signal-to-noise ratios, speech-to-silence ratios, and other factors.

🚀 Prefix Sum + Hashing Technique

The Prefix Sum + Hashing technique is a powerful approach to problems that involve subarrays with a specific sum condition. It turns problems that would otherwise need an O(N²) brute-force solution into an efficient O(N) approach using a hashmap (a dictionary in Python).


🔹 When to Use Prefix Sum + Hashing?

✅ Finding subarrays with a given sum.
✅ Problems where we need fast lookups of previously seen prefix sums.
✅ Used frequently in range sum queries, subarrays, and constraints on sum differences.


🔹 Concept Behind Prefix Sum + Hashing

1️⃣ Compute the prefix sum:

  • The prefix sum at index i is the sum of all elements from 0 to i.
  • This helps us quickly compute subarray sums without iterating through the array multiple times.

2️⃣ Use a HashMap (Dictionary) to store prefix sums:

  • The hashmap stores the prefix sum as a key and its occurrence count or index as a value.
  • This allows quick lookups to check if a required sum exists.

3️⃣ Find the target sum efficiently:

  • If we want a subarray sum of target, we check if (prefix_sum - target) exists in the hashmap.

🔹 Example: Subarray Sum Equals K (LeetCode 560)

Problem: Find the number of contiguous subarrays whose sum equals k.
💡 Optimal Solution: Use prefix sum + hashmap to track occurrences of (prefix_sum - k).

🔹 Implementation

from collections import defaultdict

def subarraySum(nums, k):
    prefix_sum = 0
    count = 0
    sum_freq = defaultdict(int)
    sum_freq[0] = 1  # To handle cases where prefix_sum itself is k

    for num in nums:
        prefix_sum += num  # Compute prefix sum

        # Check if (prefix_sum - k) exists in hashmap
        if (prefix_sum - k) in sum_freq:
            count += sum_freq[prefix_sum - k]

        # Store the current prefix sum in the hashmap
        sum_freq[prefix_sum] += 1

    return count

🔹 Explanation with Example

Input:

nums = [1, 2, 3, 2, -3, 1, 4]
k = 5

We track the prefix sum and use a hashmap to store counts.

Index | Element | Prefix Sum | (Prefix Sum − k) in sum_freq? | Count | sum_freq after update
0 | 1 | 1 | ❌ No | 0 | {0:1, 1:1}
1 | 2 | 3 | ❌ No | 0 | {0:1, 1:1, 3:1}
2 | 3 | 6 | ✅ Yes (6−5=1) | 1 | {0:1, 1:1, 3:1, 6:1}
3 | 2 | 8 | ✅ Yes (8−5=3) | 2 | {0:1, 1:1, 3:1, 6:1, 8:1}
4 | -3 | 5 | ✅ Yes (5−5=0) | 3 | {0:1, 1:1, 3:1, 6:1, 8:1, 5:1}
5 | 1 | 6 | ✅ Yes (6−5=1) | 4 | {0:1, 1:1, 3:1, 6:2, 8:1, 5:1}
6 | 4 | 10 | ✅ Yes (10−5=5) | 5 | {0:1, 1:1, 3:1, 6:2, 8:1, 5:1, 10:1}

Final Count: 5 (5 valid subarrays with sum = 5)


🔹 Time & Space Complexity

  • Time Complexity: O(N) (each element is processed once).
  • Space Complexity: O(N) (hashmap stores at most N prefix sums).

🔹 Other Problems Using Prefix Sum + Hashing

  • LeetCode 325 (Maximum Size Subarray Sum Equals k): find the longest subarray with sum = k
  • LeetCode 523 (Continuous Subarray Sum): check if a subarray sum is a multiple of k
  • LeetCode 974 (Subarray Sums Divisible by K): count subarrays whose sum is divisible by k
  • LeetCode 930 (Binary Subarrays With Sum): count subarrays with a given sum in a binary array
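The same pattern covers all of these variants. For LeetCode 974, for instance, the only change is hashing the prefix sum modulo k instead of the raw sum; a sketch under that assumption:

```python
from collections import defaultdict

def subarraysDivByK(nums, k):
    # Two prefix sums with the same remainder mod k bound a subarray
    # whose sum is divisible by k.
    count = 0
    prefix_mod = 0
    mod_freq = defaultdict(int)
    mod_freq[0] = 1  # empty prefix
    for num in nums:
        prefix_mod = (prefix_mod + num) % k  # Python's % is non-negative for k > 0
        count += mod_freq[prefix_mod]        # pair with every earlier equal remainder
        mod_freq[prefix_mod] += 1
    return count

print(subarraysDivByK([4, 5, 0, -2, -3, 1], 5))  # 7
```

Note how the hashmap key changed but the "look up, then record" order is identical to the `subarraySum` solution above.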

🔹 Summary

  • Prefix Sum: computes cumulative sums to track subarray sums efficiently
  • Hashing: stores prefix sums to allow O(1) lookup
  • Time Complexity: O(N) using a hashmap
  • Use Cases: finding subarrays with sum constraints

🚀 Takeaways

1️⃣ Always store the prefix sum in a hashmap for quick lookups.
2️⃣ (prefix_sum - target) in hashmap → A valid subarray is found.
3️⃣ O(N) time complexity makes it much faster than brute force O(N²).

Would you like me to explain any variations of this technique? 🚀


Backtracking is a powerful algorithmic technique used for solving problems that require searching through all possible solutions. It is particularly useful when the problem involves combinatorial search or decision-making with constraints.


🔹 Categories of Backtracking Problems

Backtracking problems generally fall into the following categories:

1️⃣ Subset and Subsequence Generation

  • Generate all subsets or subsequences of a given set.
  • Often involves including or excluding elements recursively.

📌 Examples:

💡 Pattern:

  • Choose an element, make a recursive call including it, then backtrack by removing it.
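A minimal sketch of this include/exclude pattern, in the shape of LeetCode 78 (Subsets):

```python
def subsets(nums):
    result = []
    current = []

    def backtrack(start):
        result.append(current[:])       # record the current subset
        for i in range(start, len(nums)):
            current.append(nums[i])     # choose
            backtrack(i + 1)            # explore with this element included
            current.pop()               # backtrack by removing it

    backtrack(0)
    return result

print(subsets([1, 2, 3]))
# [[], [1], [1, 2], [1, 2, 3], [1, 3], [2], [2, 3], [3]]
```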

2️⃣ Permutations and Arrangements

  • Generate all possible orderings of a given set of elements.
  • May involve duplicates or additional constraints.

📌 Examples:

💡 Pattern:

  • Swap elements in-place to generate permutations.
  • Use a boolean array or set to track used elements.
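The visited-tracking variant of that pattern, sketched in the shape of LeetCode 46 (Permutations):

```python
def permute(nums):
    result = []
    current = []
    used = [False] * len(nums)  # boolean array tracks used elements

    def backtrack():
        if len(current) == len(nums):
            result.append(current[:])
            return
        for i in range(len(nums)):
            if used[i]:
                continue
            used[i] = True              # choose
            current.append(nums[i])
            backtrack()                 # explore
            current.pop()               # backtrack
            used[i] = False

    backtrack()
    return result

print(permute([1, 2, 3]))
# [[1, 2, 3], [1, 3, 2], [2, 1, 3], [2, 3, 1], [3, 1, 2], [3, 2, 1]]
```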

3️⃣ Combination Problems

  • Generate combinations of elements where order does not matter.
  • Useful for choosing k elements from n.

📌 Examples:

💡 Pattern:

  • Recursively pick elements while ensuring uniqueness.
  • Maintain a current combination list and backtrack when needed.
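The k-of-n pattern as a sketch, in the shape of LeetCode 77 (Combinations):

```python
def combine(n, k):
    result = []
    current = []

    def backtrack(start):
        if len(current) == k:
            result.append(current[:])
            return
        for i in range(start, n + 1):   # start index ensures uniqueness
            current.append(i)           # choose
            backtrack(i + 1)            # explore
            current.pop()               # backtrack

    backtrack(1)
    return result

print(combine(4, 2))  # [[1, 2], [1, 3], [1, 4], [2, 3], [2, 4], [3, 4]]
```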

4️⃣ Constraint Satisfaction Problems (CSP)

  • Problems where a solution must satisfy constraints, often involving grid-based solutions.

📌 Examples:

💡 Pattern:

  • Use recursion with constraints to limit the search space.
  • Maintain additional data structures (e.g., boolean arrays for rows/columns).

5️⃣ Path-Finding in Grids (Maze, Rat in a Maze)

  • Problems where you need to explore all possible paths in a grid.

📌 Examples:

💡 Pattern:

  • Use DFS with backtracking to explore possible paths.
  • Mark visited cells and unmark when backtracking.
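The mark/unmark grid pattern as a sketch, in the style of LeetCode 79 (Word Search):

```python
def exist(board, word):
    rows, cols = len(board), len(board[0])

    def dfs(r, c, i):
        if i == len(word):
            return True                     # all characters matched
        if not (0 <= r < rows and 0 <= c < cols) or board[r][c] != word[i]:
            return False
        saved, board[r][c] = board[r][c], "#"   # mark visited
        found = (dfs(r + 1, c, i + 1) or dfs(r - 1, c, i + 1) or
                 dfs(r, c + 1, i + 1) or dfs(r, c - 1, i + 1))
        board[r][c] = saved                 # unmark when backtracking
        return found

    return any(dfs(r, c, 0) for r in range(rows) for c in range(cols))

board = [["A", "B", "C"],
         ["S", "F", "C"],
         ["A", "D", "E"]]
print(exist(board, "ABCCED"))  # True
```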

6️⃣ Word and String Problems

  • Problems that involve constructing or searching for words using recursive backtracking.

📌 Examples:

💡 Pattern:

  • Use recursive DFS to explore character sequences.
  • Often combined with Trie data structures.

7️⃣ Mathematical and Number Problems

  • Problems that involve numbers, sequences, or generating numeric solutions.

📌 Examples:

💡 Pattern:

  • Use recursive choices with constraints on numbers.
  • May involve pruning to optimize the search.

🔹 Key Techniques in Backtracking

  • Recursion with Decision Trees: explore all possible choices and backtrack when invalid.
  • Pruning (Branch Cutting): avoid unnecessary recursive calls to improve efficiency.
  • Bitmasking & Hashing: track state efficiently in some cases (e.g., N-Queens).
  • Memoization (Hybrid with DP): store previously computed results to avoid recomputation.


🚀 Summary Table

  • Subsets & Subsequences (Subsets, Palindrome Partitioning): include/exclude elements recursively
  • Permutations (Permutations, Permutation Sequence): swap elements, track visited states
  • Combinations (Combinations, Combination Sum): recursively pick k elements
  • Constraint Satisfaction (Sudoku Solver, N-Queens): recursively place elements under constraints
  • Path-Finding (Word Search, Unique Paths III): DFS with backtracking in a grid
  • Word/String Problems (Letter Combinations, Word Search): recursive character exploration
  • Math & Number Problems (Beautiful Arrangement, 24 Game): recursive numerical choices

🛠 How to Get Better at Backtracking?

1️⃣ Solve Problems from Each Category – Start with simple ones like "Subsets" and then move to harder ones like "N-Queens".
2️⃣ Draw Decision Trees – Understand how recursion unfolds.
3️⃣ Practice Writing Constraints – This helps prune the search space.
4️⃣ Combine with Other Techniques – Hybridize with DFS, Dynamic Programming, and Bitmasking for advanced problems.


💡 Would you like a step-by-step explanation of any of these categories? 🚀

The Dynamic Sliding Window technique is an optimization over the fixed-size sliding window. Instead of keeping the window at a constant size, we dynamically expand and shrink the window based on constraints.


🔹 Key Concept

The window expands until the constraint is violated and shrinks until it becomes valid again.
The main idea is to use a two-pointer (left-right) approach where:

  • Expand right to include more elements.
  • Shrink left when a condition is violated.

This method is highly effective in subarray problems, frequency-based problems, and problems with constraints.


🔹 When to Use Dynamic Sliding Window?

  • Finding longest/shortest subarray that meets a condition.
  • Problems with constraints (e.g., maximum replacements allowed, sum conditions).
  • Problems where we need to track frequency/count of elements in a window.

🛠 Key Patterns in Dynamic Sliding Window

1️⃣ Longest Subarray with a Constraint

You keep expanding the window and shrink only when needed.

🔹 Example 1: Longest Substring with At Most K Distinct Characters

from collections import defaultdict

def longestSubstringKDistinct(s, k):
    char_count = defaultdict(int)
    left = 0
    max_length = 0

    for right in range(len(s)):
        char_count[s[right]] += 1

        while len(char_count) > k:  # More than k distinct chars
            char_count[s[left]] -= 1
            if char_count[s[left]] == 0:
                del char_count[s[left]]
            left += 1  # Shrink window

        max_length = max(max_length, right - left + 1)

    return max_length

Expands until it violates the constraint (more than k distinct chars).
Shrinks until it meets the constraint again.


2️⃣ Smallest Subarray with a Constraint

🔹 Example 2: Smallest Subarray with Sum ≥ S

def minSubArrayLen(target, nums):
    left = 0
    curr_sum = 0
    min_length = float('inf')

    for right in range(len(nums)):
        curr_sum += nums[right]

        while curr_sum >= target:
            min_length = min(min_length, right - left + 1)
            curr_sum -= nums[left]  # Shrink window
            left += 1

    return min_length if min_length != float('inf') else 0

Expands while sum is below target.
Shrinks to find the smallest valid subarray.


3️⃣ Longest Subarray with K Replacements (Binary or 1s/0s Problems)

🔹 Example 3: Longest Subarray of 1s After Replacing at Most K Zeros

def longestOnes(nums, k):
    left = 0
    max_length = 0
    zeros_count = 0

    for right in range(len(nums)):
        if nums[right] == 0:
            zeros_count += 1

        while zeros_count > k:
            if nums[left] == 0:
                zeros_count -= 1
            left += 1  # Shrink window

        max_length = max(max_length, right - left + 1)

    return max_length

Expands while zeros_count ≤ k.
Shrinks only when zeros_count > k.


🔹 Complexity Analysis

For all Dynamic Sliding Window problems:

  • Time Complexity: O(N) (each element is visited at most twice).
  • Space Complexity: O(1) (constant space unless storing elements).

🔹 Summary Table

  • Longest Subarray with a Constraint: expand → shrink (e.g., longest substring with k distinct chars)
  • Smallest Subarray with a Constraint: expand → shrink to minimize (e.g., smallest subarray with sum ≥ S)
  • Replacing Elements in a Window: expand → shrink when replacements exceed the limit (e.g., longest run of 1s after replacing k 0s)

🔹 Key Takeaways

  • Expand as much as possible until a condition breaks.
  • Shrink only when necessary to regain validity.
  • Time complexity is O(N), much faster than brute-force O(N²).

Would you like a visual explanation for any of these problems? 🚀


I am going to have two functions. The first one is to handle the recursive splitting of the array into smaller parts, and the second one will handle merging those parts back together in sorted order.

The first function is the main merge_sort function. Its job is to recursively divide the input array into smaller parts until each part contains only one element or no elements.

The second function is called merge, and its purpose is to take two sorted arrays and combine them into a single sorted array

def merge_sort(arr):
    if len(arr) <= 1:
        return arr  # Base case: array with 0 or 1 elements is already sorted

    # Split the array into two halves
    mid = len(arr) // 2
    left = merge_sort(arr[:mid])
    right = merge_sort(arr[mid:])

    # Merge the sorted halves
    return merge(left, right)

def merge(left, right):
    sorted_array = []
    i = j = 0

    # Compare elements from left and right and merge them in sorted order
    while i < len(left) and j < len(right):
        if left[i] <= right[j]:
            sorted_array.append(left[i])
            i += 1
        else:
            sorted_array.append(right[j])
            j += 1

    # Add any remaining elements from left or right
    sorted_array.extend(left[i:])
    sorted_array.extend(right[j:])

    return sorted_array

# Example Usage
arr = [38, 27, 43, 3, 9, 82, 10]
sorted_arr = merge_sort(arr)
print("Sorted Array:", sorted_arr)
Time Complexity

Best case, worst case, and average case: O(n log n)

Divide step:

  • The array is repeatedly divided into two halves until we reach subarrays of size 1.
  • The number of divisions equals the height of the recursion tree, which is log₂(n).

Merge step:

  • Each level of the recursion tree processes all n elements during merging.
  • So merging at each level takes O(n) time.

Total work:

  • O(n) per level × log₂(n) levels = O(n log n)

Key insight: the log n factor comes from the depth of the recursion (splitting), and the n factor comes from merging at each level.

Space Complexity

O(n)

  • The merge step needs temporary arrays, which at any given time hold up to O(n) elements.
  • Recursion stack space: O(log n), the height of the recursion tree.
  • Total: O(log n) + O(n) = O(n)

21. Merge Two Sorted Lists

# LeetCode-style method: ListNode, Optional, and the enclosing Solution class
# are assumed from the judge environment.
def mergeTwoLists(self, list1: Optional[ListNode], list2: Optional[ListNode]) -> Optional[ListNode]:
    dummy = ListNode()
    cur = dummy

    while list1 and list2:
        if list1.val < list2.val:
            cur.next = list1
            list1 = list1.next
        else:
            cur.next = list2
            list2 = list2.next
        cur = cur.next
    if list1:
        cur.next = list1
    if list2:
        cur.next = list2
    return dummy.next

# Time: O(m+n), where m and n are the lengths of the two lists
# Space: O(1), since nodes are relinked in place

23. Merge k Sorted Lists

class Solution:
    def mergeKLists(self, lists: List[Optional[ListNode]]) -> Optional[ListNode]:
        '''
        1. implement mergeTwoLists
        2. divide and conquer using mergeTwoLists
        '''
        if not lists:
            return None

        if len(lists) == 1:
            return lists[0]

        mid = len(lists) // 2 
        left = self.mergeKLists(lists[:mid])
        right = self.mergeKLists(lists[mid:])

        return self.mergeTwoLists(left, right)

    def mergeTwoLists(self, list1, list2):
        dummy = ListNode()
        cur = dummy

        while list1 and list2:
            if list1.val < list2.val:
                cur.next = list1
                list1 = list1.next
            else:
                cur.next = list2
                list2 = list2.next
            cur = cur.next

        if list1:
            cur.next = list1
        elif list2:
            cur.next = list2
        return dummy.next

# N: total number of nodes across all linked lists; k: number of lists.
# The algorithm repeatedly splits the k lists in half until single lists
# remain, giving log k levels of merging.
# At each level, all N nodes are processed across the merged pairs,
# so the total time is O(N log k).
#
# Space: the recursion halves the list of lists at each call, so the
# recursion depth (and stack space) is O(log k).

 


Most of the following content is from the link below. This post is just a reorganized version of huyenchip's post, restructured so that it is easier for me to understand and remember.

https://huyenchip.com/2023/05/02/rlhf.html

 

Three steps to build a production-grade LLM

  1. Pretraining
  2. Supervised Fine-Tuning (SFT)
  3. Reinforcement Learning from Human Feedback (RLHF)

 

Figure 1. The development process for ChatGPT

 

 

You can skip any of the three phases. For example, you can do RLHF directly on top of the pretrained model, without going through the SFT phase. However, empirically, combining all these three steps gives the best performance.

 

Pretraining is the most resource-intensive phase. For the InstructGPT model, pretraining takes up 98% of the overall compute and data resources. You can think of SFT and RLHF as unlocking the capabilities that the pretrained model already has but are hard for users to access via prompting alone.

 

Phase I. Pretraining

The result of the pretraining phase is a large language model (LLM), often known as the pretrained model. Examples include GPT-x (OpenAI), Gopher (DeepMind), LLaMa (Meta), StableLM (Stability AI).

 

Language Model

A language model encodes statistical information about language. For simplicity, statistical information tells us how likely something (e.g. a word, a character) is to appear in a given context. You can think of tokens as the vocabulary that a language model uses.

 

Fluent speakers of a language subconsciously have statistical knowledge of that language. For example, given the context My favorite color is __, if you speak English, you know that the word in the blank is much more likely to be green than car.

 

Mathematical formulation

  • ML task: language modeling
  • Training data: low-quality data
  • Data scale: usually in the order of trillions of tokens as of May 2023.
  • Model resulting from this process: LLM

Next token prediction training

 

Data for pre-training

Side-by-side comparison of RedPajama and LLaMa data, done by RedPajama.

To put a trillion tokens in perspective: a book contains around 50,000 words, or 67,000 tokens, so 1 trillion tokens is equivalent to about 15 million books.

 

Phase II. Supervised finetuning (SFT) for dialogue

 

The goal of SFT is to optimize the pretrained model to generate the responses that users are looking for. During SFT, we show our language model examples of how to appropriately respond to prompts of different use cases (e.g. question answering, summarization, translation). The examples follow the format (prompt, response) and are called demonstration data. OpenAI calls supervised finetuning behavior cloning. OpenAI showed that the finetuned approach produces much superior results.

 

The distribution of prompts used to finetune InstructGPT

Demonstration data

Demonstration data can be generated by humans, as OpenAI did with InstructGPT and ChatGPT. This demonstration data is produced by highly educated labelers (~90% have at least a college degree and more than one-third have a master's degree). OpenAI's 40 labelers created around 13,000 (prompt, response) pairs for InstructGPT.

 

Mathematical formulation

The mathematical formulation is very similar to the one in phase 1.

  • ML task: language modeling
  • Training data: high-quality data in the format of (prompt, response)
  • Data scale: 10,000 - 100,000 (prompt, response) pairs
  • Model input and output
    • Input: prompt
    • Output: response for this prompt
  • Loss function to minimize during the training process: cross entropy, but only the tokens in the response are counted towards the loss.
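As a toy illustration of that masking (pure Python, with made-up token probabilities): prompt tokens contribute nothing, and only response tokens are averaged into the loss.

```python
import math

def sft_loss(token_probs, is_response):
    # token_probs: model probability assigned to each target token
    # is_response: True for response tokens, False for prompt tokens
    losses = [-math.log(p) for p, r in zip(token_probs, is_response) if r]
    return sum(losses) / len(losses)  # mean cross entropy over response tokens only

# 3 prompt tokens (masked out) + 2 response tokens
probs = [0.1, 0.2, 0.3, 0.5, 0.5]
mask = [False, False, False, True, True]
print(round(sft_loss(probs, mask), 4))  # 0.6931  (= -ln 0.5)
```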

 

Phase III. Reinforcement Learning from Human Feedback (RLHF)

 

Dialogues are flexible. Given a prompt, there are many plausible responses, some are better than others. Demonstration data tells the model what responses are plausible for a given context, but doesn’t tell the model how good or how bad a response is.

 

The idea: what if we have a scoring function that, if given a prompt and a response, outputs a score for how good that response is? Then we use this scoring function to further train our LLMs towards giving responses with high scores. That’s exactly what RLHF does. RLHF consists of two parts:

  1. Train a reward model to act as a scoring function.
  2. Optimize LLM to generate responses for which the reward model will give high scores.

 

Empirically, RLHF improves performance significantly compared to SFT alone. Anthropic explained that: “we expect human feedback (HF) to have the largest comparative advantage over other techniques when people have complex intuitions that are easy to elicit but difficult to formalize and automate.”

 

»»Side note: Hypotheses on why RLHF works««

Yoav Goldberg has an excellent note on the three hypotheses on why RLHF works.

  • The diversity hypothesis: during SFT, the model’s output is expected to somewhat match the demonstrated responses. For example, given the prompt “what’s an example of a language?”, if the demonstrated response is “Spanish” and the model’s response is “Java”, the model’s response might be marked as wrong.
  • The negative feedback hypothesis: demonstration only gives the model positive signals (e.g. only showing the model good responses), not negative signals (e.g. showing models what bad responses look like). RL allows us to show models negative signals.
  • The hallucination hypothesis: RLHF is supposed to help with hallucination, which we’ll go into in the RLHF and hallucination section.

 

3.1. Reward model (RM)

The RM’s job is to output a score for a pair of (prompt, response). Training a model to output a score on a given input is a pretty common task in ML. You can simply frame it as a classification or a regression task. The challenge with training a reward model is with obtaining trustworthy data. Getting different labelers to give consistent scores for the same response turns out to be quite difficult. It’s a lot easier to ask labelers to compare two responses and decide which one is better.

 

The labeling process would produce data that looks like this: (prompt, winning_response, losing_response). This is called comparison data.

 

Here’s an example of comparison data from Anthropic’s HH-RLHF dataset.

Mathematical formulation

There might be some variations, but here’s the core idea.

  • Training data: high-quality data in the format of (prompt, winning_response, losing_response)
  • Data scale: 100K - 1M examples
    • InstructGPT: 50,000 prompts. Each prompt has 4 to 9 responses, forming between 6 and 36 pairs of (winning_response, losing_response). This means between 300K and 1.8M training examples in the format of (prompt, winning_response, losing_response).
    • Constitutional AI, which is suspected to be the backbone of Claude (Anthropic): 318K comparisons – 135K generated by humans, and 183K generated by AI. Anthropic has an older version of their data open-sourced (hh-rlhf), which consists of roughly 170K comparisons.
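The RM is trained on this comparison data with a pairwise ranking loss (the form used for InstructGPT): push the winning response's score above the losing one's. A sketch, where `score_w` and `score_l` stand for the RM's scalar outputs:

```python
import math

def rm_pairwise_loss(score_w, score_l):
    # Maximize the probability that the winner outscores the loser:
    # loss = -log(sigmoid(score_w - score_l))
    return -math.log(1.0 / (1.0 + math.exp(-(score_w - score_l))))

print(round(rm_pairwise_loss(2.0, 2.0), 4))  # 0.6931 (no preference learned yet)
print(round(rm_pairwise_loss(3.0, 1.0), 4))  # 0.1269 (winner clearly scored higher)
```

Minimizing this loss needs no absolute "correct score," which is exactly why comparisons are easier to collect than scalar ratings.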

3.2. Finetuning using the reward model

 

In this phase, we will further train the SFT model to generate output responses that will maximize the scores by the RM. Today, most people use Proximal Policy Optimization (PPO), a reinforcement learning algorithm released by OpenAI in 2017.

 

During this process, prompts are randomly selected from a distribution – e.g. we might randomly select among customer prompts. Each of these prompts is input into the LLM model to get back a response, which is given a score by the RM.

 

Diagram that explains SFT and RLHF for InstructGPT, by OpenAI

 

Mathematical formulation

  • ML task: reinforcement learning
    • Action space: the vocabulary of tokens the LLM uses. Taking action means choosing a token to generate.
    • Observation space: the distribution over all possible prompts.
    • Policy: the probability distribution over all actions to take (aka all tokens to generate) given an observation (aka a prompt). An LLM constitutes a policy because it dictates how likely a token is to be generated next.
    • Reward function: the reward model.
  • Training data: randomly selected prompts
  • Data scale: 10,000 - 100,000 prompts

RLHF and hallucination

Hallucination happens when an AI model makes stuff up. It’s a big reason why many companies are hesitant to incorporate LLMs into their workflows.

 

There are two hypotheses that I found that explain why LLMs hallucinate.

The first hypothesis, first expressed by Pedro A. Ortega et al. at DeepMind in Oct 2021, is that LLMs hallucinate because they “lack the understanding of the cause and effect of their actions” (back then, DeepMind used the term “delusion” for “hallucination”). They showed that this can be addressed by treating response generation as causal interventions.

 

The second hypothesis is that hallucination is caused by the mismatch between the LLM’s internal knowledge and the labeler’s internal knowledge. In his UC Berkeley talk (April 2023), John Schulman, OpenAI co-founder and PPO author, suggested that behavior cloning causes hallucination. During SFT, LLMs are trained to mimic responses written by humans. If we give a response using the knowledge that we have but the LLM doesn’t have, we’re teaching the LLM to hallucinate.

 

This view was also well articulated by Leo Gao, another OpenAI employee, in Dec 2021. In theory, the human labeler can include all the context they know with each prompt to teach the model to use only the existing knowledge. However, this is impossible in practice.

 

 


 

https://leetcode.com/problems/top-k-frequent-elements/description/?envType=company&envId=facebook&favoriteSlug=facebook-thirty-days

 

Here’s an updated structured script that includes the Quickselect approach for solving the Top K Frequent Elements problem, along with sorting, heap, and bucket sort methods.


You:
"To solve the problem of finding the k most frequent elements in a list of integers, there are a few different approaches we can take, depending on the input size and the desired level of efficiency. I’ll walk through each approach, from simplest to most optimized, including their pros and cons.

1. Sorting Solution

The simplest approach is to use sorting:

  • Steps: First, count the frequency of each element using a dictionary or Python's Counter class. Then, convert the frequency dictionary to a list of (element, frequency) tuples and sort this list by frequency in descending order. Finally, select the first k elements from the sorted list.
  • Time Complexity: O(n log n), due to sorting the entire list of elements by frequency.
  • Space Complexity: O(n) for the frequency dictionary and sorted list.
  • Pros:
    • Straightforward and easy to implement.
    • Suitable for small to moderate input sizes.
  • Cons:
    • Sorting is inefficient for large lists when we only need the top k elements. Sorting all elements doesn’t leverage the partial results we need.

2. Heap Solution (Efficient for Larger Lists with Small k)

A more efficient approach for larger inputs is to use a min-heap:

  • Steps: After creating a frequency dictionary, we use a min-heap to keep track of the k most frequent elements. For each element and its frequency, we push it into the heap. If the heap size exceeds k, we remove the element with the smallest frequency, ensuring that only the k most frequent elements remain.
  • Time Complexity: O(n log k), where we add n elements to a heap of size k.
  • Space Complexity: O(n) for the dictionary and O(k) for the heap.
  • Pros:
    • Efficient for large inputs where k is small relative to n.
    • Uses less space by storing only k elements in the heap.
  • Cons:
    • More complex to implement due to heap operations.

3. Bucket Sort Solution (Optimal for Frequency-Based Grouping)

An even more efficient approach in terms of time complexity is bucket sort:

  • Steps: After building the frequency dictionary, we create an array of buckets where each index represents a frequency count. Each bucket stores elements that appear that many times. Finally, we collect the top k elements by iterating through the buckets from highest to lowest frequency.
  • Time Complexity: O(n), as we only count elements and place them into frequency-based buckets.
  • Space Complexity: O(n) for the dictionary and buckets.
  • Pros:
    • Highly efficient for large inputs and avoids sorting or heap maintenance.
    • Works well for situations where k is close to n.
  • Cons:
    • Bucket sort can be less intuitive to implement, and requires extra space for the buckets.

4. Quickselect Solution (Optimal for Top-k Selection)

Another highly efficient solution, especially for very large lists, is Quickselect:

  • Steps: Quickselect is a partition-based algorithm similar to quicksort. After building the frequency dictionary, we convert it into a list of (element, frequency) pairs and use Quickselect to partially sort the list such that the k most frequent elements are positioned in the first k spots. We partition the list until the k-th most frequent element is in the correct position, and return the first k elements.
  • Time Complexity: O(n) on average, as Quickselect only partially sorts the list to find the top k elements.
  • Space Complexity: O(n) for the dictionary and list.
  • Pros:
    • Very efficient with an average-case complexity of O(n), especially for very large lists.
    • Avoids sorting the entire list, which makes it faster than the sorting approach.
  • Cons:
    • The worst-case complexity is O(n^2), though using random pivot selection helps mitigate this risk.
    • Quickselect is more complex to implement and understand compared to other solutions.

Summary:

  • Sorting: Simple but less efficient for large inputs, with O(n log n) complexity.
  • Heap: Ideal for large lists when k is much smaller than n, with O(n log k) complexity.
  • Bucket Sort: Optimized for large lists and frequency-based grouping, with O(n) complexity, though it requires additional space.
  • Quickselect: Offers the best average-case efficiency with O(n) time complexity, ideal for very large lists.

Each solution has its trade-offs, so I’d choose the approach based on input size and constraints. For large lists with small k, the heap or Quickselect approach would be optimal, while for lists where k is close to n, bucket sort may be best."


This script provides a structured breakdown of each solution, explaining when each approach is optimal based on the constraints, making it easy to decide the best solution.


Problem Statement Recap:
You are given a list of integers and an integer k. The goal is to return the k most frequent elements in the list.


1. Sorting Solution

Explanation:

  • Start by counting the frequency of each element in the list. We can use Python's Counter from the collections module to achieve this.
  • Once we have the frequency of each element, we convert the frequency dictionary into a list of tuples, where each tuple is (element, frequency).
  • Sort this list of tuples in descending order based on frequency.
  • Finally, select the first k elements from this sorted list.

Implementation:

from collections import Counter
from typing import List

def top_k_frequent_sort(nums: List[int], k: int) -> List[int]:
    # Step 1: Count frequencies
    freq_count = Counter(nums)
    # Step 2: Sort items by frequency in descending order
    sorted_items = sorted(freq_count.items(), key=lambda item: item[1], reverse=True)
    # Step 3: Extract the first k elements
    return [item[0] for item in sorted_items[:k]]

# Example usage
nums = [1, 1, 1, 2, 2, 3]
k = 2
print(top_k_frequent_sort(nums, k))  # Output: [1, 2]

Time Complexity: O(n log n)

  • Counting frequencies takes O(n), while sorting the items by frequency takes O(n log n).

Space Complexity: O(n) for storing the frequency dictionary and sorted list.

Pros:

  • Simple and straightforward to implement.
  • Effective for small to medium inputs.

Cons:

  • Sorting the entire frequency list is unnecessary when we only need the top k elements, making it less efficient for large inputs.
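For comparison, `Counter.most_common(k)` collapses the count/sort/slice steps into a single call; a minimal sketch (the function name `top_k_frequent_builtin` is my own):

```python
from collections import Counter
from typing import List

def top_k_frequent_builtin(nums: List[int], k: int) -> List[int]:
    # most_common(k) returns the k (element, count) pairs with the
    # highest counts, so only the elements need to be extracted
    return [num for num, _ in Counter(nums).most_common(k)]

# Example usage
nums = [1, 1, 1, 2, 2, 3]
k = 2
print(top_k_frequent_builtin(nums, k))  # Output: [1, 2]
```

This has the same O(n log n) character in the general case, but it is the idiomatic one-liner to mention before walking through the manual version.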

2. Heap Solution (Optimal for Large Lists with Small k)

Explanation:

  • After counting the frequency of each element, we use a min-heap of size k to keep track of the k most frequent elements.
  • We push each element along with its frequency into the heap.
    • If the heap exceeds size k, we remove the element with the smallest frequency (root of the min-heap).
  • By the end, the heap contains only the k most frequent elements.

Implementation:

import heapq
from collections import Counter
from typing import List

def top_k_frequent_heap(nums: List[int], k: int) -> List[int]:
    # Step 1: Count frequencies
    freq_count = Counter(nums)
    # Step 2: Use a min-heap of size k
    min_heap = []
    for num, freq in freq_count.items():
        heapq.heappush(min_heap, (freq, num))
        if len(min_heap) > k:
            heapq.heappop(min_heap)
    # Step 3: Extract elements from the heap
    return [num for (freq, num) in min_heap]

# Example usage
nums = [1, 1, 1, 2, 2, 3]
k = 2
print(top_k_frequent_heap(nums, k))  # Output: [2, 1]

Time Complexity: O(n log k)

  • Counting frequencies takes O(n), and maintaining a heap of size k takes O(log k) time for each of the n elements.

Space Complexity: O(n) for the frequency dictionary and O(k) for the heap.

Pros:

  • Efficient for large inputs when k is much smaller than n.
  • Keeps track of only k elements, optimizing space and time usage.

Cons:

  • Slightly more complex to implement due to heap management.
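As a cross-check on the hand-rolled heap loop, the standard library's `heapq.nlargest` applies the same size-k min-heap strategy in one call; a minimal sketch (the function name `top_k_frequent_nlargest` is my own):

```python
import heapq
from collections import Counter
from typing import List

def top_k_frequent_nlargest(nums: List[int], k: int) -> List[int]:
    freq_count = Counter(nums)
    # nlargest maintains a min-heap of size k internally, giving O(n log k),
    # and returns the k keys with the largest frequencies in descending order
    return heapq.nlargest(k, freq_count.keys(), key=freq_count.get)

# Example usage
nums = [1, 1, 1, 2, 2, 3]
k = 2
print(top_k_frequent_nlargest(nums, k))  # Output: [1, 2]
```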

3. Bucket Sort Solution (Optimal for Frequency-Based Grouping)

Explanation:

  • This approach leverages the fact that the frequency of elements is bounded by the length of the list (n), as the maximum frequency an element can have is n.
  • Create an array of n+1 empty buckets (index 0 to n). Each bucket at index i holds a list of elements that appear i times.
  • Place each element from the frequency dictionary into its corresponding bucket based on frequency.
  • Finally, iterate through the buckets in reverse order, collecting elements until we have k elements.

Implementation:

from collections import Counter
from typing import List

def top_k_frequent_bucket_sort(nums: List[int], k: int) -> List[int]:
    # Step 1: Count frequencies
    freq_count = Counter(nums)
    # Step 2: Initialize buckets where index is frequency
    buckets = [[] for _ in range(len(nums) + 1)]
    for num, freq in freq_count.items():
        buckets[freq].append(num)
    # Step 3: Gather top k elements from the buckets
    result = []
    for i in range(len(buckets) - 1, 0, -1):
        for num in buckets[i]:
            result.append(num)
            if len(result) == k:
                return result
    return result  # Fallback in case k exceeds the number of distinct elements

# Example usage
nums = [1, 1, 1, 2, 2, 3]
k = 2
print(top_k_frequent_bucket_sort(nums, k))  # Output: [1, 2]

Time Complexity: O(n)

  • Counting frequencies takes O(n), and placing elements into buckets also takes O(n).

Space Complexity: O(n) for the frequency dictionary and buckets.

Pros:

  • Very efficient for problems where n is large and k is close to n.
  • No sorting or heap maintenance required, and it handles frequencies directly.

Cons:

  • Bucket sort can be less intuitive to implement.
  • Requires extra space for buckets, which may not be ideal for space-constrained environments.
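To make the bucket layout concrete, this small sketch prints the bucket array for the example input:

```python
from collections import Counter

nums = [1, 1, 1, 2, 2, 3]
freq_count = Counter(nums)                  # {1: 3, 2: 2, 3: 1}
buckets = [[] for _ in range(len(nums) + 1)]
for num, freq in freq_count.items():
    buckets[freq].append(num)
# Index i holds the elements that occur exactly i times
print(buckets)  # Output: [[], [3], [2], [1], [], [], []]
```

Reading the buckets from right to left yields 1 (frequency 3), then 2 (frequency 2), which is exactly the top-2 result.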

Finally, here’s the Quickselect implementation for solving the Top K Frequent Elements problem.

4. Quickselect Solution

Explanation:

  • We start by building a frequency dictionary to count the occurrences of each element.
  • Then, we convert this dictionary into a list of tuples (element, frequency).
  • Using Quickselect, we partition the list of tuples so that the k most frequent elements are positioned in the first k spots in the list.
  • After partitioning, we return the elements from the first k positions.

Implementation:

import random
from collections import Counter
from typing import List, Tuple

def top_k_frequent_quickselect(nums: List[int], k: int) -> List[int]:
    # Step 1: Count frequencies
    freq_count = Counter(nums)
    # Convert the dictionary to a list of (element, frequency) pairs
    freq_items = list(freq_count.items())

    def partition(left: int, right: int, pivot_index: int) -> int:
        pivot_frequency = freq_items[pivot_index][1]
        # Move pivot to the end
        freq_items[pivot_index], freq_items[right] = freq_items[right], freq_items[pivot_index]
        store_index = left
        # Move all elements with frequency greater than pivot to the left
        for i in range(left, right):
            if freq_items[i][1] > pivot_frequency:
                freq_items[store_index], freq_items[i] = freq_items[i], freq_items[store_index]
                store_index += 1
        # Move pivot to its final place
        freq_items[right], freq_items[store_index] = freq_items[store_index], freq_items[right]
        return store_index

    def quickselect(left: int, right: int, k_smallest: int):
        if left == right:  # If the list contains only one element
            return
        # Select a random pivot index
        pivot_index = random.randint(left, right)
        # Partition the array around the pivot
        pivot_index = partition(left, right, pivot_index)
        # Recursively apply quickselect on the relevant half
        if k_smallest == pivot_index:
            return
        elif k_smallest < pivot_index:
            quickselect(left, pivot_index - 1, k_smallest)
        else:
            quickselect(pivot_index + 1, right, k_smallest)

    # Perform Quickselect for the k most frequent elements
    n = len(freq_items)
    quickselect(0, n - 1, k - 1)

    # Return the first k elements' values from freq_items
    return [item[0] for item in freq_items[:k]]

# Example usage
nums = [1, 1, 1, 2, 2, 3]
k = 2
print(top_k_frequent_quickselect(nums, k))  # Output: [1, 2]

Explanation of Key Steps:

  1. Partition Function:
    • Selects a pivot and rearranges elements such that all elements with frequency higher than the pivot are on the left, and all elements with frequency lower than the pivot are on the right.
    • This allows us to position elements based on their frequency.
  2. Quickselect Function:
    • Partitions the list around a pivot index until the k-th most frequent element is in the correct position.
    • This process allows us to get the top k frequent elements in average O(n) time without fully sorting the list.

Pros and Cons:

  • Pros: Efficient with an average time complexity of O(n), ideal for large lists.
  • Cons: The worst-case time complexity is O(n^2), though random pivot selection mitigates this in practice.
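One practical caveat: the recursive version can hit Python's recursion limit on adversarial inputs. The same algorithm can be written with the recursion unrolled into a loop; a sketch (the function name `top_k_frequent_quickselect_iter` is my own):

```python
import random
from collections import Counter
from typing import List

def top_k_frequent_quickselect_iter(nums: List[int], k: int) -> List[int]:
    freq_items = list(Counter(nums).items())
    left, right = 0, len(freq_items) - 1
    while left < right:
        # Pick a random pivot and partition descending by frequency
        pivot_index = random.randint(left, right)
        pivot_frequency = freq_items[pivot_index][1]
        freq_items[pivot_index], freq_items[right] = freq_items[right], freq_items[pivot_index]
        store_index = left
        for i in range(left, right):
            if freq_items[i][1] > pivot_frequency:
                freq_items[store_index], freq_items[i] = freq_items[i], freq_items[store_index]
                store_index += 1
        freq_items[right], freq_items[store_index] = freq_items[store_index], freq_items[right]
        # Narrow the search range instead of recursing
        if store_index == k - 1:
            break
        elif store_index > k - 1:
            right = store_index - 1
        else:
            left = store_index + 1
    return [item[0] for item in freq_items[:k]]

# Example usage
nums = [1, 1, 1, 2, 2, 3]
k = 2
print(top_k_frequent_quickselect_iter(nums, k))  # Output: [1, 2]
```

Since only loop variables are updated between iterations, this version uses O(1) extra control-flow space regardless of how unbalanced the partitions get.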

Summary

  • Sorting Solution: Simple but inefficient for large n, with O(n log n) complexity.
  • Heap Solution: Ideal for large n with small k, with O(n log k) complexity.
  • Bucket Sort Solution: Optimal for large n and frequency-based grouping, with O(n) complexity, but uses more space.
  • Quickselect Solution: Best average-case efficiency with O(n) time complexity; the O(n^2) worst case is mitigated by random pivot selection.
