본문으로 건너뛰기
SuanLab

[논문 리뷰] Better & Faster Large Language Models via Multi-token Prediction

Large language models such as GPT and Llama are trained with a next-token prediction loss. In this work, we suggest that training language models to predict multiple future tokens at once results in h...

공유하기
[논문 리뷰] Better & Faster Large Language Models via Multi-token Prediction

[논문 리뷰] Better & Faster Large Language Models via Multi-token Prediction

TL;DR

대규모 언어 모델(LLM)은 보통 다음 한 토큰만 예측하도록 훈련됩니다. 하지만 이 논문은 한 번에 여러 개의 미래 토큰을 예측하도록 훈련하면, 동일한 훈련 비용으로 더 높은 성능(샘플 효율성)과 더 빠른 추론 속도를 동시에 달성할 수 있음을 보여줍니다. 특히 대형 모델과 코드 생성 같은 구조적인 작업에서 효과가 두드러지며, **자기 추측 디코딩(Self-Speculative Decoding)**을 통해 추론 속도를 최대 3배까지 향상시켜 비용 절감과 사용자 경험 개선에 기여할 수 있습니다.

연구 배경 및 동기

대규모 언어 모델(LLM)은 텍스트 생성 시 한 번에 한 토큰씩, 즉 자기회귀적(Autoregressive)으로 결과를 만들어냅니다. 이 방식은 효과적이지만, 다음과 같은 근본적인 한계를 가집니다.

  1. 근시안적 예측(Myopic Prediction): "다음" 토큰만 예측하는 것은 단기적으로는 최적일 수 있지만, 전체 문장의 구조나 논리적 흐름 같은 장기적인 의존성을 놓칠 수 있습니다. 예를 들어, "To boldly go where no..." 라는 문장을 생성할 때, 모델은 "one has gone before"라는 전체 구문을 미리 내다보는 것이 더 자연스럽습니다.
  2. 추론 병목 현상: 한 토큰을 생성해야 다음 토큰을 생성할 수 있는 순차적 과정은 GPU의 병렬 처리 능력을 제대로 활용하지 못해 추론 속도에 병목을 일으킵니다.

이 연구는 이러한 문제를 해결하기 위해 **다중 토큰 예측(Multi-token Prediction)**이라는 새로운 훈련 패러다임을 제안합니다. 한 번의 예측 단계에서 여러 미래 토큰을 동시에 내다보게 함으로써, 모델이 더 넓은 문맥을 학습하고 추론 효율성을 극대화하도록 유도합니다.

관련 연구

기존에도 LLM의 한계를 극복하려는 다양한 연구가 있었습니다.

접근법 설명 본 연구와의 차별점
어텐션 메커니즘 개선 Transformer-XL, Longformer 등 더 긴 문맥을 효율적으로 처리하려는 연구 여전히 단일 토큰 예측의 틀 안에서 작동
비-자기회귀(Non-Autoregressive) 모델 한 번에 전체 시퀀스를 생성하여 속도를 높이지만, 보통 생성 품질이 떨어짐 자기회귀 모델의 높은 품질을 유지하면서 속도를 개선
추측 디코딩(Speculative Decoding) 작은 모델로 초안을 생성하고, 큰 모델로 검증하여 속도를 높이는 기법 별도의 작은 모델 없이, 하나의 모델 내에서 초안 생성과 검증을 모두 수행 (Self-Speculative Decoding)

본 연구는 특히 추측 디코딩의 아이디어를 훈련 단계부터 통합했다는 점에서 독창적입니다. 다중 토큰 예측을 위한 헤드를 훈련하고, 이를 추론 시 추측(초안 생성)에 활용하여 추가 모델 없이 속도 향상을 이뤄냅니다.

핵심 기여

  1. 다중 토큰 예측 프레임워크 제안: 기존 단일 토큰 예측을 넘어, 여러 미래 토큰을 동시에 예측하는 새로운 훈련 방법론을 제안합니다.
  2. 샘플 효율성 향상: 동일한 데이터와 훈련 시간으로 학습했을 때, 다중 토큰 예측 모델이 기존 모델보다 더 높은 성능을 달성합니다.
  3. 추론 속도 개선: 훈련된 다중 예측 헤드를 활용한 **자기 추측 디코딩(Self-Speculative Decoding)**을 통해 추론 속도를 최대 3배까지 가속화합니다.
  4. 범용성 입증: 코드 생성(HumanEval, MBPP)뿐만 아니라 요약 등 일반 자연어 처리 작업에서도 성능 향상을 입증했습니다.

