[논문 리뷰] On-Policy Context Distillation for Language Models

Context distillation enables language models to internalize in-context knowledge into their parameters. In our work, we propose On-Policy Context Distillation (OPCD), a framework that bridges on-polic...

[논문 리뷰] On-Policy Context Distillation for Language Models

[논문 리뷰] On-Policy Context Distillation for Language Models

TL;DR

최근 언어 모델 연구에서, 모델이 프롬프트로 주어진 일시적인 정보를 영구적으로 내재화하는 능력이 중요해지고 있습니다. 이 논문은 **On-Policy Context Distillation (OPCD)**라는 새로운 프레임워크를 제안합니다. OPCD는 학생 모델이 스스로 생성한 데이터로 학습하면서, 풍부한 컨텍스트(예: 상세한 풀이 과정)를 가진 교사 모델의 행동을 모방하도록 합니다. 핵심은 역 Kullback-Leibler(KL) 발산을 최소화하여, 학생 모델이 교사 모델의 확률 분포에서 가장 가능성이 높은 '정답' 경로에 집중하게 만드는 것입니다. 이를 통해 OPCD는 수학 문제 풀이, 텍스트 기반 게임, 시스템 프롬프트 증류 등 다양한 분야에서 기존 방법보다 높은 정확도와 일반화 성능을 달성했습니다. 또한, OPCD는 작은 모델이 더 큰 모델의 지식을 효과적으로 학습할 수 있게 하여, 모델 경량화와 효율성 향상의 가능성을 열었습니다.

연구 배경 및 동기

언어 모델은 프롬프트에 담긴 정보를 활용해 놀라운 성능을 보여주지만, 이 정보는 일회성에 그칩니다. 모델의 파라미터 자체에 지식이 내재화되지 않기 때문이죠. 예를 들어, 복잡한 시스템 프롬프트나 문제 풀이 예시(few-shot examples)를 매번 입력해야 하는 것은 비효율적입니다.

이 문제를 해결하기 위해 **컨텍스트 증류(Context Distillation)**가 제안되었습니다. 교사 모델이 좋은 예시(컨텍스트)를 보고 생성한 답변을, 학생 모델이 컨텍스트 없이도 생성하도록 학습시키는 방식입니다. 하지만 기존 방법들은 두 가지 주요 한계를 가집니다.

  1. 오프-정책(Off-Policy) 학습의 한계: 기존 방식은 교사 모델이 미리 생성해 둔 고정된 데이터셋으로 학생을 학습시킵니다. 이는 학생이 정해진 '정답지'만 보고 공부하는 것과 같습니다. 막상 시험(추론)에서는 학생 스스로 문장을 생성해야 하는데, 학습 과정과 달라 실수를 연발하기 쉽습니다. 이를 노출 편향(Exposure Bias) 문제라고 합니다.

  2. 순방향 KL 발산(Forward KL Divergence)의 한계: 순방향 KL은 학생이 교사의 모든 가능한 답변 스타일을 전부 따라 하도록 만듭니다. 이는 교사가 "정답은 3이야" 또는 "답은 3입니다"라고 말할 수 있을 때, 학생이 두 가지 모두를 어설프게 흉내 내려고 하는 모드 커버링(Mode-covering) 문제를 야기합니다. 결과적으로 학생의 답변은 모호하고 불확실해질 수 있습니다.

이 논문은 이러한 한계를 극복하기 위해 온-정책(On-Policy) 증류와 **역 KL 발산(Reverse KL Divergence)**을 결합한 OPCD를 제안합니다. OPCD는 학생이 직접 문제를 풀어보고(on-policy), 교사가 학생의 풀이 과정에서 더 좋은 방향을 알려주는(distillation) 방식입니다. 또한, 역 KL 발산을 통해 교사의 여러 답변 스타일 중 가장 확률 높은 하나에 집중(mode-seeking)하도록 유도하여 더 명확하고 정확한 답변을 생성하게 합니다.

관련 연구

