@@ -12,14 +12,42 @@ use text_embeddings_backend_core::{
1212} ;
1313
1414#[ derive( Debug , Clone , Deserialize ) ]
15+ #[ serde( from = "ConfigValidator" ) ]
1516pub struct Config {
16- pub pad_token_id : Option < usize > ,
17+ pub pad_token_id : usize ,
1718 pub eos_token_id : Option < usize > ,
1819 // NOTE: the fields below are only required when the ONNX model expects the `past_key_values`
1920 // as input i.e., whenever the ONNX model has been exported with optimized MHA nodes
20- pub hidden_size : Option < usize > ,
21- pub num_hidden_layers : Option < usize > ,
22- pub num_key_value_heads : Option < usize > ,
21+ pub hidden_size : usize ,
22+ pub num_hidden_layers : usize ,
23+ pub num_key_value_heads : usize ,
24+ }
25+
26+ #[ derive( Deserialize ) ]
27+ struct ConfigValidator {
28+ pad_token_id : Option < usize > ,
29+ eos_token_id : Option < usize > ,
30+ hidden_size : usize ,
31+ num_hidden_layers : usize ,
32+ num_key_value_heads : Option < usize > ,
33+ }
34+
35+ impl From < ConfigValidator > for Config {
36+ fn from ( config : ConfigValidator ) -> Self {
37+ let pad_token_id = config. pad_token_id . or ( config. eos_token_id ) . unwrap_or ( 0 ) ;
38+
39+ let num_key_value_heads = config
40+ . num_key_value_heads
41+ . unwrap_or ( config. num_hidden_layers ) ;
42+
43+ Config {
44+ pad_token_id,
45+ eos_token_id : config. eos_token_id ,
46+ hidden_size : config. hidden_size ,
47+ num_hidden_layers : config. num_hidden_layers ,
48+ num_key_value_heads,
49+ }
50+ }
2351}
2452
2553#[ derive( Debug , Clone , Deserialize ) ]
@@ -40,7 +68,7 @@ struct ModelInputs {
4068
4169pub struct OrtBackend {
4270 session : Mutex < Session > ,
43- config : Option < Config > ,
71+ config : Config ,
4472
4573 token_type_ids : bool ,
4674 // NOTE: required since the key can either be `token_type_ids` or `input_type`
@@ -50,7 +78,6 @@ pub struct OrtBackend {
5078
5179 pool : Pool ,
5280 padding_side : PaddingSide ,
53- pad_token_id : usize ,
5481}
5582
5683impl OrtBackend {
@@ -85,6 +112,12 @@ impl OrtBackend {
85112 }
86113 } ;
87114
115+ let config: Config = {
116+ let content = std:: fs:: read_to_string ( & model_path. join ( "config.json" ) )
117+ . map_err ( |e| BackendError :: Start ( format ! ( "Failed to read `config.json`: {}" , e) ) ) ?;
118+ serde_json:: from_str ( & content)
119+ . map_err ( |e| BackendError :: Start ( format ! ( "Failed to parse `config.json`: {}" , e) ) ) ?
120+ } ;
88121 let session = Session :: builder ( )
89122 . s ( ) ?
90123 . with_intra_threads ( num_cpus:: get ( ) )
@@ -117,16 +150,6 @@ impl OrtBackend {
117150 }
118151 }
119152
120- let config_path = model_path. join ( "config.json" ) ;
121- let config: Option < Config > = if config_path. exists ( ) {
122- let content = std:: fs:: read_to_string ( & config_path)
123- . map_err ( |e| BackendError :: Start ( format ! ( "Failed to read `config.json`: {}" , e) ) ) ?;
124- Some ( serde_json:: from_str ( & content) . map_err ( |e| {
125- BackendError :: Start ( format ! ( "Failed to parse `config.json`: {}" , e) )
126- } ) ?)
127- } else {
128- None
129- } ;
130153
131154 let pad_token_id = config
132155 . as_ref ( )
@@ -181,7 +204,6 @@ impl OrtBackend {
181204 past_key_values,
182205 pool,
183206 padding_side,
184- pad_token_id,
185207 } )
186208 }
187209}
@@ -228,7 +250,7 @@ impl OrtBackend {
228250 // sequences in the batch have the same length
229251 masking = true ;
230252 for pad_pos in 0 ..padding {
231- input_ids. push ( self . pad_token_id as i64 ) ;
253+ input_ids. push ( self . config . pad_token_id as i64 ) ;
232254 attention_mask. push ( 0_i64 ) ;
233255 token_type_ids. push ( 0 ) ;
234256 position_ids. push ( ( seq_length + pad_pos) as i64 ) ;
@@ -258,7 +280,7 @@ impl OrtBackend {
258280 // sequences in the batch have the same length
259281 masking = true ;
260282 for _ in 0 ..padding {
261- input_ids. push ( self . pad_token_id as i64 ) ;
283+ input_ids. push ( self . config . pad_token_id as i64 ) ;
262284 attention_mask. push ( 0_i64 ) ;
263285 token_type_ids. push ( 0 ) ;
264286 position_ids. push ( 0 ) ;
@@ -317,10 +339,9 @@ impl OrtBackend {
317339 let input_lengths = ndarray:: Array1 :: from_vec ( input_lengths) ;
318340
319341 let past_key_values = if self . past_key_values {
320- let config = self . config . as_ref ( ) . unwrap ( ) ;
321- let hidden_size = config. hidden_size . unwrap ( ) ;
322- let num_hidden_layers = config. num_hidden_layers . unwrap ( ) ;
323- let num_key_value_heads = config. num_key_value_heads . unwrap ( ) ;
342+ let hidden_size = self . config . hidden_size ;
343+ let num_hidden_layers = self . config . num_hidden_layers ;
344+ let num_key_value_heads = self . config . num_key_value_heads ;
324345 let head_size = hidden_size / num_key_value_heads;
325346 let mut arrays = Vec :: new ( ) ;
326347
0 commit comments