저번 포스트에서 llava를 통해 이미지 패치에서 텍스트를 생성하였다. 그리고 기존 연구처럼 이 텍스트를 CLIP embedding 하는 것 대신 BERT를 통해 text에 대한 feature를 얻는 것을 고안을 했다. CLIP보다 과연 BERT가 더 텍스트의 문맥 정보를 잘 반영한 feature를 추출해낼 수 있을까? toy experiemnt 결과를 공유해보겠다.
from transformers import AutoTokenizer, BertModel
tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
bert_model = BertModel.from_pretrained("google-bert/bert-base-uncased")
우선 BERT model을 불러온다.
text = "A rubber duck wearing a pirate hat. 1. Rubber duck: A yellow rubber duck with a pirate hat on, possibly a toy or a decoration.\n2. Rubber duck: A yellow rubber duck with a pirate hat on, possibly a children's toy or a decoration.\n3. Rubber duck: A yellow rubber duck with a pirate hat on, possibly a playful or whimsical item."
text2 = "rubber duck"
inputs = tokenizer(text, return_tensors="pt")
inputs2 = tokenizer(text2, return_tensors="pt")
with torch.no_grad():
out1 = bert_model(**inputs)
out2 = bert_model(**inputs2)
# Only grab the last hidden state
states1 = out1.pooler_output
states2 = out2.pooler_output
# avg1 = states1.mean(axis=0)
# avg2 = states2.mean(axis=0)
print(states1.shape)
torch.cosine_similarity(states1, states2)
그리고 bert model에 텍스트를 입력으로 주면 다양한 output을 얻을 수 있다. 나는 그 중에서 [CLS] 토큰의 last hidden state를 사용했는데 그 이유는 [CLS] 토큰의 last hidden state가 입력 텍스트의 전체 문맥을 담고 있다고 해서다.
"rubber duck" 라는 텍스트와 러버덕 이미지를 넣어 얻은 llava output인 "A rubber duck wearing a pirate hat. 1. Rubber duck: A yellow rubber duck with a pirate hat on, possibly a toy or a decoration.\n2. Rubber duck: A yellow rubber duck with a pirate hat on, possibly a children's toy or a decoration.\n3. Rubber duck: A yellow rubber duck with a pirate hat on, possibly a playful or whimsical item." 과의 cosine 유사도를 측정해보았을 때 `0.4628` 이 나왔다. 생각한 것보다는 낮은 유사도였다.
하지만 더 큰 문제는 저 `text2`를 "table" 로 대체하여 llava output (`text`)과의 cosine 유사도를 측정하였을 때 `0.6362`가 나온 것이다. 'table' 처럼 러버덕과 의미가 먼 텍스트는 `text2`와의 코사인 유사도가 낮을 것을 기대하였기 때문이다.
그래서 전체 토큰의 last hidden state를 모두 가져와 평균내는 방법을 사용해보았다.
text = "A rubber duck wearing a pirate hat. 1. Rubber duck: A yellow rubber duck with a pirate hat on, possibly a toy or a decoration.\n2. Rubber duck: A yellow rubber duck with a pirate hat on, possibly a children's toy or a decoration.\n3. Rubber duck: A yellow rubber duck with a pirate hat on, possibly a playful or whimsical item."
text2 = "rubber duck"
inputs = tokenizer(text, return_tensors="pt")
inputs2 = tokenizer(text2, return_tensors="pt")
with torch.no_grad():
out1 = bert_model(**inputs)
out2 = bert_model(**inputs2)
# Only grab the last hidden state
states1 = out1.last_hidden_state.squeeze()
states2 = out2.last_hidden_state.squeeze()
# print(states1.shape)
avg1 = states1.mean(axis=0)
avg2 = states2.mean(axis=0)
# print(avg1.shape)
torch.cosine_similarity(avg1.reshape(1,-1), avg2.reshape(1,-1))
그랬더니 llava output과 "rubber duck"간의 cosine 유사도는 `0.4832`였고 "table"과의 cosine 유사도는 `0.2457` 이었다. 이전보다는 좀 더 나은 결과이긴 하지만 여전히 "table"과의 유사도가 크다. (나는 그렇게 생각했다)
그래서 다른 모델을 찾다 SBERT 라는 것을 발견했다. SBERT는 Sentence-BERT라고 하는데, BERT보다 text 임베딩의 성능을 더 우수하게 개선시킨 모델이라고 한다. 즉, BERT보다 텍스트를 더 잘 요약하는 피쳐 벡터를 생성할 수 있다는 것이다.
https://arxiv.org/abs/1908.10084
Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks
BERT (Devlin et al., 2018) and RoBERTa (Liu et al., 2019) has set a new state-of-the-art performance on sentence-pair regression tasks like semantic textual similarity (STS). However, it requires that both sentences are fed into the network, which causes a
arxiv.org
그래서 BERT대신 SBERT를 사용하여 feature를 추출해보았다.
from sentence_transformers import SentenceTransformer
# Initializing the Sentence Transformer model using BERT with mean-tokens pooling
model = SentenceTransformer('bert-base-nli-mean-tokens')
SBERT를 불러오는 방법도 매우 간단하다. hugging face에서 제공하는 sentence transformer를 불러오면 된다.
texta = "A rubber duck wearing a pirate hat. 1. Rubber duck: A yellow rubber duck with a pirate hat on, possibly a toy or a decoration.\n2. Rubber duck: A yellow rubber duck with a pirate hat on, possibly a children's toy or a decoration.\n3. Rubber duck: A yellow rubber duck with a pirate hat on, possibly a playful or whimsical item."
text2 = "rubber duck"
sentence_embeddingsa = model.encode(texta)
sentence_embeddings2 = model.encode(text2)
a = cos_sim(sentence_embeddingsa, sentence_embeddings2)
print(a)
마찬가지로 llava가 생성한 텍스트와 "rubber duck"과의 유사도를 측정했을 때 `0.605`가 나왔다! 기존 BERT보다 더 이상적인 유사도가 나왔다. 더 긍정적인 것은 "table"과의 유사도를 측정하였을 때 `-0.149`가 나왔다! 관련이 없는 텍스트와는 이상적으로 매우 낮은 유사도가 측정되었다. 그래서 나는 최종적으로 SBERT를 활용하여 language field를 학습시킬 GT를 생성하기로 하였다.
나처럼 텍스트 피쳐 간의 코사인 유사도를 측정해야하는 테스크를 수행할 때 BERT보다 SBERT를 사용하면 더 정확한 feature를 추출할 수 있을 것이다.
'프로젝트' 카테고리의 다른 글
[논문 발전시키기] #2. LLaVA를 통해 이미지 패치를 설명하는 텍스트 생성하기 (0) | 2025.02.14 |
---|---|
[논문 발전시키기] #1. 주제 잡기: Open-Vocabulary 3D segmentation using textual semantics (0) | 2025.02.10 |
LLaMa 3.1 fine tuning 해서 챗봇 만들기 (1) | 2025.02.05 |