KeiStory

LangChain에서 Rerank로 RAG 정확도 높이기

 

백터 검색만으로는 관련성 높은 문서를 찾기 어려울 수 있는데 Rerank 를 이용하면 결과의 순위를 재조정하는 과정을 거쳐 더 관련성 높은 문서를 선별할수 있게 됩니다.

일반적인 백터 검색은 Bi-encoder 를 사용해 질문과 문서를 각각 인코딩하나
Rerank 에서는 Cross-encoder 를 사용해 질문과 문서를 동시에 입력으로 받아 더 정확한 관련성 점수를 계산하게 됩니다.

아래 예시는 BAAI/bge-reranker-large 모델을 한국어 데이터로 파인튜닝한 'Dongjin-kr/ko-reranker' 를 사용했습니다.

패키지 설치

pip install transformers torch numpy

 

main.py

from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
import numpy as np

# 소프트맥스 계산 함수
def exp_normalize(x):
    b = x.max()  # 점수 중 최대값
    y = np.exp(x - b)  # 안정적인 계산을 위한 점수 차이의 지수값
    return y / y.sum()  # 전체 합으로 나누어 확률화

# 질문-응답 쌍 (리스트 형태)
pairs = [
    ['대한민국의 수도는 어디인가요?', '서울은 공기가 나쁘고 차가 많이 막힙니다.'], 
    ['대한민국의 수도는 어디인가요?', '대한민국의 독도는 우리땅입니다.'], 
    ['대한민국의 수도는 어디인가요?', '대한민국의 수도는 상하이입니다.'], 
    ['대한민국의 수도는 어디인가요?', '대한민국의 수도는 서울입니다.'], 
]

# 사전 학습된 모델 및 토크나이저 경로
model_path = "Dongjin-kr/ko-reranker"

# 토크나이저 로드 (입력을 모델에 맞게 변환)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# 모델 로드 (문장 쌍 유사도 평가 모델)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.eval()  # 모델을 평가 모드로 전환 (학습 X)

# 입력 데이터 전처리 및 모델 실행
with torch.no_grad():  # 그래디언트 계산 비활성화 (추론 모드)
    # 입력 데이터를 토크나이저로 처리
    inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
    
    # 모델로부터 점수 (logits) 예측
    scores = model(**inputs, return_dict=True).logits.view(-1, ).float()
    
    # 점수를 소프트맥스 방식으로 정규화
    scores = exp_normalize(scores.numpy()) 

# 점수 출력 (소수점 2자리로 반올림)
print(np.round(scores * 100, 2))

 

결과

 

위 예제는 어떤 응답에 대한 결과들이 있다고 가정했을 때 결과 목록중 서로의 상관관계 지수를 나타냅니다.

결과를 보면 알수 있듯이 마지막 질문과 응답이 관계가 높다는 판단을 내리는걸 알수 있습니다.

만약 일반적인 백터 검색을 했다면 대한민국의 수도가 상하이라는 답번을 줄수도 있지만

Rarank 를 이용하면 제대로 된 결과를 반환해주게됩니다.

 

requirements.txt

certifi==2024.12.14
charset-normalizer==3.4.1
colorama==0.4.6
filelock==3.16.1
fsspec==2024.12.0
huggingface-hub==0.27.0
idna==3.10
Jinja2==3.1.5
MarkupSafe==3.0.2
mpmath==1.3.0
networkx==3.4.2
numpy==2.2.1
packaging==24.2
PyYAML==6.0.2
regex==2024.11.6
requests==2.32.3
safetensors==0.4.5
sympy==1.13.1
tokenizers==0.21.0
torch==2.5.1
tqdm==4.67.1
transformers==4.47.1
typing_extensions==4.12.2
urllib3==2.3.0

 

반응형

공유하기

facebook twitter kakaoTalk kakaostory naver band