앞서 메모한 #RNN-T Beam search [1] 글에 이어, 최근 facebook AI 팀에서 ICASSP 2020에 제출한 것처럼 보이는 아카이브에 올린 [2] 논문에 대해서 메모하고자 한다. [2]에서는 Latency Controlled RNN과 improved RNNT Beam Search를 제안했지만, 이 글은 후자인 RNNT improved Beam Search 부분만을 위한 글이다.
먼저 간단하게 기본 beam search for RNNT를 리마인드 해보면 이렇다. RNN-T에서는 다음 joiner call을 할 때, 다음 time frame t+1로 넘어갈 것인지, 같은 time frame t에서의 output units을 더 emit할 것인지 결정하기 위해서, output units은 특별한 symbol인 blank symbol을 포함한다. 매번 joiner call 마다, 1) 우리는 time axis (t) 에서 다음 audio frame t+1로 넘어가거나, 2) hypothesis를 업데이트하고 같은 time frame t 에서 계속해서 symbols를 emit 할 수 있다. 전자는 Joiner가 blank symbol을 가장 큰 확률로 emit할 때이고, 후자는 blank가 아닌 chracter들 중 하나를 가장 큰 확률로 outputting 할 때이다. 기존 RNNT Beam Search는 다음과 같다.
사진 삭제
RNN-T Beam Search [3]
일반적으로 beam search를 computationally efficient하게 개선하기 위해 hypothesis pruning을 많이 시도한다. 이 논문의 beam search의 목표도 그러하다. RNNT beam search에서 hypothesis set A의 hypothesis들의 확장을 제한한다는 것이 이 논문 beam search 알고리즘의 핵심이다. pesudo code는 다음과 같다.
대표사진 삭제
Improved RNN-T Beam Search [2]
간단하다. 두 가지만 확인하면 된다.
1) state_beam을 도입했다. log space에서 hypothesis setB의 best hypothesis보다 (state_beam + hypothesis setA의 best hypothesis)이 더 확률이 낮은 경우, "hypothesis set B의 이미 존재하는 hypos"들이, "hypothesis setA 로부터 확장될 hypothesises들"보다 이미 더욱 좋은 hypothesises이라고 가정해서 while loop 를 끝내 버려서 그 때까지의 hypothesis로 결과를 낸다.
2) 또한 A에 추가되는 hypothesises의 개수를 제한하는 expand_beam을 도입했다. hypothesis set B가, hypothesis set A에서 가장 가능성있는 hypothesis 보다 probability가 높은 W(beam width) 개의 hypothesis를 갖자마자 beam search criterion은 시간 t에서 충족되며, 프레임 t + 1 처리를 시작할 수 있다.y_t, h를 생성하는 (t, h)에서의 Joiner 호출의 경우, 먼저 Hypothesis set A 와 B 에서의 최고의 확률 값 (non-blank output unit (yt, h) 중에서 최고의 확률과, 최고의 확률)을 계산한다. 그 후, log (best prob)보다 log (Pr (k | y ∗, t)가 더 높은 output, k만 추출하여 hypothesis set A에 추가한다.
** hypothesis set A는 여전히 시간 t에 대해 고려되고 있는 hypothesis를 포함하고 있으며, hypothesis set B는 시간 t에서 이미 blank symbol을 emit했으며, 현재 시간 프레임인 t+1에 있는 hypothesis를 포함한다.
코드를 작성해보면 다음과 같다.
def recognize_beam_facebook(self, h, recog_args, rnnlm=None, state_beam, expand_beam):
"""facebook Beam search implementation for rnn-transducer.
Args:
h (torch.Tensor): encoder hidden state sequences (maxlen_in, Henc)
recog_args (Namespace): argument Namespace containing options
rnnlm (torch.nn.Module): language model module
state_beam: ...
expand_beam: ...
Returns:
nbest_hyps (list of dicts): n-best decoding results
"""
beam = recog_args.beam_size
k_range = min(beam, self.odim)
nbest = recog_args.nbest
normscore = recog_args.score_norm_transducer
B_hyps = [{'score': 0.0, 'yseq': [self.blank], 'cache': None}]
for i, hi in enumerate(h):
A_hyps = B_hyps
B_hyps = []
while True:
new_hyp = max(A_hyps, key=lambda x: x['score'])
a_best_hyp = max(A_hyps, key=lambda x: x['score'])
b_best_hyp = max(B_hyps, key=lambda x: x['score'])
if log(b_best_hyp) >= state_beam + log(a_best_hyp):
break
A_hyps.remove(new_hyp)
ys = to_device(self, torch.tensor(new_hyp['yseq']).unsqueeze(0))
ys_mask = to_device(self, subsequent_mask(len(new_hyp['yseq'])).unsqueeze(0))
y, c = self.forward_one_step(ys, ys_mask, new_hyp['cache'])
ytu = torch.log_softmax(self.joint(hi, y[0]), dim=0)
best_prob = max(ytu[1:])
for k in six.moves.range(self.odim):
if k == self.blank:
beam_hyp = {'score': new_hyp['score'] + float(ytu[k]),
'yseq': new_hyp['yseq'][:],
'cache': new_hyp['cache']}
B_hyps.append(beam_hyp)
else:
if float(ytu[k]) >= log(best_prob) - expand_beam :
beam_hyp = {'score': new_hyp['score'] + float(ytu[k]),
'yseq': new_hyp['yseq'][:],
'cache': new_hyp['cache']}
beam_hyp['yseq'].append(int(k))
beam_hyp['cache'] = c
A_hyps.append(beam_hyp)
if len(B_hyps) >= k_range: // beam_size (W)
break
if normscore:
nbest_hyps = sorted(
B_hyps, key=lambda x: x['score'] / len(x['yseq']), reverse=True)[:nbest]
else:
nbest_hyps = sorted(
B_hyps, key=lambda x: x['score'], reverse=True)[:nbest]
return nbest_hyps
[1] https://sequencedata.tistory.com/3?category=1129285
[2] M. Jain, al "RNN-T FOR LATENCY CONTROLLED ASR WITH IMPROVED BEAM SEARCH", 2020
[3] Alex Graves, "Sequence Transduction with Recurrent Neural Networks", 2012
'Speech Signal Processing > Speech Recognition' 카테고리의 다른 글
[Acoustic Model] Feedforward Sequential Memory Networks (FSMN) (0) | 2020.06.15 |
---|---|
[speech recognition] Audio augmentation (0) | 2020.06.13 |
[E2E ASR] RNN-Transducer for ASR (0) | 2020.06.13 |
[E2E ASR] RNN-T Beam search decoding (0) | 2020.06.13 |
음성인식기(ASR) 구현하기 위한 모듈 정리 (0) | 2020.06.13 |