본문 바로가기

AI/NLP Paper

[NLP] Transformer-XL : Attentive Language Models Beyond a Fixed-Length Context 리뷰

 

Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context

Transformers have a potential of learning longer-term dependency, but are limited by a fixed-length context in the setting of language modeling. We propose a novel neural architecture Transformer-XL that enables learning dependency beyond a fixed length wi

arxiv.org

 

Introduction

sequential data에서의 장기 의존성에 대한 수용력을 가진 신경망 모델을 만드는 것은 어렵습니다. LSTM을 이용하여 장기 의존성 문제를 해결하기 위한 시도가 많았지만, 기울기 소실과 기울기 폭발에 대한 문제 때문에 최적화에 어려움이 있었습니다. 그리고 경험적으로 볼 때 LSTM을 이용한 언어 모델은 평균 200개의 context words를 사용하며 이는 추가 개선의 여지가 있음을 말합니다.

 

이 외에 character-level 언어 모델링을 위한 deep Transformer network 학습에 사용하기 위한 auxiliary loss함수를 설계한 연구가 있었습니다. 이는 LSTM의 성능을 뛰어넘었지만, 해당 연구에서의 training이 고정된 길이의 segments들 각각을 이용하여 진행됐습니다. 이때의 각 segments는 자기들끼리 어떠한 정보의 흐름이 없이 단절되어있었습니다. 

 

이러한 고정된 context의 길이 때문에, 모델은 사전에 정의된 길이 이상의 longer-term dependency를 파악하지 못합니다. 게다가, 고정된 길이의 segments는 어떠한 semantic boundary 없이 단지 연속적인 symbols의 덩어리로 만들어집니다. 이 때문에, 모델은 처음 몇 개의 symbol을 예측하는 데 필요한 contextual information이 부족하여 비효율적인 최적화와 낮은 성능을 보였습니다. 논문에서는 이러한 문제를 context fragmentation이라고 부릅니다.

 

그리고 고정된 길이의 context로 인해 생겨나는 한계를 극복하기 위해 Transformer-XL을 제안합니다. 새로운 모델에서는 deep self-attention network 내에서의 순환(recurrence) 개념에 대해 소개합니다. 자세히 말하자면, 새로운 segment에 대해 처음부터(from scratch) hidden state를 계산하는 것이 아니라 이 전의 segment에서 얻은 hidden state를 재사용한다는 것입니다. 

 

결과적으로 정보가 recurrent connection을 통해 전파되기 때문에 장기 의존성에 대한 모델링이 가능해집니다. 그리고 전파되는 정보는 context fragmentation문제로 해결할 수 있습니다. 

 

뿐만 아니라 논문에서는 relative positional encoding을 이용합니다. 

 

recurrence in self-attention model과 relative positional encoding을 이용한 Transformer-XL은 character-level, word-level 언어 모델링에서 RNN을 큰 차이로 뛰어넘은 첫 번째 self-attention모델입니다.

 

Model

$\mathbf{x} = (x_1, \cdots, x_T)$가 주어질 때, 언어 모델링 task는 결합 확률 $P(x)$를 추정하는 것입니다.

$$\prod_t P(x_t | \mathbf{x}_{<t})$$

논문에서는 이러한 조건부 확률 모델링 기법을 사용했습니다. 구체적으로, context $\mathbf{x}_{<t}$를 고정된 길이의 hidden state로 인코딩하기 위해 학습 가능한 신경망이 사용되었습니다. 구해진 hidden state는 logits를 얻기 위해 워드 임베딩과 곱해지게 되고, 소프트맥스 함수로 들어가서 확률 값을 추정합니다.

 

Vanilla Transformer Language Model

트랜스포머의 self-attention을 언어 모델링에 적용하기 위해 중요한 점은 길이가 제멋대로인 context를 고정된 길이의 표현으로 어떻게 효과적으로 학습할 수 있는지에 대한 것입니다. 메모리나 사용 가능한 계산량이 무한하다면 간단하지만, 현실은 유한하기 때문에 어렵습니다. 

 

