인공지능/논문 리뷰 or 진행

SHEARED LLAMA: ACCELERATING LANGUAGEMODEL PRE-TRAINING VIA STRUCTURED PRUNING

이게될까 2025. 8. 4. 22:17
728x90
728x90

https://arxiv.org/abs/2310.06694

 

Sheared LLaMA: Accelerating Language Model Pre-training via Structured Pruning

The popularity of LLaMA (Touvron et al., 2023a;b) and other recently emerged moderate-sized large language models (LLMs) highlights the potential of building smaller yet powerful LLMs. Regardless, the cost of training such models from scratch on trillions

arxiv.org

LLaMA와 같은 모델들은 작지만 강력한 LLM을 보여준다. 그래도 훈련 비용은 여전히 엄청나니 Pruning, Dynamic batch Loading을 통해 모델의 파라미터를 줄이고, 동일한 파라미터 수의 모델보다 높은 성능을 보여준다.

학습에 필요한 토큰 수는 엄청 줄고, 속도는 빨라지며, 정확도는 올라가는 것을 볼 수 있다.

모델의 파라미터를 직접 학습하는게 아니라 어떤 구조를 남기고, 없앨지를 결정하는 바이너리 마스크(z)를 학습하는 방식이다.

적은 수의 파라미터와, 학습 토큰을 가지고, 다른 모델들을 이기는 모습을 보여준다.

 점수 자체는 떨어지지만 원래 모델을 줄이고, 학습 량도 얼마 되지 않으니 엄청난 성과를 보여줬다고 생각한다.

선호도에서도 타 모델들을 이기는 모습을 보여준다.

Dynamic Batch Loading이 효과적임을 보여준다.

 

 

🔍 문제 상황 - LLaMA 등 고성능 LLM의 training cost가 매우 높음 (수조 토큰 필요)
- 작은 LLM을 scratch로 학습하려 해도 엄청난 연산 자원 소모
- 기존 structured pruning 기법은 성능 저하 or 비정형 구조 생성으로 실용성 부족
🎯 연구 목표 → 기존 대형 LLM을 기반으로
미리 정의된 구조로 효율적 pruning + 도메인 최적화된 재학습을 통해
작지만 강력한 LLM 생성
🧠 방법론 요약 LLM-Shearing = Targeted Structured Pruning + Dynamic Batch Loading
1. Targeted Pruning: 지정한 구조(예: Pythia-1.4B)를 목표로 pruning mask 학습
2. Dynamic Batch Loading: 도메인별 손실 감소율 기반 데이터 비율 동적 조정
→ pruning과 pretraining 모두에 적용
🧩 Pruning 상세 기법 - 레이어, 헤드 수, hidden dim, intermediate dim까지 모두 pruning 가능
- Hard Concrete 분포 기반 mask 학습
- Lagrangian constraint로 구조 제약 만족
- pruning 후, top-k만 유지하여 최종 구조 구성
🔄 Dynamic Batch Loading - 도메인별 reference loss 대비 손실 차이(∆) 측정
- w_i ∝ w_i^{t−m}⋅exp⁡(Δ_i)방식으로 배치 도메인 비율 업데이트
- Scaling Law로 예측된 domain reference loss 사용
🧪 실험 대상 - Source: LLaMA2-7B
- Target: Sheared-LLaMA-1.3B, 2.7B
- 비교 대상: OPT, Pythia, INCITE, TinyLLaMA, OpenLLaMA 등
📚 학습 데이터 - RedPajama (LLaMA1 재현 데이터)
도메인 구성: CC, C4, GitHub, Books, Wiki, ArXiv, StackExchange
- Pruning: 0.4B tokens
- Continued pretraining: 50B tokens
- 총 학습량: 50.4B
⚙️ 학습 세팅 - 모델 학습: Fully Sharded Data Parallel (FSDP)
- Optimizer: AdamW
- LR: 1.0 (mask), 1e-4 (θ), cosine decay
- Batch: 131K (pruning), 1M tokens (pretrain)
- Eval interval: 50 (pruning), 400 (CT)
- FlashAttention v1 사용
📊 결과 요약 - 1.3B 모델이 300B~3T token 사용 모델들보다 더 높은 성능
- 2.7B 모델이 OpenLLaMA 3B, INCITE-3B보다 성능 우위
- Instruction tuning 성능도 우수 (GPT-4 기준 win-rate ↑)
- 전체 학습량은 scratch 대비 3% 수준
🏅 기여 - 최초로 LLM에 대해 target 구조 명시적 pruning 가능하게 함
- 도메인 불균형을 고려한 dynamic batch loading 제안
- 소형 모델 생성의 비용-효율 최적 사례 제시
- 모델 및 코드 공개 (GitHub)
⚠️ 한계 - 고차원 도메인(C4 등) 성능 회복 어려움
- LLaMA2 기반 한정 → 다른 구조 적용은 미확인
- scaling reference를 추정하는 dev set에 의존
- 13B 이상 대형 소스 모델 실험 미수행
더보기

