
Commit 8e5e3bd

feat: colbert_vecs
1 parent 5c4c192 commit 8e5e3bd

File tree

4 files changed: +148 -57 lines changed


BGEM3TFModel.py

Lines changed: 65 additions & 33 deletions
@@ -32,11 +32,13 @@ def __init__(self, d_model, num_heads, dropout_rate=0.1, **kwargs):
         # Dropout
         self.dropout = tf.keras.layers.Dropout(dropout_rate)
 
-    def stable_softmax(self, logits, axis=None, name=None):
-        """
-        Stable softmax implementation
-        """
-        return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name)
+    def stable_softmax(self, logits, axis=-1, name=None):
+        """Numerically stable softmax: subtract max and compute in float32."""
+        dtype = logits.dtype
+        x = tf.cast(logits, tf.float32)
+        x = x - tf.reduce_max(x, axis=axis, keepdims=True)
+        probs = tf.nn.softmax(x, axis=axis, name=name)
+        return tf.cast(probs, dtype)
 
     def split_heads(self, x, batch_size):
         x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
@@ -45,27 +47,29 @@ def split_heads(self, x, batch_size):
     def call(self, inputs, mask=None, training=False):
         batch_size = tf.shape(inputs)[0]
 
-        # Compute Query, Key, Value
-        q = self.wq(inputs)  # (batch_size, seq_len, d_model)
-        k = self.wk(inputs)  # (batch_size, seq_len, d_model)
-        v = self.wv(inputs)  # (batch_size, seq_len, d_model)
+        # Projections
+        q = self.wq(inputs)
+        k = self.wk(inputs)
+        v = self.wv(inputs)
 
-        # Split into multiple heads
-        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
-        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
-        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)
+        # Split heads
+        q = self.split_heads(q, batch_size)
+        k = self.split_heads(k, batch_size)
+        v = self.split_heads(v, batch_size)
 
-        # Scaled Dot-Product Attention
-        sqrt_att_head_size = math.sqrt(self.depth)
-
-        attention_scores = tf.matmul(q, k, transpose_b=True)  # (batch_size, num_heads, seq_len_q, seq_len_k)
-        dk = tf.cast(sqrt_att_head_size, tf.float32)
-        attention_scores = tf.divide(attention_scores, dk)
+        # Scaled dot-product attention (compute in float32 for stability)
+        q_f = tf.cast(q, tf.float32)
+        k_f = tf.cast(k, tf.float32)
+        attention_scores = tf.matmul(q_f, k_f, transpose_b=True)
+        scale = tf.sqrt(tf.cast(self.depth, tf.float32))
+        attention_scores = attention_scores / scale
 
         if mask is not None:
-            attention_scores = tf.add(attention_scores, mask)
+            attention_scores = attention_scores + tf.cast(mask, tf.float32)
 
         attention_probs = self.stable_softmax(attention_scores, axis=-1)
+        # Cast back to v dtype for matmul efficiency under mixed precision
+        attention_probs = tf.cast(attention_probs, v.dtype)
         attention_probs = self.dropout(attention_probs, training=training)
 
         # Attention result
@@ -118,11 +122,23 @@ def __init__(self, model_name, normalize_embeddings=False, use_fp16=True,
         self.num_layers = self.config.num_hidden_layers
         self.vocab_size = self.config.vocab_size
 
+        # Optional mixed precision
+        if self.use_fp16:
+            from tensorflow.keras import mixed_precision
+            try:
+                mixed_precision.set_global_policy("mixed_float16")
+            except Exception:
+                pass
+
         # Build components
         self._build_embeddings()
         self._build_encoder_layers()
         self._build_pooler()
+        # Handle ColBERT dim parameter
+        self.colbert_dim = self.d_model if not colbert_dim or colbert_dim < 1 else int(colbert_dim)
         self._build_colbert()
+        # Sparse head (optional)
+        self.sparse_linear = tf.keras.layers.Dense(1, name="sparse_linear")
 
         # Tokenizer
         self.tokenizer = AutoTokenizer.from_pretrained(
@@ -207,9 +223,7 @@ def _build_pooler(self):
         )
 
     def _build_colbert(self):
-        self.colbert_linear = tf.keras.layers.Dense(
-            units=self.d_model,
-        )
+        self.colbert_linear = tf.keras.layers.Dense(self.colbert_dim, name="colbert_linear")
 
     def call(self, inputs, training=False, output_hidden_states=False):
 
@@ -278,7 +292,10 @@ def call(self, inputs, training=False, output_hidden_states=False):
 
         # Pooling
         if self.pooling_method == "mean":
-            pooled_output = tf.reduce_mean(hidden_states, axis=1)
+            m = tf.cast(attention_mask_origin, tf.float32)[:, :, None]
+            summed = tf.reduce_sum(tf.cast(hidden_states, tf.float32) * m, axis=1)
+            denom = tf.reduce_sum(m, axis=1) + tf.cast(1e-9, tf.float32)
+            pooled_output = tf.cast(summed / denom, hidden_states.dtype)
         else:  # default: cls
             pooled_output = hidden_states[:, 0, :]
 
@@ -291,15 +308,23 @@ def call(self, inputs, training=False, output_hidden_states=False):
             pooled_output = tf.nn.l2_normalize(pooled_output, axis=-1)
 
         ## colbert_vecs
-        colbert_vecs = self.colbert_linear(hidden_states[:, 1:])
-        colbert_vecs = colbert_vecs * tf.cast(attention_mask_origin[:, 1:][:, :, None], dtype=tf.float32)
+        colbert_vecs = None
+        if self.return_colbert_vecs:
+            m = tf.cast(attention_mask_origin[:, 1:], hidden_states.dtype)[:, :, None]
+            colbert_vecs = self.colbert_linear(hidden_states[:, 1:]) * m
 
         outputs = {
             "dense_vecs": pooled_output,
-            "colbert_vecs": colbert_vecs,
             "last_hidden_state": hidden_states
         }
 
+        if colbert_vecs is not None:
+            outputs["colbert_vecs"] = colbert_vecs
+
+        if self.return_sparse:
+            token_weights = tf.nn.relu(self.sparse_linear(hidden_states))
+            outputs["token_weights"] = token_weights
+
         if output_hidden_states:
             outputs["hidden_states"] = all_hidden_states
 
@@ -368,8 +393,6 @@ def save_model_with_tokenizer(model, tokenizer, save_path):
         tf.TensorSpec(shape=[None, None], dtype=tf.int32, name='attention_mask')
     ])
     def serving_fn(input_ids, attention_mask):
-
-        print(input_ids)
         inputs = {
             'input_ids': input_ids,
             'attention_mask': attention_mask
@@ -379,15 +402,24 @@ def serving_fn(input_ids, attention_mask):
 
         if outputs.get('hidden_states'):
             hidden_states = tf.stack(outputs['hidden_states'], axis=0)
-            return {
-                'dense_vecs': outputs['dense_vecs'],  # CLS Token
-                'colbert_vecs': outputs['colbert_vecs'],
+            ret = {
+                'dense_vecs': outputs['dense_vecs'],  # CLS Token or masked mean
                 'hidden_states': hidden_states  # (num_layers, batch, seq_len, hidden_dim)
             }
+            if 'colbert_vecs' in outputs:
+                ret['colbert_vecs'] = outputs['colbert_vecs']
+            if 'token_weights' in outputs:
+                ret['token_weights'] = outputs['token_weights']
+            return ret
         else:
-            return {
+            ret = {
                 'dense_vecs': outputs['dense_vecs'],
             }
+            if 'colbert_vecs' in outputs:
+                ret['colbert_vecs'] = outputs['colbert_vecs']
+            if 'token_weights' in outputs:
+                ret['token_weights'] = outputs['token_weights']
+            return ret
 
     # Save model
     tf.saved_model.save(
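
Side note on the pooling change above: masked mean pooling averages only over real tokens, so padding can no longer shift the dense vector. A minimal standalone sketch of that idea (the helper name masked_mean_pool and the toy tensors are illustrative, not part of this repo):

import tensorflow as tf

def masked_mean_pool(hidden_states, attention_mask):
    """Mean over the sequence axis, counting only positions where the mask is 1."""
    m = tf.cast(attention_mask, tf.float32)[:, :, None]           # (batch, seq, 1)
    summed = tf.reduce_sum(tf.cast(hidden_states, tf.float32) * m, axis=1)
    denom = tf.reduce_sum(m, axis=1) + 1e-9                        # avoid divide-by-zero
    return summed / denom                                          # (batch, hidden)

# Toy check: appending a padded (masked-out) timestep must not change the pooled vector.
h = tf.constant([[[1.0, 2.0], [3.0, 4.0]]])            # (1, 2, 2)
mask = tf.constant([[1, 1]])
h_pad = tf.concat([h, tf.zeros((1, 1, 2))], axis=1)    # (1, 3, 2) with one pad token
mask_pad = tf.constant([[1, 1, 0]])
print(masked_mean_pool(h, mask).numpy())               # [[2. 3.]]
print(masked_mean_pool(h_pad, mask_pad).numpy())       # [[2. 3.]] -- identical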

BGEM3WeightConverter.py

Lines changed: 77 additions & 11 deletions
@@ -55,11 +55,58 @@ def _init_colbert_weights(tf_model):
     colbert = load_colbert_weights()
     colbert_weights = colbert['weight']
     colbert_bias = colbert['bias']
+    # Convert to numpy and report shape
+    w = colbert_weights.detach().cpu().numpy() if hasattr(colbert_weights, "detach") else np.array(colbert_weights)
+    b = colbert_bias.detach().cpu().numpy() if hasattr(colbert_bias, "detach") else np.array(colbert_bias)
 
-    tf_model.colbert_linear.set_weights([
-        colbert_weights.numpy().T,
-        colbert_bias.numpy()
-    ])
+    out_dim, in_dim = w.shape  # PT: (out_dim, in_dim)
+    print(f"ColBERT head weight shape: (out_dim={out_dim}, in_dim={in_dim})")
+
+    # Ensure the Dense layer has matching units and is built
+    try:
+        current_units = getattr(tf_model.colbert_linear, "units", None)
+    except Exception:
+        current_units = None
+
+    if current_units is not None and current_units != out_dim:
+        # Units mismatch; warn. Ideally create the model with detected colbert_dim to avoid this.
+        print(f"Warning: colbert_linear units ({current_units}) != detected out_dim ({out_dim}). We will attempt to set weights and may fail.")
+
+    # Ensure variables exist. If not built yet, do a dummy call to build with correct in_dim.
+    if not getattr(tf_model.colbert_linear, "built", False):
+        dummy = tf.zeros((1, 2, in_dim), dtype=tf.float32)
+        _ = tf_model.colbert_linear(dummy)
+
+    # Set weights (kernel shape: (in_dim, out_dim))
+    tf_model.colbert_linear.set_weights([w.T, b])
+
+
+def _init_sparse_weights(tf_model):
+    """Initialize sparse head weights if available (optional)."""
+    try:
+        st = load_sparse_weights()
+    except FileNotFoundError as e:
+        print(str(e))
+        return
+
+    # Expect PyTorch shape: (out_dim=1, in_dim=hidden)
+    w_pt = st["weight"]
+    b_pt = st["bias"]
+    # Ensure numpy
+    if hasattr(w_pt, "cpu"):
+        w_np = w_pt.cpu().numpy()
+    else:
+        w_np = np.array(w_pt)
+    if hasattr(b_pt, "cpu"):
+        b_np = b_pt.cpu().numpy()
+    else:
+        b_np = np.array(b_pt)
+
+    # Build layer if not built
+    in_dim = w_np.shape[1]
+    tf_model.sparse_linear.build((None, None, in_dim))
+    # Keras Dense kernel shape: (in_dim, out_dim)
+    tf_model.sparse_linear.set_weights([w_np.T, b_np])
 
 
 class BGEM3WeightConverter:
@@ -85,15 +132,15 @@ def initialize_weights(self, tf_model):
        # Initialize encoder layers
        self._init_transformer_blocks(tf_model)

-       # Initialize pooler
-       self._init_pooler_weights(tf_model)
-
-       # Initialize pooler
+       # Initialize pooler (once)
        self._init_pooler_weights(tf_model)

        # Initialize colbert
        _init_colbert_weights(tf_model)

+       # Initialize sparse head (optional)
+       _init_sparse_weights(tf_model)
+
        return tf_model

    def _init_embedding_weights(self, tf_model):
@@ -230,9 +277,28 @@ def _init_pooler_weights(self, tf_model):
 
 
 def convert_and_save_model(model_name: str, save_path: str):
-    """Convert PyTorch model to TensorFlow and save"""
-    # Initialize TensorFlow model
-    tf_model = BGEM3TensorFlow(model_name)
+    """Convert PyTorch model to TensorFlow and save.
+    Also detects and uses original ColBERT dimension for TF head.
+    """
+    # Detect ColBERT original dimension from weights (out_dim)
+    try:
+        colbert = load_colbert_weights()
+        colbert_w = colbert['weight']
+        out_dim = int(colbert_w.shape[0])
+        print(f"Detected ColBERT dimension: {out_dim}")
+        colbert_dim = out_dim
+        return_colbert_vecs = True
+    except Exception as e:
+        print(f"ColBERT weights not found or failed to load: {e}")
+        colbert_dim = -1
+        return_colbert_vecs = False
+
+    # Initialize TensorFlow model with detected colbert_dim
+    tf_model = BGEM3TensorFlow(
+        model_name,
+        colbert_dim=colbert_dim,
+        return_colbert_vecs=return_colbert_vecs,
+    )
 
     # Convert weights
     converter = BGEM3WeightConverter(model_name)
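
The transposes in _init_colbert_weights and _init_sparse_weights follow from the two frameworks' layouts: torch.nn.Linear stores its weight as (out_features, in_features), while a Keras Dense kernel is (in_features, out_features). A minimal sketch of that convention with NumPy stand-ins for the loaded PyTorch tensors (shapes and names are illustrative):

import numpy as np
import tensorflow as tf

in_dim, out_dim = 8, 4

# Pretend these came from a torch.nn.Linear state_dict: weight (out, in), bias (out,)
w_pt = np.random.randn(out_dim, in_dim).astype("float32")
b_pt = np.random.randn(out_dim).astype("float32")

dense = tf.keras.layers.Dense(out_dim)
dense.build((None, in_dim))                    # creates kernel (in, out) and bias (out,)
dense.set_weights([w_pt.T, b_pt])              # transpose the PyTorch weight

x = np.random.randn(2, in_dim).astype("float32")
tf_out = dense(x).numpy()
pt_out = x @ w_pt.T + b_pt                     # y = x W^T + b, PyTorch convention
print(np.allclose(tf_out, pt_out, atol=1e-5))  # True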

model_conversion_validator.py

Lines changed: 5 additions & 2 deletions
@@ -132,8 +132,11 @@ def encode_with_tf_model_and_get_hidden_states(serving_fn, tokenizer, queries, m
 
     hidden_states = outputs["hidden_states"]  # (num_layers, batch, seq_len, hidden_dim)
     final_embeddings = outputs["dense_vecs"]
-    print("outputs['colbert_vecs'] : ")
-    print(outputs["colbert_vecs"])
+    if "colbert_vecs" in outputs:
+        print("outputs['colbert_vecs'] : ")
+        print(outputs["colbert_vecs"])
+    else:
+        print("colbert_vecs not returned by TF model (flag disabled).")
 
     return final_embeddings.numpy(), hidden_states
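
For context on why per-token colbert_vecs are exported at all: ColBERT-style late interaction scores a query against a passage by taking, for each query token, its best match among the passage tokens and summing those maxima. A hedged sketch of that MaxSim score over two randomly generated token-embedding matrices; this helper is not part of this repository's code:

import tensorflow as tf

def colbert_maxsim(query_vecs, passage_vecs):
    """Late-interaction score: per query token, best cosine match over passage tokens, then sum."""
    q = tf.nn.l2_normalize(query_vecs, axis=-1)         # (q_len, dim)
    p = tf.nn.l2_normalize(passage_vecs, axis=-1)       # (p_len, dim)
    sim = tf.matmul(q, p, transpose_b=True)             # (q_len, p_len) cosine similarities
    return tf.reduce_sum(tf.reduce_max(sim, axis=-1))   # sum of per-query-token maxima

q_vecs = tf.random.normal((5, 16))    # e.g. 5 query tokens, 16-dim ColBERT head (illustrative)
p_vecs = tf.random.normal((40, 16))   # e.g. 40 passage tokens
print(float(colbert_maxsim(q_vecs, p_vecs)))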

torch_tf_validator.py

Lines changed: 1 addition & 11 deletions
@@ -61,34 +61,24 @@ def main():
 
     inputs_tf = tokenize_wo_padding(tokenizer, text, return_tensors="tf")
     inputs_tf_w_padding = tokenize_w_padding(tokenizer, text, return_tensors="tf")
-    inputs_tf_w_padding_attnFixed = inputs_tf_w_padding.copy()
-    inputs_tf_w_padding_attnFixed['attention_mask'] = tf.where(inputs_tf_w_padding['attention_mask'] == 0, -9999999, 0)
     tf_model = load_tf_model(model_path_tf).signatures["serving_default"]
 
     loguru.logger.info("Tensorflow] Model output".ljust(50, "-"))
     with tf.device("/GPU:0"):
         output_tf = tf_model(**inputs_tf)
         output_tf_w_padding = tf_model(**inputs_tf_w_padding)
-        output_tf_w_padding_attnFixed = tf_model(**inputs_tf_w_padding_attnFixed)
     loguru.logger.info("output without padding (GT)".ljust(50, "-"))
     loguru.logger.info(output_tf['hidden_states'][-1][:,0])
     loguru.logger.info("="*50)
     loguru.logger.info("output with padding".ljust(50, "-"))
     loguru.logger.info(output_tf_w_padding['hidden_states'][-1][:,0])
     loguru.logger.info("="*50)
-    loguru.logger.info("output with padding (attention fixed)".ljust(50, "-"))
-    loguru.logger.info(output_tf_w_padding_attnFixed['hidden_states'][-1][:,0])
-    loguru.logger.info("="*50)
     err_tf = tf.abs(output_tf['hidden_states'][-1][:,0] - output_tf_w_padding['hidden_states'][-1][:,0])
     loguru.logger.info("Error".ljust(50, "-"))
     loguru.logger.info(tf.reduce_mean(err_tf))
     loguru.logger.info("="*50)
-    err_tf_attnFixed = tf.abs(output_tf_w_padding['hidden_states'][-1][:,0] - output_tf_w_padding_attnFixed['hidden_states'][-1][:,0])
-    loguru.logger.info("Error (attention fixed)".ljust(50, "-"))
-    loguru.logger.info(tf.reduce_mean(err_tf_attnFixed))
-    loguru.logger.info("="*50)
 
 
 
 if __name__ == "__main__":
-    main()
+    main()
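
Because the serving function now emits colbert_vecs and token_weights only when they exist, downstream callers should treat those keys as optional. A sketch of that pattern, assuming a SavedModel exported by save_model_with_tokenizer at a hypothetical path and the BAAI/bge-m3 tokenizer (both are assumptions; adjust to your setup):

import tensorflow as tf
from transformers import AutoTokenizer

model_path = "./bge_m3_tf_savedmodel"                     # hypothetical export path
serving_fn = tf.saved_model.load(model_path).signatures["serving_default"]
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")  # assumed tokenizer name

enc = tokenizer(["what is colbert?"], return_tensors="tf", padding=True)
outputs = serving_fn(
    input_ids=tf.cast(enc["input_ids"], tf.int32),
    attention_mask=tf.cast(enc["attention_mask"], tf.int32),
)

print(outputs["dense_vecs"].shape)                        # always present
if "colbert_vecs" in outputs:                             # optional: per-token ColBERT vectors
    print(outputs["colbert_vecs"].shape)
if "token_weights" in outputs:                            # optional: sparse/lexical weights
    print(outputs["token_weights"].shape)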