한 가지 간단한 방법으로는 전체 corpus를 짧은 segment로 나누고, 이 전의 segments의 모든 contextual information을 무시한 채로 각각의 segments에 대해서만 훈련을 진행하는 것입니다. 이 방법을 논문에서는 vanilla model이라 부르고 Fig 1(a). 에 해당합니다. 이 방식으로 훈련을 진행한다면 segments끼리는 순방향이든 역방향이든 어떠한 정보의 교류도 없을 것입니다.

 

위 방법에는 두 가지의 문제점이 있습니다. 첫째로 possible dependency length가 미리 정해둔 segment길이 이상으로는 갈 수가 없다는 것입니다. 이 때문에 self-attention 메커니즘이 RNN에 비해 기울기 소실에 영향을 적게 받지만 vanilla model은 이러한 장점을 살리지 못합니다. 두 번째로 sequence를 고정된 길이의 segments로 나누는 것은 context fragmentation문제를 야기할 수 있습니다.

 

Evaluation phase에서 vanilla model은 훈련에 사용된 길이와 동일한 길이의 segment를 사용하지만 제일 마지막 위치에서 단 하나의 예측만을 만들어냅니다. 그리고 다음 step에서 segment는 딱 하나의 position만큼 오른쪽으로 이동하고, 새로운 segment를 이용해 처음부터 다시 이 전 단계의 과정을 진행합니다. 이는 Fig 1(b)에 해당합니다. 이러한 방식은 각 예측에서 훈련 중에 봤던 가장 긴 context를 활용할 수 있게 하고, context fragmentation issue도 완화해줍니다. 하지만 비용이 너무 비싸다는 단점이 있습니다.

 

Segment-Level Recurrence with State Reuse

논문에서는 고정된 길이의 context 사용에 대한 한계를 다루기 위해 트랜스포머 구조에 recurrence 메커니즘을 도입한 것을 제안했습니다.

 

훈련 중, 이 전 segment를 계산한 hidden state sequence는 고정(fixed)되고 저장(cached)됩니다. 이는 추후에 모델이 새로운 segment를 처리할 때 확장된 context로 재사용됩니다. Fig 2(a)에 해당하는 부분입니다. 이러한 추가적인 입력은 네트워크가 과거의 정보를 이용할 수 있게 하고, long-term dependency 모델링 능력을 키워주는 것뿐 아니라 context fragmentation을 피할 수 있게 해 줍니다.

 

식으로 살펴보면 길이가 $L$인 연속적인 두 개의 segment는 다음과 같이 표현할 수 있습니다. $\mathbf{s}_{\tau} = [x_{\tau, 1}, \cdots, x_{\tau, L}], \mathbf{s}_{\tau+1} = [x_{\tau+1,1}, \cdots, x_{\tau+1, L}]$이 때 $\tau$번 째 segment의 $n$번 째 layer의 hidden state는 다음과 같습니다. $\mathbf{h}_{\tau}^n \in \mathbb{R}^{L \times d}$ 여기서 $d$는 hidden dimension입니다.

다음으로 $\tau+1$번 째 segment의 $n$번 째 layer의 hidden state는 아래와 같이 계산됩니다.

\begin{align} &\tilde{\mathbf{h}}_{\tau+1}^{n-1} = [\text{SG}(\mathbf{h}_{\tau}^{n-1}) \circ \mathbf{h}_{\tau+1}^{n-1}]\\ &\mathbf{q}_{\tau+1}^n, \mathbf{k}_{\tau+1}^n, \mathbf{v}_{\tau+1}^n = \mathbf{h}_{\tau+1}^{n-1}W_q^T, \tilde{\mathbf{h}}_{\tau+1}^{n-1}W_k^T, \tilde{\mathbf{h}}_{\tau+1}^{n-1}W_v^T\\ &\mathbf{h}_{\tau+1}^n = \text{Transformer-Layer}(\mathbf{q}_{\tau+1}^n, \mathbf{k}_{\tau+1}^n, \mathbf{v}_{\tau+1}^n)\end{align}

식에서 $\text{SG}(\cdot)$함수는 Stop-gradient를 의미하고, $W$는 모델 파라미터, $[\mathbf{h}_u \circ \mathbf{h}_v]$는 concatenate 연산을 의미합니다.

 