제안 방법론

1. 다중 헤드 아키텍처 (Multi-Head Architecture)

기존 LLM의 구조는 거의 그대로 유지하면서, 마지막 레이어의 은닉 상태(hidden state) 위에 **여러 개의 독립적인 예측 헤드(prediction head)**를 추가합니다.

  • Head 1: 다음 토큰(t+1t+1)을 예측
  • Head 2: 다다음 토큰(t+2t+2)을 예측
  • ...
  • Head n: nn번째 미래 토큰(t+nt+n)을 예측

개념도: [Input] -> [LLM Body] -> [Hidden State] -> [Head 1 (t+1)], [Head 2 (t+2)], ..., [Head n (t+n)]

2. 다중 토큰 예측 손실 함수 (Loss Function)

훈련 시, 각 예측 헤드가 올바른 미래 토큰을 예측하도록 손실 함수를 구성합니다. 기존의 다음 토큰 예측 손실이 단일 토큰에 대한 것이었다면, 다중 토큰 예측 손실은 nn개의 미래 토큰에 대한 손실을 모두 합산합니다.

  • 기존 손실 함수 (Next-token Prediction): Lnext=tlogPθ(xt+1x1:t)L_{\text{next}} = \sum_{t} -\log P_\theta(x_{t+1} | x_{1:t})

  • 제안된 손실 함수 (Multi-token Prediction): Lmulti=ti=1nλi(logPθ(xt+ix1:t))L_{\text{multi}} = \sum_{t} \sum_{i=1}^{n} \lambda_i \cdot (-\log P_\theta(x_{t+i} | x_{1:t}))

여기서 x1:tx_{1:t}는 시점 tt까지의 입력 시퀀스, xt+ix_{t+i}는 예측해야 할 ii번째 미래 토큰입니다. Pθ(xt+ix1:t)P_\theta(x_{t+i} | x_{1:t})ii번째 예측 헤드가 계산한 확률 분포를 의미하며, λi\lambda_i는 각 예측의 중요도를 조절하는 가중치입니다 (논문에서는 λi=1\lambda_i=1로 동일하게 설정).

3. 자기 추측 디코딩 (Self-Speculative Decoding)

추론 속도 향상의 핵심입니다. 이 과정은 두 단계로 이루어집니다.

  1. 초안 생성 (Drafting): 현재 시점 tt에서, 훈련된 nn개의 예측 헤드를 사용해 한 번의 정방향 패스(forward pass)로 미래 토큰 초안 (x^t+1,,x^t+n)(\hat{x}_{t+1}, \dots, \hat{x}_{t+n})을 동시에 생성합니다.
  2. 검증 (Verification): 생성된 초안을 원래 입력 x1:tx_{1:t}에 이어 붙인 새로운 시퀀스 [x1:t,x^t+1,,x^t+n1][x_{1:t}, \hat{x}_{t+1}, \dots, \hat{x}_{t+n-1}]를 모델에 입력하여, 각 위치의 예측 확률을 다시 계산합니다.
  3. 수락 및 수정 (Accept & Correct): 초안 토큰 x^t+i\hat{x}_{t+i}가 검증 단계의 예측과 일치하면 해당 토큰을 '수락'합니다. 만약 kk개의 토큰이 일치했다면, 한 번에 kk개의 토큰을 생성한 셈이 됩니다. 불일치가 발생한 첫 번째 위치에서는 검증된 확률 분포에 따라 토큰을 샘플링하고, 그 이후의 초안은 버립니다.

이 방식을 통해 약 2번의 정방향 패스 비용으로 여러 개의 토큰을 한 번에 생성할 수 있어, 토큰당 평균 추론 시간이 크게 단축됩니다.

실험 설정

  • 모델: 3억(300M)부터 130억(13B) 파라미터에 이르는 다양한 크기의 Transformer 모델을 처음부터 훈련.
  • 데이터셋:
    • 코드: The Stack (1.1TB)
    • 자연어: C4 (800GB)
  • 평가 벤치마크:
    • 코드 생성: HumanEval, MBPP
    • 요약: CNN/Dailymail, XSum
  • 베이스라인: 동일한 데이터와 훈련 스텝으로 학습된 표준 다음 토큰 예측(Next-Token Prediction) 모델.

