프로젝트

[논문 발전시키기] #3. BERT, SBERT의 last hidden state 로 이것저것

나나바 2025. 2. 19. 19:25

저번 포스트에서 llava를 통해 이미지 패치에서 텍스트를 생성하였다. 그리고 기존 연구처럼 이 텍스트를 CLIP embedding 하는 것 대신 BERT를 통해 text에 대한 feature를 얻는 것을 고안을 했다. CLIP보다 과연 BERT가 더 텍스트의 문맥 정보를 잘 반영한 feature를 추출해낼 수 있을까? toy experiemnt 결과를 공유해보겠다.

 

from transformers import AutoTokenizer, BertModel

tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
bert_model = BertModel.from_pretrained("google-bert/bert-base-uncased")

우선 BERT model을 불러온다. 

text = "A rubber duck wearing a pirate hat. 1. Rubber duck: A yellow rubber duck with a pirate hat on, possibly a toy or a decoration.\n2. Rubber duck: A yellow rubber duck with a pirate hat on, possibly a children's toy or a decoration.\n3. Rubber duck: A yellow rubber duck with a pirate hat on, possibly a playful or whimsical item."
text2 = "rubber duck"

inputs = tokenizer(text, return_tensors="pt")
inputs2 = tokenizer(text2, return_tensors="pt")

with torch.no_grad():
    out1 = bert_model(**inputs)
    out2 = bert_model(**inputs2)

# Only grab the last hidden state
states1 = out1.pooler_output
states2 = out2.pooler_output

# avg1 = states1.mean(axis=0)
# avg2 = states2.mean(axis=0)
print(states1.shape)
torch.cosine_similarity(states1, states2)

 

그리고 bert model에 텍스트를 입력으로 주면 다양한 output을 얻을 수 있다. 나는 그 중에서 [CLS] 토큰의 last hidden state를 사용했는데 그 이유는 [CLS] 토큰의 last hidden state가 입력 텍스트의 전체 문맥을 담고 있다고 해서다.

 

"rubber duck" 라는 텍스트와 러버덕 이미지를 넣어 얻은 llava output인 "A rubber duck wearing a pirate hat. 1. Rubber duck: A yellow rubber duck with a pirate hat on, possibly a toy or a decoration.\n2. Rubber duck: A yellow rubber duck with a pirate hat on, possibly a children's toy or a decoration.\n3. Rubber duck: A yellow rubber duck with a pirate hat on, possibly a playful or whimsical item." 과의 cosine 유사도를 측정해보았을 때 `0.4628` 이 나왔다. 생각한 것보다는 낮은 유사도였다. 

 

하지만 더 큰 문제는 저 `text2`를 "table" 로 대체하여 llava output (`text`)과의 cosine 유사도를 측정하였을 때 `0.6362`가 나온 것이다. 'table' 처럼 러버덕과 의미가 먼 텍스트는 `text2`와의 코사인 유사도가 낮을 것을 기대하였기 때문이다. 

 

 

그래서 전체 토큰의 last hidden state를 모두 가져와 평균내는 방법을 사용해보았다.

text = "A rubber duck wearing a pirate hat. 1. Rubber duck: A yellow rubber duck with a pirate hat on, possibly a toy or a decoration.\n2. Rubber duck: A yellow rubber duck with a pirate hat on, possibly a children's toy or a decoration.\n3. Rubber duck: A yellow rubber duck with a pirate hat on, possibly a playful or whimsical item."
text2 = "rubber duck"

inputs = tokenizer(text, return_tensors="pt")
inputs2 = tokenizer(text2, return_tensors="pt")

with torch.no_grad():
    out1 = bert_model(**inputs)
    out2 = bert_model(**inputs2)

# Only grab the last hidden state
states1 = out1.last_hidden_state.squeeze()
states2 = out2.last_hidden_state.squeeze()

# print(states1.shape)
avg1 = states1.mean(axis=0)
avg2 = states2.mean(axis=0)

# print(avg1.shape)
torch.cosine_similarity(avg1.reshape(1,-1), avg2.reshape(1,-1))

그랬더니 llava output과 "rubber duck"간의 cosine 유사도는 `0.4832`였고 "table"과의 cosine 유사도는 `0.2457` 이었다. 이전보다는 좀 더 나은 결과이긴 하지만 여전히 "table"과의 유사도가 크다. (나는 그렇게 생각했다)

 

그래서 다른 모델을 찾다 SBERT 라는 것을 발견했다. SBERT는 Sentence-BERT라고 하는데, BERT보다 text 임베딩의 성능을 더 우수하게 개선시킨 모델이라고 한다. 즉, BERT보다 텍스트를 더 잘 요약하는 피쳐 벡터를 생성할 수 있다는 것이다.

https://arxiv.org/abs/1908.10084

 

Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks

BERT (Devlin et al., 2018) and RoBERTa (Liu et al., 2019) has set a new state-of-the-art performance on sentence-pair regression tasks like semantic textual similarity (STS). However, it requires that both sentences are fed into the network, which causes a

arxiv.org

 

그래서 BERT대신 SBERT를 사용하여 feature를 추출해보았다.

from sentence_transformers import SentenceTransformer
# Initializing the Sentence Transformer model using BERT with mean-tokens pooling
model = SentenceTransformer('bert-base-nli-mean-tokens')

 

SBERT를 불러오는 방법도 매우 간단하다. hugging face에서 제공하는 sentence transformer를 불러오면 된다.

texta = "A rubber duck wearing a pirate hat. 1. Rubber duck: A yellow rubber duck with a pirate hat on, possibly a toy or a decoration.\n2. Rubber duck: A yellow rubber duck with a pirate hat on, possibly a children's toy or a decoration.\n3. Rubber duck: A yellow rubber duck with a pirate hat on, possibly a playful or whimsical item."
text2 = "rubber duck"

sentence_embeddingsa = model.encode(texta)
sentence_embeddings2 = model.encode(text2)

a = cos_sim(sentence_embeddingsa, sentence_embeddings2)

print(a)

 

마찬가지로 llava가 생성한 텍스트와 "rubber duck"과의 유사도를 측정했을 때 `0.605`가 나왔다! 기존 BERT보다 더 이상적인 유사도가 나왔다. 더 긍정적인 것은 "table"과의 유사도를 측정하였을 때 `-0.149`가 나왔다! 관련이 없는 텍스트와는 이상적으로 매우 낮은 유사도가 측정되었다. 그래서 나는 최종적으로 SBERT를 활용하여 language field를 학습시킬 GT를 생성하기로 하였다.

 

 

나처럼 텍스트 피쳐 간의 코사인 유사도를 측정해야하는 테스크를 수행할 때 BERT보다 SBERT를 사용하면 더 정확한 feature를 추출할 수 있을 것이다.