Skip to content

Commit 983cb26

Browse files
committed
Revert bloat around Config and keep it simpler
1 parent 6795e1e commit 983cb26

File tree

1 file changed

+13
-33
lines changed

1 file changed

+13
-33
lines changed

backends/ort/src/lib.rs

Lines changed: 13 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -12,42 +12,14 @@ use text_embeddings_backend_core::{
1212
};
1313

1414
#[derive(Debug, Clone, Deserialize)]
15-
#[serde(from = "ConfigValidator")]
1615
pub struct Config {
17-
pub pad_token_id: usize,
16+
pub pad_token_id: Option<usize>,
1817
pub eos_token_id: Option<usize>,
1918
// NOTE: the fields below are only required when the ONNX model expects the `past_key_values`
2019
// as input, i.e., whenever the ONNX model has been exported with optimized MHA nodes
2120
pub hidden_size: usize,
2221
pub num_hidden_layers: usize,
23-
pub num_key_value_heads: usize,
24-
}
25-
26-
#[derive(Deserialize)]
27-
struct ConfigValidator {
28-
pad_token_id: Option<usize>,
29-
eos_token_id: Option<usize>,
30-
hidden_size: usize,
31-
num_hidden_layers: usize,
32-
num_key_value_heads: Option<usize>,
33-
}
34-
35-
impl From<ConfigValidator> for Config {
36-
fn from(config: ConfigValidator) -> Self {
37-
let pad_token_id = config.pad_token_id.or(config.eos_token_id).unwrap_or(0);
38-
39-
let num_key_value_heads = config
40-
.num_key_value_heads
41-
.unwrap_or(config.num_hidden_layers);
42-
43-
Config {
44-
pad_token_id,
45-
eos_token_id: config.eos_token_id,
46-
hidden_size: config.hidden_size,
47-
num_hidden_layers: config.num_hidden_layers,
48-
num_key_value_heads,
49-
}
50-
}
22+
pub num_key_value_heads: Option<usize>,
5123
}
5224

5325
#[derive(Debug, Clone, Default, Deserialize)]
@@ -195,6 +167,11 @@ impl OrtBackend {
195167
let max_length = batch.max_length as usize;
196168
let elems = batch_size * max_length;
197169

170+
let pad_token_id = self
171+
.config
172+
.pad_token_id
173+
.unwrap_or(self.config.eos_token_id.unwrap_or(0)) as i64;
174+
198175
let (input_ids, attention_mask, token_type_ids, position_ids, input_lengths, masking) =
199176
if batch_size > 1 {
200177
let mut input_ids = Vec::with_capacity(elems);
@@ -227,7 +204,7 @@ impl OrtBackend {
227204
// sequences in the batch have the same length
228205
masking = true;
229206
for pad_pos in 0..padding {
230-
input_ids.push(self.config.pad_token_id as i64);
207+
input_ids.push(pad_token_id);
231208
attention_mask.push(0_i64);
232209
token_type_ids.push(0);
233210
position_ids.push((seq_length + pad_pos) as i64);
@@ -257,7 +234,7 @@ impl OrtBackend {
257234
// sequences in the batch have the same length
258235
masking = true;
259236
for _ in 0..padding {
260-
input_ids.push(self.config.pad_token_id as i64);
237+
input_ids.push(pad_token_id);
261238
attention_mask.push(0_i64);
262239
token_type_ids.push(0);
263240
position_ids.push(0);
@@ -318,7 +295,10 @@ impl OrtBackend {
318295
let past_key_values = if self.past_key_values {
319296
let hidden_size = self.config.hidden_size;
320297
let num_hidden_layers = self.config.num_hidden_layers;
321-
let num_key_value_heads = self.config.num_key_value_heads;
298+
let num_key_value_heads = self
299+
.config
300+
.num_key_value_heads
301+
.unwrap_or(self.config.num_hidden_layers);
322302
let head_size = hidden_size / num_key_value_heads;
323303
let mut arrays = Vec::new();
324304

0 commit comments

Comments
 (0)