앞서 메모한 #RNN-T Beam search [1] 글에 이어, 최근 facebook AI 팀에서 ICASSP 2020에 제출한 것처럼 보이는 아카이브에 올린 [2] 논문에 대해서 메모하고자 한다. [2]에서는 Latency Controlled RNN과 improved RNNT Beam Search를 제안했지만, 이 글은 후자인 RNNT improved Beam Search 부분만을 위한 글이다.

 

먼저 간단하게 기본 beam search for RNNT를 리마인드 해보면 이렇다. RNN-T에서는 다음 joiner call을 할 때, 다음 time frame t+1로 넘어갈 것인지, 같은 time frame t에서의 output units을 더 emit할 것인지 결정하기 위해서, output units은 특별한 symbol인 blank symbol을 포함한다. 매번 joiner call 마다, 1) 우리는 time axis (t) 에서 다음 audio frame t+1로 넘어가거나, 2) hypothesis를 업데이트하고 같은 time frame t 에서 계속해서 symbols를 emit 할 수 있다. 전자는 Joiner가 blank symbol을 가장 큰 확률로 emit할 때이고, 후자는 blank가 아닌 chracter들 중 하나를 가장 큰 확률로 outputting 할 때이다. 기존 RNNT Beam Search는 다음과 같다.

사진 삭제

RNN-T Beam Search [3]

일반적으로 beam search를 computationally efficient하게 개선하기 위해 hypothesis pruning을 많이 시도한다. 이 논문의 beam search의 목표도 그러하다. RNNT beam search에서 hypothesis set A의 hypothesis들의 확장을 제한한다는 것이 이 논문 beam search 알고리즘의 핵심이다. pesudo code는 다음과 같다.

대표사진 삭제

Improved RNN-T Beam Search [2]

 

 

간단하다. 두 가지만 확인하면 된다.

 

1) state_beam을 도입했다. log space에서 hypothesis setB의 best hypothesis보다 (state_beam + hypothesis setA의 best hypothesis)이 더 확률이 낮은 경우, "hypothesis set B의 이미 존재하는 hypos"들이, "hypothesis setA 로부터 확장될 hypothesises들"보다 이미 더욱 좋은 hypothesises이라고 가정해서 while loop 를 끝내 버려서 그 때까지의 hypothesis로 결과를 낸다.

 

