인공지능/자연어 처리

vllm 활용해서 logit 추출 및 logprob, CoT, SC-CoT Inference 진행

이게될까 2025. 4. 2. 23:47
728x90
728x90

class로 된 python이라 self나 다른 것 들이 붙어있긴 한데 적당히 보면 될 것 같습니다.

기록 용이라....

from datasets import load_from_disk, DatasetDict
import argparse, os, json, torch, itertools, math, re
from typing import List, Dict, Tuple
from scipy.special import digamma
from vllm import LLM, SamplingParams
from collections import defaultdict, Counter
from transformers import AutoTokenizer
from setproctitle import setproctitle

 

일단 전부 import 다 해두고...

 

sampling_params = SamplingParams(
        max_tokens=2048,
        temperature=temperature, 
        logprobs=top_k, 
    )

vllm의 generation 혹은 chat에서 log prob를 사용하기 위해 samplinParams에 logprobs에 topk를 넣어줍니다.

 

저번 vllm에 활용할 때 모델 쳇 탬플릿을 그대로 넣어주느라 귀찮았던 것이 있었습니다.

이번엔 dict 구조를 활용해 prompt를 작성하고, 그걸 vllm의 chat에 넣어줘 진행하였습니다.

이건 Zero-shot 일반 출력이었습니다.

input_prompts = [[
                {"role": "system", "content": self.prompts['system']['base']},
                {"role": "user", "content": self.prompts['generation']['base'].format(question=query)}
            ] for query in queries]
        
outputs = self.model.chat(input_prompts, self.sampling_params)
outputs = [output.outputs[0].text for output in outputs]

요렇게 진행하면 모델의 쳇 탬플릿에 맞게 알아서 들어갑니다.

    def extract_most_common_finalanswers(self, all_outputs: List[List[str]]) -> List[str]:
        final_outputs = []
        finalanswer_pattern = re.compile(r"<finalanswer>(.*?)</finalanswer>", re.DOTALL)

        def normalize(text: str) -> str:
            return text.strip().replace(" ", "").lower()

        for i, sc_outputs in enumerate(all_outputs):
            try:
                extracted = []
                for output in sc_outputs:
                    match = finalanswer_pattern.search(output)
                    if match:
                        normalized = normalize(match.group(1))
                        extracted.append(normalized)

                if not extracted:
                    print(f"[Warning] No <finalanswer> found for idx {i}.")
                    final_outputs.append(sc_outputs[0])
                    continue
                #print(sc_outputs)
                print(Counter(extracted))
                most_common = Counter(extracted).most_common(1)[0][0]

                # 원본 복원: 정확히 일치하는 정답만
                for output in sc_outputs:
                    match = finalanswer_pattern.search(output)
                    if match and normalize(match.group(1)) == most_common:
                        final_outputs.append(output)
                        break
                else:
                    final_outputs.append(sc_outputs[0])

            except Exception as e:
                print(f"[Error] Processing failed at idx {i}: {e}")
                final_outputs.append(sc_outputs[0])

        return final_outputs

SC-CoT를 진행하기 위해 답을 모아 가장 많은 수의 답을 선택하여 넘기는 과정입니다.


        self.sampling_params =  SamplingParams(
            max_tokens=2048,
            temperature=self.temperature,
            top_k=50,     
            top_p=0.95
            )
        
        # num_SC 배수만큼 복제
        expanded_prompts = list(itertools.chain.from_iterable([base_prompts] * num_SC))

        outputs = self.model.chat(expanded_prompts, self.sampling_params)
        flat_outputs = [output.outputs[0].text for output in outputs]

        # [batch_size * num_SC] → [batch_size][num_SC] 로 재구성
        grouped_outputs = [
            flat_outputs[i * num_SC:(i + 1) * num_SC] for i in range(len(queries))
        ]
        final_outputs = self.extract_most_common_finalanswers(grouped_outputs)
        
        
        self.sampling_params = SamplingParams(
            max_tokens=2048,
            temperature=self.temperature, 
            logprobs=self.top_k, 
            )

num_SC만큼 Sampling을 진행하여 가장 많은 수의 답을 선택해서 끝냅니다.

정답의 다양성을 위해 온도와 topk. topp를 늘렸는데 좀 바보가 되긴 하더라고요....

 

logprobs_seq = output.outputs[0].logprobs

logprobs는 이렇게 이렇게 접근이 가능합니다.

 

이제 logit에 접근합니다.

class TopKLogitCapture:
    def __init__(self, tokenizer, top_k: int = 10, temperature: float = 0.2):
        self.top_k = top_k
        self.temperature = temperature
        self.logit_store = defaultdict(list)
        self.tokenizer = tokenizer

    def __call__(self, prompt_token : List[int], token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
        logits = logits.float().cpu()
        scaled_logits = logits / self.temperature
        probs = F.softmax(scaled_logits, dim=-1)
        logprobs = torch.log(probs + 1e-8) 
        
        logprobs, top_indices = torch.topk(logprobs, self.top_k)
        
        raw_values = logits[top_indices]

      
        self.logit_store[prompt_hash].append({
            "token_ids": top_indices.numpy().tolist(),
            "top_tokens": [
                (self.tokenizer.decode([idx.item()]), idx.item()) 
                for idx in top_indices
            ],
            "raw_logits": raw_values.numpy().tolist(),
            "probs": probs.numpy().tolist(),
            "logprobs": logprobs.numpy().tolist(), 
            "prompt_token": prompt_token,  # 원본 저장
            "generated_len": len(token_ids)
        })
        
        return logits  # Return original logits for actual generation
        
        
    def reset(self):
        self.logit_store = defaultdict(list)

이 것만 있으면 지속적으로 logit이 출력됩니다.

 

    logit_capture = TopKLogitCapture(
        tokenizer = tokenizer,
        top_k=top_k,
        temperature=temperature
    )
    sampling_params = SamplingParams(
        max_tokens=2048,
        temperature=temperature, 
        logits_processors=[logit_capture]
    )

이렇게 SamplingParams에 저 Class를 집어넣어 줘야 합니다.

 

그런데 이Logit은 계속해서 쌓이기만 합니다.

무조건 reset()을 한번씩 해줘서 지워줘야 해요 

        decomp_logits = self.logit_capture.logit_store.copy()

이렇게 복사한 Logit을 사용하면 됩니다.

 

이게 성공만 했으면 좋았을텐데

너무 아쉽네요,...

 

 

 

 

728x90