[논문 리뷰] Better & Faster Large Language Models via Multi-token Prediction
TL;DR
대규모 언어 모델(LLM)은 보통 다음 한 토큰만 예측하도록 훈련됩니다. 하지만 이 논문은 한 번에 여러 개의 미래 토큰을 예측하도록 훈련하면, 동일한 훈련 비용으로 더 높은 성능(샘플 효율성)과 더 빠른 추론 속도를 동시에 달성할 수 있음을 보여줍니다. 특히 대형 모델과 코드 생성 같은 구조적인 작업에서 효과가 두드러지며, **자기 추측 디코딩(Self-Speculative Decoding)**을 통해 추론 속도를 최대 3배까지 향상시켜 비용 절감과 사용자 경험 개선에 기여할 수 있습니다.
연구 배경 및 동기
대규모 언어 모델(LLM)은 텍스트 생성 시 한 번에 한 토큰씩, 즉 자기회귀적(Autoregressive)으로 결과를 만들어냅니다. 이 방식은 효과적이지만, 다음과 같은 근본적인 한계를 가집니다.
- 근시안적 예측(Myopic Prediction): "다음" 토큰만 예측하는 것은 단기적으로는 최적일 수 있지만, 전체 문장의 구조나 논리적 흐름 같은 장기적인 의존성을 놓칠 수 있습니다. 예를 들어, "To boldly go where no..." 라는 문장을 생성할 때, 모델은 "one has gone before"라는 전체 구문을 미리 내다보는 것이 더 자연스럽습니다.
- 추론 병목 현상: 한 토큰을 생성해야 다음 토큰을 생성할 수 있는 순차적 과정은 GPU의 병렬 처리 능력을 제대로 활용하지 못해 추론 속도에 병목을 일으킵니다.
이 연구는 이러한 문제를 해결하기 위해 **다중 토큰 예측(Multi-token Prediction)**이라는 새로운 훈련 패러다임을 제안합니다. 한 번의 예측 단계에서 여러 미래 토큰을 동시에 내다보게 함으로써, 모델이 더 넓은 문맥을 학습하고 추론 효율성을 극대화하도록 유도합니다.
관련 연구
기존에도 LLM의 한계를 극복하려는 다양한 연구가 있었습니다.
| 접근법 | 설명 | 본 연구와의 차별점 |
|---|---|---|
| 어텐션 메커니즘 개선 | Transformer-XL, Longformer 등 더 긴 문맥을 효율적으로 처리하려는 연구 | 여전히 단일 토큰 예측의 틀 안에서 작동 |
| 비-자기회귀(Non-Autoregressive) 모델 | 한 번에 전체 시퀀스를 생성하여 속도를 높이지만, 보통 생성 품질이 떨어짐 | 자기회귀 모델의 높은 품질을 유지하면서 속도를 개선 |
| 추측 디코딩(Speculative Decoding) | 작은 모델로 초안을 생성하고, 큰 모델로 검증하여 속도를 높이는 기법 | 별도의 작은 모델 없이, 하나의 모델 내에서 초안 생성과 검증을 모두 수행 (Self-Speculative Decoding) |
본 연구는 특히 추측 디코딩의 아이디어를 훈련 단계부터 통합했다는 점에서 독창적입니다. 다중 토큰 예측을 위한 헤드를 훈련하고, 이를 추론 시 추측(초안 생성)에 활용하여 추가 모델 없이 속도 향상을 이뤄냅니다.
핵심 기여
- 다중 토큰 예측 프레임워크 제안: 기존 단일 토큰 예측을 넘어, 여러 미래 토큰을 동시에 예측하는 새로운 훈련 방법론을 제안합니다.
- 샘플 효율성 향상: 동일한 데이터와 훈련 시간으로 학습했을 때, 다중 토큰 예측 모델이 기존 모델보다 더 높은 성능을 달성합니다.
- 추론 속도 개선: 훈련된 다중 예측 헤드를 활용한 **자기 추측 디코딩(Self-Speculative Decoding)**을 통해 추론 속도를 최대 3배까지 가속화합니다.
- 범용성 입증: 코드 생성(HumanEval, MBPP)뿐만 아니라 요약 등 일반 자연어 처리 작업에서도 성능 향상을 입증했습니다.
제안 방법론
1. 다중 헤드 아키텍처 (Multi-Head Architecture)
기존 LLM의 구조는 거의 그대로 유지하면서, 마지막 레이어의 은닉 상태(hidden state) 위에 **여러 개의 독립적인 예측 헤드(prediction head)**를 추가합니다.
- Head 1: 다음 토큰()을 예측
- Head 2: 다다음 토큰()을 예측
- ...
- Head n: 번째 미래 토큰()을 예측
개념도: [Input] -> [LLM Body] -> [Hidden State] -> [Head 1 (t+1)], [Head 2 (t+2)], ..., [Head n (t+n)]
2. 다중 토큰 예측 손실 함수 (Loss Function)
훈련 시, 각 예측 헤드가 올바른 미래 토큰을 예측하도록 손실 함수를 구성합니다. 기존의 다음 토큰 예측 손실이 단일 토큰에 대한 것이었다면, 다중 토큰 예측 손실은 개의 미래 토큰에 대한 손실을 모두 합산합니다.
-
기존 손실 함수 (Next-token Prediction):
-
제안된 손실 함수 (Multi-token Prediction):
여기서 는 시점 까지의 입력 시퀀스, 는 예측해야 할 번째 미래 토큰입니다. 는 번째 예측 헤드가 계산한 확률 분포를 의미하며, 는 각 예측의 중요도를 조절하는 가중치입니다 (논문에서는 로 동일하게 설정).
3. 자기 추측 디코딩 (Self-Speculative Decoding)
추론 속도 향상의 핵심입니다. 이 과정은 두 단계로 이루어집니다.
- 초안 생성 (Drafting): 현재 시점 에서, 훈련된 개의 예측 헤드를 사용해 한 번의 정방향 패스(forward pass)로 미래 토큰 초안 을 동시에 생성합니다.
- 검증 (Verification): 생성된 초안을 원래 입력 에 이어 붙인 새로운 시퀀스 를 모델에 입력하여, 각 위치의 예측 확률을 다시 계산합니다.
- 수락 및 수정 (Accept & Correct): 초안 토큰 가 검증 단계의 예측과 일치하면 해당 토큰을 '수락'합니다. 만약 개의 토큰이 일치했다면, 한 번에 개의 토큰을 생성한 셈이 됩니다. 불일치가 발생한 첫 번째 위치에서는 검증된 확률 분포에 따라 토큰을 샘플링하고, 그 이후의 초안은 버립니다.
이 방식을 통해 약 2번의 정방향 패스 비용으로 여러 개의 토큰을 한 번에 생성할 수 있어, 토큰당 평균 추론 시간이 크게 단축됩니다.
실험 설정
- 모델: 3억(300M)부터 130억(13B) 파라미터에 이르는 다양한 크기의 Transformer 모델을 처음부터 훈련.
- 데이터셋:
- 코드: The Stack (1.1TB)
- 자연어: C4 (800GB)
- 평가 벤치마크:
- 코드 생성: HumanEval, MBPP
- 요약: CNN/Dailymail, XSum
- 베이스라인: 동일한 데이터와 훈련 스텝으로 학습된 표준 다음 토큰 예측(Next-Token Prediction) 모델.
실험 결과 분석
다중 토큰 예측 모델은 모든 모델 크기와 데이터셋에서 베이스라인을 일관되게 능가했습니다.
| 데이터셋 | 베이스라인 (1-토큰) | 다중 토큰 예측 (4-토큰) | 성능 향상 |
|---|---|---|---|
| HumanEval (Code) | 33.9% | 38.5% | +4.6%p |
| MBPP (Code) | 49.6% | 53.0% | +3.4%p |
| CNN/Dailymail (Summary) | ROUGE-L 29.5 | ROUGE-L 30.1 | +0.6 |
- 성능 향상: 특히 코드 생성과 같이 구조가 명확하고 예측 가능한 패턴이 많은 작업에서 성능 향상이 두드러졌습니다. 이는 모델이 여러 토큰을 내다보며 코드의 구문 구조를 더 잘 학습했기 때문으로 분석됩니다.
- 추론 속도: 4-토큰 예측 모델의 경우, 자기 추측 디코딩을 통해 베이스라인 대비 최대 3배 빠른 추론 속도를 달성했습니다. 이는 생성해야 할 토큰이 많을수록(long-form generation) 더 큰 효과를 보였습니다.
비판적 평가 (한계점)
- 훈련 오버헤드: 다중 예측 헤드로 인해 훈련 시 약간의 계산 및 메모리 오버헤드가 발생합니다. (논문에 따르면 약 8% 내외)
- 사전 훈련된 모델 적용의 어려움: 이미 다음 토큰 예측으로 사전 훈련된 모델을 다중 토큰 예측 방식으로 파인튜닝하는 것은 효과가 미미했습니다. 이 방법론의 최대 효과를 보려면 처음부터(from scratch) 훈련해야 합니다.
- 최적의
n값: 예측할 토큰의 수(n)는 태스크와 모델 크기에 따라 달라지는 하이퍼파라미터로, 적절한 값을 찾기 위한 튜닝이 필요합니다.
향후 연구 방향
- 적응형 예측 수(Adaptive
n): 문맥에 따라 예측할 토큰의 수를 동적으로 조절하는 방법을 연구하여 효율성을 더욱 높일 수 있습니다. - 강화학습과의 결합: 어떤 초안을 생성하고 수락할지 결정하는 과정을 강화학습을 통해 최적화하여 생성 품질을 높이는 연구가 유망합니다.
- 다양한 모달리티로의 확장: 텍스트뿐만 아니라 이미지, 오디오 등 다른 데이터 형식의 생성 모델에도 다중 예측 방법론을 적용해볼 수 있습니다.
실무 적용 가이드
다중 토큰 예측을 실무에 도입하려는 개발자는 다음 사항을 고려해야 합니다.
- 처음부터 훈련할 때 고려: 새로운 LLM을 직접 훈련하는 경우, 이 방법론은 동일 비용으로 더 좋은 모델을 얻을 수 있는 강력한 대안입니다.
n값의 선택: 예측할 토큰 수n은 트레이드오프 관계에 있습니다.- 작은
n(예: 2): 훈련 안정성이 높고, 미미한 성능 향상과 약간의 속도 개선. - 큰
n(예: 4-8): 잠재적 성능 및 속도 향상이 크지만, 훈련이 불안정해질 수 있고 정확도가 떨어질 수 있음. 일반적으로n=4가 좋은 출발점입니다.
- 작은
- 적합한 태스크: 코드 생성, API 호출 시퀀스 생성, 구조화된 데이터(JSON, YAML) 생성 등 지역적 예측 가능성이 높은 태스크에서 특히 효과적입니다.
결론
본 연구는 LLM 훈련의 기본 패러다임에 대한 중요한 질문을 던집니다. 단순히 "다음 한 단어"를 맞추는 것을 넘어, "미래의 여러 단어"를 예측하도록 훈련함으로써, 모델은 더 깊은 문맥 이해와 뛰어난 생성 능력을 갖추게 됩니다. 다중 토큰 예측은 동일한 훈련 예산으로 더 똑똑한 모델을 만들고, 추가 비용 없이 더 빠른 추론을 가능하게 하는, 그야말로 '더 좋고 더 빠른' LLM을 향한 중요한 진일보입니다.
참고 자료
- 논문 원문: Better and Faster Large Language Models via Multi-token Prediction (arXiv:2404.19737)
- 관련 기술: Speculative Decoding, Non-Autoregressive Transformers

![[논문 리뷰] Better & Faster Large Language Models via Multi-token Prediction](/assets/images/blog/20260515-paper-2404-19737-better-amp-faster-large-langua.jpg)