실험 결과 분석

다중 토큰 예측 모델은 모든 모델 크기와 데이터셋에서 베이스라인을 일관되게 능가했습니다.

데이터셋 베이스라인 (1-토큰) 다중 토큰 예측 (4-토큰) 성능 향상
HumanEval (Code) 33.9% 38.5% +4.6%p
MBPP (Code) 49.6% 53.0% +3.4%p
CNN/Dailymail (Summary) ROUGE-L 29.5 ROUGE-L 30.1 +0.6
  • 성능 향상: 특히 코드 생성과 같이 구조가 명확하고 예측 가능한 패턴이 많은 작업에서 성능 향상이 두드러졌습니다. 이는 모델이 여러 토큰을 내다보며 코드의 구문 구조를 더 잘 학습했기 때문으로 분석됩니다.
  • 추론 속도: 4-토큰 예측 모델의 경우, 자기 추측 디코딩을 통해 베이스라인 대비 최대 3배 빠른 추론 속도를 달성했습니다. 이는 생성해야 할 토큰이 많을수록(long-form generation) 더 큰 효과를 보였습니다.

비판적 평가 (한계점)

  • 훈련 오버헤드: 다중 예측 헤드로 인해 훈련 시 약간의 계산 및 메모리 오버헤드가 발생합니다. (논문에 따르면 약 8% 내외)
  • 사전 훈련된 모델 적용의 어려움: 이미 다음 토큰 예측으로 사전 훈련된 모델을 다중 토큰 예측 방식으로 파인튜닝하는 것은 효과가 미미했습니다. 이 방법론의 최대 효과를 보려면 처음부터(from scratch) 훈련해야 합니다.
  • 최적의 n: 예측할 토큰의 수(n)는 태스크와 모델 크기에 따라 달라지는 하이퍼파라미터로, 적절한 값을 찾기 위한 튜닝이 필요합니다.

향후 연구 방향

  • 적응형 예측 수(Adaptive n): 문맥에 따라 예측할 토큰의 수를 동적으로 조절하는 방법을 연구하여 효율성을 더욱 높일 수 있습니다.
  • 강화학습과의 결합: 어떤 초안을 생성하고 수락할지 결정하는 과정을 강화학습을 통해 최적화하여 생성 품질을 높이는 연구가 유망합니다.
  • 다양한 모달리티로의 확장: 텍스트뿐만 아니라 이미지, 오디오 등 다른 데이터 형식의 생성 모델에도 다중 예측 방법론을 적용해볼 수 있습니다.

실무 적용 가이드

다중 토큰 예측을 실무에 도입하려는 개발자는 다음 사항을 고려해야 합니다.

  1. 처음부터 훈련할 때 고려: 새로운 LLM을 직접 훈련하는 경우, 이 방법론은 동일 비용으로 더 좋은 모델을 얻을 수 있는 강력한 대안입니다.
  2. n 값의 선택: 예측할 토큰 수 n은 트레이드오프 관계에 있습니다.
    • 작은 n (예: 2): 훈련 안정성이 높고, 미미한 성능 향상과 약간의 속도 개선.
    • n (예: 4-8): 잠재적 성능 및 속도 향상이 크지만, 훈련이 불안정해질 수 있고 정확도가 떨어질 수 있음. 일반적으로 n=4가 좋은 출발점입니다.
  3. 적합한 태스크: 코드 생성, API 호출 시퀀스 생성, 구조화된 데이터(JSON, YAML) 생성 등 지역적 예측 가능성이 높은 태스크에서 특히 효과적입니다.

결론

본 연구는 LLM 훈련의 기본 패러다임에 대한 중요한 질문을 던집니다. 단순히 "다음 한 단어"를 맞추는 것을 넘어, "미래의 여러 단어"를 예측하도록 훈련함으로써, 모델은 더 깊은 문맥 이해와 뛰어난 생성 능력을 갖추게 됩니다. 다중 토큰 예측은 동일한 훈련 예산으로 더 똑똑한 모델을 만들고, 추가 비용 없이 더 빠른 추론을 가능하게 하는, 그야말로 '더 좋고 더 빠른' LLM을 향한 중요한 진일보입니다.

참고 자료

댓글