📘 Sheared LLaMA 요약표

문제 상황 - LLaMA와 같은 7B급 모델도 수조 토큰 단위의 학습에 막대한 비용 발생
- 기존 structured pruning은 일반적으로 inference-friendly한 아키텍처를 생성하지 않거나 성능이 떨어짐
- 작은 LLM을 scratch로 학습하는 것은 비효율적
연구 질문 "Pre-trained 대형 모델을 기반으로, 훨씬 적은 연산량으로 강력한 소형 LLM을 만들 수 있는가?"
주요 기법 🔧 LLM-Shearing:
1. Targeted Structured Pruning: 미리 지정된 목표 구조로 레이어, attention head, hidden dim 등을 제거
2. Dynamic Batch Loading: pruning 및 pre-training 중 도메인별 손실 감소 속도를 기반으로 동적으로 batch 구성
모델 구성 - Source: LLaMA2-7B
- Target: Sheared-LLaMA-1.3B, 2.7B (Pythia-1.4B, INCITE-3B 구조를 목표로 설정)
- 구조 예시: 2.7B 모델은 32 layer, hidden 2560, intermediate 6912, head 20
학습 데이터 - 사용 데이터: RedPajama (7개 도메인 포함: C4, GitHub, Wiki 등)
- Pruning: 0.4B tokens
- Continued Pretraining: 50B tokens
학습법 - Pruning: ℓ₀ regularization 기반 mask 학습 + Lagrangian constraint 사용
- Dynamic Batch Loading: 도메인별 손실 감소율 기반으로 batch 구성 비율을 동적으로 조정 (Scaling reference 활용)
비교 대상 OPT, Pythia, TinyLLaMA, OpenLLaMA, INCITE 등 다양한 1.3B~3B 모델
실험 결과 Downstream (SciQ, ARC, MMLU 등 11개 벤치마크)
→ Sheared-LLaMA-1.3B/2.7B는 같은 크기의 open-source 모델보다 우수

Instruction Tuning
→ GPT-4 평가 기준에서 Pythia/INCITE/OpenLLaMA 대비 높은 win-rate

Efficiency
→ 전체 compute는 scratch 학습 대비 3% (1/32) 수준
추가 실험 - 기존 pruning 기법(CoFiPruning, LLM-Pruner)과 비교: 추론 속도 + 구조 효율성 우위
- Easy 도메인 제외 vs Dynamic batch loading: 후자가 훨씬 효과적
기여 - 대형 모델에서의 inference-friendly 구조를 가진 pruning 방법 제시
- 도메인 손실 편향 보정을 위한 동적 배치 로딩 기법 제안
- 단 50B 토큰으로 scratch training 대비 높은 효율성과 성능 증명
한계 - 소스 모델 및 pre-training 데이터가 공개된 오픈 모델에 한정됨
- 아직 13B 이상 대형 모델에는 실험 적용되지 않음
- C4 등 entropy 높은 도메인의 손실 회복은 어렵다는 한계 존재

