앞서 메모한 #RNN-T Beam search [1] 글에 이어, 최근 facebook AI 팀에서 ICASSP 2020에 제출한 것처럼 보이는 아카이브에 올린 [2] 논문에 대해서 메모하고자 한다. [2]에서는 Latency Controlled RNN과 improved RNNT Beam Search를 제안했지만, 이 글은 후자인 RNNT improved Beam Search 부분만을 위한 글이다.

 

먼저 간단하게 기본 beam search for RNNT를 리마인드 해보면 이렇다. RNN-T에서는 다음 joiner call을 할 때, 다음 time frame t+1로 넘어갈 것인지, 같은 time frame t에서의 output units을 더 emit할 것인지 결정하기 위해서, output units은 특별한 symbol인 blank symbol을 포함한다. 매번 joiner call 마다, 1) 우리는 time axis (t) 에서 다음 audio frame t+1로 넘어가거나, 2) hypothesis를 업데이트하고 같은 time frame t 에서 계속해서 symbols를 emit 할 수 있다. 전자는 Joiner가 blank symbol을 가장 큰 확률로 emit할 때이고, 후자는 blank가 아닌 chracter들 중 하나를 가장 큰 확률로 outputting 할 때이다. 기존 RNNT Beam Search는 다음과 같다.

사진 삭제

RNN-T Beam Search [3]

일반적으로 beam search를 computationally efficient하게 개선하기 위해 hypothesis pruning을 많이 시도한다. 이 논문의 beam search의 목표도 그러하다. RNNT beam search에서 hypothesis set A의 hypothesis들의 확장을 제한한다는 것이 이 논문 beam search 알고리즘의 핵심이다. pesudo code는 다음과 같다.

대표사진 삭제

Improved RNN-T Beam Search [2]

 

 

간단하다. 두 가지만 확인하면 된다.

 

1) state_beam을 도입했다. log space에서 hypothesis setB의 best hypothesis보다 (state_beam + hypothesis setA의 best hypothesis)이 더 확률이 낮은 경우, "hypothesis set B의 이미 존재하는 hypos"들이, "hypothesis setA 로부터 확장될 hypothesises들"보다 이미 더욱 좋은 hypothesises이라고 가정해서 while loop 를 끝내 버려서 그 때까지의 hypothesis로 결과를 낸다.

 

2) 또한 A에 추가되는 hypothesises의 개수를 제한하는 expand_beam을 도입했다. hypothesis set B가, hypothesis set A에서 가장 가능성있는 hypothesis 보다 probability가 높은 W(beam width) 개의 hypothesis를 갖자마자 beam search criterion은 시간 t에서 충족되며, 프레임 t + 1 처리를 시작할 수 있다.y_t, h를 생성하는 (t, h)에서의 Joiner 호출의 경우, 먼저 Hypothesis set A 와 B 에서의 최고의 확률 값 (non-blank output unit (yt, h) 중에서 최고의 확률과, 최고의 확률)을 계산한다. 그 후, log (best prob)보다 log (Pr (k | y ∗, t)가 더 높은 output, k만 추출하여 hypothesis set A에 추가한다.

 

** hypothesis set A는 여전히 시간 t에 대해 고려되고 있는 hypothesis를 포함하고 있으며, hypothesis set B는 시간 t에서 이미 blank symbol을 emit했으며, 현재 시간 프레임인 t+1에 있는 hypothesis를 포함한다.

 


코드를 작성해보면 다음과 같다.

 

 

def recognize_beam_facebook(self, h, recog_args, rnnlm=None, state_beam, expand_beam):
    """facebook Beam search implementation for rnn-transducer.
    Args:
        h (torch.Tensor): encoder hidden state sequences (maxlen_in, Henc)
        recog_args (Namespace): argument Namespace containing options
        rnnlm (torch.nn.Module): language model module
        state_beam: ...
        expand_beam: ...
    Returns:
        nbest_hyps (list of dicts): n-best decoding results
    """
    beam = recog_args.beam_size
    k_range = min(beam, self.odim)
    nbest = recog_args.nbest
    normscore = recog_args.score_norm_transducer

    B_hyps = [{'score': 0.0, 'yseq': [self.blank], 'cache': None}]

    for i, hi in enumerate(h):
        A_hyps = B_hyps
        B_hyps = []

        while True:
            new_hyp = max(A_hyps, key=lambda x: x['score'])
            a_best_hyp = max(A_hyps, key=lambda x: x['score'])
            b_best_hyp = max(B_hyps, key=lambda x: x['score'])
            
            if log(b_best_hyp) >= state_beam + log(a_best_hyp):
                break
            
            A_hyps.remove(new_hyp)

            ys = to_device(self, torch.tensor(new_hyp['yseq']).unsqueeze(0))
            ys_mask = to_device(self, subsequent_mask(len(new_hyp['yseq'])).unsqueeze(0))
            y, c = self.forward_one_step(ys, ys_mask, new_hyp['cache'])

            ytu = torch.log_softmax(self.joint(hi, y[0]), dim=0)

            best_prob = max(ytu[1:])
            
            for k in six.moves.range(self.odim):

                if k == self.blank:
                    beam_hyp = {'score': new_hyp['score'] + float(ytu[k]),
                            'yseq': new_hyp['yseq'][:],
                            'cache': new_hyp['cache']}

                    B_hyps.append(beam_hyp)
    
                else:
                    if float(ytu[k]) >= log(best_prob) - expand_beam :
                        beam_hyp = {'score': new_hyp['score'] + float(ytu[k]),
                            'yseq': new_hyp['yseq'][:],
                            'cache': new_hyp['cache']}

                        beam_hyp['yseq'].append(int(k))
                        beam_hyp['cache'] = c

                        A_hyps.append(beam_hyp)

            if len(B_hyps) >= k_range: // beam_size (W)
                break

    if normscore:
        nbest_hyps = sorted(
            B_hyps, key=lambda x: x['score'] / len(x['yseq']), reverse=True)[:nbest]
    else:
        nbest_hyps = sorted(
            B_hyps, key=lambda x: x['score'], reverse=True)[:nbest]

    return nbest_hyps

 


[1] https://sequencedata.tistory.com/3?category=1129285

[2] M. Jain, al "RNN-T FOR LATENCY CONTROLLED ASR WITH IMPROVED BEAM SEARCH", 2020

[3] Alex Graves, "Sequence Transduction with Recurrent Neural Networks", 2012

 

+ Recent posts