Commit bc28c41
change precision
1 parent: 2812307

2 files changed (+4 -2)

compute_mexa.py (+2)

@@ -7,6 +7,8 @@
 import argparse
 
 def cosine_similarity(array1, array2):
+    array1 = array1.astype(np.float64)
+    array2 = array2.astype(np.float64)
     cosine_dist = cosine(array1, array2)
     cosine_similarity = 1 - cosine_dist
     return cosine_similarity
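
For context, a minimal sketch (not part of this commit) of what the upcast changes. It assumes numpy is imported as np and that cosine is scipy.spatial.distance.cosine, consistent with the calls visible in compute_mexa.py:

import numpy as np
from scipy.spatial.distance import cosine

rng = np.random.default_rng(0)
a = rng.standard_normal(4096).astype(np.float32)
b = rng.standard_normal(4096).astype(np.float32)

# With float32 inputs, the dot products and norms inside cosine()
# accumulate in single precision; upcasting first keeps the whole
# distance computation in double precision.
sim_f32 = 1 - cosine(a, b)
sim_f64 = 1 - cosine(a.astype(np.float64), b.astype(np.float64))
print(sim_f32, sim_f64)  # can differ in the low-order digits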

embed_extractor.py (+2 -2)

@@ -13,7 +13,7 @@ def weighted_embeddings(layer, attention_mask, device='cuda'):
     sum_embeddings = torch.sum(layer * weights_for_non_padding.unsqueeze(-1), dim=1)
     num_of_non_padding_tokens = torch.sum(weights_for_non_padding, dim=-1).unsqueeze(-1)
     sentence_embeddings = sum_embeddings / num_of_non_padding_tokens
-    sentence_embeddings = sentence_embeddings.squeeze().to(torch.float16).cpu().numpy()
+    sentence_embeddings = sentence_embeddings.squeeze().to(torch.float32).cpu().numpy()
     return sentence_embeddings
 
 
@@ -22,7 +22,7 @@ def lasttoken_embeddings(layer, attention_mask, device='cuda'):
     idx_of_last_token = attention_mask.bool().sum().item() - 1  # scalar index
     # Extract the embedding from the layer
     embedding = layer[0, idx_of_last_token, :]  # shape: [hidden_dim]
-    sentence_embedding = embedding.to(torch.float16).cpu().numpy()
+    sentence_embedding = embedding.to(torch.float32).cpu().numpy()
     return sentence_embedding
 
 
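Why the dtype change matters, as a hypothetical illustration (not from the repo): float16 carries roughly 3 significant decimal digits, so rounding embeddings through it before saving can erase the small differences that cosine similarity is meant to measure, while float32 keeps roughly 7 digits:

import torch

x = torch.randn(8, dtype=torch.float32)
# Round-trip through float16 and measure the per-element rounding error.
err = (x.to(torch.float16).to(torch.float32) - x).abs().max()
print(err)  # typically on the order of 1e-4 for values near 1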