@@ -32,11 +32,13 @@ def __init__(self, d_model, num_heads, dropout_rate=0.1, **kwargs):
         # Dropout layer
         self.dropout = tf.keras.layers.Dropout(dropout_rate)

-    def stable_softmax(self, logits, axis=None, name=None):
-        """
-        Stable softmax implementation
-        """
-        return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name)
+    def stable_softmax(self, logits, axis=-1, name=None):
+        """Numerically stable softmax: subtract max and compute in float32."""
+        dtype = logits.dtype
+        x = tf.cast(logits, tf.float32)
+        x = x - tf.reduce_max(x, axis=axis, keepdims=True)
+        probs = tf.nn.softmax(x, axis=axis, name=name)
+        return tf.cast(probs, dtype)


     def split_heads(self, x, batch_size):
         x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
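The max-subtraction in the new stable_softmax relies on softmax being shift-invariant: the probabilities are unchanged, but every exp() argument stays at or below zero. A minimal sanity check, not part of the patch and assuming TensorFlow 2.x:

import numpy as np
import tensorflow as tf

logits = tf.constant([[10.0, 20.0, 30.0]])

# Shifting by the row max changes nothing mathematically ...
shifted = logits - tf.reduce_max(logits, axis=-1, keepdims=True)
stable = tf.nn.softmax(shifted, axis=-1)
reference = tf.nn.softmax(logits, axis=-1)

# ... but keeps exp() arguments <= 0, which matters for large or float16 logits.
np.testing.assert_allclose(stable.numpy(), reference.numpy(), rtol=1e-6)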
@@ -45,27 +47,29 @@ def split_heads(self, x, batch_size):

     def call(self, inputs, mask=None, training=False):
         batch_size = tf.shape(inputs)[0]

-        # Compute Query, Key, Value
-        q = self.wq(inputs)  # (batch_size, seq_len, d_model)
-        k = self.wk(inputs)  # (batch_size, seq_len, d_model)
-        v = self.wv(inputs)  # (batch_size, seq_len, d_model)
+        # Projections
+        q = self.wq(inputs)
+        k = self.wk(inputs)
+        v = self.wv(inputs)

-        # Split into multiple heads
-        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
-        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
-        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)
+        # Split heads
+        q = self.split_heads(q, batch_size)
+        k = self.split_heads(k, batch_size)
+        v = self.split_heads(v, batch_size)

-        # Scaled Dot-Product Attention
-        sqrt_att_head_size = math.sqrt(self.depth)
-
-        attention_scores = tf.matmul(q, k, transpose_b=True)  # (batch_size, num_heads, seq_len_q, seq_len_k)
-        dk = tf.cast(sqrt_att_head_size, tf.float32)
-        attention_scores = tf.divide(attention_scores, dk)
+        # Scaled dot-product attention (compute in float32 for stability)
+        q_f = tf.cast(q, tf.float32)
+        k_f = tf.cast(k, tf.float32)
+        attention_scores = tf.matmul(q_f, k_f, transpose_b=True)
+        scale = tf.sqrt(tf.cast(self.depth, tf.float32))
+        attention_scores = attention_scores / scale

         if mask is not None:
-            attention_scores = tf.add(attention_scores, mask)
+            attention_scores = attention_scores + tf.cast(mask, tf.float32)

         attention_probs = self.stable_softmax(attention_scores, axis=-1)
+        # Cast back to v dtype for matmul efficiency under mixed precision
+        attention_probs = tf.cast(attention_probs, v.dtype)
         attention_probs = self.dropout(attention_probs, training=training)

         # Attention result
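For the additive mask that now gets cast to float32 and added to the scores, callers typically turn a 0/1 padding mask into large negative biases so masked keys vanish after the softmax. An illustrative sketch, not part of the patch; the shapes and the -1e9 constant are assumptions:

import tensorflow as tf

attention_mask = tf.constant([[1, 1, 1, 0, 0]])                # 1 = real token, 0 = padding
additive_mask = (1.0 - tf.cast(attention_mask, tf.float32)) * -1e9
additive_mask = additive_mask[:, tf.newaxis, tf.newaxis, :]    # broadcast over (batch, heads, q_len, k_len)

scores = tf.random.normal((1, 8, 5, 5))                        # (batch, heads, q_len, k_len)
masked = scores + additive_mask                                # same op as `attention_scores + tf.cast(mask, tf.float32)`
probs = tf.nn.softmax(masked, axis=-1)                         # padded keys receive ~0 probability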
@@ -118,11 +122,23 @@ def __init__(self, model_name, normalize_embeddings=False, use_fp16=True,
         self.num_layers = self.config.num_hidden_layers
         self.vocab_size = self.config.vocab_size

+        # Optional mixed precision
+        if self.use_fp16:
+            from tensorflow.keras import mixed_precision
+            try:
+                mixed_precision.set_global_policy("mixed_float16")
+            except Exception:
+                pass
+
         # Build components
         self._build_embeddings()
         self._build_encoder_layers()
         self._build_pooler()
+        # Handle ColBERT dim parameter
+        self.colbert_dim = self.d_model if not colbert_dim or colbert_dim < 1 else int(colbert_dim)
         self._build_colbert()
+        # Sparse head (optional)
+        self.sparse_linear = tf.keras.layers.Dense(1, name="sparse_linear")

         # Tokenizer
         self.tokenizer = AutoTokenizer.from_pretrained(
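set_global_policy("mixed_float16") makes Keras layers compute in float16 while keeping their variables in float32, which is why the attention and pooling changes in this commit cast the numerically sensitive reductions to float32. A minimal sketch of what the policy does, not part of the patch and assuming TensorFlow 2.4+:

import tensorflow as tf
from tensorflow.keras import mixed_precision

mixed_precision.set_global_policy("mixed_float16")

layer = tf.keras.layers.Dense(4)
y = layer(tf.zeros((2, 3)))

print(layer.compute_dtype)  # float16 -> matmuls run in half precision
print(layer.dtype)          # float32 -> weights stay in full precision
print(y.dtype)              # float16 -> outputs need care before losses/softmax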
@@ -207,9 +223,7 @@ def _build_pooler(self):
         )

     def _build_colbert(self):
-        self.colbert_linear = tf.keras.layers.Dense(
-            units=self.d_model,
-        )
+        self.colbert_linear = tf.keras.layers.Dense(self.colbert_dim, name="colbert_linear")

     def call(self, inputs, training=False, output_hidden_states=False):
@@ -278,7 +292,10 @@ def call(self, inputs, training=False, output_hidden_states=False):

         # Pooling
         if self.pooling_method == "mean":
-            pooled_output = tf.reduce_mean(hidden_states, axis=1)
+            m = tf.cast(attention_mask_origin, tf.float32)[:, :, None]
+            summed = tf.reduce_sum(tf.cast(hidden_states, tf.float32) * m, axis=1)
+            denom = tf.reduce_sum(m, axis=1) + tf.cast(1e-9, tf.float32)
+            pooled_output = tf.cast(summed / denom, hidden_states.dtype)
         else:  # default: cls
             pooled_output = hidden_states[:, 0, :]
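The mean-pooling branch now averages only over real tokens; plain tf.reduce_mean over axis=1 would let padding vectors drag the sentence embedding around. A tiny check, not part of the patch (the toy tensors are made up):

import tensorflow as tf

hidden_states = tf.constant([[[1.0, 1.0], [3.0, 3.0], [100.0, 100.0]]])  # last position is padding
attention_mask = tf.constant([[1, 1, 0]])

m = tf.cast(attention_mask, tf.float32)[:, :, None]
masked_mean = tf.reduce_sum(hidden_states * m, axis=1) / (tf.reduce_sum(m, axis=1) + 1e-9)

print(masked_mean.numpy())                            # [[2. 2.]] -> padding ignored
print(tf.reduce_mean(hidden_states, axis=1).numpy())  # roughly [[34.67 34.67]] -> padding leaks in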
@@ -291,15 +308,23 @@ def call(self, inputs, training=False, output_hidden_states=False):
             pooled_output = tf.nn.l2_normalize(pooled_output, axis=-1)

         ## colbert_vecs
-        colbert_vecs = self.colbert_linear(hidden_states[:, 1:])
-        colbert_vecs = colbert_vecs * tf.cast(attention_mask_origin[:, 1:][:, :, None], dtype=tf.float32)
+        colbert_vecs = None
+        if self.return_colbert_vecs:
+            m = tf.cast(attention_mask_origin[:, 1:], hidden_states.dtype)[:, :, None]
+            colbert_vecs = self.colbert_linear(hidden_states[:, 1:]) * m

         outputs = {
             "dense_vecs": pooled_output,
-            "colbert_vecs": colbert_vecs,
             "last_hidden_state": hidden_states
         }

+        if colbert_vecs is not None:
+            outputs["colbert_vecs"] = colbert_vecs
+
+        if self.return_sparse:
+            token_weights = tf.nn.relu(self.sparse_linear(hidden_states))
+            outputs["token_weights"] = token_weights
+
         if output_hidden_states:
             outputs["hidden_states"] = all_hidden_states
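The optional colbert_vecs and token_weights follow the usual dense / multi-vector / sparse split used by BGE-M3-style retrievers: the former feed a late-interaction (MaxSim) score, the latter act as per-token lexical weights. A hedged sketch of downstream use, not part of the patch; q_out/d_out and the scoring combination are illustrative only:

import tensorflow as tf

def maxsim_score(q_colbert, d_colbert):
    # q_colbert: (q_len, dim), d_colbert: (d_len, dim)
    sim = tf.matmul(q_colbert, d_colbert, transpose_b=True)  # (q_len, d_len) token-token similarities
    return tf.reduce_sum(tf.reduce_max(sim, axis=-1))        # best document token per query token, summed

# q_out = model(query_inputs); d_out = model(doc_inputs)      # hypothetical calls
# dense_score   = tf.reduce_sum(q_out["dense_vecs"] * d_out["dense_vecs"], axis=-1)
# colbert_score = maxsim_score(q_out["colbert_vecs"][0], d_out["colbert_vecs"][0])
# sparse_score  = overlap of q_out["token_weights"] and d_out["token_weights"] on shared token ids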
@@ -368,8 +393,6 @@ def save_model_with_tokenizer(model, tokenizer, save_path):
         tf.TensorSpec(shape=[None, None], dtype=tf.int32, name='attention_mask')
     ])
     def serving_fn(input_ids, attention_mask):
-
-        print(input_ids)
         inputs = {
             'input_ids': input_ids,
             'attention_mask': attention_mask
@@ -379,15 +402,24 @@ def serving_fn(input_ids, attention_mask):

         if outputs.get('hidden_states'):
             hidden_states = tf.stack(outputs['hidden_states'], axis=0)
-            return {
-                'dense_vecs': outputs['dense_vecs'],  # CLS Token
-                'colbert_vecs': outputs['colbert_vecs'],
+            ret = {
+                'dense_vecs': outputs['dense_vecs'],  # CLS Token or masked mean
                 'hidden_states': hidden_states  # (num_layers, batch, seq_len, hidden_dim)
             }
+            if 'colbert_vecs' in outputs:
+                ret['colbert_vecs'] = outputs['colbert_vecs']
+            if 'token_weights' in outputs:
+                ret['token_weights'] = outputs['token_weights']
+            return ret
         else:
-            return {
+            ret = {
                 'dense_vecs': outputs['dense_vecs'],
             }
+            if 'colbert_vecs' in outputs:
+                ret['colbert_vecs'] = outputs['colbert_vecs']
+            if 'token_weights' in outputs:
+                ret['token_weights'] = outputs['token_weights']
+            return ret

     # Save model
     tf.saved_model.save(
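Once exported, the signature declared by serving_fn can be driven like any SavedModel signature. A hedged sketch, not part of the patch; the "serving_default" key, the export path, and loading the tokenizer from the same directory are assumptions:

import tensorflow as tf
from transformers import AutoTokenizer

export_dir = "exported_model"                        # hypothetical path passed as save_path
loaded = tf.saved_model.load(export_dir)
serving = loaded.signatures["serving_default"]       # assumes the default signature key

tokenizer = AutoTokenizer.from_pretrained(export_dir)  # assumes the tokenizer was saved alongside
enc = tokenizer(["hello world"], return_tensors="tf", padding=True)

out = serving(input_ids=tf.cast(enc["input_ids"], tf.int32),
              attention_mask=tf.cast(enc["attention_mask"], tf.int32))
print(out["dense_vecs"].shape)                       # (1, hidden_dim) sentence embedding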