2) 또한 A에 추가되는 hypothesises의 개수를 제한하는 expand_beam을 도입했다. hypothesis set B가, hypothesis set A에서 가장 가능성있는 hypothesis 보다 probability가 높은 W(beam width) 개의 hypothesis를 갖자마자 beam search criterion은 시간 t에서 충족되며, 프레임 t + 1 처리를 시작할 수 있다.y_t, h를 생성하는 (t, h)에서의 Joiner 호출의 경우, 먼저 Hypothesis set A 와 B 에서의 최고의 확률 값 (non-blank output unit (yt, h) 중에서 최고의 확률과, 최고의 확률)을 계산한다. 그 후, log (best prob)보다 log (Pr (k | y ∗, t)가 더 높은 output, k만 추출하여 hypothesis set A에 추가한다.

 

** hypothesis set A는 여전히 시간 t에 대해 고려되고 있는 hypothesis를 포함하고 있으며, hypothesis set B는 시간 t에서 이미 blank symbol을 emit했으며, 현재 시간 프레임인 t+1에 있는 hypothesis를 포함한다.

 


코드를 작성해보면 다음과 같다.

 

 

def recognize_beam_facebook(self, h, recog_args, rnnlm=None, state_beam, expand_beam):
    """facebook Beam search implementation for rnn-transducer.
    Args:
        h (torch.Tensor): encoder hidden state sequences (maxlen_in, Henc)
        recog_args (Namespace): argument Namespace containing options
        rnnlm (torch.nn.Module): language model module
        state_beam: ...
        expand_beam: ...
    Returns:
        nbest_hyps (list of dicts): n-best decoding results
    """
    beam = recog_args.beam_size
    k_range = min(beam, self.odim)
    nbest = recog_args.nbest
    normscore = recog_args.score_norm_transducer

    B_hyps = [{'score': 0.0, 'yseq': [self.blank], 'cache': None}]

    for i, hi in enumerate(h):
        A_hyps = B_hyps
        B_hyps = []

        while True:
            new_hyp = max(A_hyps, key=lambda x: x['score'])
            a_best_hyp = max(A_hyps, key=lambda x: x['score'])
            b_best_hyp = max(B_hyps, key=lambda x: x['score'])
            
            if log(b_best_hyp) >= state_beam + log(a_best_hyp):
                break
            
            A_hyps.remove(new_hyp)

            ys = to_device(self, torch.tensor(new_hyp['yseq']).unsqueeze(0))
            ys_mask = to_device(self, subsequent_mask(len(new_hyp['yseq'])).unsqueeze(0))
            y, c = self.forward_one_step(ys, ys_mask, new_hyp['cache'])

            ytu = torch.log_softmax(self.joint(hi, y[0]), dim=0)

            best_prob = max(ytu[1:])
            
            for k in six.moves.range(self.odim):

                if k == self.blank:
                    beam_hyp = {'score': new_hyp['score'] + float(ytu[k]),
                            'yseq': new_hyp['yseq'][:],
                            'cache': new_hyp['cache']}

                    B_hyps.append(beam_hyp)
    
                else:
                    if float(ytu[k]) >= log(best_prob) - expand_beam :
                        beam_hyp = {'score': new_hyp['score'] + float(ytu[k]),
                            'yseq': new_hyp['yseq'][:],
                            'cache': new_hyp['cache']}

                        beam_hyp['yseq'].append(int(k))
                        beam_hyp['cache'] = c

                        A_hyps.append(beam_hyp)

            if len(B_hyps) >= k_range: // beam_size (W)
                break

    if normscore:
        nbest_hyps = sorted(
            B_hyps, key=lambda x: x['score'] / len(x['yseq']), reverse=True)[:nbest]
    else:
        nbest_hyps = sorted(
            B_hyps, key=lambda x: x['score'], reverse=True)[:nbest]

    return nbest_hyps

 


[1] https://sequencedata.tistory.com/3?category=1129285

[2] M. Jain, al "RNN-T FOR LATENCY CONTROLLED ASR WITH IMPROVED BEAM SEARCH", 2020

[3] Alex Graves, "Sequence Transduction with Recurrent Neural Networks", 2012

 

 

몇 년 전, Alex Graves가 길이가 다른 input/output sequences 를 잘 mapping 할 수 있는 RNN-T 모델을 Sequence Transduction with RNN [1] 논문에서 제안했다. 여러 가지 sequence transduction 문제 중 ASR 문제에 집중을 했는데, 구글이 작년에 이 모델로 모바일 상에서 steamable 음성인식 성과를 소개해서인지, 최근들어 이 모델에 대한 관심이 급격하게 높아졌다. ASR 문제에서 input sequence는 audio signal, output sequence는 텍스트이다. RNN-T에 대한 설명은 생략하고, 이 글의 주제인 beam search for RNN-T 알고리즘에 대해서 메모를 남겨보고자 한다. 이 알고리즘은 [1] 논문에서 처음 제안되었다. pseudo code는 다음과 같다.

 

대표사진 삭제

사진 설명을 입력하세요.

 

 

RNN-T 모델을 위한 beam search를 위해 우선 time t에서 hypothesis h 는 frame embedding a_t 와 text embedding t_h 를 가진다고 가정한다. 이 두 embedding들은 결합되고 Joiner를 태워서 y_t,h output units(eg, 256) 에 대한 확률을 생성한다. while 문 안에 있는 for문에서의 y는 한 시점의 output units 이다.

 

RNN-T의 output units 은 우리가 모델링하고자 하는 output character들 외에 특별한 symbol 인 blank (∅) 가 존재한다. 이 blank symbol 은 디코딩 진행시, lattice 상에서 t 시점에 있을 때, 다음 time frame인 t+1 으로 진행할 것인지 혹은 t 시점에서의 output unit 과 같은 문자를 emit할 것인지 결정하기 위해 존재한다.

 

Beam search는 두 가지 hypothesis set A, B 를 사용하여 수행이 된다. 이것은 t + 1로 이동되는 최소 W (beam size) hypothesis가 t에서 여전히 생성 될 수있는 것보다 높은 확률을 갖도록하기 위함이다.

 

알고리즘을 보면 크게 세 부분으로 나눌 수 있다.

 

1. 첫번째는 특정시점까지 찾아진 hypothesis set A에 포함된 모든 hypothesis에 대한 probability를 그 hypothesis 의 적절한 prefix의 set을 사용하여 계산해 내는, 첫번째 for-loop에 대한 부분이다. 모든 time-step마다 그 시점까지 찾아진 hypothesis set B(=A) 에 포함된 모든 hypothesis에 대한 확률을 계산해낸다.

 

2. 두번째는 Beam search 를 진행하는 중에, hypothesis set A에서 최고 확률이 높은 hypothesis를 선택하고 blank 또는 non-blank symbol 로 확장하는 부분이다. while 를 돌며 Joiner(softmax)를 태워서 나온 y_t,h output units(eg, 256) 에서 모든 dimension output(probability)에 대해서, 기존 score(prob)에 곱한 score(prob)를 업데이트 하고, 해당 chracter(정확히는 index)를 append 시킨 새로운 hypothesis를 hypothesis set A 혹은 B에 append 한다. 이 때, blank에 해당하는 output dimension에 대해서는 score 계산 후, hypothesis set B에 append 하고, 다른 dimension에 해당하는 값(다른 캐릭터) 들에 대해서는 socre 계산 후, hypothesis set A에 append한다. While 문은 beam width 만큼 hypothesis가 hypothesis set B에 존재할 때까지 진행된다. 다시 말하면, blank가 마지막에 사용되는 expansion hypothesis는 를 hypo set B에 삽입된다. 한편, blank가 사용되지 않는 다른 모든 symbols을 사용하는 expansion hypothesis들은 hypothesis set A에 다시 삽입되어 t에서 확장 된 hypothesis set A가 계속해서 재 정의 된다.

 

3. return 할 때, hypothesis set B에 속한 W 개 미만의 hypothesis 중에 "각각의 hypotesis의 확률 값을 chracter들의 숫자로 나눈 최종 score"을 비교하여 가장 높은 score 값이 되는 hypothesis를 return 한다.

 

** 참고로 hypothesis set A, B에 저장 할 때는 {score(probability), character sequence(hypothesis)} 가 사전 형태로 함께 저장된다는 것을 함께 생각해야, 헷갈리지 않는다.


espnet에 구현되어 있는 beam search 코드[2]이다. 참고하면 더 이해하기 편하다. pseudo code와는 다르게 코드 상에선 blank output 에 대해서 해당 hypothesis를 hypothesis set B 에 append 하는 부분을 for 문 안으로 넣어놓았다.

 

 

def recognize_beam(self, h, recog_args, rnnlm=None):
    """Beam search implementation for transformer-transducer.
    Args:
        h (torch.Tensor): encoder hidden state sequences (maxlen_in, Henc)
        recog_args (Namespace): argument Namespace containing options
        rnnlm (torch.nn.Module): language model module
    Returns:
        nbest_hyps (list of dicts): n-best decoding results
    """
    beam = recog_args.beam_size
    k_range = min(beam, self.odim)
    nbest = recog_args.nbest
    normscore = recog_args.score_norm_transducer

    B_hyps = [{'score': 0.0, 'yseq': [self.blank], 'cache': None}]

    for i, hi in enumerate(h):
        A_hyps = B_hyps
        B_hyps = []

        while True:
            new_hyp = max(A_hyps, key=lambda x: x['score'])
            A_hyps.remove(new_hyp)

            ys = to_device(self, torch.tensor(new_hyp['yseq']).unsqueeze(0))
            ys_mask = to_device(self, subsequent_mask(len(new_hyp['yseq'])).unsqueeze(0))
            y, c = self.forward_one_step(ys, ys_mask, new_hyp['cache'])

            ytu = torch.log_softmax(self.joint(hi, y[0]), dim=0)

            for k in six.moves.range(self.odim):
                beam_hyp = {'score': new_hyp['score'] + float(ytu[k]),
                            'yseq': new_hyp['yseq'][:],
                            'cache': new_hyp['cache']}

                if k == self.blank:
                    B_hyps.append(beam_hyp)
                else:
                    beam_hyp['yseq'].append(int(k))
                    beam_hyp['cache'] = c

                    A_hyps.append(beam_hyp)

            if len(B_hyps) >= k_range: // beam_size (W)
                break

    if normscore:
        nbest_hyps = sorted(
            B_hyps, key=lambda x: x['score'] / len(x['yseq']), reverse=True)[:nbest]
    else:
        nbest_hyps = sorted(
            B_hyps, key=lambda x: x['score'], reverse=True)[:nbest]

    return nbest_hyps

 


[1] Alex Graves, "Sequence Transduction with Recurrent Neural Networks", 2012

[2] ESPnet; https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transducer/transformer_decoder.py

 

 

+ Recent posts