@@ -12,42 +12,14 @@ use text_embeddings_backend_core::{
1212} ;
1313
1414#[ derive( Debug , Clone , Deserialize ) ]
15- #[ serde( from = "ConfigValidator" ) ]
1615pub struct Config {
17- pub pad_token_id : usize ,
16+ pub pad_token_id : Option < usize > ,
1817 pub eos_token_id : Option < usize > ,
1918 // NOTE: the fields below are only required when the ONNX model expects the `past_key_values`
2019 // as input i.e., whenever the ONNX model has been exported with optimized MHA nodes
2120 pub hidden_size : usize ,
2221 pub num_hidden_layers : usize ,
23- pub num_key_value_heads : usize ,
24- }
25-
26- #[ derive( Deserialize ) ]
27- struct ConfigValidator {
28- pad_token_id : Option < usize > ,
29- eos_token_id : Option < usize > ,
30- hidden_size : usize ,
31- num_hidden_layers : usize ,
32- num_key_value_heads : Option < usize > ,
33- }
34-
35- impl From < ConfigValidator > for Config {
36- fn from ( config : ConfigValidator ) -> Self {
37- let pad_token_id = config. pad_token_id . or ( config. eos_token_id ) . unwrap_or ( 0 ) ;
38-
39- let num_key_value_heads = config
40- . num_key_value_heads
41- . unwrap_or ( config. num_hidden_layers ) ;
42-
43- Config {
44- pad_token_id,
45- eos_token_id : config. eos_token_id ,
46- hidden_size : config. hidden_size ,
47- num_hidden_layers : config. num_hidden_layers ,
48- num_key_value_heads,
49- }
50- }
22+ pub num_key_value_heads : Option < usize > ,
5123}
5224
5325#[ derive( Debug , Clone , Default , Deserialize ) ]
@@ -195,6 +167,11 @@ impl OrtBackend {
195167 let max_length = batch. max_length as usize ;
196168 let elems = batch_size * max_length;
197169
170+ let pad_token_id = self
171+ . config
172+ . pad_token_id
173+ . unwrap_or ( self . config . eos_token_id . unwrap_or ( 0 ) ) as i64 ;
174+
198175 let ( input_ids, attention_mask, token_type_ids, position_ids, input_lengths, masking) =
199176 if batch_size > 1 {
200177 let mut input_ids = Vec :: with_capacity ( elems) ;
@@ -227,7 +204,7 @@ impl OrtBackend {
227204 // sequences in the batch have the same length
228205 masking = true ;
229206 for pad_pos in 0 ..padding {
230- input_ids. push ( self . config . pad_token_id as i64 ) ;
207+ input_ids. push ( pad_token_id) ;
231208 attention_mask. push ( 0_i64 ) ;
232209 token_type_ids. push ( 0 ) ;
233210 position_ids. push ( ( seq_length + pad_pos) as i64 ) ;
@@ -257,7 +234,7 @@ impl OrtBackend {
257234 // sequences in the batch have the same length
258235 masking = true ;
259236 for _ in 0 ..padding {
260- input_ids. push ( self . config . pad_token_id as i64 ) ;
237+ input_ids. push ( pad_token_id) ;
261238 attention_mask. push ( 0_i64 ) ;
262239 token_type_ids. push ( 0 ) ;
263240 position_ids. push ( 0 ) ;
@@ -318,7 +295,10 @@ impl OrtBackend {
318295 let past_key_values = if self . past_key_values {
319296 let hidden_size = self . config . hidden_size ;
320297 let num_hidden_layers = self . config . num_hidden_layers ;
321- let num_key_value_heads = self . config . num_key_value_heads ;
298+ let num_key_value_heads = self
299+ . config
300+ . num_key_value_heads
301+ . unwrap_or ( self . config . num_hidden_layers ) ;
322302 let head_size = hidden_size / num_key_value_heads;
323303 let mut arrays = Vec :: new ( ) ;
324304
0 commit comments