1
+ import loguru
2
+
3
+ from transformers import AutoTokenizer , AutoModel
4
+ import tensorflow as tf
5
+ import torch
6
+
7
def load_torch_model(model_path):
    """Load the PyTorch (HuggingFace) model found at *model_path*."""
    return AutoModel.from_pretrained(model_path)
10
+
11
+
12
def load_tf_model(model_path):
    """Load a TensorFlow SavedModel from *model_path*.

    Loading is pinned to the CPU so model weights are not placed on the
    GPU at load time.
    """
    with tf.device("/CPU:0"):
        loaded = tf.saved_model.load(model_path)
    return loaded
16
+
17
+
18
def load_tokenizer(model_path):
    """Load the tokenizer matching the model at *model_path*."""
    return AutoTokenizer.from_pretrained(model_path)
21
+
22
+
23
def tokenize_wo_padding(tokenizer, text, return_tensors="pt"):
    """Encode *text* with no padding at all.

    *return_tensors* selects the tensor framework of the encoding
    ("pt" for PyTorch, "tf" for TensorFlow).
    """
    encoding = tokenizer(text, padding=False, return_tensors=return_tensors)
    return encoding
25
+
26
+
27
def tokenize_w_padding(tokenizer, text, return_tensors="pt", max_length=512):
    """Encode *text* padded out to exactly *max_length* tokens.

    *return_tensors* selects the tensor framework of the encoding
    ("pt" for PyTorch, "tf" for TensorFlow).
    """
    encoding = tokenizer(
        text,
        padding="max_length",
        max_length=max_length,
        return_tensors=return_tensors,
    )
    return encoding
29
+
30
+
31
def main():
    """Compare BGE-M3 CLS-token embeddings between the PyTorch model and
    its TensorFlow export, with and without max_length padding.

    For each backend this logs the no-padding output (treated as ground
    truth), the padded output, and the mean absolute error between the
    two CLS embeddings. For TensorFlow it additionally tries an additive
    attention mask (0 for real tokens, a large negative number for
    padding) in case the exported graph adds the mask to the attention
    logits instead of multiplying by it.
    """
    # Hub id for the PyTorch model/tokenizer; local path for the TF SavedModel.
    model_path = "BAAI/bge-m3"
    model_path_tf = "/workspace/BGE-M3-Model-Converter/model"
    model = load_torch_model(model_path)
    tokenizer = load_tokenizer(model_path)

    # Tokenize the same sentence twice so the only difference between the
    # two inputs is the padding tokens appended up to max_length.
    text = "Hello, my dog is cute"
    inputs = tokenize_wo_padding(tokenizer, text)
    inputs_w_padding = tokenize_w_padding(tokenizer, text)

    # ---- PyTorch reference run ----------------------------------------
    # Fixed log label: was "Torch] Model output" (missing opening bracket).
    loguru.logger.info("[Torch] Model output".ljust(50, "-"))
    model.eval().to("cuda")
    with torch.no_grad():
        inputs = {k: v.to("cuda") for k, v in inputs.items()}
        inputs_w_padding = {k: v.to("cuda") for k, v in inputs_w_padding.items()}

        output = model(**inputs)
        output_w_padding = model(**inputs_w_padding)
        loguru.logger.info("output without padding (GT)".ljust(50, "-"))
        loguru.logger.info(output['last_hidden_state'][:, 0])
        loguru.logger.info("=" * 50)
        loguru.logger.info("output with padding".ljust(50, "-"))
        loguru.logger.info(output_w_padding['last_hidden_state'][:, 0])
        loguru.logger.info("=" * 50)
        # Mean absolute difference on the CLS (position 0) embedding only.
        err = torch.abs(output['last_hidden_state'][:, 0] - output_w_padding['last_hidden_state'][:, 0])
        loguru.logger.info("Error".ljust(50, "-"))
        loguru.logger.info(err.mean())

    # ---- TensorFlow export run ----------------------------------------
    inputs_tf = tokenize_wo_padding(tokenizer, text, return_tensors="tf")
    inputs_tf_w_padding = tokenize_w_padding(tokenizer, text, return_tensors="tf")
    # Experiment: swap the 0/1 multiplicative mask for an additive one
    # (large negative on padded positions, 0 elsewhere).
    inputs_tf_w_padding_attnFixed = inputs_tf_w_padding.copy()
    inputs_tf_w_padding_attnFixed['attention_mask'] = tf.where(inputs_tf_w_padding['attention_mask'] == 0, -9999999, 0)
    tf_model = load_tf_model(model_path_tf).signatures["serving_default"]

    # Fixed log label: was "Tensorflow] Model output" (missing opening bracket).
    loguru.logger.info("[Tensorflow] Model output".ljust(50, "-"))
    with tf.device("/GPU:0"):
        output_tf = tf_model(**inputs_tf)
        output_tf_w_padding = tf_model(**inputs_tf_w_padding)
        output_tf_w_padding_attnFixed = tf_model(**inputs_tf_w_padding_attnFixed)
    loguru.logger.info("output without padding (GT)".ljust(50, "-"))
    loguru.logger.info(output_tf['hidden_states'][-1][:, 0])
    loguru.logger.info("=" * 50)
    loguru.logger.info("output with padding".ljust(50, "-"))
    loguru.logger.info(output_tf_w_padding['hidden_states'][-1][:, 0])
    loguru.logger.info("=" * 50)
    loguru.logger.info("output with padding (attention fixed)".ljust(50, "-"))
    loguru.logger.info(output_tf_w_padding_attnFixed['hidden_states'][-1][:, 0])
    loguru.logger.info("=" * 50)
    err_tf = tf.abs(output_tf['hidden_states'][-1][:, 0] - output_tf_w_padding['hidden_states'][-1][:, 0])
    loguru.logger.info("Error".ljust(50, "-"))
    loguru.logger.info(tf.reduce_mean(err_tf))
    loguru.logger.info("=" * 50)
    # NOTE(review): this compares the padded run against the attn-fixed run,
    # not against the no-padding ground truth (output_tf); confirm that the
    # intent is "how much did the fix change the output" rather than
    # "how close is the fix to ground truth".
    err_tf_attnFixed = tf.abs(output_tf_w_padding['hidden_states'][-1][:, 0] - output_tf_w_padding_attnFixed['hidden_states'][-1][:, 0])
    loguru.logger.info("Error (attention fixed)".ljust(50, "-"))
    loguru.logger.info(tf.reduce_mean(err_tf_attnFixed))
    loguru.logger.info("=" * 50)
90
+
91
+
92
+
93
# Script entry point: run the PyTorch-vs-TensorFlow comparison.
if __name__ == "__main__":
    main()
0 commit comments