Skip to content

Commit 7b4335e

Browse files
committed
Add ConfigValidator to validate Config beforehand
1 parent 24dd9e4 commit 7b4335e

File tree

1 file changed

+44
-23
lines changed

1 file changed

+44
-23
lines changed

backends/ort/src/lib.rs

Lines changed: 44 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,42 @@ use text_embeddings_backend_core::{
1212
};
1313

1414
// NOTE(review): this is a rendered diff. A declaration preceded by an `N-` marker line is the
// removed pre-commit version; one preceded by an `N+` marker line is its replacement.
/// Model configuration used by the ORT backend.
///
/// Deserialization goes through `ConfigValidator` (see the `serde(from = ...)` attribute), so
/// the non-`Option` fields are always populated once a `Config` exists.
#[derive(Debug, Clone, Deserialize)]
15+
#[serde(from = "ConfigValidator")]
1516
pub struct Config {
16-
pub pad_token_id: Option<usize>,
17+
/// Padding token id; the `From<ConfigValidator>` conversion takes `pad_token_id`, then
/// `eos_token_id`, then `0`.
pub pad_token_id: usize,
1718
/// End-of-sequence token id, passed through unchanged from the parsed input.
pub eos_token_id: Option<usize>,
1819
// NOTE: the fields below are only required when the ONNX model expects the `past_key_values`
1920
// as input i.e., whenever the ONNX model has been exported with optimized MHA nodes
// NOTE(review): `ConfigValidator` declares `hidden_size` and `num_hidden_layers` as
// non-optional, so parsing appears to require them regardless — confirm this is intended.
20-
pub hidden_size: Option<usize>,
21-
pub num_hidden_layers: Option<usize>,
22-
pub num_key_value_heads: Option<usize>,
21+
pub hidden_size: usize,
22+
pub num_hidden_layers: usize,
23+
/// Resolved by the `From<ConfigValidator>` conversion when absent from the parsed input.
pub num_key_value_heads: usize,
24+
}
25+
26+
#[derive(Deserialize)]
27+
struct ConfigValidator {
28+
pad_token_id: Option<usize>,
29+
eos_token_id: Option<usize>,
30+
hidden_size: usize,
31+
num_hidden_layers: usize,
32+
num_key_value_heads: Option<usize>,
33+
}
34+
35+
impl From<ConfigValidator> for Config {
36+
fn from(config: ConfigValidator) -> Self {
37+
let pad_token_id = config.pad_token_id.or(config.eos_token_id).unwrap_or(0);
38+
39+
let num_key_value_heads = config
40+
.num_key_value_heads
41+
.unwrap_or(config.num_hidden_layers);
42+
43+
Config {
44+
pad_token_id,
45+
eos_token_id: config.eos_token_id,
46+
hidden_size: config.hidden_size,
47+
num_hidden_layers: config.num_hidden_layers,
48+
num_key_value_heads,
49+
}
50+
}
2351
}
2452

2553
#[derive(Debug, Clone, Deserialize)]
@@ -40,7 +68,7 @@ struct ModelInputs {
4068

4169
pub struct OrtBackend {
4270
session: Mutex<Session>,
43-
config: Option<Config>,
71+
config: Config,
4472

4573
token_type_ids: bool,
4674
// NOTE: required since the key can either be `token_type_ids` or `input_type`
@@ -50,7 +78,6 @@ pub struct OrtBackend {
5078

5179
pool: Pool,
5280
padding_side: PaddingSide,
53-
pad_token_id: usize,
5481
}
5582

5683
impl OrtBackend {
@@ -85,6 +112,12 @@ impl OrtBackend {
85112
}
86113
};
87114

115+
let config: Config = {
116+
let content = std::fs::read_to_string(&model_path.join("config.json"))
117+
.map_err(|e| BackendError::Start(format!("Failed to read `config.json`: {}", e)))?;
118+
serde_json::from_str(&content)
119+
.map_err(|e| BackendError::Start(format!("Failed to parse `config.json`: {}", e)))?
120+
};
88121
let session = Session::builder()
89122
.s()?
90123
.with_intra_threads(num_cpus::get())
@@ -117,16 +150,6 @@ impl OrtBackend {
117150
}
118151
}
119152

120-
let config_path = model_path.join("config.json");
121-
let config: Option<Config> = if config_path.exists() {
122-
let content = std::fs::read_to_string(&config_path)
123-
.map_err(|e| BackendError::Start(format!("Failed to read `config.json`: {}", e)))?;
124-
Some(serde_json::from_str(&content).map_err(|e| {
125-
BackendError::Start(format!("Failed to parse `config.json`: {}", e))
126-
})?)
127-
} else {
128-
None
129-
};
130153

131154
let pad_token_id = config
132155
.as_ref()
@@ -181,7 +204,6 @@ impl OrtBackend {
181204
past_key_values,
182205
pool,
183206
padding_side,
184-
pad_token_id,
185207
})
186208
}
187209
}
@@ -228,7 +250,7 @@ impl OrtBackend {
228250
// sequences in the batch have the same length
229251
masking = true;
230252
for pad_pos in 0..padding {
231-
input_ids.push(self.pad_token_id as i64);
253+
input_ids.push(self.config.pad_token_id as i64);
232254
attention_mask.push(0_i64);
233255
token_type_ids.push(0);
234256
position_ids.push((seq_length + pad_pos) as i64);
@@ -258,7 +280,7 @@ impl OrtBackend {
258280
// sequences in the batch have the same length
259281
masking = true;
260282
for _ in 0..padding {
261-
input_ids.push(self.pad_token_id as i64);
283+
input_ids.push(self.config.pad_token_id as i64);
262284
attention_mask.push(0_i64);
263285
token_type_ids.push(0);
264286
position_ids.push(0);
@@ -317,10 +339,9 @@ impl OrtBackend {
317339
let input_lengths = ndarray::Array1::from_vec(input_lengths);
318340

319341
let past_key_values = if self.past_key_values {
320-
let config = self.config.as_ref().unwrap();
321-
let hidden_size = config.hidden_size.unwrap();
322-
let num_hidden_layers = config.num_hidden_layers.unwrap();
323-
let num_key_value_heads = config.num_key_value_heads.unwrap();
342+
let hidden_size = self.config.hidden_size;
343+
let num_hidden_layers = self.config.num_hidden_layers;
344+
let num_key_value_heads = self.config.num_key_value_heads;
324345
let head_size = hidden_size / num_key_value_heads;
325346
let mut arrays = Vec::new();
326347

0 commit comments

Comments
 (0)