기존의 트랜스포머와 가장 다른 점은, key와 value가 확장된 context $\tilde{\mathbf{h}}_{\tau+1}^{n-1}$에 의해 조절된다는 것입니다. 그리고 이를 위해서 이 전 segment에서 $\mathbf{h}_{\tau}^{n-1}$이 저장된 것입니다. 이 부분은 Fig 2(a)에서 초록색 선으로 표시된 부분입니다. 

 

이처럼 연속적인 두 개의 segment에 적용되는 reucurrence 메커니즘을 이용해서 segment-level recurrence를 만들게 됩니다. 하지만 보통의 RNN-LM과는 다른 점이 있는데, 바로 $\mathbf{h}_{\tau+1}^n$이 $\mathbf{h}_{\tau}^{n-1}$처럼 한 레이어 밑에 있는 hidden state를 사용한다는 것입니다. 그 결과 largest possible dependency length는 레이어의 수와 segment 길이에 대해 선형적으로 증가하게 됩니다. $\text{O}(N \times L)$ (Fig 2(b)의 색칠된 부분)

 

이는 truncated BPTT와 유사하지만, 논문에서는 hidden state의 sequence를 저장해놓고(caches) relative positional encoding을 이용한다는 차이가 있습니다. 또한, 이러한 방식을 이용하여 아주 긴 context를 다룰 수 있고, fragmentation 문제를 해결한 것 이외에도 매우 빠른 evaluation speed를 가진다는 장점이 있습니다.

 

Relative Positional Encodings

이번에는 state를 재사용할 때 위치 정보도 유지할 수 있는 방법에 대한 고민을 하게 됩니다. 먼저, 기본적인 트랜스포머에서 sequence 순서에 대한 정보는 positional encoding($U \in \mathbb{R}^{L_{\text{max}} \times d}$)을 통해 정의됩니다. 여기서 $i$번 째 행 $U_i$는 segment 내에서 $i$번 째 absolute position을 나타냅니다. 그리고 $L_{\text{max}}$는 입력으로 들어올 수 있는 sequence의 최대 길이입니다. 그다음 실제 입력은 워드 임베딩과 positional encoding의 element-wise 합으로 구성됩니다. 

 

만약 이 방식을 recurrence 메커니즘에 사용한다고 하면, hidden state sequence는 아래와 같이 계산될 것입니다.

\begin{align} \mathbf{h}_{\tau+1} &= f(\mathbf{h}_{\tau}, E_{\mathbf{s}_{\tau+1}} + U_{1:L})\\ \mathbf{h}_{\tau} &= f(\mathbf{h}_{\tau-1}, E_{\mathbf{s}_{\tau}} + U_{1:L})\end{align}

 

$E_{\mathbf{s}_{\tau}}$는 $s_{\tau}$의 워드 임베딩 시퀀스입니다. 여기서 눈여겨볼 것은 $E_{\mathbf{s}_{\tau}}$와 $E_{\mathbf{s}_{\tau+1}}$이 모두 같은 positional encoding $U_{1:L}$을 사용한다는 점입니다. 그 결과 모델은 $x_{\tau, j}$와 $x_{\tau+1, j}$의 위치적인 차이를 구분할 수 없습니다.

 

이에 논문에서는 hidden state에 relative positional information을 인코딩하자는 아이디어를 사용했습니다. 또한 이러한 relative positional encoding은 임베딩과 결합하는 것이 아니라 attention score에 주입됩니다. 그리고 더 중요한 것은 이러한 방법이 좀 더 직관적이고 일반화가 가능하다는 점입니다.

 

예를 들어, 쿼리 벡터 $q_{\tau, i}$가 키 벡터 $\mathbf{k}_{\tau, \leq i}$를 본다고 할 때 segment내에서의 키 벡터에 대한 절대적인 위치(absolute position)은 필요가 없습니다. 정확히 어느 위치에 있는지가 중요한 게 아니라 정확한 위치는 몰라도 키 벡터 $\mathbf{k}_{\tau, j}$와 쿼리 벡터 $q_{\tau, i}$사이의 상대적인 거리만 알아도 충분합니다. 

 