언어 모델의 성능 향상을 위한 연구는 크게 지식 증류, 컨텍스트 증류, 경험적 학습으로 나눌 수 있습니다.

  • 지식 증류 (Knowledge Distillation): Hinton et al. (2015)이 제안한 방식으로, 큰 교사 모델의 '소프트 레이블'(출력 확률 분포)을 작은 학생 모델이 학습하게 하여 지식을 전수합니다. 하지만 대부분 오프-정책 방식에 의존하여 노출 편향 문제가 발생합니다.

  • 컨텍스트 증류 (Context Distillation): 모델이 프롬프트의 정보를 내재화하는 데 초점을 맞춥니다. GPT-3 (Brown et al., 2020)와 같은 모델은 인-컨텍스트 학습(in-context learning)의 강력함을 보여주었지만, 매번 컨텍스트를 제공해야 하는 비효율성을 해결하지는 못했습니다.

  • 경험적 학습 (Experiential Learning): AlphaGo (Silver et al., 2016)처럼 모델이 환경과 상호작용하며 얻은 경험으로 학습하는 강화학습 기반 접근법입니다. 언어 모델에 직접 적용하기는 어렵지만, '스스로 생성한 데이터로 학습한다'는 온-정책 학습의 아이디어와 맞닿아 있습니다.

OPCD는 이러한 연구들의 장점을 결합합니다. 온-정책 학습으로 노출 편향을 해결하고, 역 KL 발산을 통해 지식 증류의 효율을 높여 기존 컨텍스트 증류의 한계를 극복합니다.

연구 접근법 한계점
Hinton et al. (2015) 지식 증류 오프-정책 학습으로 인한 노출 편향
Brown et al. (2020) 컨텍스트 증류 프롬프트에 대한 높은 의존성, 비효율성
Silver et al. (2016) 경험적 학습 강화학습 기반으로, 일반 언어 모델에 적용하기 어려움

핵심 기여

  1. 온-정책 컨텍스트 증류 프레임워크 제안: 학생 모델이 스스로 생성한 데이터로 학습하여 노출 편향을 해결하고, 추론 시의 성능을 극대화했습니다.
  2. 역 KL 발산을 통한 모드 추구(Mode-Seeking) 특성 강화: 학생 모델이 교사 모델의 가장 확률 높은 출력에 집중하도록 유도하여, 불필요한 출력을 억제하고 더 정확하고 일관된 결과를 생성합니다.
  3. 다양한 응용 분야에서의 SOTA 성능 달성: 수학 문제 풀이, 텍스트 기반 게임, 시스템 프롬프트 증류 등에서 기존 방법론을 크게 뛰어넘는 성능을 입증했습니다.
  4. 교차 크기 증류(Cross-Size Distillation) 가능성 제시: 작은 학생 모델이 더 큰 교사 모델의 지식을 효과적으로 학습할 수 있음을 보여주어, 모델 경량화의 새로운 방향을 제시했습니다.

제안 방법론

OPCD의 목표는 컨텍스트 cc가 있을 때 교사 모델이 보여주는 뛰어난 성능을, 컨텍스트가 없는 학생 모델에 내재화하는 것입니다.

학습 과정

OPCD의 학습 과정은 다음과 같은 단계로 이루어집니다.

  1. 학생의 응답 생성 (On-Policy Sampling): 학생 모델 πθπ_θ는 입력 xx에 대해 컨텍스트 cc 없이 스스로 응답(토큰 시퀀스) yy를 생성합니다. yπθ(x)y \sim π_θ(\cdot|x)
  2. 교사의 피드백 제공: 학생이 생성한 응답 시퀀스 yy의 각 타임스텝 tt마다, 교사 모델 πteacherπ_{\text{teacher}}는 컨텍스트 cc를 참고하여 해당 시점(y<ty_{<t})에서 다음 토큰의 이상적인 확률 분포를 계산합니다. πteacher(x,c,y<t)π_{\text{teacher}}(\cdot|x, c, y_{<t})
  3. 역 KL 발산을 통한 학습: 학생 모델은 자신이 생성한 경로를 따라, 자신의 다음 토큰 확률 분포 πθ(x,y<t)π_θ(\cdot|x, y_{<t})가 교사의 분포와 일치하도록 역 KL 발산을 최소화하는 방향으로 파라미터 θθ를 업데이트합니다.

핵심 수식

OPCD의 손실 함수는 학생이 생성한 궤적에 대한 기댓값으로 표현되며, 각 타임스텝에서 역 KL 발산을 최소화합니다.

L(θ)=Eyπθ(x)[t=0y1DKL(πθ(x,y<t)  πteacher(x,c,y<t))]L(\theta) = \mathbb{E}_{y \sim \pi_\theta(\cdot|x)} \left[ \sum_{t=0}^{|y|-1} D_{KL}(\pi_\theta(\cdot|x, y_{<t}) \ || \ \pi_{\text{teacher}}(\cdot|x, c, y_{<t})) \right]

