인공지능/자연어 처리

Late Chunking 사용해보기 및 Chunking 코드 익숙해지기

이게될까 2025. 1. 22. 11:22
728x90
728x90

https://github.com/jina-ai/late-chunking

 

GitHub - jina-ai/late-chunking: Code for explaining and evaluating late chunking (chunked pooling)

Code for explaining and evaluating late chunking (chunked pooling) - jina-ai/late-chunking

github.com

 

일단 코드는 여기서 나왔습니다.

코드에 익숙해지기 위해 조금 제맘대로 파 해치기도 했습니다.

청크 풀링 (Chunked Pooling)

그 다음으로, 우리가 임베딩에 사용할 모델을 로드합니다. 여기에서는 jinaai/jina-embeddings-v2-base-en을 선택했지만, 평균 풀링(mean pooling)을 지원하는 다른 모델도 사용할 수 있습니다. 다만, 최대 컨텍스트 길이가 긴 모델을 사용하는 것이 권장됩니다.

 

from transformers import AutoModel
from transformers import AutoTokenizer

from chunked_pooling import chunked_pooling, chunk_by_sentences

# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True)
model = AutoModel.from_pretrained('jinaai/jina-embeddings-v2-base-en', trust_remote_code=True)

모델을 불러옵니다!

이제 인코딩하려는 텍스트를 정의하고 이를 청크로 분할합니다. chunk_by_sentences 함수는 청크별 토큰 수를 나타내는 범위(span) 주석도 반환합니다. 이 주석은 청크 풀링(chunked pooling)에 필요한 정보를 제공합니다.

input_text = "Berlin is the capital and largest city of Germany, both by area and by population. Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits. The city is also one of the states of Germany, and is the third smallest state in the country in terms of area."

# determine chunks
chunks, span_annotations = chunk_by_sentences(input_text, tokenizer)
print('Chunks:\n- "' + '"\n- "'.join(chunks) + '"')

Chunks:
- "Berlin is the capital and largest city of Germany, both by area and by population."
 " Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits."
- " The city is also one of the states of Germany, and is the third smallest state in the country in terms of area."

 

여기서 이게 어떻게 이렇게 나오나 확인해봤습니다.

inputs = tokenizer(input_text, return_tensors='pt', return_offsets_mapping=True)
token_offsets = inputs['offset_mapping'][0] # 각 토큰의 시작과 끝으로 문자자 수와 위치를 알 수 있음
token_ids = inputs['input_ids'][0] # 이건 각 토큰

punctuation_mark_id = tokenizer.convert_tokens_to_ids('.') # .토큰 번호
sep_id = tokenizer.convert_tokens_to_ids('[SEP]') # sep 토큰 번호

각 토큰의 시작과 끝 위치를 알 수 있는 것도 처음 알았습니다....

chunk_positions = [
    (i, int(start + 1))  # 추출할 값: 인덱스와 시작 위치
    for i, (token_id, (start, end)) in enumerate(zip(token_ids, token_offsets))  # 토큰 ID와 오프셋 정보를 순회
    if token_id == punctuation_mark_id  # 조건 1: 특정 구두점 토큰에 해당하는 경우
    and (
        token_offsets[i + 1][0] - token_offsets[i][1] > 0  # 조건 2a: 다음 토큰과 현재 토큰 사이에 간격이 있는 경우
        or token_ids[i + 1] == sep_id  # 조건 2b: 다음 토큰이 구분자 토큰인 경우
    )
]

음.... 이걸 말로 풀어서 설명하면

토큰 번호와 토큰의 위치정보를 가지고 for 문을 도는데 토큰 아이디가 .이고, 토큰 간 공백이 있거나 SEP 토큰인 경우 그 토큰 위치와 문자열 상 위치를 x,y 좌표로 출력해줍니다.

chunks = [
        input_text[x[1] : y[1]]
        for x, y in zip([(1, 0)] + chunk_positions[:-1], chunk_positions)
    ]
span_annotations = [
        (x[0], y[0]) for (x, y) in zip([(1, 0)] + chunk_positions[:-1], chunk_positions)
    ]

이렇게 하면 아까 진행했던 청크로 나뉘어진 문장이 나오게 됩니다. 

Chunk는 뒤쪽 번호를 활용하여 문자열을 나눈 청크를 출력해주고, span_annotations는 그 문자열의 토큰 위치를 출력합니다.

 

이제 청크를 전통적인 방식과 컨텍스트에 민감한 청크 풀링(context-sensitive chunked pooling) 방식을 사용하여 인코딩합니다:

# chunk before
embeddings_traditional_chunking = model.encode(chunks)

# chunk afterwards (context-sensitive chunked pooling)
inputs = tokenizer(input_text, return_tensors='pt')
model_output = model(**inputs)
embeddings = chunked_pooling(model_output, [span_annotations])[0]
 

이제 여기서 traditional chunking에서는 기존에 진행했던 chunk된 text를 집어넣고, Late Chunking은 모델에 텍스트 원본을 그대로 집어넣고, 후에 나눠주는 방식입니다. 

Chunked Pooling 함수를 한번 확인해 봤습니다.

def chunked_pooling(
    model_output: 'BatchEncoding', span_annotation: list, max_length=None
):
    token_embeddings = model_output[0]
    outputs = []
    for embeddings, annotations in zip(token_embeddings, span_annotation):
        if (
            max_length is not None
        ):  # remove annotations which go bejond the max-length of the model
            annotations = [
                (start, min(end, max_length - 1))
                for (start, end) in annotations
                if start < (max_length - 1)
            ]
        pooled_embeddings = [
            embeddings[start:end].sum(dim=0) / (end - start)
            for start, end in annotations
            if (end - start) >= 1
        ]
        pooled_embeddings = [
            embedding.float().detach().cpu().numpy() for embedding in pooled_embeddings
        ]
        outputs.append(pooled_embeddings)

    return outputs

 

 

마지막으로, 단어 "Berlin"과 각 청크 간의 유사도를 비교합니다. 컨텍스트에 민감한 청크 풀링(context-sensitive chunked pooling) 방식을 사용할 경우, 유사도가 더 높게 나타나야 합니다:

 
 
import numpy as np

cos_sim = lambda x, y: np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))

berlin_embedding = model.encode('Berlin')

for chunk, new_embedding, trad_embeddings in zip(chunks, embeddings, embeddings_traditional_chunking):
    print(f'similarity_new("Berlin", "{chunk}"):', cos_sim(berlin_embedding, new_embedding))
    print(f'similarity_trad("Berlin", "{chunk}"):', cos_sim(berlin_embedding, trad_embeddings))

similarity_new("Berlin", "Berlin is the capital and largest city of Germany, both by area and by population."): 0.849546
similarity_trad("Berlin", "Berlin is the capital and largest city of Germany, both by area and by population."): 0.84862185
similarity_new("Berlin", " Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits."): 0.82489026
similarity_trad("Berlin", " Its more than 3.85 million inhabitants make it the European Union's most populous city, as measured by population within city limits."): 0.708434
similarity_new("Berlin", " The city is also one of the states of Germany, and is the third smallest state in the country in terms of area."): 0.8498008
similarity_trad("Berlin", " The city is also one of the states of Germany, and is the third smallest state in the country in terms of area."): 0.75345546

 

이렇게 크게 별 일 없이 코드는 종료 됩니다.괜히 어렵게 생각한 것 같네요 ㅎㅎ...

 

728x90