먼저, 기본적인 트랜스포머의 attention 연산을 살펴보겠습니다. 

\begin{align} &\text{Attention} = Q K^T\\ &Q = (E + U)W_q\\ &K = (E + U)W_k\end{align}

위 식은 다음과 같이 풀어쓸 수 있습니다.

\begin{align} A_{i,j}^{\text{abs}} &= \underbrace{E_{x_i}^TW_q^TW_kE_{x_j}}_{(a)} + \underbrace{E_{x_i}^TW_q^TW_kU_j}_{(b)}\\ &+ \underbrace{U_i^TW_q^TW_kE_{x_j}}_{(c)} + \underbrace{U_i^TW_q^TW_kU_j}_{(d)}\end{align}

논문에서는 이 식을 상대적인 위치 정보를 사용하기 위해 아래와 같이 변형합니다.

\begin{align} A_{i, j}^{\text{rel}} &= \underbrace{E_{x_i}^TW_q^TW_{k, E}E_{x_j}}_{(a)} + \underbrace{E_{x_i}^TW_q^TW_{k, R}\mathbf{R}_{i-j}}_{(b)}\\ &+ \underbrace{u^TW_{k, E}E_{x_j}}_{(c)} + \underbrace{v^TW_{k, R}\mathbf{R}_{i-j}}_{(d)}\end{align}

처음 식과 바뀐 부분은 세 가지가 있습니다.

  1. (b)와 (d)에서 키 벡터를 계산하기 위해 사용된 절대적 위치 임베딩 $U_j$를 상대적 위치 임베딩 $R_{i-j}$로 바꿨습니다. 여기서 $R$은 sinusoid encoding matrix로 학습되는 파라미터는 없습니다.
  2. 새로운 학습 가능한 파라미터 $u \in \mathbb{R}^{d}$로 쿼리 $U_i^TW_q^T$를 대체했습니다. 상대 위치를 사용할 경우, 쿼리 벡터를 모든 쿼리 포지션에 대해서 동일합니다.  따라서 포지션에 상관없이 같은 값인 $u$로 대체했고, 이와 같은 이유로 (d)에 있는 $U_i^TW_q^T$도 $v \in \mathbb{R}^d$로 바꿨습니다.
  3. 마지막으로 content-based key vector와 location-based key vector를 독립적으로 생성하기 위해 $W_{k,E}$와 $W_{k, R}$을 분리했습니다.

이제 recurrence 메커니즘과 relative positional embedding을 합치면 Transformer-XL구조가 완성됩니다. 모델의 계산 과정 아래와 같습니다.

\begin{align} \tilde{\mathbf{h}}_{\tau}^{n-1} &= [\text{SG}(\mathbf{m}_{\tau}^{n-1}) \circ \mathbf{h}_{\tau}^{n-1}]\\ \mathbf{q}_{\tau}^n, \mathbf{k}_{\tau}^n, \mathbf{v}_{\tau}^n &= \mathbf{h}_{\tau}^{n-1}{W_q^n}^T, \tilde{\mathbf{h}}_{\tau}^{n-1}{W_{k, E}^n}^T, \tilde{\mathbf{h}}_{\tau}^{n-1}{W_v^n}^T\\ A_{\tau, i, j}^n &= {\mathbf{q}_{\tau, i}^n}^T \mathbf{k}_{\tau, j}^n + {\mathbf{q}_{\tau, i}^n}^T W_{k, R}^n R_{i-j} + u^T \mathbf{k}_{\tau, j} + v^TW_{k,R}^nR_{i-j}\\ \mathbf{a}_{\tau}^n &= \text{Masked-Softmax}(A_{\tau}^n)\mathbf{v}_{\tau}^n\\ \mathbf{o}_{\tau}^n &= \text{LayerNorm}(\text{Linear}(\mathbf{a}_{\tau}^n) + \mathbf{h}_{\tau}^{n-1})\\ \mathbf{h}_{\tau}^n &= \text{Positionwise-Feed-Forward}(\mathbf{o}_{\tau}^n)\end{align}