🔍 구조적 Pruning 최적화 수식 (요약)

  • 각 구조 단위 (layer, head, hidden, intermediate)에 대해 mask 변수 z 를 학습
  • 목표 구조와 일치하도록 다음과 같은 constraint loss를 함께 학습:
  • 전체 loss는 다음과 같은 min-max 최적화 문제:

🧠 Dynamic Batch Loading 핵심 아이디어

  • 도메인별 validation loss를 주기적으로 측정하여, scaling law 기반 reference loss와 비교
  • Δ_i=max⁡(ℓ_i − ℓ_{ref,i}, 0) 로 도메인별 과적합/부족 보정
  • 다음 batch에서 도메인 i의 sampling 확률을 ∝exp⁡(Δ_i) 로 조정

 

 


🔗 관련 연구 정리: Sheared LLaMA와 연결되는 기술 흐름

  대표 연구 핵심 내용 및 관계
1. Structured Pruning(모델 구조 단위 삭제) CoFiPruning - Task-specific pruning에서 layer/head 단위 structured pruning 제안
-그러나 layer 구성이 비정형 (non-uniform) → inference 비효율적
➡️ Sheared LLaMA는 target shape을 명시하여 uniform 구조 생성
  LLM-Pruner - LLM 대상 structured pruning 기법
- hidden dim pruning은 못함, 구조 비효율 + 속도 느림
➡️ Sheared LLaMA는 모든 차원 pruning 지원 + 더 빠름
  SparseGPT - One-shot unstructured pruning 기반 LLM 압축
- 가중치 sparsity 기반 성능 유지 가능
➡️ Sheared LLaMA는 dense 구조로 유지되며, inference-friendly
2. Semi-Structured / Unstructured Pruning Movement Pruning - fine-tuning 중 pruning mask를 이동시켜 적응
- task-specific 세팅에서 유용
➡️ 일반 LLM pretraining과는 방향이 다름
  Wanda - Semi-structured pruning (2:4, 4:8 sparsity)으로 모델 압축
- 50% sparsity 기준 hardware acceleration 가능
➡️ Sheared LLaMA와 비교 시, speed는 비슷하나 구조 제한 있음
3. Distillation 기반 소형화 DistilBERT
TinyBERT
- Teacher → Student 모델로 soft target을 학습
➡️ Pretraining보다는 task-specific tuning에 특화됨
  Gopher Distill - 대규모 LLM distillation 실험
- 학습 효율이 좋지만 Teacher inference 비용이 큼
➡️ Sheared LLaMA는 teacher 없이 pruning + self-training
4. Efficient Pretraining Doremi - 데이터 도메인 별로 proxy model을 학습하여 최적 mixture weight를 찾아 학습 효율 향상
➡️ Sheared LLaMA는 proxy model 없이 실시간 loss 기반 domain 비중 조정
  Scaling Laws - 모델 크기와 데이터 양 간의 loss scaling 관계 분석
➡️ Sheared LLaMA는 이를 기반으로 reference loss 추정하여 dynamic loading 기준 설정
  SlimPajama - 효율적인 pretraining 데이터 구성법 제안
➡️ TinyLLaMA에서 사용된 dataset으로 비교 baseline 역할 수행
5. Layer Dropping / Dynamic Depth Once-for-All
Progressive Layer Dropping
- 다양한 depth/width로 학습하여 subnetwork 선택 가능
➡️ pruning과 유사하지만 target structure 지정과는 다름
6. Instruction Tuning 비교 대상 Alpaca
OpenLLaMA
TinyLLaMA
- Open-source 모델 기반 instruction tuning 벤치마크
➡️ Sheared LLaMA는 이들과 크기/학습량 동등 조건에서 성능 우위 증명

📈 연관 기술 흐름도 (개념적 정리)