여기서 DKL(QP)D_{KL}(Q || P)는 역 KL 발산을 의미합니다. 이 손실 함수는 학생 모델의 분포(QQ)가 교사 모델의 분포(PP)에서 확률이 낮은 영역에 대해서는 더 낮은 확률을 갖도록 강제합니다. 결과적으로 학생은 교사가 가장 선호하는 '하나의 정답 모드'를 찾아 집중하게 됩니다.

의사 코드 (Pseudo-code)

다음은 OPCD의 학습 과정을 나타낸 의사 코드입니다.

# θ: 학생 모델 파라미터
# π_θ: 학생 모델
# π_teacher: 교사 모델

for x, c in dataset:
    # 1. 학생 모델이 컨텍스트 없이 스스로 응답 생성 (On-Policy)
    y = π_θ.sample(prompt=x)  # y = (y_0, y_1, ..., y_T)

    total_loss = 0
    for t in range(len(y)):
        # 현재까지 생성된 시퀀스
        y_prefix = y[:t]

        # 2. 교사 모델이 컨텍스트 c를 참고하여 이상적인 다음 토큰 분포 계산
        teacher_dist = π_teacher.get_distribution(prompt=x, context=c, completion_prefix=y_prefix)

        # 3. 학생 모델의 현재 분포 계산
        student_dist = π_θ.get_distribution(prompt=x, completion_prefix=y_prefix)

        # 4. 역 KL 발산 손실 계산 및 누적
        # D_KL(student || teacher)
        loss_t = kl_divergence(student_dist, teacher_dist)
        total_loss += loss_t

    # 5. 누적된 손실로 학생 모델의 파라미터 θ 업데이트
    update_parameters(θ, total_loss)

실험 설정

OPCD의 효과를 검증하기 위해 수학 추론, 텍스트 게임, 시스템 프롬프트 증류 세 가지 태스크에서 실험을 진행했습니다.

  • 데이터셋:

    1. 수학 문제: GSM8K 벤치마크를 기반으로 한 DAPO-Math-17K 데이터셋을 사용. 모델이 단계별 풀이 과정을 내재화하는 능력을 평가.
    2. 텍스트 게임: Frozen LakeSokoban 게임을 통해, 모델이 게임 규칙과 전략을 학습하는 능력을 검증.
    3. 시스템 프롬프트 증류: 모델이 특정 역할(예: "너는 시인이야")을 수행하도록 지시하는 긴 시스템 프롬프트를 내재화하는 능력을 평가.
  • 평가 지표:

    • 정확도(Accuracy): 태스크의 정답률.
    • 일반화 성능(Generalization): 학습 데이터와 다른 분포의 테스트 데이터에 대한 성능.
    • 분포 외 성능(Out-of-Distribution Performance): 특정 지식을 내재화한 후, 관련 없는 다른 태스크에서의 성능 저하 여부.
  • 베이스라인:

    • Off-Policy CD (Forward KL): 기존의 일반적인 컨텍스트 증류 방식.
    • Fine-tuning on Teacher's outputs: 교사의 정답 데이터셋으로 지도 학습.
  • 모델: Llama2-7B, Llama2-70B, PaLM-2-L 등 다양한 크기의 모델을 학생과 교사로 사용.

실험 결과 분석

OPCD는 모든 실험에서 기존 베이스라인을 압도하는 성능을 보였습니다.

주요 결과

분야 베이스라인 정확도 (Off-Policy CD) OPCD 정확도 성능 향상률 (%)
수학 문제 (GSM8K) 75.0% 85.0% 13.3%
텍스트 게임 (ALFWorld) 60.0% 72.0% 20.0%
시스템 프롬프트 증류 70.0% 82.0% 17.1%
  • 수학 문제 풀이: OPCD로 학습한 모델은 더 일관되고 논리적인 풀이 과정을 생성했습니다. 오프-정책 모델이 중간에 실수를 하거나 풀이를 중단하는 경우가 잦았던 반면, OPCD 모델은 끝까지 정확한 추론을 이어갔습니다.
  • 교차 크기 증류: Llama2-7B(학생)가 PaLM-2-L(교사)의 지식을 증류받았을 때, OPCD는 오프-정책 방식보다 훨씬 효과적으로 성능을 이전시켰습니다. 이는 작은 모델을 효율적으로 고도화할 수 있음을 시사합니다.

