Training LLMs to be Better Text Embedders through Bidirectional Reconstruction
https://arxiv.org/abs/2509.03020 Training LLMs to be Better Text Embedders through Bidirectional ReconstructionLarge language models (LLMs) have increasingly been explored as powerful text embedders. Existing LLM-based text embedding approaches often lever
yoonschallenge.tistory.com

LM casual Model을 Bidirectional 하게 바꿔야 하는데 엄청 단순합니다 ㅎㅎ
다양한 곳에서 사용할 수 있을 것 같아요
"""
Qwen2 양방향(Bidirectional) 변형 모듈
- 목적: 원래는 causal(미래 토큰을 볼 수 없는) self-attention인 Qwen2를, 임베딩 추출에 적합하도록
양방향 self-attention으로 바꿉니다.
- 핵심 변경점:
1) Attention 모듈의 `is_causal` 플래그를 False로 설정한 파생 클래스(ModifiedQwen2Attention 계열)를 사용
2) 디코더 레이어를 해당 Modified Attention을 쓰는 `ModifiedQwen2DecoderLayer`로 교체
3) 이를 쌓아서 `Qwen2BiModel`을 구성하고, LM 헤드를 붙인 `Qwen2BiForMNTP`를 제공합니다.
즉, 큰 구조 변경 없이, 각 레이어의 어텐션에서 causal 마스크를 적용하지 않도록 하여 전체 시퀀스를
양방향으로 볼 수 있게 만드는 방식입니다. MLP/Norm 등은 동일하며, 가중치 초기화/후처리는 기존 로직을 따릅니다.
"""
from typing import List, Optional, Tuple, Union
import torch
from transformers import Qwen2Model, Qwen2ForCausalLM, Qwen2PreTrainedModel, Qwen2Config
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.cache_utils import Cache, DynamicCache
from transformers.models.qwen2.modeling_qwen2 import (
Qwen2DecoderLayer,
Qwen2RMSNorm,
Qwen2Attention,
Qwen2FlashAttention2,
Qwen2SdpaAttention,
Qwen2MLP,
)
from torch import nn
from transformers.utils import logging
from .attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from peft import PeftModel
logger = logging.get_logger(__name__)
class ModifiedQwen2Attention(Qwen2Attention):
"""
기본 Qwen2Attention에서 `is_causal=False`로만 바꾼 경량 파생 클래스.
- 이 설정으로 causal(자동 회귀) 마스크 대신 비-인과(양방향) 어텐션이 사용됩니다.
- 나머지 로직은 동일하며, 제공되는 attention_mask(패딩 등)는 그대로 존중됩니다.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.is_causal = False
class ModifiedQwen2FlashAttention2(Qwen2FlashAttention2):
"""FlashAttention2 구현체의 비-인과(양방향) 버전"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.is_causal = False
class ModifiedQwen2SdpaAttention(Qwen2SdpaAttention):
"""SDPA 구현체의 비-인과(양방향) 버전"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.is_causal = False
# Qwen2가 내부적으로 선택하는 어텐션 구현(eager/flash_attention_2/sdpa)에 맞춰
# causal이 해제된(양방향) Attention 클래스를 매핑합니다.
QWEN2_ATTENTION_CLASSES = {
"eager": ModifiedQwen2Attention,
"flash_attention_2": ModifiedQwen2FlashAttention2,
"sdpa": ModifiedQwen2SdpaAttention,
}
class ModifiedQwen2DecoderLayer(Qwen2DecoderLayer):
"""
기본 Qwen2DecoderLayer를 복제하되, self_attn만 비-인과 Attention으로 교체합니다.
- 레이어 정규화/MLP는 동일
- config._attn_implementation에 따라 위 매핑된 Modified Attention을 사용
"""
def __init__(self, config: Qwen2Config, layer_idx: int):
nn.Module.__init__(self)
self.hidden_size = config.hidden_size
self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](
config=config, layer_idx=layer_idx
)
self.mlp = Qwen2MLP(config)
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Qwen2RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
class Qwen2BiModel(Qwen2Model):
_no_split_modules = ["ModifiedQwen2DecoderLayer"]
def __init__(self, config: Qwen2Config):
"""
Qwen2의 기본 구조를 따르면서, 모든 디코더 레이어를 양방향 Modified 레이어로 대체한 모델.
- 임베딩/정규화/파이프라인 초기화는 Qwen2 표준을 유지
- 결과적으로 모든 레이어에서 causal 마스크 없이 전체 토큰 상호 주의가 가능
"""
Qwen2PreTrainedModel.__init__(self, config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, self.padding_idx
)
self.layers = nn.ModuleList(
[
ModifiedQwen2DecoderLayer(config, layer_idx)
for layer_idx in range(config.num_hidden_layers)
]
)
self._attn_implementation = config._attn_implementation
self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
class Qwen2BiForMNTP(Qwen2ForCausalLM):
"""
양방향 Qwen2BiModel 위에 LM Head를 얹은 래퍼.
- MNTP(문장 임베딩 학습 등) 같은 목적의 손실 계산/헤드 결합을 위해 사용
- PEFT(LoRA 등) 연동을 위한 getter/setter/helper 제공
"""
def __init__(self, config):
Qwen2PreTrainedModel.__init__(self, config)
self.model = Qwen2BiModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
# getter for PEFT model
def get_model_for_peft(self):
"""PEFT 적용 시 내부 본체 모델을 반환"""
return self.model
# setter for PEFT model
def set_model_for_peft(self, model: PeftModel):
"""PEFT 적용 후 본체 모델을 교체"""
self.model = model
# save the PEFT model
def save_peft_model(self, path):
"""PEFT 가중치만 저장"""
self.model.save_pretrained(path)
무엇이 어떻게 양방향으로 바뀌었나요?
원래 Qwen2의 Self-Attention은 causal(자동회귀) 모드라서 현재 토큰이 미래 토큰을 볼 수 없습니다. 이 파일에서는 그 “인과성(causality)”만 해제해서, 임베딩을 만들 때 문장 전체를 동시에 참조할 수 있게 했습니다. 바뀐 점은 크게 2가지입니다.
Attention 모듈의 인과성 끄기
ModifiedQwen2Attention, ModifiedQwen2FlashAttention2, ModifiedQwen2SdpaAttention에서 공통적으로 self.is_causal = False로 설정했습니다.
즉, 사용하는 어텐션 구현(eager/flash_attention_2/sdpa)에 맞는 “비-인과” 버전 클래스를 매핑해 사용합니다.
디코더 레이어 교체
ModifiedQwen2DecoderLayer에서 원본 Qwen2DecoderLayer 대신 위의 “비-인과” 어텐션을 쓰도록 self.self_attn만 바꿉니다. 나머지 MLP, LayerNorm 등은 동일합니다.
이 레이어를 여러 층 쌓아 Qwen2BiModel을 구성합니다. 따라서 모든 층에서 causal mask가 아닌 양방향 self-attention이 작동합니다.
즉, 구조를 크게 바꾸지 않고 각 Attention 레이어의 “causal 여부”만 False로 바꾼 것이 핵심입니다. 그래서 외형상 코드 변경이 많지 않아 보이지만, 어텐션의 동작 방식이 달라지므로 결과적으로 임베딩 추출 시 문맥을 양방향으로 활용할 수 있게 됩니다.
주석 추가 내용
파일 상단 모듈 주석: 목적과 핵심 변경점(양방향화)을 요약
각 Modified Attention 클래스: “is_causal=False” 의미 설명
QWEN2_ATTENTION_CLASSES 매핑: 어떤 구현에서 어떤 Modified 클래스를 쓰는지 설명
ModifiedQwen2DecoderLayer: self_attn만 교체한다는 점과 나머지는 동일함을 명시
Qwen2BiModel: 기존 Qwen2 구조를 유지하면서 레이어만 교체했다는 점 설명
Qwen2BiForMNTP: LM Head 연결과 PEFT 헬퍼 목적 설명
참고
현재 편집은 문서화(주석/Docstring)만 수행했고 기능 변화는 없습니다.
편집 후 정적 임포트 경고가 보일 수 있지만, 이는 로컬 환경에 torch/peft가 설치되어 있지 않을 때 발생하는 편집기 경고일 가능성이 큽니다. 실행 환경에 해당 패키지가 설치되어 있으면 문제 없이 동작합니다.
이제 모델에 관련된 코드입니다
패딩을 맞추고, instruction에 Pooling이 들어가지 않도록 합니다
여기선 Instruction과 test 분리를 !@#$%^&*()를 통해 합니다
ex) 너는 요약해 !@#$%^&*() 나는 AI다
Instruction 너는 요약해
text 나는 AI다
결국 마지막에는 저 특수기호가 빠져서 들어가게 됩니다.
"""
LLM2Vec_q2d_d2q 모듈
- 목적: Causal LM(예: Llama, Mistral, Gemma, Qwen2)의 은닉 상태를 풀링하여 문장/문서 임베딩을 생성합니다.
- 구성 개요:
* from_pretrained: HF 허브/로컬 경로에서 기반 모델 및 토크나이저 로딩, (선택) PEFT 어댑터 병합
* tokenize + embed_mask: 입력 중 임베딩에 반영할 “본문” 영역만 마스킹으로 선택
* get_pooling: mean/weighted_mean/eos 등 다양한 풀링 방식 지원
* encode/_encode: 배치 인코딩 및 다중 GPU 병렬 처리(왼쪽 패딩 가정)
- 참고: 일부 Instruct 계열 모델은 프롬프트 포맷팅이 필요하며, 해당 로직을 prepare_for_tokenization 계열 함수에서 처리합니다.
"""
import json
import logging
import os
from functools import partial
from typing import Dict, List, Optional, Union
import numpy as np
import torch
import torch.multiprocessing as mp
from peft import PeftModel
from torch import Tensor, device, nn
from tqdm.autonotebook import tqdm, trange
from transformers import (
AutoModelForCausalLM,
AutoModel,
AutoConfig,
PretrainedConfig,
AutoTokenizer,
LlamaConfig,
MistralConfig,
GemmaConfig,
Qwen2Config,
)
from .models import (
MistralBiModel,
LlamaBiModel,
GemmaBiModel,
Qwen2BiModel,
)
logger = logging.getLogger(__name__)
def batch_to_device(batch, target_device: device):
"""
PyTorch 배치를 지정한 장치(CPU/GPU)로 이동합니다.
- 텐서(Tensor) 타입의 값들에 대해 .to(target_device)를 호출합니다.
- 비-텐서 값은 그대로 둡니다.
"""
for key in batch:
if isinstance(batch[key], Tensor):
batch[key] = batch[key].to(target_device)
return batch
class LLM2Vec_q2d_d2q(nn.Module):
"""
Causal LM의 은닉 상태를 문장 임베딩으로 변환하는 래퍼 모듈.
- model: Causal LM의 backbone(언어모델 본체)로, .model (decoder)만 사용합니다.
- lm_head: 원래의 언어모델 헤드(로스 계산 등 추후 확장을 위해 보관)
- tokenizer: 토크나이저. 왼쪽 패딩(left padding)과 EOS를 pad 토큰으로 사용하는 설정을 기본으로 합니다.
- pooling_mode: 임베딩 풀링 방식(mean, weighted_mean, eos_token/last_token, bos_token)
- skip_instruction: 임베딩 계산 시 안내/프롬프트 영역은 무시하고 본문만 사용하도록 하는 플래그
- max_length / doc_max_length: 토크나이즈 및 본문 길이 제한
핵심 아이디어:
- 입력 문자열을 "지시문 !@#$%^&*() 본문" 형태로 합쳐서 토크나이즈한 뒤, embed_mask를 통해 본문 토큰만
attention_mask로 취급하여 풀링에 반영합니다.
"""
def __init__(
self,
lm_head,
model: AutoModel,
tokenizer: AutoTokenizer,
pooling_mode: str = "mean",
max_length: int = 512,
doc_max_length: int = 400,
skip_instruction: bool = True,
):
super().__init__()
self.model = model
self.tokenizer = tokenizer
self.pooling_mode = pooling_mode
self.skip_instruction = skip_instruction
self.max_length = max_length
self.doc_max_length = doc_max_length
self.config = model.config
self.lm_head = lm_head
@classmethod
def _get_model_class(cls, config_class_name, enable_bidirectional):
"""
양방향(bidirectional) 임베딩 모델을 사용할지 여부에 따라 적절한 모델 클래스를 반환합니다.
- enable_bidirectional=False: 일반 AutoModel 사용
- True인 경우, 모델 config 타입에 맞춘 BiModel 클래스를 선택
"""
if not enable_bidirectional:
return AutoModel
if config_class_name == "MistralConfig":
return MistralBiModel
elif config_class_name == "LlamaConfig":
return LlamaBiModel
elif config_class_name == "GemmaConfig":
return GemmaBiModel
elif config_class_name == "Qwen2Config":
return Qwen2BiModel
else:
raise ValueError(
f"{config_class_name} is not supported yet with bidirectional models."
)
@classmethod
def from_pretrained(
cls,
base_model_name_or_path,
peft_model_name_or_path=None,
merge_peft=False,
enable_bidirectional=True,
**kwargs,
):
"""
사전학습 모델과 토크나이저를 로드하고, (선택적으로) PEFT 어댑터를 병합하여 인코더를 구성합니다.
- base_model_name_or_path: HF 모델 이름 또는 로컬 경로
- peft_model_name_or_path: LoRA 등 PEFT 어댑터 경로 (선택)
- merge_peft=True인 경우 어댑터를 가중치에 병합 후 언로드
- enable_bidirectional=True: 양방향 BiModel로 래핑(현재는 아래 import된 BiModel들을 통해 지원)
- 추가 인자(kwargs): AutoModelForCausalLM.from_pretrained에 그대로 전달
반환: LLM2Vec_q2d_d2q 인스턴스 (lm_head는 보존, model은 .model(decoder)로 교체)
"""
# pop out encoder args
keys = ["pooling_mode", "max_length", "doc_max_length", "skip_instruction"]
encoder_args = {
key: kwargs.pop(key, None) for key in keys if kwargs.get(key) is not None
}
tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
config = AutoConfig.from_pretrained(base_model_name_or_path)
config_class_name = config.__class__.__name__
model = AutoModelForCausalLM.from_pretrained(base_model_name_or_path, **kwargs)
if os.path.isdir(base_model_name_or_path) and os.path.exists(
f"{base_model_name_or_path}/config.json"
):
with open(f"{base_model_name_or_path}/config.json", "r") as fIn:
config_dict = json.load(fIn)
config = PretrainedConfig.from_dict(config_dict)
model.config._name_or_path = config._name_or_path
# For special case where config.json and adapter weights are in the same directory
if hasattr(model, "peft_config"):
model = PeftModel.from_pretrained(
model,
base_model_name_or_path,
)
model = model.merge_and_unload()
if peft_model_name_or_path is not None:
model = PeftModel.from_pretrained(
model,
peft_model_name_or_path,
)
if merge_peft:
model = model.merge_and_unload()
config = {}
config_addr = (
peft_model_name_or_path
if peft_model_name_or_path is not None
else base_model_name_or_path
)
if os.path.exists(f"{config_addr}/llm2vec_config.json"):
with open(f"{config_addr}/llm2vec_config.json", "r") as fIn:
llm2vec_config = json.load(fIn)
config.update(llm2vec_config)
for key, value in encoder_args.items():
config[key] = value
# 주의: 후속 q2d/d2q 손실 계산을 위해 causal-llm의 lm_head를 보존하고,
# 본체 모델은 .model(decoder)만 사용합니다.
lm_head = model.lm_head
model = model.model
return cls(lm_head=lm_head, model=model, tokenizer=tokenizer, **config)
def prepare_for_tokenization_llama_3_instruct(model, text):
"""
Llama 3 Instruct 포맷팅.
- 주의: 메서드 시그니처가 (self가 아닌) model, text 형태로 정의되어 있으나,
바운드 메서드로 호출되므로 첫 번째 인자로 self가 전달됩니다.
여기서는 self를 사용하지 않으므로 이름만 model로 되어 있습니다.
"""
text = (
"<|start_header_id|>user<|end_header_id|>\n\n" + text.strip() + "<|eot_id|>"
)
return text
def prepare_for_tokenization_qwen_instruct(model, text, pooling_mode="mean"):
"""
Qwen Instruct 포맷팅. 사용자 메시지 블록을 생성하고, 필요 시 EOS 토큰을 붙입니다.
- 주의: 위와 동일하게 첫 번째 인자는 실제로 self가 바인딩됩니다.
"""
text = "<|im_start|>user\n" + text.strip() + "<|im_end|>"
text = text.strip() + "<|endoftext|>"
return text
def prepare_for_tokenization(self, text):
"""
모델 종류에 따라 Instruct 프롬프트 템플릿을 적용하거나, eos 기반 풀링 시 EOS 토큰을 강제로 붙입니다.
- 모델 이름은 config._name_or_path를 기반으로 분기합니다.
- pooling_mode == "eos_token"인 경우, 프레임워크별 EOS 표기를 붙입니다.
"""
if self.model.config._name_or_path == "meta-llama/Meta-Llama-3-8B-Instruct":
text = (
"<|start_header_id|>user<|end_header_id|>\n\n"
+ text.strip()
+ "<|eot_id|>"
)
return text
if self.model.config._name_or_path in [
"mistralai/Mistral-7B-Instruct-v0.2",
"meta-llama/Llama-2-7b-chat-hf",
]:
text = "[INST] " + text.strip() + " [/INST]"
if self.model.config._name_or_path in [
"google/gemma-2-9b-it",
]:
text = "<bos><start_of_turn>user\n" + text.strip() + "<end_of_turn>"
if self.model.config._name_or_path in [
"Qwen/Qwen2-1.5B-Instruct",
"Qwen/Qwen2-7B-Instruct",
]:
text = "<|im_start|>user\n" + text.strip() + "<|im_end|>"
if self.pooling_mode == "eos_token":
if self.model.config._name_or_path == "meta-llama/Meta-Llama-3-8B":
text = text.strip() + "<|end_of_text|>"
elif isinstance(self.model.config, LlamaConfig) or isinstance(
self.model.config, MistralConfig
):
text = text.strip() + " </s>"
elif isinstance(self.model.config, GemmaConfig):
text = text.strip() + "<eos>"
elif isinstance(self.model.config, Qwen2Config):
text = text.strip() + "<|endoftext|>"
return text
def tokenize(self, texts):
"""
입력 텍스트들을 토크나이즈하고, 본문 영역만 임베딩에 반영하기 위한 embed_mask를 생성합니다.
- 입력은 "지시문 !@#$%^&*() 본문"처럼 특수 구분자("!@#$%^&*()")를 포함할 수 있습니다.
구분자 이후의 토큰들만 임베딩 풀링 시 사용합니다.
- original_texts: 전체 텍스트(지시문+본문)를 토크나이즈하여 input_ids/attention_mask 생성
- embed_mask: 각 배치별로 "본문"의 토큰 길이만 1로 표기하여 attention_mask로 교체할 때 사용
"""
texts_2 = []
original_texts = []
for text in texts:
t = text.split("!@#$%^&*()")
texts_2.append(t[1] if len(t) > 1 else "")
original_texts.append("".join(t))
original = self.tokenizer(
original_texts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=self.max_length,
)
embed_mask = None
for t_i, t in enumerate(texts_2):
ids = self.tokenizer(
[t],
return_tensors="pt",
padding=True,
truncation=True,
max_length=self.max_length,
add_special_tokens=False,
)
if embed_mask is None:
e_m = torch.zeros_like(original["attention_mask"][t_i])
if len(ids["input_ids"][0]) > 0:
e_m[-len(ids["input_ids"][0]) :] = torch.ones(
len(ids["input_ids"][0])
)
embed_mask = e_m.unsqueeze(0)
else:
e_m = torch.zeros_like(original["attention_mask"][t_i])
if len(ids["input_ids"][0]) > 0:
e_m[-len(ids["input_ids"][0]) :] = torch.ones(
len(ids["input_ids"][0])
)
embed_mask = torch.cat((embed_mask, e_m.unsqueeze(0)), dim=0)
original["embed_mask"] = embed_mask
return original
def _skip_instruction(self, sentence_feature):
"""
embed_mask를 attention_mask로 교체하여, 임베딩 계산 시 본문 토큰만 고려하도록 합니다.
- 왼쪽 패딩을 가정하므로, 마지막 토큰들이 실제 시퀀스 토큰입니다.
"""
assert (
sentence_feature["attention_mask"].shape
== sentence_feature["embed_mask"].shape
)
sentence_feature["attention_mask"] = sentence_feature["embed_mask"]
def forward(self, sentence_feature: Dict[str, Tensor]):
"""
Transformer 모델을 실행하고, 마지막 은닉 상태(last_hidden_state)를 풀링하여 임베딩을 반환합니다.
- tokenize()에서 생성한 embed_mask는 forward 이후 다시 복구합니다.
"""
embed_mask = None
if "embed_mask" in sentence_feature:
embed_mask = sentence_feature.pop("embed_mask")
reps = self.model(**sentence_feature)
sentence_feature["embed_mask"] = embed_mask
return self.get_pooling(sentence_feature, reps.last_hidden_state)
def get_pooling(self, features, last_hidden_states): # All models padded from left
"""
다양한 방식으로 은닉 상태를 풀링하여 문장 임베딩을 생성합니다. (왼쪽 패딩 가정)
- mean: 시퀀스 길이 구간의 평균
- weighted_mean: 뒤쪽 토큰에 선형 가중치(1..L)를 부여한 가중 평균
- eos_token/last_token: 마지막 토큰의 은닉 상태
- bos_token: BOS 토큰의 은닉 상태 위치를 찾아 반환
- skip_instruction=True이면 embed_mask를 적용하여 본문 토큰만 대상으로 풀링
"""
assert (
self.tokenizer.padding_side == "left"
), "Pooling modes are implemented for padding from left."
if self.skip_instruction:
self._skip_instruction(features)
seq_lengths = features["attention_mask"].sum(dim=-1)
if self.pooling_mode == "mean":
return torch.stack(
[
last_hidden_states[i, -length:, :].mean(dim=0)
for i, length in enumerate(seq_lengths)
],
dim=0,
)
elif self.pooling_mode == "weighted_mean":
bs, l, _ = last_hidden_states.shape
complete_weights = torch.zeros(bs, l, device=last_hidden_states.device)
for i, seq_l in enumerate(seq_lengths):
if seq_l > 0:
complete_weights[i, -seq_l:] = torch.arange(seq_l) + 1
complete_weights[i] /= torch.clamp(
complete_weights[i].sum(), min=1e-9
)
return torch.sum(last_hidden_states * complete_weights.unsqueeze(-1), dim=1)
elif self.pooling_mode == "eos_token" or self.pooling_mode == "last_token":
return last_hidden_states[:, -1]
elif self.pooling_mode == "bos_token":
return last_hidden_states[
features["input_ids"] == self.tokenizer.bos_token_id
]
else:
raise ValueError(f"{self.pooling_mode} is not implemented yet.")
def _convert_to_str(self, instruction, text):
"""
지시문(instruction)과 본문(text)을 하나의 문자열로 결합하고, 본문 길이가 너무 긴 경우 토큰 길이(doc_max_length)까지 축소합니다.
- 결합 형식: "{instruction} !@#$%^&*(){text}" (instruction이 비어있으면 구분자부터 시작)
- 축소 로직: 토크나 길이가 doc_max_length를 초과할 때마다 비율에 따라 단어 단위로 잘라 재시도
"""
tokenized_q = self.tokenizer(
text,
return_tensors="pt",
padding=True,
truncation=True,
max_length=self.max_length,
add_special_tokens=False,
)
tokenized_q_length = len(tokenized_q["input_ids"][0])
while tokenized_q_length > self.doc_max_length:
reduction_ratio = self.doc_max_length / tokenized_q_length
reduced_length = int(len(text.split()) * reduction_ratio)
text = " ".join(text.split()[:reduced_length])
tokenized_q = self.tokenizer(
text,
return_tensors="pt",
padding=True,
truncation=True,
max_length=self.max_length,
add_special_tokens=False,
)
tokenized_q_length = len(tokenized_q["input_ids"][0])
return (
f"{instruction.strip()} !@#$%^&*(){text}"
if instruction
else f"!@#$%^&*(){text}"
)
def encode(
self,
sentences: Union[str, List[str]],
batch_size: int = 32,
show_progress_bar: bool = True,
convert_to_numpy: bool = False,
convert_to_tensor: bool = False,
device: Optional[str] = None,
):
# 한글 요약:
# - 입력 문장(들)을 배치로 토크나이즈하고 모델에 통과시켜 임베딩을 생성합니다.
# - 왼쪽 패딩을 가정하고, 길이순 정렬로 효율적으로 배치를 구성합니다.
# - 단일/다중 GPU 모두 지원하며, 다중 GPU에서는 프로세스 풀을 통해 병렬로 _encode를 호출합니다.
"""
Encode a list of sentences to their respective embeddings. The sentences can be a list of strings or a string.
Args:
sentences: sentence or sentences to encode.
batch_size: batch size for turning sentence tokens into embeddings.
show_progress_bar: whether to show progress bars during encoding steps.
convert_to_numpy: If true, return numpy arrays instead of torch tensors.
convert_to_tensor: If true, return torch tensors (default).
device: torch backend device identifier (e.g., 'cuda', 'cpu','mps' etc.). If not specified,
the default is to use cuda when available, otherwise cpu. Note that only the choice of 'cuda' supports
multiprocessing as currently implemented.
Returns: embeddings of the sentences. Embeddings are detached and always on the CPU (see _encode implementation).
"""
if isinstance(sentences[0], str) and isinstance(sentences[-1], int):
sentences = [sentences]
# required for MEDI version of MTEB
if isinstance(sentences[0], str):
sentences = [[""] + [sentence] for sentence in sentences]
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
concatenated_input_texts = []
for sentence in sentences:
assert isinstance(sentence[0], str)
assert isinstance(sentence[1], str)
concatenated_input_texts.append(
self._convert_to_str(sentence[0], sentence[1])
)
sentences = concatenated_input_texts
self.eval()
if convert_to_tensor:
convert_to_numpy = False
length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])
sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
all_embeddings = []
if torch.cuda.device_count() <= 1:
# This branch also support mps devices
self.to(device)
for start_index in trange(
0,
len(sentences),
batch_size,
desc="Batches",
disable=not show_progress_bar,
):
sentences_batch = sentences_sorted[
start_index : start_index + batch_size
]
embeddings = self._encode(
sentences_batch, device=device, convert_to_numpy=convert_to_numpy
)
all_embeddings.append(embeddings)
else:
num_proc = torch.cuda.device_count()
cuda_compatible_multiprocess = mp.get_context("spawn")
with cuda_compatible_multiprocess.Pool(num_proc) as p:
sentences_batches = [
sentences_sorted[start_index : start_index + batch_size]
for start_index in range(0, len(sentences), batch_size)
]
progress_bar = tqdm(
total=len(sentences_batches),
desc="Batches",
disable=not show_progress_bar,
)
results = []
def update(*args):
progress_bar.update()
for batch in sentences_batches:
results.append(
p.apply_async(
self._encode,
args=(batch, None, convert_to_numpy, True),
callback=update,
)
)
all_embeddings = [result.get() for result in results]
progress_bar.close()
all_embeddings = torch.cat(all_embeddings, dim=0)
all_embeddings = all_embeddings[np.argsort(length_sorted_idx)]
all_embeddings = all_embeddings.to(torch.float32)
if convert_to_numpy:
all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
return all_embeddings
def save(self, output_path, merge_before_save=False, save_config=True):
"""
현재 모델(.model)과 토크나이저를 저장하고, LLM2Vec 전용 설정(llm2vec_config.json)을 함께 기록합니다.
- merge_before_save=True이고 PEFT 모델인 경우, 가중치에 병합 후 저장합니다.
"""
if merge_before_save and isinstance(self.model, PeftModel):
self.model = self.model.merge_and_unload()
# Fixes the issue of saving - https://huggingface.co/McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-unsup-simcse/discussions/1
if hasattr(self.model, "_hf_peft_config_loaded"):
self.model._hf_peft_config_loaded = False
self.model.save_pretrained(output_path)
self.tokenizer.save_pretrained(output_path)
llm2vec_config = {
"pooling_mode": self.pooling_mode,
"max_length": self.max_length,
"doc_max_length": self.doc_max_length,
"skip_instruction": self.skip_instruction,
}
if save_config:
os.makedirs(output_path, exist_ok=True)
with open(f"{output_path}/llm2vec_config.json", "w") as fOut:
json.dump(llm2vec_config, fOut, indent=4)
def _encode(
self,
sentences_batch,
device: Optional[str] = None,
convert_to_numpy: bool = False,
multiprocessing=False,
):
"""
단일 배치에 대해 토크나이즈 -> 디바이스 이동 -> 모델 추론 -> 풀링 임베딩을 수행합니다.
- multiprocessing=True이면 CUDA 장치만 지원하며, 프로세스 id 기반으로 장치를 라운드로빈 할당합니다.
- prepare_for_tokenization_llama_3_instruct를 통해 Instruct 포맷을 적용 후 tokenize를 호출합니다.
"""
if multiprocessing:
# multiprocessing only supports CUDA devices at this time, so we ignore the value of device
# and use cuda:rank for the device
rank = mp.current_process()._identity[0]
if device is None and torch.cuda.is_available():
device = f"cuda:{rank % torch.cuda.device_count()}"
self.to(device)
# Instruct 포맷 적용 후 토크나이즈
features = self.tokenize(
[self.prepare_for_tokenization_llama_3_instruct(sentence) for sentence in sentences_batch]
)
features = batch_to_device(features, device)
with torch.no_grad():
embeddings = self.forward(features)
embeddings = embeddings.detach()
embeddings = embeddings.cpu()
return embeddings
def _text_length(self, text: Union[List[int], List[List[int]]]):
"""
입력 텍스트의 길이를 계산하는 보조 함수.
- 문자열: 문자 길이 반환
- 정수 리스트: 토큰 길이 반환
- 리스트들의 리스트/튜플: 내부 모든 길이의 합
- 딕셔너리: 첫 value의 길이
"""
if (
isinstance(text, str)
or (isinstance(text, list) and isinstance(text[0], int))
or len(text) == 0
): # Single text, list of ints, or empty
return len(text)
if isinstance(text, dict): # {key: value} case
return len(next(iter(text.values())))
elif not hasattr(text, "__len__"): # Object has no len() method
return 1
else:
return sum([len(t) for t in text])
def resize_token_embeddings(
self,
new_num_tokens: Optional[int] = None,
pad_to_multiple_of: Optional[int] = None,
) -> nn.Embedding:
"""토큰 임베딩 테이블 크기를 조정합니다."""
return self.model.resize_token_embeddings(
new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of
)
def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
"""그라디언트 체크포인팅을 활성화합니다(메모리 절약)."""
self.model.gradient_checkpointing_enable(
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs
)
좋은 질문이에요.
결론부터 말하면, 네—문자열 사이에 실제로 "!@#$%^&*()"를 끼워 넣지만, 그 특수기호는 모델로 보내기 전에 제거됩니다.
이 구분자는 “어디까지가 지시문(instruction)이고 어디부터가 본문(text)인지”를 표시하기 위한 내부 마커로만 쓰여요.
어떻게 동작하나요?
_convert_to_str
문자열을 “지시문 + 공백 + !@#$%^&*() + 본문” 형태로 합칩니다.
예: "instr ... !@#$%^&*()text ..."
tokenize
위에서 만든 문자열을 split("!@#$%^&*()") 해서
t[0] → 지시문 쪽
t[1] → 본문 쪽으로 분리합니다.
original_texts에는 "".join(t)를 넣어 구분자를 제거한 전체 문자열을 만듭니다.
즉, 모델이 실제로 보는 입력에는 구분자가 들어가지 않습니다. 결과는 "instr ... text ..." 형태가 됩니다. 공백은 구분자 앞에 둔 한 칸이 유지되어 자연스럽게 이어집니다.
texts_2에는 본문만(t[1])을 저장해, 본문만 따로 토크나이즈하여 “본문 토큰 길이”를 구합니다.
왼쪽 패딩을 가정하고, embed_mask에서 마지막 N(=본문 토큰 수)개 위치만 1로 만들어, forward에서 attention_mask를 이 마스크로 바꿔 끼웁니다. 이렇게 하면 풀링은 본문 토큰만 대상으로 수행됩니다.
정리하면, 구분자는 “경계 표시용”으로만 쓰이고, 모델 입력에는 포함되지 않습니다.
굳이 저렇게 하는 이유는?
목적: 임베딩 풀링에서 “지시문 토큰은 제외하고 본문 토큰만 평균/가중평균 등으로 풀링”하기 위함이에요.
문제: 토크나이저가 모델/설정마다 특수토큰을 다르게 붙이고, 왼쪽 패딩·truncation도 있어서 “지시문과 본문의 경계 인덱스”를 토큰 단위로 깔끔하게 추적하기 어렵습니다.
해결: 본문만 따로 토크나이즈한 길이 N을 구한 뒤, “왼쪽 패딩”을 가정해 마지막 N개 토큰이 본문이라고 간주하여 마스크(=embed_mask)를 만듭니다. 이때 경계를 구분하기 위해 문자열에 임시 구분자를 넣고, 나중에 제거하는 방식이 단순하면서도 모델 불가지론적으로 동작합니다.
주의사항과 개선 아이디어
희귀 구분자 선택
"!@#$%^&*()"는 자연 언어에 거의 등장하지 않는 패턴이라 충돌 가능성이 낮지만, 0%는 아닙니다.
안전하게 하려면 더 희귀한 토큰(예: UUID 기반 "<<<LLM2VEC_SEP_7f0c9d4c>>>")을 쓰거나 설정에서 바꿀 수 있게 하는 게 좋아요.
입력에 구분자가 들어오는 경우
현재 코드는 text.split("!@#$%^&*()")로 나눕니다. 사용자가 지시문/본문에 우연히 같은 문자열을 포함하면 경계 판단이 틀어질 수 있어요.
완화책:
split("... ", 1)처럼 maxsplit=1을 써서 “첫 번째 등장만” 분리하거나,
아예 매우 특이한 구분자를 쓰고, instruction/text에 그 구분자가 포함되면 예외 처리/치환을 하는 방법.
더 근본적인 대안
구분자 없이도 가능하게, 구조적으로 (instruction, text) 튜플을 끝까지 들고 가며 토큰 단계에서 경계를 계산하는 방식도 있습니다. 다만 현재 구현은 “문자열 리스트” 인터페이스를 유지하려고 구분자 전략을 쓴 것으로 보입니다.
필요하시면 “구분자 안전성”을 위해 split을 maxsplit=1로 바꾸거나, 더 희귀한 구분자/검증 로직을 추가하는 패치를 제안드릴 수 있어요.
이젠 학습 코드!!
Argument가 진짜 많습니다 ㅎㅎ
Collator = Trainer에 넘겨줄 batch 사이즈의 토큰화된 text 전달
LLM2VecSupervisedTrainer
LM2Vec 전용 Trainer 확장 클래스
주요 목적:
- 기본 HF Trainer의 contrastive / seq2seq 목적이 아닌 "임베딩→조건 토큰 생성" 양방향(q→d, d→q) 생성 손실을
하나의 LM 헤드로 계산하는 커스텀 학습 루프를 제공.
- (향후) ranking / contrastive loss 를 함께 결합할 수 있는 확장 포인트(self.loss_function) 유지.
입력 구조(데이터 콜레이터가 제공):
features: List[Dict[str, Tensor]] 길이 = 텍스트 슬롯 개수 (예: 2 -> [query, doc])
features[0]: { 'input_ids': (B, Lq), 'attention_mask': (B, Lq), ... }
features[1]: { 'input_ids': (B, Ld), 'attention_mask': (B, Ld), ... }
(선택) features[2..]: negatives (랭킹 손실 추가 시 사용 가능)
labels: Tensor(shape=[B]) 현재 compute_loss에서는 미사용 (placeholder)
손실 구성:
1. q 임베딩(q_reps), d 임베딩(d_reps) 생성 (LLM2Vec wrapper -> pooling 포함)
2. 두 임베딩 각각을 첫 토큰(가상 prefix embedding)으로 삽입하여
- d_reps + q_toks => d2q (문서 임베딩이 query를 복원하도록)
- q_reps + d_toks => q2d (쿼리 임베딩이 문서를 복원하도록)
3. LM CrossEntropy(next token 예측) 두 개를 계산
4. 경험적 비율 0.8(d2q) : 0.2(q2d) 로 가중 합 → 최종 loss
설계상 특징:
- 임베딩을 첫 토큰으로 prepend 할 때 labels 첫 위치는 -100 (무시) 처리 → 임베딩 자체를 예측 대상으로 삼지 않음
- bidirectional 어텐션이 활성화된 백본이면 (옵션) 더 풍부한 문맥 학습 가능
- model.module.* 형태 접근은 Accelerate / DDP 래핑 환경(DistributedDataParallel) 고려
확장 아이디어:
- ranking loss 추가: features 2개 초과 시 negatives 로 간주하여 in-batch + hard negative 결합
- 가중치 스케줄링: 학습 초반 q2d 비중을 키우고 후반 줄이거나 반대로 변형
- 임베딩 삽입 위치를 prefix 외 mid / suffix로 바꿔 비교 실험
Comput_loss_q2d_d2q
(토큰 예측 LM Loss)
Teacher forcing 시 t 시점 로짓으로 t+1 시점 토큰을 맞추는 전형적 next-token CrossEntropy.
입력 텐서 형상:
- logits : [B, S, V]
B = batch size, S = (prefix 포함) 시퀀스 길이, V = vocab size
- labels : [B, S]
prefix 자리(맨 앞)는 -100 (ignore_index) 로 채워져 있어 loss 제외
1) shift (오프셋 이동) 이유
- 언어모델은 position i 의 출력으로 position i+1 의 정답을 예측
- 따라서 마지막 토큰의 로짓은 다음 토큰이 없어 사용 불가 → logits[..., :-1, :]
- labels[..., 1:] 로 한 칸 앞으로 맞춰 정렬
2) contiguous() 호출 이유
- 슬라이싱/transpose 후 텐서는 Memory 상 비연속(non-contiguous)이 될 수 있음
- view() 는 연속(contiguous) 메모리 레이아웃을 가정 → 비연속이면 RuntimeError 발생 가능
- contiguous() 는 (필요 시) 새로운 연속 버퍼를 만들어 view 안전성 확보
- 대안: reshape() 는 내부적으로 contiguous 보장해 주기도 하지만, 의도를 명확히 하기 위해 contiguous+view 패턴 사용
3) view(-1, V) / view(-1)
- CrossEntropyLoss 입력 형태 요구: input [N, C], target [N]
- 여기서 N = B * (S-1) (prefix 제외 + 마지막 토큰 제외)
- shift_logits: [B, S-1, V] → view(-1, V) ⇒ [B*(S-1), V]
- shift_labels: [B, S-1] → view(-1) ⇒ [B*(S-1)]
- 이렇게 평탄화(flatten)하면 배치/시퀀스 차원을 하나로 합쳐 모든 위치 토큰을 동일한 샘플로 간주하여 CE 계산
4) ignore_index=-100 작동
- CrossEntropyLoss 기본 ignore_index=-100 (PyTorch 기본값)
- prefix 위치와 pad/eos 마스킹된 토큰들은 -100 → loss에서 무시
5) 장점
- for 루프 없이 벡터라이즈된 단일 CE 호출 → GPU 효율
반환:
- scalar loss (tensor)
"""
Stage I 학습 스크립트 (q2d / d2q 양방향 생성 + 임베딩 학습)
전체 흐름 개요:
1) 인자 파싱 (모델 / 데이터 / 훈련 / 커스텀 손실 설정)
2) 실험 ID 생성 -> 출력 디렉토리 구성
3) 데이터셋 로딩 및 메모리에 예제 로드 (collator가 지시문/본문 처리)
4) LLM2Vec_q2d_d2q 모델 로딩 (옵션: bidirectional, PEFT 병합)
5) (선택) PEFT LoRA 적용 / trainable 파라미터 확인
6) Trainer 확장(LLM2VecSupervisedTrainer)으로 커스텀 compute_loss 구현:
- q_reps, d_reps 임베딩 추출
- 임베딩을 토큰 시퀀스 첫 위치로 삽입하여 teacher forcing 방식으로
q2d(d→q) / d2q(q→d) 생성 손실 계산 (CrossEntropy)
- 두 손실을 가중 합(0.8/0.2)
7) 학습 진행, (선택) 일정 스텝 후 조기 종료 콜백
핵심 포인트:
- q_reps / d_reps는 LLM2Vec 임베딩 (pooling 포함) 결과
- 생성 손실을 통해 임베딩이 "상대 시퀀스" 복원에 유용하도록 유도
- attention_mask / labels를 재구성하여 첫 위치는 임베딩(학습 입력)만 주고 label은 -100 (무시)
추가 개선 여지:
- rank loss(주석 처리)와 결합하여 검색 품질 향상 가능
- collator에서 instruction과 본문을 구분하는 separator를 인자로 받을 수 있도록 확장 가능
"""
import logging
from dataclasses import dataclass, field
import os
import sys
from typing import Any, Dict, List, Optional, Tuple, Union
from torch.nn import CrossEntropyLoss
import torch
from torch import nn
from torch.utils.data import DataLoader, SequentialSampler
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.logging import get_logger
import transformers
from transformers import (
MODEL_FOR_MASKED_LM_MAPPING,
HfArgumentParser,
TrainingArguments,
Trainer,
TrainerCallback,
LlamaConfig,
MistralConfig,
GemmaConfig,
Qwen2Config,
set_seed,
)
from transformers.trainer_utils import seed_worker
from peft import LoraConfig, get_peft_model
from llm2vec import LLM2Vec_q2d_d2q # 임베딩 + q2d/d2q 생성 손실 계산에 사용되는 사용자 정의 래퍼
from llm2vec.dataset.utils import load_dataset
from llm2vec.loss.utils import load_loss
from llm2vec.experiment_utils import generate_experiment_id
from tqdm import tqdm
transformers.logging.set_verbosity_error()
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
level=logging.INFO,
)
logger = get_logger(__name__, log_level="INFO")
MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
def prepare_for_tokenization_llama_3_instruct(model, text, pooling_mode="mean"):
"""Llama3 Instruct 형식으로 사용자 입력을 래핑.
- Stage I collator에서 배치 텍스트 전처리에 사용
- 모델별 프롬프트 템플릿을 맞추어 토크나이저 일관성 확보
"""
text = (
"<|start_header_id|>user<|end_header_id|>\n\n" + text.strip() + "<|eot_id|>"
)
return text
def prepare_for_tokenization(model, text, pooling_mode="mean"):
"""모델 종류/풀링 모드에 따라 Instruct 포맷 또는 EOS 토큰을 추가.
- 다양한 backbones(Llama/Mistral/Gemma/Qwen2) 간 일관된 사용자 프롬프트 구조를 제공
- pooling_mode가 eos_token이면 EOS를 강제 부착하여 마지막 토큰 기반 풀링 편의 확보
"""
if model.config._name_or_path == "meta-llama/Meta-Llama-3-8B-Instruct":
text = (
"<|start_header_id|>user<|end_header_id|>\n\n" + text.strip() + "<|eot_id|>"
)
return text
if model.config._name_or_path in [
"mistralai/Mistral-7B-Instruct-v0.2",
"meta-llama/Llama-2-7b-chat-hf",
]:
text = "[INST] " + text.strip() + " [/INST]"
if model.config._name_or_path in [
"google/gemma-2-9b-it",
]:
text = "<bos><start_of_turn>user\n" + text.strip() + "<end_of_turn>"
if model.config._name_or_path in [
"Qwen/Qwen2-1.5B-Instruct",
"Qwen/Qwen2-7B-Instruct",
]:
text = "<|im_start|>user\n" + text.strip() + "<|im_end|>"
if pooling_mode == "eos_token":
if model.config._name_or_path == "meta-llama/Meta-Llama-3-8B":
text = text.strip() + "<|end_of_text|>"
elif isinstance(model.config, LlamaConfig) or isinstance(
model.config, MistralConfig
):
text = text.strip() + " </s>"
elif isinstance(model.config, GemmaConfig):
text = text.strip() + "<eos>"
elif isinstance(model.config, Qwen2Config):
text = text.strip() + "<|endoftext|>"
return text
def initialize_peft(
model,
lora_r: int = 32,
lora_alpha: int = 64,
lora_dropout: float = 0.05,
lora_modules: Optional[List[str]] = None,
):
if lora_modules is None and model.config.__class__.__name__ in [
"LlamaConfig",
"MistralConfig",
"GemmaConfig",
"Qwen2Config",
]:
lora_modules = [
"q_proj",
"v_proj",
"k_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]
elif lora_modules is None:
raise ValueError("lora_modules must be specified for this model.")
config = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
target_modules=lora_modules,
lora_dropout=lora_dropout,
bias="none",
task_type=None,
)
model = get_peft_model(model, config)
# trainable 파라미터 (LoRA 인젝션된 모듈 위주) 출력
print(f"Model's Lora trainable parameters:")
model.print_trainable_parameters()
return model
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
"""
model_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": (
"The base model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
)
},
)
peft_model_name_or_path: Optional[str] = field(
default=None,
metadata={"help": ("The PEFT model checkpoint to add on top of base model.")},
)
bidirectional: Optional[bool] = field(
default=False,
metadata={
"help": (
"Whether to enable bidirectional attention in the model. If set to False, the model will use unidirectional attention."
)
},
)
max_seq_length: Optional[int] = field(
default=None,
metadata={
"help": (
"The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated."
)
},
)
torch_dtype: Optional[str] = field(
default=None,
metadata={
"help": (
"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
"dtype will be automatically derived from the model's weights."
),
"choices": ["auto", "bfloat16", "float16", "float32"],
},
)
attn_implementation: Optional[str] = field(
default="sdpa",
metadata={
"help": ("The attention implementation to use in the model."),
"choices": ["eager", "sdpa", "flash_attention_2"],
},
)
pooling_mode: Optional[str] = field(
default="mean",
metadata={
"help": ("The pooling mode to use in the model."),
"choices": ["mean", "weighted_mean", "eos_token"],
},
)
use_peft: Optional[bool] = field(
default=False,
metadata={
"help": ("Whether to use PEFT or not.")
},
)
@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""
dataset_name: Optional[str] = field(
default=None,
metadata={"help": "The name of the dataset to use. Options: E5"},
)
dataset_file_path: Optional[str] = field(
default=None, metadata={"help": "The input training data file or folder."}
)
# TODO: implement this
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": (
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
)
},
)
@dataclass
class CustomArguments:
"""
Custom arguments for the script
"""
lora_dropout: float = field(
default=0.05, metadata={"help": "The dropout rate for lora"}
)
lora_r: int = field(default=8, metadata={"help": "The r value for lora"})
stop_after_n_steps: int = field(
default=10000, metadata={"help": "Stop training after n steps"}
)
experiment_id: Optional[str] = field(
default=None, metadata={"help": "The experiment id"}
)
loss_class: Optional[str] = field(
default="HardNegativeNLLLoss",
metadata={
"help": "The loss class to use for training. Options: HardNegativeNLLLoss"
},
)
loss_scale: float = field(
default=50.0, metadata={"help": "The loss scale for the loss function"}
)
@dataclass
class DefaultCollator:
model: LLM2Vec_q2d_d2q
def __init__(self, model: LLM2Vec_q2d_d2q) -> None:
self.model = model
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
"""배치 내 샘플들을 q/d (및 추가 negative가 있다면 확장) 토크나이즈.
- example.texts: [query_text, doc_text, (optional negatives...)] 구조 가정
- 여기서는 Llama3 instruct 템플릿 적용 (다른 모델일 경우 확장 가능)
- 반환: sentence_features(list(dict)), labels(tensor)
"""
batch = features
num_texts = len(batch[0].texts) # 각 샘플이 가진 텍스트 개수 (2: q/d, 3+: negatives 포함 가능)
texts = [[] for _ in range(num_texts)]
labels = []
for example in batch:
for idx, text in enumerate(example.texts):
# 모델별 프롬프트 포맷 적용
text = prepare_for_tokenization_llama_3_instruct(
self.model, text, pooling_mode=self.model.pooling_mode
)
texts[idx].append(text)
labels.append(example.label)
labels = torch.tensor(labels)
# 각 위치(q, d, neg1, neg2 ...)별로 따로 토크나이즈 -> 리스트로 묶어 Trainer에 전달
sentence_features = []
for idx in range(num_texts):
tokenized = self.model.tokenize(texts[idx])
sentence_features.append(tokenized)
return sentence_features, labels
class StopTrainingCallback(TrainerCallback):
def __init__(self, stop_after_n_steps: int):
self.stop_after_n_steps = stop_after_n_steps
def on_step_end(self, args, state, control, **kwargs):
if state.global_step >= self.stop_after_n_steps:
control.should_training_stop = True
class LLM2VecSupervisedTrainer(Trainer):
def __init__(
self,
*args,
loss_function=None,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self.loss_function = loss_function
def compute_loss(
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
return_outputs: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
features, labels = inputs # labels는 아직 사용하지 않음 (확장: 랭킹 손실 적용 가능)
# 1) 쿼리 / 문서 임베딩 추출 (pooling 적용된 dense 벡터)
q_reps = model(features[0]) # [bs, hidden]
d_reps = model(features[1]) # [bs, hidden]
# 2) teacher forcing에 사용할 원본 토큰 시퀀스 추출
q_input_ids = features[0]['input_ids']
q_attention_mask = features[0]['attention_mask']
d_input_ids = features[1]['input_ids']
d_attention_mask = features[1]['attention_mask']
# 3) 생성 학습용 label 준비 (pad/eos 토큰은 -100으로 마스킹하여 loss 제외)
q_labels = q_input_ids.clone()
q_labels[q_labels == model.module.tokenizer.eos_token_id] = -100
d_labels = d_input_ids.clone()
d_labels[d_labels == model.module.tokenizer.eos_token_id] = -100
# 4) attention mask 확장: 임베딩이 삽입될 첫 위치(길이 1)를 위한 mask(1) 추가
attention_mask_tmp = torch.full(
(d_attention_mask.shape[0], 1), 1, dtype=d_attention_mask.dtype, device=d_attention_mask.device
)
d2q_attention_mask = torch.cat((attention_mask_tmp, q_attention_mask), dim=1) # d 임베딩 + q 토큰
q2d_attention_mask = torch.cat((attention_mask_tmp, d_attention_mask), dim=1) # q 임베딩 + d 토큰
# 5) labels도 동일 길이로 맞추되, 첫 위치(임베딩 자리)는 -100 (예측 제외)
label_tmp = torch.full(
(d_labels.shape[0], 1), -100, dtype=d_labels.dtype, device=d_labels.device
)
d2q_labels = torch.cat((label_tmp, q_labels), dim=1)
q2d_labels = torch.cat((label_tmp, d_labels), dim=1)
def compute_loss_q2d_d2q(logits, labels):
"""표준 LM CrossEntropy: shift하여 next token 예측.
- logits: [bs, seq_len, vocab]
- labels: [bs, seq_len]
- 첫 토큰은 (임베딩 자리) label=-100 이므로 loss 미적용
"""
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, model.module.model.config.vocab_size)
shift_labels = shift_labels.view(-1).to(shift_logits.device)
return loss_fct(shift_logits, shift_labels)
# 6) d→q (d2q): d_reps를 첫 토큰으로 삽입, 나머지 q 시퀀스를 teacher forcing
q_input_embeds = model.module.model.embed_tokens(q_input_ids) # [bs, q_len, hidden]
combined_d2q_input_embeds = torch.cat([d_reps.unsqueeze(1), q_input_embeds], dim=1)
d2q_outputs = model.module.model(
inputs_embeds=combined_d2q_input_embeds,
attention_mask=d2q_attention_mask,
)
d2q_logits = model.module.lm_head(d2q_outputs[0])
d2q_loss = compute_loss_q2d_d2q(d2q_logits, d2q_labels)
# 7) q→d (q2d): q_reps를 첫 토큰으로 삽입, 나머지 d 시퀀스를 teacher forcing
d_input_embeds = model.module.model.embed_tokens(d_input_ids)
combined_q2d_input_embeds = torch.cat([q_reps.unsqueeze(1), d_input_embeds], dim=1)
q2d_outputs = model.module.model(
inputs_embeds=combined_q2d_input_embeds,
attention_mask=q2d_attention_mask,
)
q2d_logits = model.module.lm_head(q2d_outputs[0])
q2d_loss = compute_loss_q2d_d2q(q2d_logits, q2d_labels)
# 8) (선택) Ranking loss 자리 - 현재는 주석 처리
# d_reps_neg = None
# if len(features) > 2:
# d_reps_neg = self.model(features[2])
# rank_loss = self.loss_function(q_reps, d_reps, d_reps_neg)
# total_loss = d2q_loss + q2d_loss + rank_loss
# 9) 두 생성 손실 가중 결합 (경험적 비율: d2q 비중 0.8, q2d 비중 0.2)
loss = 0.8 * d2q_loss + 0.2 * q2d_loss
return loss
def get_train_dataloader(self) -> DataLoader:
# Copying most of the code from the parent class, changing the sampler to SequentialSampler
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")
train_dataset = self.train_dataset
data_collator = self.data_collator
data_collator = self._get_collator_with_removed_columns(
data_collator, description="training"
)
dataloader_params = {
"batch_size": self._train_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
}
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
# Changing from random sampler to sequential sampler
dataloader_params["sampler"] = SequentialSampler(train_dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = seed_worker
return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
def _save(self, output_dir: Optional[str] = None, state_dict=None):
# If we are executing this function, we are the process zero, so we don't check for that.
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info(f"Saving model checkpoint to {output_dir}")
self.model.save(output_dir)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
def main():
parser = HfArgumentParser(
(ModelArguments, DataTrainingArguments, TrainingArguments, CustomArguments)
)
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args, custom_args = parser.parse_json_file(
json_file=os.path.abspath(sys.argv[1])
)
else:
(
model_args,
data_args,
training_args,
custom_args,
) = parser.parse_args_into_dataclasses()
if training_args.ddp_find_unused_parameters:
kwargs = [
DistributedDataParallelKwargs(
dim=0,
broadcast_buffers=True,
bucket_cap_mb=25,
find_unused_parameters=True,
check_reduction=False,
gradient_as_bucket_view=False,
)
]
else:
kwargs = []
# accelerator = Accelerator(kwargs_handlers=kwargs)
# Accelerate 초기화 (DDP, 혼합정밀, 장치 관리)
accelerator = Accelerator()
set_seed(training_args.seed)
if training_args.gradient_checkpointing:
# reentrant=False로 설정하여 일부 모델에서 발생 가능한 체크포인팅 문제 회피
training_args.gradient_checkpointing_kwargs = {"use_reentrant": False}
if custom_args.experiment_id is not None:
experiment_id = custom_args.experiment_id
else:
experiment_id = generate_experiment_id(
name=data_args.dataset_name,
split="train",
model_name=(
model_args.model_name_or_path
if "/" not in model_args.model_name_or_path
else model_args.model_name_or_path.split("/")[-1]
),
pooling_mode=model_args.pooling_mode,
train_batch_size=training_args.per_device_train_batch_size
* accelerator.num_processes
* training_args.gradient_accumulation_steps,
max_seq_length=model_args.max_seq_length,
bidirectional=model_args.bidirectional,
epochs=training_args.num_train_epochs,
seed=training_args.seed,
warmup_steps=training_args.warmup_steps,
lr=training_args.learning_rate,
lora_r=custom_args.lora_r,
use_peft=model_args.use_peft,
)
training_args.output_dir = f"{training_args.output_dir}/{experiment_id}"
# TODO: can also pass separator arg here
# 데이터셋 로드 (효과적 배치 크기: per_device * num_processes)
train_dataset = load_dataset(
data_args.dataset_name,
split="train",
file_path=data_args.dataset_file_path,
effective_batch_size=training_args.per_device_train_batch_size
* accelerator.num_processes,
)
# 전체 예제를 메모리에 로드 (데이터가 크다면 스트리밍/샘플링 전략 고려 가능)
train_examples = [
train_dataset[i]
for i in tqdm(
range(len(train_dataset)),
desc="Loading train examples...",
disable=not accelerator.is_main_process,
)
]
torch_dtype = (
model_args.torch_dtype
if model_args.torch_dtype in ["auto", None]
else getattr(torch, model_args.torch_dtype)
)
# LLM2Vec_q2d_d2q 모델 로딩 (bidirectional 옵션 포함)
model = LLM2Vec_q2d_d2q.from_pretrained(
base_model_name_or_path=model_args.model_name_or_path,
enable_bidirectional=model_args.bidirectional,
peft_model_name_or_path=model_args.peft_model_name_or_path,
merge_peft=True,
pooling_mode=model_args.pooling_mode,
max_length=model_args.max_seq_length,
torch_dtype=torch_dtype,
attn_implementation=model_args.attn_implementation,
)
if model_args.use_peft:
# 내부 구조: wrapper.model 이 HF 원본 모델 → PEFT는 내부에 적용
model.model = initialize_peft(
model.model,
lora_r=custom_args.lora_r,
lora_alpha=2 * custom_args.lora_r,
lora_dropout=custom_args.lora_dropout,
)
else:
print("Not using PEFT")
tokenizer = model.tokenizer
train_loss = load_loss(custom_args.loss_class, scale=custom_args.loss_scale)
data_collator = DefaultCollator(model)
# 커스텀 Trainer (q2d/d2q 생성 손실 활용)
trainer = LLM2VecSupervisedTrainer(
model=model,
args=training_args,
train_dataset=train_examples,
data_collator=data_collator,
tokenizer=tokenizer,
loss_function=train_loss,
)
if custom_args.stop_after_n_steps is not None:
trainer.add_callback(StopTrainingCallback(custom_args.stop_after_n_steps))
# 학습 시작
trainer.train()
if __name__ == "__main__":
main()
변경 내용
최상단에 Stage I 전체 파이프라인 개요 모듈 docstring 추가.
모델별 프롬프트 전처리 함수들에 목적 및 역할 설명.
LoRA 초기화(initialize_peft)에 trainable 파라미터 출력 의미 주석.
DefaultCollator에 q/d(+negatives 확장 가능) 처리 흐름과 반환 구조 설명.
LLM2VecSupervisedTrainer.compute_loss 전 과정을 9단계로 세분화:
1. 임베딩 추출 → teacher forcing용 label/mask 구성 → d2q / q2d 두 생성 손실 → 가중 합산.
2. shift 로직(CrossEntropy)와 첫 위치 임베딩(-100 마스킹) 이유 명시.
랭킹 손실 자리에 향후 확장 가능성 주석 유지.
Accelerate 초기화, gradient checkpointing 설정 이유, 데이터셋 전부 메모리 로드 주석.
모델 로딩/PEFT 적용/Trainer 생성/학습 호출 지점에 목적 및 내부 구조 설명.
현재 상태
파일 편집 성공적으로 적용됨.
표시된 import 관련 문제(Import "torch" could not be resolved 등)는 편집 과정에서 새로 생긴 것이 아니라, 해당 환경(분석기/IDE)에서 PyTorch/Accelerate/PEFT 패키지를 인식하지 못할 때 흔히 나오는 타입/분석 경고입니다. 실제 실행 환경에 패키지가 설치되어 있다면 런타임에서는 문제 없고, 없다면 아래 설치가 필요합니다.
다음에 고려할 수 있는 확장 아이디어
현재 주석 처리된 ranking loss 복구 및 가중치 튜닝.
separator 기반 마스킹을 collator 수준에서 직접 처리해 tokenize 호출 횟수 줄이기.
메모리 전체 로드 대신 큰 코퍼스일 때 streaming + iterable dataset 적용.
d2q : q2d 비율(0.8/0.2) 학습 중 동적 스케줄링 실험.
"""Stage II 학습 스크립트 (Supervised Contrastive / Retrieval Fine-tuning)
이 스크립트는 LLM2Vec 모델을 2단계로 학습하는 파이프라인 중 "Stage II" 단계에 해당합니다.
Stage I (생성 기반 사전 학습) 과의 핵심 차이점:
- Stage I: query->doc, doc->query 양방향 생성(reconstruction) 손실을 통해 임베딩이 텍스트 의미를 보존하도록 만듭니다.
- Stage II: (현재 파일) Hard Negative 가 포함된 감독(contrastive / ranking) 기반 손실을 사용하여
실제 검색(retrieval) 시나리오에서 "올바른 쿼리-문서 쌍" 의 점수가 높아지고
하드/랜덤 negative 의 점수가 낮아지도록 임베딩 공간을 재정렬(refinement) 합니다.
전체 흐름 개요:
1) 인자(Arguments) 파싱 (모델, 데이터, 학습, 커스텀 설정)
2) 가속기(Accelerator) 및 시드 설정
3) 실험 ID 자동 생성 (재현성 및 결과 관리 목적)
4) 데이터셋 로드 (query, positive doc, optional negative docs 포함 구조)
5) LLM2Vec base 모델 로드 (+ 선택적으로 PEFT/LoRA 어댑터 적용)
6) Collator 로 배치 구성 (동일 인덱스끼리 query / positive / negative 텍스트 묶음 토크나이즈)
7) Trainer (커스텀 compute_loss) 구성: 모델 forward -> 임베딩 산출 -> loss_function(q, pos, negs)
8) stop_after_n_steps 콜백 옵션 적용
9) trainer.train() 호출로 학습 루프 시작 (Evaluation / Save 로직은 Stage II 목적에 맞게 단순화)
중요 포인트:
- Stage II 손실(load_loss): HardNegativeNLLLoss (예: in-batch + 명시적 negative 에 대해 NLL 기반)
- Collator 는 labels(정답 인덱스)를 그대로 유지하며 Trainer 내부 compute_loss 에서 사용
- bidirectional 플래그: 임베딩 추출 시 양방향 어텐션 허용 여부 (검색 성능 향상 목적)
- LoRA 사용 시 학습 가능한 적은 수의 파라미터로 빠른 파인튜닝
이 주석은 학습 파이프라인의 구조를 빠르게 이해하고 Stage I 코드와의 차이를 구분하는데 목적이 있습니다.
"""
import logging
from dataclasses import dataclass, field
import os
import sys
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from torch.utils.data import DataLoader, SequentialSampler
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.logging import get_logger
import transformers
from transformers import (
MODEL_FOR_MASKED_LM_MAPPING,
HfArgumentParser,
TrainingArguments,
Trainer,
TrainerCallback,
LlamaConfig,
MistralConfig,
GemmaConfig,
Qwen2Config,
set_seed,
)
from transformers.trainer_utils import seed_worker
from peft import LoraConfig, get_peft_model
from llm2vec import LLM2Vec
from llm2vec.dataset.utils import load_dataset
from llm2vec.loss.utils import load_loss
from llm2vec.experiment_utils import generate_experiment_id
from tqdm import tqdm
transformers.logging.set_verbosity_error()
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
level=logging.INFO,
)
logger = get_logger(__name__, log_level="INFO")
MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
def prepare_for_tokenization_llama_3_instruct(model, text, pooling_mode="mean"):
# Llama 3 Instruct 형식 프롬프트 템플릿 적용
# - Stage II 에서는 임베딩 추출 일관성을 위해 Stage I 과 동일/유사한 프롬프트 래핑 유지
# - user 역할 토큰 헤더/푸터를 붙여 모델이 사용자 질의 문맥으로 인식하도록 함
text = (
"<|start_header_id|>user<|end_header_id|>\n\n" + text.strip() + "<|eot_id|>"
)
return text
def prepare_for_tokenization(model, text, pooling_mode="mean"):
# 다양한 기반 모델별(Chat / Instruct) 토크나이저가 기대하는 특수 토큰 패턴을 부착
# - 모델별 system/user/assistant turn 구분 토큰이 다르므로 조건 분기
# - pooling_mode 가 eos_token 일 때는 뒤에 eos 를 강제로 붙여 문장 경계 명확화
if model.config._name_or_path == "meta-llama/Meta-Llama-3-8B-Instruct":
text = (
"<|start_header_id|>user<|end_header_id|>\n\n" + text.strip() + "<|eot_id|>"
)
return text
if model.config._name_or_path in [
"mistralai/Mistral-7B-Instruct-v0.2",
"meta-llama/Llama-2-7b-chat-hf",
]:
text = "[INST] " + text.strip() + " [/INST]"
if model.config._name_or_path in [
"google/gemma-2-9b-it",
]:
text = "<bos><start_of_turn>user\n" + text.strip() + "<end_of_turn>"
if model.config._name_or_path in [
"Qwen/Qwen2-1.5B-Instruct",
"Qwen/Qwen2-7B-Instruct",
]:
text = "<|im_start|>user\n" + text.strip() + "<|im_end|>"
if pooling_mode == "eos_token":
if model.config._name_or_path == "meta-llama/Meta-Llama-3-8B":
text = text.strip() + "<|end_of_text|>"
elif isinstance(model.config, LlamaConfig) or isinstance(
model.config, MistralConfig
):
text = text.strip() + " </s>"
elif isinstance(model.config, GemmaConfig):
text = text.strip() + "<eos>"
elif isinstance(model.config, Qwen2Config):
text = text.strip() + "<|endoftext|>"
return text
def initialize_peft(
model,
lora_r: int = 8,
lora_alpha: int = 16,
lora_dropout: float = 0.05,
lora_modules: Optional[List[str]] = None,
):
# LoRA(PEFT) 구성 함수
# - 특정 projection 모듈(q,k,v,o,up,down,gate)에 low-rank 업데이트 행렬을 삽입하여
# 원본 모델 파라미터는 동결(freeze) + 적은 수의 학습 파라미터로 파인튜닝 가속
# - 지원 모델(Llama/Mistral/Gemma/Qwen2)은 동일한 모듈 네이밍 컨벤션을 따름
# - 그 외 모델이면 사용자 정의 lora_modules 를 명시적으로 전달해야 함
if lora_modules is None and model.config.__class__.__name__ in [
"LlamaConfig",
"MistralConfig",
"GemmaConfig",
"Qwen2Config",
]:
lora_modules = [
"q_proj",
"v_proj",
"k_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]
elif lora_modules is None:
raise ValueError("lora_modules must be specified for this model.")
config = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
target_modules=lora_modules,
lora_dropout=lora_dropout,
bias="none",
task_type=None,
)
model = get_peft_model(model, config)
print(f"Model's Lora trainable parameters:")
model.print_trainable_parameters()
return model
@dataclass
class ModelArguments:
"""모델 관련 설정 인자
- model_name_or_path: 사전학습(또는 Stage I 완료)된 베이스 LLM 경로/허브 이름
- peft_model_name_or_path: 기존에 학습된 LoRA 어댑터를 이어서 사용할 때 경로
- bidirectional: True 시 causal mask 제거된 양방향 attention (문맥 정보 극대화) 사용
- max_seq_length: 토크나이즈 후 입력을 자를 최대 길이 (길면 truncate)
- torch_dtype: 모델 로드 dtype (auto / bf16 / fp16 / fp32)
- attn_implementation: PyTorch 기본(eager) / sdpa / flash_attention_2 선택
- pooling_mode: 임베딩 풀링 전략 (mean / weighted_mean / eos_token)
- use_peft: LoRA 적용 여부(False 시 전체 파라미터 학습 또는 동결 전략 필요)
"""
model_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": (
"The base model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
)
},
)
peft_model_name_or_path: Optional[str] = field(
default=None,
metadata={"help": ("The PEFT model checkpoint to add on top of base model.")},
)
bidirectional: Optional[bool] = field(
default=False,
metadata={
"help": (
"Whether to enable bidirectional attention in the model. If set to False, the model will use unidirectional attention."
)
},
)
max_seq_length: Optional[int] = field(
default=None,
metadata={
"help": (
"The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated."
)
},
)
torch_dtype: Optional[str] = field(
default=None,
metadata={
"help": (
"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
"dtype will be automatically derived from the model's weights."
),
"choices": ["auto", "bfloat16", "float16", "float32"],
},
)
attn_implementation: Optional[str] = field(
default="sdpa",
metadata={
"help": ("The attention implementation to use in the model."),
"choices": ["eager", "sdpa", "flash_attention_2"],
},
)
pooling_mode: Optional[str] = field(
default="mean",
metadata={
"help": ("The pooling mode to use in the model."),
"choices": ["mean", "weighted_mean", "eos_token"],
},
)
use_peft: Optional[bool] = field(
default=True,
metadata={
"help": ("Whether to use PEFT or not.")
}
)
@dataclass
class DataTrainingArguments:
"""데이터 로딩 관련 설정
- dataset_name: 커스텀 로더 내부 스위칭용 이름 (예: E5)
- dataset_file_path: 로컬 json / jsonl / 디렉토리 경로
- max_train_samples: (선택) 디버깅/빠른 테스트용 샘플 수 제한
"""
dataset_name: Optional[str] = field(
default=None,
metadata={"help": "The name of the dataset to use. Options: E5"},
)
dataset_file_path: Optional[str] = field(
default=None, metadata={"help": "The input training data file or folder."}
)
# TODO: implement this
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": (
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
)
},
)
@dataclass
class CustomArguments:
"""스크립트 전용 커스텀 인자
- lora_dropout / lora_r: LoRA 구성 하이퍼파라미터
- stop_after_n_steps: 특정 step 도달 시 조기 종료 (학습 길이 제어, 비용 절약)
- experiment_id: 직접 ID 지정 (None 이면 자동 생성 규칙 사용)
- loss_class: Stage II 에서 사용할 랭킹/대조 손실 클래스 이름
- loss_scale: 손실 내부 점수 scaling (온도/스케일 역할)
"""
lora_dropout: float = field(
default=0.05, metadata={"help": "The dropout rate for lora"}
)
lora_r: int = field(default=8, metadata={"help": "The r value for lora"})
stop_after_n_steps: int = field(
default=10000, metadata={"help": "Stop training after n steps"}
)
experiment_id: Optional[str] = field(
default=None, metadata={"help": "The experiment id"}
)
loss_class: Optional[str] = field(
default="HardNegativeNLLLoss",
metadata={
"help": "The loss class to use for training. Options: HardNegativeNLLLoss"
},
)
loss_scale: float = field(
default=50.0, metadata={"help": "The loss scale for the loss function"}
)
@dataclass
class DefaultCollator:
model: LLM2Vec
def __init__(self, model: LLM2Vec) -> None:
self.model = model
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
# features: Dataset __getitem__ 이 반환한 예시들의 리스트
# - 각 example 은 .texts (list[str]) 와 .label 속성을 가진다고 가정
# - texts[0] = query, texts[1] = positive document, texts[2..] = (optional) hard negatives
batch = features
num_texts = len(batch[0].texts) # 한 샘플 내 텍스트 슬롯 개수 (예: 2 또는 3+)
texts = [[] for _ in range(num_texts)] # 슬롯별로 모아서 토크나이즈하기 위한 버퍼
labels = [] # 예: HardNegativeNLLLoss 가 참조하는 정답 인덱스 (대개 0 또는 1) / 또는 class id
for example in batch:
for idx, text in enumerate(example.texts):
# 모델/풀링 모드에 맞게 프롬프트 래핑 → 토크나이즈 일관성 확보
text = prepare_for_tokenization_llama_3_instruct(
self.model, text, pooling_mode=self.model.pooling_mode
)
texts[idx].append(text)
labels.append(example.label)
labels = torch.tensor(labels) # (batch,)
sentence_features = []
for idx in range(num_texts):
# 동일 슬롯(예: 모든 query) 끼리 모아서 한 번에 토크나이즈 → 패딩 효율성, 길이 균질화
tokenized = self.model.tokenize(texts[idx]) # dict(input_ids, attention_mask, ...)
sentence_features.append(tokenized)
# Trainer 의 compute_loss 에서 unpack 하도록 (list[dict], labels) 형태로 반환
return sentence_features, labels
class StopTrainingCallback(TrainerCallback):
def __init__(self, stop_after_n_steps: int):
self.stop_after_n_steps = stop_after_n_steps
def on_step_end(self, args, state, control, **kwargs):
# 매 step 종료 후 global_step 체크하여 조기 종료 플래그 설정
if state.global_step >= self.stop_after_n_steps:
control.should_training_stop = True # Trainer 루프 중단 신호
class LLM2VecSupervisedTrainer(Trainer):
def __init__(
self,
*args,
loss_function=None,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self.loss_function = loss_function
def compute_loss(
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
return_outputs: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
# inputs: collator 가 반환한 (sentence_features_list, labels)
# - sentence_features_list: [slot0_batch_dict, slot1_batch_dict, slot2_neg_batch_dict, ...]
# - labels: (batch,) (현재 HardNegative 손실이 참조하거나 로깅용)
features, labels = inputs
# 1) Query 임베딩 추출
q_reps = self.model(features[0]) # shape: (batch, hidden)
# 2) Positive 문서 임베딩 추출
d_reps = self.model(features[1]) # shape: (batch, hidden)
# 3) (선택) Hard Negative 문서 임베딩 (있다면 첫 번째 negative 만 사용하거나
# loss 함수 내부에서 리스트로 확장 가능하도록 추가 설계 가능)
d_reps_neg = None
if len(features) > 2:
d_reps_neg = self.model(features[2]) # shape: (batch, hidden)
# 4) 손실 계산
# loss_function: load_loss 로 로드된 HardNegativeNLLLoss (예시)
# 인터페이스: loss(q_reps, pos_reps, neg_reps(optional))
loss = self.loss_function(q_reps, d_reps, d_reps_neg)
# 5) HF Trainer 와 호환을 위해 필요 시 outputs 반환
if return_outputs:
# 각 슬롯 임베딩을 (batch, num_slots, hidden) 형태로 합쳐 모니터링 가능
output = torch.cat(
[model(row)["sentence_embedding"][:, None] for row in features], dim=1
)
return loss, output
return loss
def get_train_dataloader(self) -> DataLoader:
# Copying most of the code from the parent class, changing the sampler to SequentialSampler
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")
train_dataset = self.train_dataset
data_collator = self.data_collator
data_collator = self._get_collator_with_removed_columns(
data_collator, description="training"
)
dataloader_params = {
"batch_size": self._train_batch_size,
"collate_fn": data_collator, # 우리 collator: (list[slot_dict], labels) 반환
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
}
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
# 기본 RandomSampler 대신 순차 샘플링(SequentialSampler)
# - 재현성 향상 / 데이터셋 순서 기반 커스텀 처리 가능
dataloader_params["sampler"] = SequentialSampler(train_dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = seed_worker # 각 worker 시드 설정
return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
def _save(self, output_dir: Optional[str] = None, state_dict=None):
# process zero 전용 저장 로직 (Trainer 내부에서 호출 시 이미 rank0 보장)
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info(f"Saving model checkpoint to {output_dir}")
# LLM2Vec 래퍼의 save 메서드 호출 (모델/토크나이저 config 등 일괄 저장)
self.model.save(output_dir)
# 학습 arguments 도 함께 저장해 재현성 확보
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
def main():
parser = HfArgumentParser(
(ModelArguments, DataTrainingArguments, TrainingArguments, CustomArguments)
)
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# 단일 JSON 설정 파일을 넘겼을 때 해당 파일에서 모든 인자 로드
model_args, data_args, training_args, custom_args = parser.parse_json_file(
json_file=os.path.abspath(sys.argv[1])
)
else:
(
model_args,
data_args,
training_args,
custom_args,
) = parser.parse_args_into_dataclasses()
# DDP 환경에서 사용하지 않는 파라미터 탐색 여부 설정 → accelerate 로 전달
if training_args.ddp_find_unused_parameters:
kwargs = [
DistributedDataParallelKwargs(
dim=0,
broadcast_buffers=True,
bucket_cap_mb=25,
find_unused_parameters=True,
check_reduction=False,
gradient_as_bucket_view=False,
)
]
else:
kwargs = []
accelerator = Accelerator(kwargs_handlers=kwargs)
set_seed(training_args.seed)
if training_args.gradient_checkpointing:
# PyTorch 2.x 와 충돌 피하기 위해 reentrant False 설정 (메모리 절감)
training_args.gradient_checkpointing_kwargs = {"use_reentrant": False}
if custom_args.experiment_id is not None:
experiment_id = custom_args.experiment_id
else:
# 실험 ID 자동 생성: 주요 하이퍼파라미터(모델명, 배치, 길이, seed ...) 조합
experiment_id = generate_experiment_id(
name=data_args.dataset_name,
split="train",
model_name=(
model_args.model_name_or_path
if "/" not in model_args.model_name_or_path
else model_args.model_name_or_path.split("/")[-1]
),
pooling_mode=model_args.pooling_mode,
train_batch_size=training_args.per_device_train_batch_size
* accelerator.num_processes
* training_args.gradient_accumulation_steps,
max_seq_length=model_args.max_seq_length,
bidirectional=model_args.bidirectional,
epochs=training_args.num_train_epochs,
seed=training_args.seed,
warmup_steps=training_args.warmup_steps,
lr=training_args.learning_rate,
lora_r=custom_args.lora_r,
use_peft=model_args.use_peft,
)
training_args.output_dir = f"{training_args.output_dir}/{experiment_id}"
# TODO: can also pass separator arg here
# 데이터셋 로드: 내부 유틸이 dataset_name/type 에 따라 로딩/전처리
train_dataset = load_dataset(
data_args.dataset_name,
split="train",
file_path=data_args.dataset_file_path,
effective_batch_size=training_args.per_device_train_batch_size
* accelerator.num_processes,
)
# 전부 메모리에 올리는 간단 구현 (대규모일 경우 IterableDataset / streaming 고려)
train_examples = [
train_dataset[i]
for i in tqdm(
range(len(train_dataset)),
desc="Loading train examples...",
disable=not accelerator.is_main_process,
)
]
torch_dtype = (
model_args.torch_dtype
if model_args.torch_dtype in ["auto", None]
else getattr(torch, model_args.torch_dtype)
)
# Stage I 에서 학습된(또는 베이스) 모델 로드 + bidirectional / pooling 설정 반영
model = LLM2Vec.from_pretrained(
base_model_name_or_path=model_args.model_name_or_path,
enable_bidirectional=model_args.bidirectional,
peft_model_name_or_path=model_args.peft_model_name_or_path,
merge_peft=True,
pooling_mode=model_args.pooling_mode,
max_length=model_args.max_seq_length,
torch_dtype=torch_dtype,
attn_implementation=model_args.attn_implementation,
)
if model_args.use_peft:
# LLM2Vec 래퍼 내부 구조: self.model (HF 모델). LoRA 는 내부 HF 모델에 직접 적용
model.model = initialize_peft(
model.model,
lora_r=custom_args.lora_r,
lora_alpha=2 * custom_args.lora_r,
lora_dropout=custom_args.lora_dropout,
)
else:
print("Not using PEFT")
tokenizer = model.tokenizer
# Hard Negative 기반 NLL / contrastive 손실 로드 (scale 은 온도역할)
train_loss = load_loss(custom_args.loss_class, scale=custom_args.loss_scale)
data_collator = DefaultCollator(model)
trainer = LLM2VecSupervisedTrainer(
model=model,
args=training_args,
train_dataset=train_examples,
data_collator=data_collator,
tokenizer=tokenizer,
loss_function=train_loss,
)
if custom_args.stop_after_n_steps is not None:
# 일정 step 이후 조기 종료 (학습 안정성/리소스 절약)
trainer.add_callback(StopTrainingCallback(custom_args.stop_after_n_steps))
# 학습 시작 - 평가 / 로깅 확장은 HuggingFace Trainer 옵션으로 추가 가능
trainer.train()
if __name__ == "__main__":
main()
주요 추가 내용
- 최상단 모듈 한글 도큐스트링: Stage I 대비 Stage II 목적(하드 네거티브 기반 감독/대조 학습)과 전체 흐름 단계별 정리.
- 토크나이즈 준비 함수: 모델별 프롬프트 래핑 이유, eos 토큰 조건 설명.
- LoRA 초기화: 어떤 모듈에 low-rank 삽입하는지와 필요성/조건 정리.
- Dataclass들: 각 필드 의미를 한국어로 상세 주석화.
- Collator: texts 슬롯 구조(q / positive / negatives)와 토크나이즈 전략(슬롯 단위 배치) 설명.
- StopTrainingCallback: 조기 종료 로직 주석.
- Trainer 서브클래스:
compute_loss: 입력 구조, 임베딩 추출 단계, 네거티브 처리, loss 함수 인터페이스 설명.
get_train_dataloader: SequentialSampler 선택 이유, collator 반환 형태.
_save: 저장 절차 및 재현성 포인트. - main() 파이프라인: 인자 파싱, accelerator/DDP 설정, experiment_id 생성, 데이터 메모리 로드 주석, 모델/PEFT 적용, 손실 로드, callback 등록, 학습 시작.
다음에 고려할 수 있는 개선 (선택)
Hard negatives 여러 개 지원: collator에서 features[2:] 모두 임베딩 후 loss_function에 리스트 전달하도록 확장.
대규모 데이터셋: 전부 리스트로 메모리에 올리는 대신 IterableDataset/streaming 도입.
평가 루프: 주기적 in-batch MRR / Recall@k 계산 콜백 추가.
experiment_id 구성 요소 옵션화 (불필요한 길이 축소).
LoRA 안 쓰는 경우 파라미터 freeze 정책(예: embedding/head만 학습) 명시.
좋은 포인트예요. 결론부터 말하면, 이 손실은 “삼중쌍(q, pos, neg)”만 쓰는 게 아니라 InfoNCE 형태로 in-batch negatives를 자동으로 포함합니다. 그리고 DDP일 때는 모든 프로세스 배치까지 all_gather로 합쳐서 “cross-device in-batch negatives”도 씁니다.
핵심 동작 요약
손실 형태: HardNegativeNLLLoss는 scores = sim(Q, [Positives_in_batch (+ Negatives_optional)]) * scale 형태의 점수 행렬을 만들고, 각 q_i의 정답은 “동일 위치의 positive”로 고정됩니다(레이블이 [0..Q-1]).
In-batch negatives: 같은 배치의 “다른” positives 전부가 자동으로 negatives가 됩니다. 단일 GPU라면 per_device_train_batch_size만큼, 멀티 GPU(DDP)면 all_gather되어 전체 글로벌 배치 크기(Σ device batch)가 negatives로 들어갑니다.
명시적 negatives: d_reps_neg가 주어지면 in-batch negatives에 더해 추가 클래스로 붙여 더 많은 negative 후보가 됩니다.
왜 in-batch가 적게 느껴지나?
단일 GPU + 작은 per_device_train_batch_size면 당연히 in-batch negatives 수가 적습니다. gradient_accumulation_steps는 누적만 할 뿐, 하나의 step에서 계산되는 손실의 “동시 후보 수”를 늘리진 않아요.
DDP가 아니면 cross-device 합치기도 안 되니 더 줄어듭니다.
실제 코드 기준 설명
HardNegativeNLLLoss는 torch.distributed.is_initialized()일 때 q/pos/neg 모두 all_gather로 모아 “전체 디바이스 배치”를 후보로 씁니다. 즉 멀티 GPU면 in-batch negatives가 크게 늘어납니다.
neg이 None이면 명시적 negatives 없이 “in-batch positives 전체”가 negatives로 기능합니다.
방금 적용한 보완
Trainer의 compute_loss에서 features[2:]에 들어있는 모든 negative 슬롯을 임베딩해 세로로 concat하도록 바꿨어요. 이전엔 첫 번째 negative만 사용했을 수 있는데, 이제는 배치당 제공되는 모든 하드 네거티브가 점수 행렬에 추가됩니다. 이렇게 하면 in-batch가 작더라도 명시적 negative 풀을 넓힐 수 있습니다.
더 크게 늘리고 싶다면
per_device_train_batch_size를 늘리기 (메모리 허용 범위 내).
멀티 GPU로 DDP 실행해 cross-device in-batch negatives 활성화.
데이터 쪽에서 negative 슬롯을 더 많이 제공하기(지금은 모두 활용).
고급: 메모리 뱅크/큐(과거 배치 임베딩을 캐시)로 cross-batch negatives 추가.
주의
후보가 늘어날수록 점수 행렬 크기가 커져 메모리/시간 비용이 증가합니다. 배치/슬롯/디바이스 수 사이에서 균형을 잡아야 합니다.
원하시면 메모리 뱅크 방식(큐 저장)이나 여러 negative 슬롯을 효율적으로 처리하는 형태로 손실 유틸/트레이너를 더 확장해드릴게요.
'인공지능 > 공부' 카테고리의 다른 글
| 딥러닝 응용 3 - Fundamentals 2 (1) | 2025.10.03 |
|---|---|
| 딥러닝 응용 1, 2 - Intro, Fundamentals 1 (0) | 2025.09.29 |
| 딥러닝 응용 - 3주차 (0) | 2025.09.15 |
| 딥러닝 응용 - 2주차 (0) | 2025.09.08 |
| mteb 한글 평가하기 (0) | 2025.08.29 |