graph TD
  A1[GPT-like LLMs Pretraining] --> B1[Scaling Laws]
  A1 --> B2[Task-specific Pruning (Han et al., 2016)]
  B2 --> C1[Structured Pruning (CoFiPruning)]
  B2 --> C2[Unstructured (SparseGPT, Wanda)]

  A1 --> B3[Distillation (DistilBERT, Gopher Distill)]
  A1 --> B4[Efficient Pretraining (Doremi, SlimPajama)]
  A1 --> B5[Dynamic Training (Layer Dropping)]

  C1 --> D1[Sheared LLaMA (Targeted Structured Pruning)]
  B4 --> D1
  B1 --> D1

  D1 --> E1[Dynamic Batch Loading]
  D1 --> E2[Small LLMs w/ minimal compute]

📚 확장 가능 관련 연구 주제

Pruning + Quantization Quantization-aware pruning으로 추가 효율성 확보 가능 (ex. QLoRA + Sheared-LLaMA)
Multi-Objective Structured Pruning 속도 + 파라미터 수 + 다운스트림 성능 등 다목적 최적화 고려
Instruction 기반 데이터 선택 Instruction tuning 특화 도메인만을 대상으로 pruning & pretraining 조정
Continual Pretraining vs Pruning Initialization 기존 소형 모델을 단순히 더 학습시키는 것과 pruning 기반 전이 초기화 비교 (논문에서도 실험됨)

 


🔧 Sheared LLaMA의 방법론 개요

Sheared LLaMA는 기존 대형 LLM을 기반으로, 소형 모델을 효율적으로 생성하는 방법론입니다. 전체 과정은 크게 두 단계로 구성됩니다:

  1. Targeted Structured Pruning: 기존 모델(LLaMA2-7B)을 정해진 구조로 “잘라내기”
  2. Continued Pretraining with Dynamic Batch Loading: pruned 모델을 효율적으로 다시 학습시키기

① 🎯 Targeted Structured Pruning (구조 지정 기반 압축)

❓ 목적

  • 기존 structured pruning은 비정형 구조를 만들어 inference 비효율 발생
  • Sheared LLaMA는 "미리 지정한 목표 구조(target architecture)"로 정확하게 pruning함
    • 예: LLaMA2-7B → Pythia-1.4B 구조로 pruning

🧩 세부 구조 단위 (Substructures)

  pruning 대상 예시 mask 변수
Global Layer 수 z_{layer} ∈ R^{L_S}
Global Hidden dim z_{hidden} ∈ R^{d_S}
Local Head 수 z_{head}∈R^{L_S × H_S}
Local FFN 중간 dim z_{int}∈R^{L_S×m_S}

각 mask 값은 z = 1이면 유지, z = 0이면 제거


🧮 수식 기반 학습: 목표 구조를 향한 제약 최적화

  • mask 변수는 hard concrete distribution으로 연속적으로 학습됨
    (최종적으로는 binary mask로 수렴)
  • 각 구조에 대해 Lagrange 제약조건을 적용:

예: attention head 수가 H_T가 되도록 하기 위한 loss

  • 전체 pruning 목적함수:
  • pruning 후에는 최종 binary mask를 기준으로 상위 점수 구조만 유지

🔧 예시 구조 변화

Source: 32-layer, 4096 hidden, 11008 intermediate → Target: 24-layer, 2048 hidden, 5504 intermediate
→ pruning 단계에서 각 레이어의 submodule들을 선택적으로 제거하여 이 구조로 재조립

② 🔄 Dynamic Batch Loading (동적 도메인별 학습 비율 조정)

❓ 문제 인식

  • Pruned 모델은 도메인마다 잔존 지식 정도가 다름
    • GitHub: 손실 낮음 (low-entropy), C4: 손실 큼 (high-entropy)
    • → 동일한 데이터 비율로 학습하면 비효율적

📈 해결 아이디어: 도메인 손실 감소율 기반으로 batch 구성 비율 동적 조정

  • 도메인 D_i에 대해 손실 차이 Δi = max⁡(ℓ_i − ℓ_i^{ref} ,0)
  • 비율 갱신:
  • m step마다 validation loss 기준으로 도메인 비율 갱신

📋 알고리즘 요약 (Algorithm 1)