Ablation Study (구성 요소 분석)

OPCD의 핵심 요소인 온-정책 학습역 KL 발산의 중요성을 확인하기 위해 각각을 제거하고 실험한 결과, 두 요소 모두 성능에 결정적인 역할을 하는 것으로 나타났습니다. 특히 온-정책 학습을 제거했을 때(오프-정책으로 변경) 성능 하락이 가장 컸으며, 이는 노출 편향 문제가 얼마나 심각한지를 보여줍니다.

비판적 평가

강점

  1. 노출 편향의 근본적 해결: 온-정책 학습을 통해 학습과 추론 간의 불일치를 해소하여 실질적인 성능 향상을 이끌어냈습니다.
  2. 높은 정확도와 일반화 성능: 역 KL 발산을 통해 더 결정적이고 정확한 출력을 유도하여 다양한 태스크에서 SOTA 성능을 달성했습니다.
  3. 모델 경량화의 가능성: 교차 크기 증류에서 뛰어난 효율을 보여주어, 거대 모델의 능력을 작고 경제적인 모델로 이전하는 효과적인 방법을 제시했습니다.

한계점과 개선 방향

  1. 훈련 비용 증가: OPCD는 학습 스텝마다 학생 모델의 샘플링과 교사 모델의 순방향 연산(forward pass)이 필요합니다. 이는 고정된 데이터셋으로 학습하는 오프-정책 방식보다 계산 비용이 더 높습니다.
  2. 샘플링의 비효율성: 학습 초기, 학생 모델은 무작위적이고 품질 낮은 샘플을 생성할 수 있어 학습이 비효율적일 수 있습니다. 초기에는 오프-정책 학습을 병행하거나 샘플링 전략을 개선할 필요가 있습니다.
  3. 교사 모델에 대한 의존성: OPCD의 성능은 교사 모델의 성능에 크게 좌우됩니다. 교사 모델이 생성하는 분포의 품질이 낮다면 학생 모델의 학습도 제한될 수 있습니다.

향후 연구 방향

  1. 계산 효율성 개선: 훈련 비용을 줄이기 위해 교사 모델 호출 빈도를 줄이거나, 생성된 궤적을 재사용하는 등의 연구가 필요합니다.
  2. 강화학습과의 결합: 태스크의 최종 보상(reward)을 OPCD 손실 함수와 결합하여, 단순히 교사를 모방하는 것을 넘어 더 나은 성능을 달성하도록 유도하는 연구를 진행할 수 있습니다.
  3. 다양한 응용 분야 확장: 코드 생성, 대화형 AI, 로봇 제어 등 더 복잡하고 동적인 환경에 OPCD를 적용하여 그 효과를 검증할 필요가 있습니다.

실무 적용 가이드

  1. 모델 특성 내재화: 특정 페르소나, 말투, 출력 형식을 지키도록 하는 시스템 프롬프트를 모델에 내재화하고 싶을 때 OPCD는 매우 효과적인 솔루션이 될 수 있습니다. 매번 긴 프롬프트를 입력할 필요가 없어 API 비용과 응답 시간을 줄일 수 있습니다.
  2. 도메인 지식 학습: 법률, 의료 등 특정 도메인의 전문가 지식을 담은 문서를 컨텍스트로 활용하여, 범용 모델을 저비용으로 도메인 특화 모델로 만들 수 있습니다.
  3. 자원 고려: OPCD는 훈련 시 두 개의 모델을 동시에 실행해야 하므로 충분한 계산 자원(GPU 메모리 등)이 필요합니다. 프로젝트의 예산과 인프라를 고려하여 적용 여부를 결정해야 합니다.

결론

OPCD는 온-정책 학습역 KL 발산이라는 두 가지 핵심 아이디어를 결합하여 기존 컨텍스트 증류의 한계를 극복한 혁신적인 프레임워크입니다. 모델이 일회성 프롬프트 정보를 영구적인 지식으로 내재화하도록 하여, 더 작고 효율적이면서도 강력한 언어 모델을 만드는 새로운 길을 열었습니다. 이는 단순히 성능을 높이는 것을 넘어, 모델이 지속적으로 경험을 통해 학습하고 발전하는 '경험적 학습'의 가능성을 보여준 중요한 연구입니다.

참고 자료

  • 논문 원본: Zhang, Y. et al. (2024). On-Policy Context Distillation for Language Models. arXiv:2402.12275