Encoding Byte-Level Representation. We consider UTF-8 encoding of text, which encodes each Unicode character into 1 to 4 bytes. This allows us to model a sentence as a sequence of bytes instead of characters. While there are 138K Unicode characters covering over 150 languages, we can represent a sentence in any language as a sequence of UTF-8 bytes (248 out of 256 possible bytes).
A byte sequence representation of text is often much longer (up to 4x) than a character sequence representation, which makes it computationally demanding to use bytes as they are. As an alternative, we consider segmenting a byte sequence into variable-length n-grams (byte-level “subwords”). Specifically, we learn a BPE vocabulary on the byte-level representation, which extends the UTF-8 byte set with byte n-grams. We denote this type of vocabulary as B(yte-level)BPE in the rest of the paper. Figure 1 shows an example of BBPE tokenization.
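To make the byte-level representation concrete, the following minimal Python sketch (added for illustration, not from the paper) shows how a short mixed-language string becomes the UTF-8 byte sequence over which BPE merges are learned; the example string and hex formatting are illustrative choices.

```python
text = "5月10日"  # illustrative mixed digit/CJK string; any sentence works

# UTF-8 byte-level representation: each character becomes 1 to 4 bytes.
byte_tokens = [f"{b:02X}" for b in text.encode("utf-8")]
print(byte_tokens)
# ['35', 'E6', '9C', '88', '31', '30', 'E6', '97', 'A5']
# '5' and '0' are single ASCII bytes, while each CJK character takes 3 bytes.
# BBPE then learns BPE merges over such byte sequences, producing byte n-grams.
```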
BBPE symbols can be partial characters shared by different characters, or combinations of complete and partial characters. This arbitrariness may necessitate incorporating a larger context surrounding each symbol for disambiguation and for learning the character boundaries. In this work, we base our experiments on Transformer (Vaswani et al. 2017) models. We propose to use either a depth-wise convolutional layer (Kaiser, Gomez, and Chollet 2017) or a bidirectional recurrent layer with gated recurrent units (GRU; Cho et al. 2014) to contextualize BBPE embeddings before feeding them into the model.
Decoding with Byte-Level Subwords. While any sentence can be represented as a byte sequence, the converse is not necessarily true: there are byte sequences that do not translate to valid character sequences. Empirically, we find that invalid outputs from trained models are very rare. We do not observe any in the experiments described below (note that one of them has a large test set of 165K examples). A common error pattern in half-trained models is redundant repeating bytes. In our system, we try to recover as many Unicode characters as possible from this error pattern efficiently in linear time. The algorithm is as follows: for a given byte sequence \(\{B_k\}_{k=1}^{N}\), we denote the maximum number of characters that we can recover from its first \(k\) bytes as \(f(k)\). Then \(f(k)\) has optimal substructure and can be solved by dynamic programming:

\[
f(k) = \max_{t=1,2,3,4} \{\, f(k-t) + g(k-t+1, k) \,\}, \tag{1}
\]

where \(g(i, j) = 1\) if \(B_i, \ldots, B_j\) corresponds to a valid character, and otherwise 0. When \(f(k)\) is calculated recursively, we also record the selections at each position \(k\) so that we can recover the solution through backtracking. The design of UTF-8 encoding ensures the uniqueness of this recovery process: for a character UTF-8 encoded with multiple bytes, its trailing bytes will not make a valid UTF-8 encoded character. Therefore the best selection in Eq. 1 is unique, and so is the final solution.
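The recovery procedure can be sketched in a few lines of Python. This is a minimal illustration of the dynamic program above (not the authors' code); it assumes the input is a raw bytes object and uses Python's own UTF-8 decoder as the validity check \(g\).

```python
def recover_characters(seq: bytes) -> str:
    """Recover the maximum number of valid Unicode characters from a byte sequence."""
    N = len(seq)
    f = [0] * (N + 1)     # f[k]: max characters recoverable from the first k bytes
    back = [1] * (N + 1)  # back[k]: length t of the last piece chosen at position k
    for k in range(1, N + 1):
        best, best_t = f[k - 1], 1            # default: drop byte k (g = 0)
        for t in range(1, min(4, k) + 1):     # UTF-8 characters are 1 to 4 bytes long
            piece = seq[k - t:k]
            try:
                g = 1 if len(piece.decode("utf-8")) == 1 else 0
            except UnicodeDecodeError:
                g = 0
            if f[k - t] + g > best:
                best, best_t = f[k - t] + g, t
        f[k], back[k] = best, best_t
    # Backtrack to collect the recovered characters.
    chars, k = [], N
    while k > 0:
        t = back[k]
        piece = seq[k - t:k]
        try:
            chars.append(piece.decode("utf-8"))
        except UnicodeDecodeError:
            pass                              # skipped (invalid) bytes
        k -= t
    return "".join(reversed(chars))

# A dangling lead byte (0xE6) between two valid characters is dropped.
print(recover_characters("日".encode("utf-8") + b"\xe6" + "月".encode("utf-8")))  # -> "日月"
```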
Byte-level models have been proposed for natural language processing (NLP) [9] [10] [11]. The idea is to convert text to a sequence of variable-length UTF-8 codewords and to have the model predict one byte at each decoding step. The advantages of byte-level representation are compactness and universality, as any combination of languages may be represented with an output dimension of only 256. However, a sequence represented at the byte level is always much longer than its character-level counterpart for languages such as Chinese and Japanese [12], because many characters of these languages are represented by multiple bytes in UTF-8. As a result, a byte-level model can be error-prone, since it needs to make multiple predictions for many single characters and each prediction has a chance to make a mistake. To compensate for this drawback, [12] proposes byte-level subwords for neural machine translation: byte pair encoding (BPE) [13] is applied to UTF-8 codeword sequences, resulting in an approach referred to as byte-level BPE (BBPE). BBPE inherits the advantages of the UTF-8 byte-level representation: it can represent all languages while keeping the output dimension in check. At the same time, as BBPE tokens are in general longer than byte-level tokens, the approach reduces the number of steps required by the decoding process.
In this work, we investigate bilingual (English and Mandarin) E2E ASR models by exploring different types of output representations, including character-level, BPE, byte-level (UTF-8) and BBPE. Similar to some of the previous work cited, we build a single E2E model for utterance-based bilingual speech recognition. Our contributions are threefold. First, we compare the strengths and weaknesses of different output representations in monolingual and bilingual use cases. Second, we propose a method to adjust the bigram statistics in the BPE algorithm and show that the BBPE representation leads to accuracy improvements in the bilingual scenario. Finally, we analyze different representations and show how we might improve them for multilingual ASR.
OUTPUT REPRESENTATIONS FOR E2E ASR
Character-level Representation
Using a character-level representation in an E2E model means that the output symbol set for the model is the set of graphemes of the target language. In addition to graphemes, the output representation may also contain punctuation marks, digits, emojis or special tokens such as begin-of-sentence (BOS) or end-of-sentence (EOS). According to [14] [15], character-level representation is often a good representation for Mandarin E2E models, and this serves as one of the baselines in our experiments.
BPE Representation
The BPE algorithm [13] starts from the character representation and iteratively merges the most frequent bigrams given a training text corpus. At the end of this process, the BPE algorithm produces a symbol set that consists of subwords with different lengths. This symbol set can then be used by an E2E model as its output units. It is common to keep the single characters in the final symbol set, so unseen words in the test set can still be represented by the symbol set. For English, BPE is widely used in E2E ASR systems, as it improves accuracy and reduces computation due to the use of frequent subwords and the resulting shorter labeling sequences.
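As a concrete illustration of the merge loop described above, here is a minimal BPE training sketch (a simplified illustration, not the exact algorithm of [13]; it omits the end-of-word marker): it counts bigrams over words represented as symbol tuples and greedily merges the most frequent pair.

```python
from collections import Counter

def learn_bpe(corpus_words, num_merges):
    """Learn BPE merges from a dict {word: count}; words start as character tuples."""
    vocab = {tuple(word): count for word, count in corpus_words.items()}
    merges = []
    for _ in range(num_merges):
        pairs = Counter()
        for symbols, count in vocab.items():
            for a, b in zip(symbols, symbols[1:]):
                pairs[(a, b)] += count            # bigram count weighted by word frequency
        if not pairs:
            break
        best = max(pairs, key=pairs.get)           # most frequent bigram
        merges.append(best)
        merged_vocab = {}
        for symbols, count in vocab.items():
            out, i = [], 0
            while i < len(symbols):
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    out.append(symbols[i] + symbols[i + 1])  # merge the pair into one symbol
                    i += 2
                else:
                    out.append(symbols[i])
                    i += 1
            merged_vocab[tuple(out)] = count
        vocab = merged_vocab
    return merges

print(learn_bpe({"low": 5, "lower": 2, "newest": 6, "widest": 3}, num_merges=4))
```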
Byte-level Representation
Scalability is one of the important aspects in designing an output representation for a multilingual E2E ASR model. As the model supports more languages, the size of the symbol set increases. To tackle this problem, [8] proposes a byte-level representation based on UTF-8. Instead of using characters or subwords as the symbols, a byte-level model uses UTF-8 codewords as the output symbol set. The resulting representation is compact: since there are only 256 possible byte values, each symbol uses one byte. Yet this representation is capable of representing any language, and adding more languages does not increase the size of the symbol set, which is an advantage over the character-level and BPE representations. However, the byte-level representation has two drawbacks. First, it increases the length of the sequence by up to 4x [12], which increases the number of decoding steps during inference. Second, not all byte sequences are valid UTF-8 sequences, so byte-level models may generate invalid byte sequences that require special handling.
To repair an invalid byte sequence, [8] proposes a dynamic programming algorithm to recover the Unicode characters given any byte sequence. We use this post-processing approach to recover characters from byte sequences as much as possible.
Byte-level BPE Representation
To circumvent the increase in sequence length of the byte-level representation, [12] proposes byte-level BPE (BBPE) for neural machine translation, which applies BPE to byte-represented text. The advantage of this approach is that it reduces the sequence length by adopting frequent byte-level subwords, while keeping the size of the symbol set in check. It is important to note that BBPE is equivalent to BPE for many Latin-based languages, since in UTF-8 all Latin characters are single-byte units. However, for languages like Chinese or Japanese, characters can use multiple bytes, so BBPE could be helpful. Similar to the byte-level representation, BBPE might generate invalid byte sequences, and post-processing using dynamic programming is necessary to remedy that. Another aspect is that if we keep all the single-byte UTF-8 codewords in the symbol set after BPE, BBPE can represent all languages, as with the byte-level representation.
Reference
[1] Anjuli Kannan, Arindrima Datta, Tara Sainath, Eugene Weinstein, Bhuvana Ramabhadran, Yonghui Wu, Ankur Bapna, and Zhifeng Chen, “Large-scale multilingual speech recognition with a streaming end-to-end model,” in Proceedings of the INTERSPEECH, 2019.
[2] Surabhi Punjabi, Harish Arsikere, Zeynab Raeesy, Chander Chandak, Nikhil Bhave, Ankish Bansal, Markus Müller, Sergio Murillo, Ariya Rastrow, Sri Garimella, et al., “Streaming end-to-end bilingual ASR systems with joint language identification,” arXiv preprint arXiv:2007.03900, 2020.
[3] Vineel Pratap, Anuroop Sriram, Paden Tomasello, Awni Hannun, Vitaliy Liptchinsky, Gabriel Synnaeve, and Ronan Collobert, “Massively multilingual ASR: 50 languages, 1 model, 1 billion parameters,” in Proceedings of the INTERSPEECH, 2020, pp. 4751–4755.
[4] Ke Li, Jinyu Li, Guoli Ye, Rui Zhao, and Yifan Gong, “Towards code-switching ASR for end-to-end CTC models,” in Proceedings of the IEEE International Conference on Acoustics, Speech, and Signal Processing, 2019.
[5] Changhao Shan, Chao Weng, Guangsen Wang, Dan Su, Min Luo, Dong Yu, and Lei Xie, “Investigating end-to-end speech recognition for Mandarin-English code-switching,” in Proceedings of the IEEE International Conference on Acoustics, Speech, and Signal Processing, 2019.
[6] Zimeng Qiu, Yiyuan Li, Xinjian Li, Florian Metze, and William M. Campbell, “Towards context-aware end-to-end code-switching speech recognition,” in Proceedings of the INTERSPEECH, 2020.
[8] Bo Li, Yu Zhang, Tara Sainath, Yonghui Wu, and William Chan, “Bytes are all you need: End-to-end multilingual speech recognition and synthesis with bytes,” in Proceedings of the IEEE International Conference on Acoustics, Speech, and Signal Processing, 2019, pp. 5621–5625.
[9] Dan Gillick, Cliff Brunk, Oriol Vinyals, and Amarnag Subramanya, “Multilingual language processing from bytes,” in Proceedings of the Conference of the North American Chapter of the Association for Computational Linguistics - Human Language Technologies, 2016, pp. 1296–1306.
[10] Marta Ruiz Costa-Jussà, Carlos Escolano Peinado, and José Adrián Rodríguez Fonollosa, “Byte-based neural machine translation,” in Proceedings of the First Workshop on Subword and Character Level Models in NLP, 2017, pp. 154–158.
[11] Linting Xue, Aditya Barua, Noah Constant, Rami Al-Rfou, Sharan Narang, Mihir Kale, Adam Roberts, and Colin Raffel, “ByT5: Towards a token-free future with pre-trained byte-to-byte models,” 2021.
[12] Changhan Wang, Kyunghyun Cho, and Jiatao Gu, “Neural machine translation with byte-level subwords,” in Proceedings of the AAAI Conference on Artificial Intelligence, 2020, pp. 9154–9160.
End-to-end (E2E) neural networks are flexible and accurate models for multilingual automatic speech recognition (ASR). The output of such a multilingual model is often a union of the characters or subwords of the supported languages. However, as the number of languages increases, the size of the output layer increases, which can negatively affect compute, memory usage and asset size. This problem is more prominent when the system supports languages that have large character sets, such as Chinese, Japanese and Korean (CJK). To tackle this problem, previous work proposed the use of byte-level representation for E2E ASR [1, 2]. By using UTF-8 [3] codewords as the underlying base tokens, the output vocabulary is no longer constrained by the character sets of each language, allowing developers to choose a vocabulary size based on compute and memory constraints. One well-known multilingual ASR system that uses UTF-8 subwords is Whisper [4].
UTF-8 aims to represent all the characters used in major languages. The encoding and decoding processes are designed to be simple and efficient. UTF-8 is a variable-length prefix code where each character is represented by one to four bytes. Most byte sequences are not valid UTF-8 strings, and a UTF-8 decoder needs to detect invalid sequences. UTF-8 also provides backward compatibility: ASCII characters are represented by a single byte, identical to their ASCII encoding. While UTF-8 has proven to be an effective output representation for ASR, it is unclear whether it is optimal. For example, characters with similar pronunciations or meanings are not guaranteed to share the same prefixes. In addition, the large number of invalid byte sequences means the model needs to learn to produce valid UTF-8 strings, an additional burden.
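For instance (a quick illustrative check, not from the paper), a lone continuation byte is rejected by a UTF-8 decoder:

```python
# 0x80 is a continuation byte; on its own it is not a valid UTF-8 sequence.
try:
    bytes([0x80]).decode("utf-8")
except UnicodeDecodeError as e:
    print(e)  # 'utf-8' codec can't decode byte 0x80 in position 0: invalid start byte
```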
UTF-8 BASED REPRESENTATION
UTF-8 based models have been proposed for natural language processing (NLP) [5] [6] [7]. The idea is to convert text to a sequence of variable-length UTF-8 codewords and to have the model predict one byte at each decoding step. The advantages of byte-level representation are compactness and universality, as any combination of languages may be represented with an output dimension of only 256. However, a sequence represented at the byte level is often longer than its character-level counterpart, especially for CJK languages [8]. This is because, while Latin characters are represented by a single byte, many CJK characters and accented characters are represented by multiple bytes. As a result, a byte-level model can be error-prone, since it needs to make multiple predictions for many single characters, and each prediction might make a mistake.
To compensate for the drawback of making byte-level mistakes, [1, 2] propose byte-level subwords for E2E ASR. The idea is to apply byte pair encoding (BPE) [9] to UTF-8 codeword sequences to create UTF-8 subwords. As subwords are in general longer than byte-level tokens, this approach reduces the number of steps required by the decoding process. However, BPE does not guarantee that the output will be a valid UTF-8 sequence. To repair an invalid byte sequence, [1] proposes a dynamic programming algorithm to recover as many characters as possible given any byte sequence. While this dynamic programming approach ensures the output sequence is always valid, it optimizes for the number of valid characters, not ASR quality.
Reference
[1] Bo Li, Yu Zhang, Tara Sainath, Yonghui Wu, and William Chan, “Bytes are all you need: End-to-end multilingual speech recognition and synthesis with bytes,” in Proceedings of the IEEE International Conference on Acoustics, Speech, and Signal Processing, 2019, pp. 5621–5625.
[2] L. Deng, R. Hsiao, and A. Ghoshal, “Bilingual end-to-end ASR with byte-level subwords,” in Proceedings of the IEEE International Conference on Acoustics, Speech, and Signal Processing, 2022.
[8] Changhan Wang, Kyunghyun Cho, and Jiatao Gu, “Neural machine translation with byte-level subwords,” in Proceedings of the AAAI Conference on Artificial Intelligence, 2020, pp. 9154–9160.
[9] Rico Sennrich, Barry Haddow, and Alexandra Birch, “Neural machine translation of rare words with subword units,” in Proceedings of the Annual Meeting of the Association for Computational Linguistics, 2016, pp. 1715–1725.
This format is hexadecimal (hex) notation, and it is used when Unicode characters such as Hangul are encoded in UTF-8.
The number 256 is tied to the byte: it is the maximum number of values an 8-bit unit can take. Its meaning in UTF-8 encoding and byte-based processing can be summarized as follows:
1. Bytes and 256
A byte is a data unit consisting of 8 bits.
One byte can take values from 0 to 255.
The reason 256 distinct values (0 through 255) can be represented is that 8 bits allow 2^8 = 256 combinations.
2. UTF-8 and bytes
UTF-8 encodes each Unicode character into 1 to at most 4 bytes.
For example, English letters and other ASCII characters are represented by a single byte (0-127).
Unicode characters such as Hangul require 2 or more bytes, and each of those bytes takes a value between 0 and 255.
3. Byte-Level BPE and 256
When Byte-Level BPE tokenizes text in byte units, each byte takes a value in the range 0-255.
This makes it possible to handle every possible byte combination, and the 256 individual byte values can express combinations forming any Unicode character.
Summary
256 is the number of values a byte can represent: 0 through 255.
In UTF-8, each byte takes a value between 0 and 255, and this is how the full range of characters is expressed.
Byte-Level BPE uses these 256 byte values as the basic units of tokenization, covering Unicode characters and their combinations.
The theoretical number of characters UTF-8 can represent is tied to the design of Unicode. Unicode itself is designed to cover all of the world's writing systems, and UTF-8 encodes these Unicode characters as bytes. The theoretical capacity of UTF-8 can be worked out as follows:
1. Structure of UTF-8
UTF-8 is a variable-length encoding that uses 1 to 4 bytes per character. Each byte follows a specific bit pattern, and the pattern depends on the Unicode code point range.
1 byte (7 usable bits): 0x00-0x7F (identical to ASCII, 128 characters)
2 bytes (11 usable bits): 0x0080-0x07FF (about 2,048 characters)
3 bytes (16 usable bits): 0x0800-0xFFFF (about 65,536 characters)
4 bytes (21 usable bits): 0x010000-0x10FFFF (about 1,112,064 characters)
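A quick way to sanity-check these ranges (an illustrative snippet added to these notes) is to encode one code point from each range and look at the byte count:

```python
# One representative code point per UTF-8 length class.
for cp in (0x41, 0x0080, 0x0800, 0x10000):
    encoded = chr(cp).encode("utf-8")
    print(f"U+{cp:04X} -> {len(encoded)} byte(s): {encoded.hex(' ').upper()}")
# U+0041 -> 1 byte(s): 41
# U+0080 -> 2 byte(s): C2 80
# U+0800 -> 3 byte(s): E0 A0 80
# U+10000 -> 4 byte(s): F0 90 80 80
```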
2. Unicode code point range
Unicode code points are defined from U+0000 to U+10FFFF.
This range contains 1,114,112 code points. However, the number of actually usable characters is smaller.
3. Number of characters actually representable
Not every code point represents a character. Some code points are reserved for control characters or special purposes and are not used as characters directly.
The current Unicode standard (as of Unicode 15.0) defines roughly 149,186 characters, spanning the many scripts and symbols that Unicode covers.
4. Theoretical UTF-8 capacity
In theory, UTF-8 can represent the entire Unicode range from 0x0000 to 0x10FFFF, i.e., about 1,114,112 code points.
In practice, however, the number of characters defined by the Unicode standard is smaller, and control characters and other reserved code points must also be taken into account.
Summary
Theoretically, UTF-8 can represent about 1,114,112 code points.
The characters currently defined in Unicode number about 149,186, and Unicode continues to grow.
UTF-8 is a flexible scheme that can encode every Unicode code point as a byte sequence.
A notation such as 0x0080 writes a number in hexadecimal. This notation is widely used in computer science and programming, particularly for memory addresses, Unicode code points and byte data. Here is a closer look at what 0x0080 means.
1. Hexadecimal notation
Hexadecimal is a base-16 number system that uses the digits 0-9 and the letters A-F, so a single digit can take up to 16 different values.
For example:
0x0 = 0 (decimal)
0xA = 10 (decimal)
0xF = 15 (decimal)
0x10 = 16 (decimal)
The 0x prefix marks the value as hexadecimal. So 0x0080 means hexadecimal 80, which is 128 in decimal.
2. The meaning of 0x0080
0x0080 is the number 128 written in hexadecimal.
The leading 00 is only padding to align the digits; the significant part is 80.
Converted to decimal, this is 128.
Values like this are commonly used for bytes in memory, Unicode code points, color codes and memory addresses.
3. Unicode and hexadecimal
Unicode code points are usually written in hexadecimal.
For example, U+0080 denotes code point 128 in the Unicode standard.
In UTF-8, the character at 0x0080 is encoded with more than one byte.
Seen as a Unicode code point, 0x0080 lies just beyond ASCII: ASCII covers 0x0000 to 0x007F (0-127) with a single byte, and values from 0x0080 (128) upward require two or more bytes.
4. Why hexadecimal is convenient
Memory addresses: writing addresses and byte data in hexadecimal makes sizes easier to compute and inspect.
Machine-friendly: hardware processes data in binary (0s and 1s), but binary is converted to hexadecimal to make it readable, since one hex digit corresponds to exactly 4 bits.
Example: 0x10 (hexadecimal) corresponds to 0001 0000 in binary.
Summary
0x0080 is hexadecimal notation for the decimal number 128.
Hexadecimal notation is useful for memory addresses, byte values and Unicode code points, and is widely used in computer science.
In Unicode, 0x0080 lies just beyond the ASCII range, and in UTF-8 it is encoded with multiple bytes.
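The following short check (illustrative, not part of the original notes) confirms both claims: 0x0080 is decimal 128, and U+0080 needs two bytes in UTF-8:

```python
print(0x0080)                        # 128
print(hex(128))                      # 0x80
print(chr(0x0080).encode("utf-8"))   # b'\xc2\x80' -> two bytes once we leave the ASCII range
```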
Converting binary to hexadecimal is fairly simple. Binary uses only 0 and 1, while hexadecimal uses 0-9 and A-F. The method is described below, followed by examples.
1. How to convert binary to hexadecimal
To convert a binary number to hexadecimal, group the bits in sets of four and convert each group to one hexadecimal digit.
Because one hexadecimal digit corresponds to exactly 4 bits, splitting the binary number into 4-bit groups makes the conversion straightforward.
2. Conversion procedure
Group into 4-bit blocks:
Starting from the right, split the binary number into groups of 4 bits. If the number of bits is not a multiple of 4, pad with zeros on the left.
Convert each 4-bit group to a hexadecimal digit:
Each 4-bit group maps to a hexadecimal digit as follows:

| Binary | Hex |
|--------|-----|
| 0000   | 0   |
| 0001   | 1   |
| 0010   | 2   |
| 0011   | 3   |
| 0100   | 4   |
| 0101   | 5   |
| 0110   | 6   |
| 0111   | 7   |
| 1000   | 8   |
| 1001   | 9   |
| 1010   | A   |
| 1011   | B   |
| 1100   | C   |
| 1101   | D   |
| 1110   | E   |
| 1111   | F   |

Convert each group to its hexadecimal digit and concatenate the results.
3. Examples
Example 1: convert binary 11011011 to hexadecimal
Group the bits in fours:
1101 1011
Convert each group:
1101 → D
1011 → B
So the hexadecimal representation of 11011011 is 0xDB.
Example 2: convert binary 101010 to hexadecimal
Group the bits in fours. Since there are 6 bits, pad with 00 on the left to reach a multiple of 4:
0010 1010
Convert each group:
0010 → 2
1010 → A
So the hexadecimal representation of 101010 is 0x2A.
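These conversions can also be checked programmatically (an illustrative snippet added to these notes):

```python
# Binary -> hexadecimal, matching the worked examples above.
print(format(0b11011011, "X"))   # DB
print(format(0b101010, "X"))     # 2A
print(hex(int("11011011", 2)))   # 0xdb
```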
The DynamicBatchSampler groups data samples into batches.
It takes into account the length of each sample, a target batch size, and a maximum token limit per batch.
A batch is generated when either the batch size or the sum of sample lengths reaches the specified limits.
Example Code:
import random
class DynamicBatchSampler:
def __init__(self, data_lengths, batch_size, max_tokens):
"""
Initializes the DynamicBatchSampler with data lengths, target batch size, and max tokens constraint.
:param data_lengths: List of lengths of each data sample.
:param batch_size: The base size of each batch.
:param max_tokens: The maximum number of tokens (or length sum) allowed per batch.
"""
self.data_lengths = data_lengths
self.batch_size = batch_size
self.max_tokens = max_tokens
def __iter__(self):
# Shuffle the indices to get a different sampling order each time.
indices = list(range(len(self.data_lengths)))
random.shuffle(indices)
batch = []
total_length = 0
for idx in indices:
length = self.data_lengths[idx]
            # Yield the current batch first if adding this sample would exceed the limits
            # (guard against yielding an empty batch when a single sample exceeds max_tokens).
            if batch and (total_length + length > self.max_tokens or len(batch) >= self.batch_size):
# Yield the current batch and reset for the next one.
yield batch
batch = []
total_length = 0
# Add the current sample to the batch.
batch.append(idx)
total_length += length
# Yield any remaining samples as the last batch.
if batch:
yield batch
# Example usage:
data_lengths = [5, 8, 3, 7, 10, 2, 4, 6] # Example lengths of samples.
batch_size = 3 # Maximum number of samples per batch.
max_tokens = 15 # Maximum total length of samples in a batch.
sampler = DynamicBatchSampler(data_lengths, batch_size, max_tokens)
# Iterate over the batches generated by the DynamicBatchSampler.
for batch in sampler:
print(f"Batch indices: {batch}, Batch lengths: {[data_lengths[i] for i in batch]}")
Explanation:
Data Lengths: [5, 8, 3, 7, 10, 2, 4, 6] represents the lengths of different samples.
Batch Size: 3 means each batch can have up to 3 samples.
Max Tokens: 15 restricts each batch to a maximum total length of 15.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class FullSelfAttention(nn.Module):
"""
A vanilla multi-head self-attention layer with no masking and a projection at the end.
This implementation doesn't use causal masking, meaning all tokens can attend to each other.
"""
def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop):
super().__init__()
assert n_embd % n_head == 0
self.n_head = n_head
# key, query, value projections for all heads
self.key = nn.Linear(n_embd, n_embd)
self.query = nn.Linear(n_embd, n_embd)
self.value = nn.Linear(n_embd, n_embd)
# regularization
self.attn_drop = nn.Dropout(attn_pdrop)
self.resid_drop = nn.Dropout(resid_pdrop)
# output projection
self.proj = nn.Linear(n_embd, n_embd)
def forward(self, x):
B, T, C = x.size()
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
# full self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = F.softmax(att, dim=-1) # no masking here, full attention
att = self.attn_drop(att)
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
# output projection
y = self.resid_drop(self.proj(y))
return y
Causal SelfAttention
class CausalSelfAttention(nn.Module):
"""
A vanilla multi-head masked self-attention layer with a projection at the end.
It is possible to use torch.nn.MultiheadAttention here but I am including an
explicit implementation here to show that there is nothing too scary here.
"""
def __init__(self, n_embd, block_size, n_head, attn_pdrop, resid_pdrop):
super().__init__()
assert n_embd % n_head == 0
self.n_head = n_head
# key, query, value projections for all heads
self.key = nn.Linear(n_embd, n_embd)
self.query = nn.Linear(n_embd, n_embd)
self.value = nn.Linear(n_embd, n_embd)
# regularization
self.attn_drop = nn.Dropout(attn_pdrop)
self.resid_drop = nn.Dropout(resid_pdrop)
# output projection
self.proj = nn.Linear(n_embd, n_embd)
# causal mask to ensure that attention is only applied to the left in the input sequence
self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size))
.view(1, 1, block_size, block_size))
def forward(self, x, layer_past=None):
B, T, C = x.size()
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
att = F.softmax(att, dim=-1)
att = self.attn_drop(att)
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
# output projection
y = self.resid_drop(self.proj(y))
return y
Look-ahead MHSA
To implement causal attention with a "look ahead" mechanism (i.e., allowing some future tokens to be attended to, while still maintaining causality by limiting how far into the future attention can extend), you can modify the attention mechanism to apply a causal mask with a specified number of future tokens.
Here's a PyTorch implementation of causal attention with look-ahead:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class CausalAttentionWithLookAhead(nn.Module):
"""
Causal self-attention mechanism with a look-ahead window, allowing each token to attend to a limited number
of future tokens, while still maintaining the causal nature (i.e., no attention to tokens further ahead).
Usage:
n_embd = 64 # embedding dimension
n_head = 8 # number of attention heads
attn_pdrop = 0.1 # dropout for attention weights
resid_pdrop = 0.1 # dropout for output projection
look_ahead_size = 2 # allow attention up to 2 future tokens
model = CausalAttentionWithLookAhead(n_embd, n_head, attn_pdrop, resid_pdrop, look_ahead_size)
x = torch.randn(16, 10, n_embd) # Batch of 16 sequences, each of length 10, embedding size 64
output = model(x)
print(output.size()) # should return (16, 10, 64)
"""
def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop, look_ahead_size):
super().__init__()
assert n_embd % n_head == 0
self.n_head = n_head
self.look_ahead_size = look_ahead_size
# key, query, value projections for all heads
self.key = nn.Linear(n_embd, n_embd)
self.query = nn.Linear(n_embd, n_embd)
self.value = nn.Linear(n_embd, n_embd)
# regularization
self.attn_drop = nn.Dropout(attn_pdrop)
self.resid_drop = nn.Dropout(resid_pdrop)
# output projection
self.proj = nn.Linear(n_embd, n_embd)
def forward(self, x):
B, T, C = x.size()
# calculate query, key, values for all heads and move head forward to be the batch dim
k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
# Causal attention with look ahead: Self-attend (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
# Create the causal mask with a look-ahead window
causal_mask = torch.tril(torch.ones(T, T), diagonal=self.look_ahead_size).view(1, 1, T, T).to(x.device) # (1, 1, T, T)
att = att.masked_fill(causal_mask == 0, float('-inf'))
# Apply softmax and dropout
att = F.softmax(att, dim=-1)
att = self.attn_drop(att)
# Apply attention to the value: (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
y = att @ v
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
# Output projection
y = self.resid_drop(self.proj(y))
        return y
```
### Key Points:
1. **Causal Masking with Look-Ahead**:
- A causal mask is created using `torch.tril`, which generates a lower triangular matrix. The `diagonal=self.look_ahead_size` argument allows attention to future tokens within a window of size `look_ahead_size`.
- For example, if `look_ahead_size=2`, each token will be able to attend to up to 2 future tokens in addition to past and current tokens.
2. **Attention Calculation**:
- As usual, the queries (`q`), keys (`k`), and values (`v`) are computed and reshaped for the multi-head attention operation.
- The attention scores are computed by multiplying the query matrix with the transpose of the key matrix.
- After masking, softmax is applied to obtain the attention weights, and these weights are used to compute a weighted sum of the values.
3. **Handling the Future Look-Ahead**:
- The `look_ahead_size` controls how many future tokens can be attended to. The larger the value, the further ahead the model can look while still being restricted by causality.
ChunkBasedAttention
To modify the `ChunkBasedAttention` implementation for parallel processing during training, the key change is to avoid sequential processing of chunks and instead process all chunks in parallel. This requires reshaping the input tensors so that the attention computation can be performed on all chunks simultaneously. Here's how you can adjust the `ChunkBasedAttention` class for parallel processing:
### Modified `ChunkBasedAttention` with Parallel Processing:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class ChunkBasedAttention(nn.Module):
"""
Chunk-based self-attention mechanism with configurable left and right attention chunk sizes.
This version allows for parallel processing of chunks during training.
Example Usage:
n_embd = 64 # embedding dimension
n_head = 8 # number of attention heads
attn_pdrop = 0.1 # dropout for attention weights
resid_pdrop = 0.1 # dropout for output projection
attention_chunk_size = 4 # size of each chunk
left_chunk_size = 1 # allow attention to 1 chunk on the left
right_chunk_size = 1 # allow attention to 1 chunk on the right
model = ChunkBasedAttention(n_embd, n_head, attn_pdrop, resid_pdrop, attention_chunk_size, left_chunk_size, right_chunk_size)
x = torch.randn(16, 12, n_embd) # Batch of 16 sequences, each of length 12, embedding size 64
output = model(x)
print(output.size()) # should return (16, 12, 64)
"""
def __init__(self, n_embd, n_head, attn_pdrop, resid_pdrop, attention_chunk_size, left_chunk_size, right_chunk_size):
super().__init__()
assert n_embd % n_head == 0
self.n_head = n_head
self.attention_chunk_size = attention_chunk_size
self.left_chunk_size = left_chunk_size
self.right_chunk_size = right_chunk_size
# key, query, value projections for all heads
self.key = nn.Linear(n_embd, n_embd)
self.query = nn.Linear(n_embd, n_embd)
self.value = nn.Linear(n_embd, n_embd)
# regularization
self.attn_drop = nn.Dropout(attn_pdrop)
self.resid_drop = nn.Dropout(resid_pdrop)
# output projection
self.proj = nn.Linear(n_embd, n_embd)
    def forward(self, x):
        B, T, C = x.size()  # B: batch size, T: sequence length, C: embedding dimension
        H, hs = self.n_head, C // self.n_head
        cs = self.attention_chunk_size
        num_chunks = (T + cs - 1) // cs  # total number of chunks
        pad = num_chunks * cs - T
        if pad > 0:
            # Pad the time dimension so the sequence splits evenly into chunks.
            x = F.pad(x, (0, 0, 0, pad))
        Tp = num_chunks * cs
        # calculate query, key, values for all heads and move head forward to be the batch dim
        k = self.key(x).view(B, Tp, H, hs).transpose(1, 2)    # (B, nh, Tp, hs)
        q = self.query(x).view(B, Tp, H, hs).transpose(1, 2)  # (B, nh, Tp, hs)
        v = self.value(x).view(B, Tp, H, hs).transpose(1, 2)  # (B, nh, Tp, hs)
        # Reshape the queries, keys, and values into chunks for parallel processing
        q_chunks = q.reshape(B, H, num_chunks, cs, hs)  # (B, nh, num_chunks, chunk_size, hs)
        k_chunks = k.reshape(B, H, num_chunks, cs, hs)
        v_chunks = v.reshape(B, H, num_chunks, cs, hs)
        # Attention scores between every query position and every key position, grouped by
        # (query chunk, key chunk): (B, nh, num_chunks, chunk_size, num_chunks, chunk_size)
        att = torch.einsum('bhqic,bhkjc->bhqikj', q_chunks, k_chunks) / math.sqrt(hs)
        # Chunk-level mask: query chunk i may attend to key chunks [i - left_chunk_size, i + right_chunk_size]
        chunk_idx = torch.arange(num_chunks, device=x.device)
        rel = chunk_idx[None, :] - chunk_idx[:, None]  # key chunk index minus query chunk index
        chunk_mask = (rel >= -self.left_chunk_size) & (rel <= self.right_chunk_size)
        att = att.masked_fill(~chunk_mask.view(1, 1, num_chunks, 1, num_chunks, 1), float('-inf'))
        if pad > 0:
            # Mask out the padded key positions in the last chunk.
            key_pad = torch.zeros(Tp, dtype=torch.bool, device=x.device)
            key_pad[T:] = True
            att = att.masked_fill(key_pad.view(1, 1, 1, 1, num_chunks, cs), float('-inf'))
        # Softmax over all allowed key positions (flatten the key chunk/position dims)
        att = F.softmax(att.reshape(B, H, num_chunks, cs, num_chunks * cs), dim=-1)
        att = self.attn_drop(att).reshape(B, H, num_chunks, cs, num_chunks, cs)
        # Apply attention to the values and reassemble the sequence
        y = torch.einsum('bhqikj,bhkjc->bhqic', att, v_chunks)                # (B, nh, num_chunks, chunk_size, hs)
        y = y.reshape(B, H, Tp, hs).transpose(1, 2).reshape(B, Tp, C)[:, :T]  # drop the padding
        # Output projection
        y = self.resid_drop(self.proj(y))
        return y
```
### Explanation of Changes:
1. **Parallel Processing**:
- Instead of processing each chunk sequentially, we reshape the input into chunks and process all chunks in parallel using `torch.einsum`. This allows for faster training with GPUs since parallelization is leveraged.
2. **Causal Mask**:
- A chunk-based mask is constructed to ensure that each chunk can only attend to its allowed left and right neighboring chunks.
- This mask is then applied to the attention scores to ensure that attention is restricted to a specific range of chunks.
3. **Efficient Attention Calculation**:
- The attention scores are computed using `torch.einsum`, which efficiently handles the matrix multiplication for all chunks in parallel.
- The `softmax` operation is applied across the appropriate dimension (`-1`), and then the attention probabilities are used to weight the values.
4. **Reshaping for Output**:
- After computing the attention-weighted values for all chunks, the output is reshaped back into the original sequence length to ensure consistency with the input format.
### Key Points:
- **Chunk-Based Masking**: Each chunk can attend only to a specified range of chunks to the left and right, and this is enforced using the chunk mask.
- **Parallel Computation**: By reshaping the input into chunks and using `torch.einsum`, we can compute attention for all chunks in parallel, which speeds up training.
- **Flexibility**: The chunk size, left, and right attention window sizes are flexible and can be adjusted based on the model's requirements.
BEST-RQ: SSL with Random-projection Quantizer for Speech Recognition
BEST-RQ introduces a novel technique of self-supervised training using a combination of Random Projection Quantizer (RPQ) and Masked Language Modeling (MLM).
The overall process of BEST-RQ first applies the Random Projection Quantizer (RPQ), i.e., a randomly initialized linear layer and a single codebook used to quantize and discretize the audio:
The Mel filterbanks are projected through the linear layer.
The index of the nearest codebook entry to the projection is selected as the target.
The nearest codebook entry is found by calculating the argmin of the normalized distance between the projection and each codebook entry.
Afterward, a mask is applied to a portion of the Mel filterbanks, and the model’s objective is to predict the correct targets for the masked sections. This is framed as a classification task, and cross-entropy loss is used to compute the training objective.
1. Random Projection Quantizer (RPQ)
The Random Projection Quantizer is the core component of BEST-RQ, designed to discretize continuous speech features and make them suitable for BERT-like pretraining. RPQ consists of two major components: the Projection Matrix and the Codebook. Both are randomly initialized and remain fixed throughout the training process.
1) Projection Matrix
The projection matrix projects the original speech features into a lower-dimensional space. The matrix is of size ( d \times k ), where:
d: Dimensionality of the original speech features (typically high, such as hundreds or thousands).
k: Target dimensionality after projection (usually much lower than ( d )).
This dimensionality reduction is essential for handling the vast amount of speech data efficiently.
2) Codebook
The Codebook is a collection of n code vectors, each of size ( k ). These code vectors represent the discrete code space into which the speech features are projected.
n: The size of the codebook, which can be tuned based on the task at hand.
Given an input vector ( x ) (a ( d )-dimensional vector computed from the speech signal), RPQ maps ( x ) to a discrete label ( y ) through the following operation:

( y = \arg\min_i \lVert \text{norm}_{l2}(c_i) - \text{norm}_{l2}(Ax) \rVert )

Where:
The projection matrix ( A ) is randomly initialized and maps ( x ) from ( d ) dimensions to the codebook dimension ( k ).
The codebook ( C = {c_1, ..., c_n} ) contains randomly initialized ( k )-dimensional vectors.
The function ( \text{norm}_{l2} ) denotes L2 normalization.
This transformation enables the speech signals to be quantized into discrete labels, providing a structured learning signal for the downstream tasks.
The projection matrix is initialized using Xavier initialization (Glorot & Bengio, 2010). The codebook is initialized using a standard normal distribution. Both are kept frozen during the entire pretraining process, ensuring that the quantization remains consistent.
Code
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.linalg import vector_norm
class RandomProjectionQuantizer(nn.Module):
"""
Vector quantization using a projection and a randomly initialized codebook.
The output is the indices of the closest code in the codebook for each time step of the input.
Example
-------
>>> quantizer = RandomProjectionQuantizer(16, 16, 8192)
>>> inputs = torch.rand(10, 12, 16)
>>> output = quantizer(inputs)
>>> output.shape
torch.Size([10, 12])
"""
def __init__(self, input_dim, codebook_dim, codebook_vocab):
super().__init__()
self.input_dim = input_dim
self.codebook_dim = codebook_dim
self.codebook_vocab = codebook_vocab
        # Initialize the projection matrix with Xavier initialization; register it as a buffer so it
        # is saved with the module and moved to the right device, but never updated during training.
        self.register_buffer("Prj_A_init", nn.init.xavier_uniform_(torch.empty(input_dim, codebook_dim)))
        # Randomly initialize and L2-normalize the codebook; also kept frozen.
        self.register_buffer("codebook", F.normalize(torch.randn(codebook_vocab, codebook_dim)))
def forward(self, x):
"""
Forward the input through the projection and return the indices of the closest codebook entries.
"""
# Normalize the projected input
x = F.normalize(torch.matmul(x, self.Prj_A_init))
# Calculate distances between codebook entries and input, and find the closest code
distances = vector_norm(self.codebook.unsqueeze(1) - x.unsqueeze(1), dim=-1)
# Return the indices of the closest code for each input
return distances.argmin(dim=1)
2. Masked Language Modeling (MLM)
BEST-RQ applies Masked Language Modeling (MLM), much like BERT does for text, but in this case for speech. During training, certain portions of the speech signal are masked and replaced with noise.
Masking Strategy: Each frame of speech is masked with a fixed probability, and the masked portions are replaced with noise sampled from a normal distribution (mean = 0, standard deviation = 0.1).
The model, typically based on a Transformer architecture, is then tasked with predicting the labels (codebook indices) of the masked speech based on the surrounding context. This allows the model to focus on learning robust speech representations.
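To make the masking-plus-prediction objective concrete, here is a minimal sketch (my illustration, not the reference implementation), assuming per-frame masking with a fixed probability, noise drawn from N(0, 0.1^2), and a hypothetical `encoder` module that maps masked features to per-frame logits over the codebook:

```python
import torch
import torch.nn.functional as F

def bestrq_step(features, quantizer, encoder, mask_prob=0.01):
    """One illustrative BEST-RQ training step: mask, predict codebook indices, cross-entropy."""
    B, T, D = features.shape
    # Targets come from the frozen random-projection quantizer applied to the *unmasked* features.
    with torch.no_grad():
        targets = quantizer(features)                      # (B, T) codebook indices
    # Mask frames with a fixed probability and replace them with Gaussian noise (std 0.1).
    mask = torch.rand(B, T, device=features.device) < mask_prob
    noise = 0.1 * torch.randn_like(features)
    masked_features = torch.where(mask.unsqueeze(-1), noise, features)
    # The encoder (e.g., a Transformer/Conformer) predicts logits over the codebook vocabulary.
    logits = encoder(masked_features)                      # (B, T, codebook_vocab)
    # The loss is computed only on the masked positions.
    return F.cross_entropy(logits[mask], targets[mask])
```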
A unique point of BEST-RQ is that the RPQ's projection matrix and codebook are frozen and independent of the ASR encoder. This ensures that the model focuses solely on learning meaningful speech representations without needing to adapt to the intricacies of the quantization process.
Pre-training. The pre-training uses a mask length of 400ms with a masking probability of 0.01. The learning rate follows the transformer schedule (Vaswani et al., 2017), using the Adam optimizer with a 0.004 peak learning rate and 25000 warmup steps. The batch size is 2048. Since the encoder has a 4x temporal-dimension reduction, the quantization with random projections stacks every 4 frames for its projections. The codebook vocab size is 8192 and its dimension is 16.
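As an illustration of the frame stacking mentioned above (a sketch under the stated 4x reduction assumption, not the reference code), stacking simply concatenates every 4 consecutive feature frames before the random projection:

```python
import torch

def stack_frames(features, stack=4):
    """Concatenate every `stack` consecutive frames: (B, T, D) -> (B, T // stack, D * stack)."""
    B, T, D = features.shape
    T = (T // stack) * stack                       # drop any leftover frames
    return features[:, :T].reshape(B, T // stack, stack * D)

x = torch.randn(2, 103, 80)                        # e.g., 80-dim Mel filterbanks
print(stack_frames(x).shape)                       # torch.Size([2, 25, 320])
```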
The pre-training quality is not very sensitive to the codebook vocab size and the codebook dimension; it is more sensitive to the masking probability and the mask length. The role of the projection layer in the random-projection quantizer is to allow using different codebook dimensions, and one can achieve similar results without the projection by setting the codebook dimension to be the same as the input dimension. Due to the variance coming from the random initialization, verifying the impact of a hyperparameter usually requires multiple experimental runs.
Codebook utilization. One of the most critical factors for pre-training quality is the percentage of the codebook that is used during training. In particular, a higher percentage of the codebook being used in each batch at each training step correlates strongly with good pre-training quality. When the distribution of codebook utilization is skewed toward a smaller subset of codes, the pre-training task usually becomes easier and provides less effective pre-training. The l2 normalizations on the projected vectors and the codebook are critical for providing more uniform codebook utilization. On the other hand, a randomly initialized codebook and projection matrix can yield different codebook utilizations with different random seeds, which impacts pre-training quality across runs with the same experiment configuration. This variance affects quality more when training with smaller pre-training and fine-tuning datasets. How to reduce this reproducibility issue caused by random initialization is an important next step for improving random-projection quantization.
Initialization. The quantizer uses random initialization and does not update its parameters, and therefore the initialization algorithm can play an important role in the results. In this paper we showed results with Xavier initialization for the projection matrix and the standard normal distribution for the codebook; further comparisons of different initialization algorithms can be conducted in future work.
There are generally three ways to perform text-only adaptation:
Injecting synthesized speech data into the model
generate audio for training texts via TTS and inject it into the model
LM fusion
Fusion and biasing (shallow fusion):
during decoding, interpolate the posterior word probabilities with text priors from external LMs (see the sketch after this list)
another recent approach is to extract internal LM probabilities and discount them with the ratio of external and internal LM probabilities
Rescoring and reranking
after decoding, use a powerful external LM to update scores and rerank n-best results or recognition lattice
These techniques incur a significant overhead at inference time due to the external LM and also require careful tuning of the interpolation weight used for the external LM.
Explicit separation of internal LMs
force the E2E decoder/predictor to behave more like a language model (e.g. Hybrid autoregressive transducer (HAT), Modular hybrid autoregressive transducer, and Factorized transducer)
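As a sketch of the shallow-fusion interpolation mentioned above (hypothetical scores and interpolation weight, not tied to any specific toolkit), the E2E score of each hypothesis is combined with an external LM score during decoding:

```python
def shallow_fusion_score(e2e_logprob, lm_logprob, lm_weight=0.3):
    """Score used to rank a hypothesis during decoding:
    log P_E2E(y|x) + lambda * log P_LM(y)."""
    return e2e_logprob + lm_weight * lm_logprob

# Toy example: two competing hypotheses with hypothetical log-probabilities.
hyps = {"recognize speech": (-4.1, -7.5), "wreck a nice beach": (-3.9, -12.0)}
best = max(hyps, key=lambda h: shallow_fusion_score(*hyps[h]))
print(best)  # "recognize speech": the LM prior outweighs the small acoustic gap
```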