for each training step t:
    if t mod m == 0:
        for each domain i:
            compute validation loss ℓ_t[i]
            compute delta ∆_t[i] = max(ℓ_t[i] - ℓ_ref[i], 0)
        update weights w_t ∝ w_{t-m} * exp(∆_t)

    sample a batch using w_t
    update model weights (either prune loss or LM loss)
  • pruning 중에도 사용되고, continued pretraining에도 계속 사용됨
  • reference loss는 scaling law 기반 추정값 or source model 기준값

⚖️ 두 방법의 상호작용

단계  내용
1단계 Targeted Pruning→ 지정된 구조로 pruning (비용은 많지만 0.4B token 내 소화 가능)
2단계 Continued Pretraining→ Dynamic Batch Loading으로 token 50B 사용해 균형 있게 성능 회복

🧠 핵심 설계 철학

  • 기존 LLM이 어떤 도메인에 더 많이 학습됐는지 반영하는 방식으로 pruning & 학습 조정
  • pruning을 단순한 “압축”이 아닌 “전이 초기화”로 해석함
  • 기존 distillation이나 scratch-training보다 훨씬 적은 비용으로 효율적 소형 모델 생성 가능

 

 


✅ 1. 결과 (Results)

🔬 실험 모델

  • Source: LLaMA2-7B (pretrained, 2T tokens)
  • Target: Sheared-LLaMA-1.3B / 2.7B
  • Pruning + continued pretraining: 총 50.4B tokens 사용

📊 성능 결과 요약

Downstream Tasks (11개) Sheared-LLaMA-1.3B 50B 51.0%
  OPT-1.3B 300B 48.2%
  Pythia-1.4B 300B 48.9%
  TinyLLaMA-1.1B 3T 50.0%
  Sheared-LLaMA-2.7B 50B 56.7%
  INCITE-3B 800B 54.7%
  OpenLLaMA-3B-v2 1T 55.7%

요약:

  • Sheared-LLaMA는 학습량이 훨씬 적음에도 SOTA 오픈모델들과 동등~우위 성능 달성
  • 특히 1.3B 모델이 3T tokens 사용한 TinyLLaMA-1.1B보다 성능 우수
  • Instruction Tuning에서도 GPT-4 평가 기준 높은 승률 기록

⚡ 효율성 결과

  • LLaMA2-7B 대비 계산량 3%만으로 훈련
  • 동일 파라미터 수 기준 추론 속도 향상 (예: 1.3B 기준 58 tokens/sec vs CoFi 51)

🧠 2. 결론 (Conclusion)

"Strong small language models can be built from large pre-trained ones via structured pruning and minimal additional compute."

핵심 결론 요약

  • Pruning + 재학습 방식은 scratch training 대비 훨씬 비용 효율적
  • 기존 pretraining 구조 정보를 적극 활용하여 추론 친화적 구조 생성 가능
  • 동적 배치 로딩은 pruning 이후 도메인 불균형 문제 해결에 매우 효과적

⚠️ 3. 한계 (Limitations)

도메인 편중 특정 도메인(C4 등 high-entropy)의 성능 회복이 어렵거나 느림
확장성 미검증 LLaMA2-7B 이상 사이즈(13B, 70B 등)에 대해서는 실험하지 못함
데이터 의존성 RedPajama 기반 학습에 한정됨 → 도메인 다양성이 부족할 수 있음
Dynamic Loading 평가셋 고정 Reference loss 예측에 사용하는 dev set이 고정되어 있음 (진짜 성능 대표성 부족 우려)

🏅 4. 기여 (Contributions)

Pruning 기법 학습 가능한 pruning mask를 이용해 임의 구조(target config)로 transformer pruning 가능하도록 한 최초의 시도
효율적 pretraining 단 50B tokens만으로 SOTA 수준의 1.3B / 2.7B 모델 생성
Dynamic Batch Loading 도메인별 손실 감소율 기반 동적 데이터 구성 기법 제안 → 성능 및 효율 향상
공개 모델 Sheared-LLaMA-1.3B / 2.7B 모델과 코드 공개 (GitHub)
후속 연구 기반 제공 압축 + 전이 초기화 관점의 새로운 LLM scaling 전략 제안

 

 

 

 

728x90