
Commit 6795e1e

Add TokenizerConfig with default right-padding
1 parent 7b4335e commit 6795e1e

1 file changed (+33, -55)

backends/ort/src/lib.rs

Lines changed: 33 additions & 55 deletions
@@ -50,13 +50,20 @@ impl From<ConfigValidator> for Config {
     }
 }
 
-#[derive(Debug, Clone, Deserialize)]
+#[derive(Debug, Clone, Default, Deserialize)]
 #[serde(rename_all = "lowercase")]
 enum PaddingSide {
     Left,
+    #[default]
     Right,
 }
 
+#[derive(Debug, Clone, Deserialize)]
+pub struct TokenizerConfig {
+    #[serde(default)]
+    padding_side: PaddingSide,
+}
+
 struct ModelInputs {
     pub input_ids: ndarray::Array2<i64>,
     pub attention_mask: ndarray::Array2<i64>,
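
For reference, a minimal stand-alone sketch of how the derives above behave (not part of this commit; the `main` function and the extra `PartialEq` derive exist only for the assertions, and `serde` with the `derive` feature plus `serde_json` are assumed to be available, as elsewhere in the backend): with `#[derive(Default)]` and `#[default]` on `Right`, plus `#[serde(default)]` on the field, a `tokenizer_config.json` that omits `padding_side` deserializes to right padding, while an explicit lowercase value is still honoured.

```rust
use serde::Deserialize;

// Mirrors the definitions added in this commit; `PartialEq` is added here only
// so the assertions below compile.
#[derive(Debug, Clone, Default, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
enum PaddingSide {
    Left,
    #[default]
    Right,
}

#[derive(Debug, Clone, Deserialize)]
struct TokenizerConfig {
    #[serde(default)]
    padding_side: PaddingSide,
}

fn main() {
    // `padding_side` absent: `#[serde(default)]` falls back to the `#[default]`
    // variant, i.e. `Right`.
    let cfg: TokenizerConfig = serde_json::from_str("{}").unwrap();
    assert_eq!(cfg.padding_side, PaddingSide::Right);

    // An explicit lowercase value is picked up via `rename_all = "lowercase"`;
    // any other keys in the file are ignored by serde's default behaviour.
    let cfg: TokenizerConfig =
        serde_json::from_str(r#"{"padding_side": "left", "model_max_length": 8192}"#).unwrap();
    assert_eq!(cfg.padding_side, PaddingSide::Left);
}
```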
@@ -69,6 +76,7 @@ struct ModelInputs {
 pub struct OrtBackend {
     session: Mutex<Session>,
     config: Config,
+    tokenizer_config: TokenizerConfig,
 
     token_type_ids: bool,
     // NOTE: required since the key can either be `token_type_ids` or `input_type`
@@ -77,7 +85,6 @@ pub struct OrtBackend {
     past_key_values: bool,
 
     pool: Pool,
-    padding_side: PaddingSide,
 }
 
 impl OrtBackend {
@@ -113,11 +120,26 @@ impl OrtBackend {
         };
 
         let config: Config = {
-            let content = std::fs::read_to_string(&model_path.join("config.json"))
+            let content = std::fs::read_to_string(model_path.join("config.json"))
                 .map_err(|e| BackendError::Start(format!("Failed to read `config.json`: {}", e)))?;
             serde_json::from_str(&content)
                 .map_err(|e| BackendError::Start(format!("Failed to parse `config.json`: {}", e)))?
         };
+
+        let tokenizer_config_path = model_path.join("tokenizer_config.json");
+        let tokenizer_config: TokenizerConfig = if tokenizer_config_path.exists() {
+            let content = std::fs::read_to_string(&tokenizer_config_path).map_err(|e| {
+                BackendError::Start(format!("Failed to read `tokenizer_config.json`: {}", e))
+            })?;
+            serde_json::from_str(&content).map_err(|e| {
+                BackendError::Start(format!("Failed to parse `tokenizer_config.json`: {}", e))
+            })?
+        } else {
+            TokenizerConfig {
+                padding_side: PaddingSide::default(),
+            }
+        };
+
         let session = Session::builder()
             .s()?
             .with_intra_threads(num_cpus::get())
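
As a follow-on to the sketch above (reusing its `PaddingSide` and `TokenizerConfig` definitions; the helper name and the plain `String` error type are illustrative stand-ins, not the backend's `BackendError` plumbing), the loading logic added here reduces to: parse the file if it exists, otherwise fall back to the derived default, i.e. right padding.

```rust
use std::path::Path;

// Hypothetical helper mirroring the control flow added in `OrtBackend::new`.
fn load_tokenizer_config(model_path: &Path) -> Result<TokenizerConfig, String> {
    let path = model_path.join("tokenizer_config.json");
    if path.exists() {
        let content = std::fs::read_to_string(&path)
            .map_err(|e| format!("Failed to read `tokenizer_config.json`: {}", e))?;
        serde_json::from_str(&content)
            .map_err(|e| format!("Failed to parse `tokenizer_config.json`: {}", e))
    } else {
        // No file at all: right padding via the `#[default]` variant.
        Ok(TokenizerConfig {
            padding_side: PaddingSide::default(),
        })
    }
}
```

One behavioural consequence visible in the diff: a `tokenizer_config.json` that exists but cannot be read or parsed now aborts startup with a `BackendError::Start`, whereas the removed code logged a warning and fell back to right padding.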
@@ -150,60 +172,15 @@ impl OrtBackend {
             }
         }
 
-
-        let pad_token_id = config
-            .as_ref()
-            .and_then(|c| c.pad_token_id.or(c.eos_token_id))
-            .unwrap_or(0);
-
-        // NOTE: given that `hidden_size`, `num_hidden_layers`, and `num_key_value_heads` are set
-        // to `Option` in the `Config`, but required if `past_key_values` is an input of ONNX, then
-        // those should be validated in advance
-        if past_key_values {
-            match &config {
-                Some(config) => {
-                    if config.hidden_size.is_none()
-                        || config.num_hidden_layers.is_none()
-                        || config.num_key_value_heads.is_none()
-                    {
-                        return Err(BackendError::Start(
-                            "`config.json` doesn't contain all required keys: `hidden_size`, `num_hidden_layers`, and `num_key_value_heads`.".into()
-                        ));
-                    }
-                }
-                None => {
-                    return Err(BackendError::Start(format!(
-                        "`config.json` not found at {config_path:?}, but it's required as this ONNX expects `past_key_values` as input, meaning that the `config.json` file should contain at least the keys: `hidden_size`, `num_hidden_layers`, and `num_key_value_heads`."
-                    )));
-                }
-            }
-        }
-
-        let padding_side = model_path
-            .join("tokenizer_config.json")
-            .exists()
-            .then(|| {
-                let content = std::fs::read_to_string(model_path.join("tokenizer_config.json")).ok()?;
-                serde_json::from_str::<serde_json::Value>(&content)
-                    .ok()?
-                    .get("padding_side")
-                    .and_then(|v| serde_json::from_value::<PaddingSide>(v.clone()).ok())
-            })
-            .flatten()
-            .unwrap_or_else(|| {
-                tracing::warn!("Could not determine `padding_side` from `tokenizer_config.json`, hence using `right` padding by default.");
-                PaddingSide::Right
-            });
-
         Ok(Self {
             session: Mutex::new(session),
             config,
+            tokenizer_config,
             token_type_ids,
             token_type_ids_key,
             position_ids,
             past_key_values,
             pool,
-            padding_side,
         })
     }
 }
@@ -433,7 +410,8 @@ impl Backend for OrtBackend {
         let batch_size = batch.len();
         let max_length = batch.max_length as usize;
 
-        let (model_inputs, masking) = self.prepare_inputs(&batch, &self.padding_side)?;
+        let (model_inputs, masking) =
+            self.prepare_inputs(&batch, &self.tokenizer_config.padding_side)?;
 
         let inputs = self.prepare_ort_inputs(
             model_inputs.input_ids,
@@ -482,7 +460,7 @@ impl Backend for OrtBackend {
         };
 
         let pooled_embeddings = match self.pool {
-            Pool::Cls => match self.padding_side {
+            Pool::Cls => match self.tokenizer_config.padding_side {
                 PaddingSide::Left => {
                     if masking {
                         let mut cls_embeddings = Vec::new();
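
The right-padding arm of `Pool::Cls` (visible in the next hunk) is a single slice, so a tiny stand-alone ndarray illustration may help (made-up shapes and values, not the backend's tensors): `s![.., 0, ..]` keeps the position-0 (CLS) token of every sequence in the batch.

```rust
use ndarray::{array, s};

fn main() {
    // Model output with shape (batch = 2, seq_len = 3, hidden = 2).
    let outputs = array![
        [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
        [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]
    ];

    // With right padding the CLS token is always at position 0, so selecting
    // index 0 along the sequence axis yields one embedding per batch member.
    let cls = outputs.slice(s![.., 0, ..]).into_owned();
    assert_eq!(cls, array![[1.0, 2.0], [7.0, 8.0]]);
}
```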
@@ -506,7 +484,7 @@ impl Backend for OrtBackend {
                 }
                 PaddingSide::Right => outputs.slice(s![.., 0, ..]).into_owned().into_dyn(),
             },
-            Pool::LastToken => match self.padding_side {
+            Pool::LastToken => match self.tokenizer_config.padding_side {
                 // NOTE: when using left-padding, the last-token is always in the last position
                 // as the padding tokens are on the left (note that given that the last token
                 // in the sequence is the EOS token we need to use the last - 1.
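
The NOTE above captures the invariant that drives both `Pool::LastToken` arms: with left padding the real tokens always end at the final column, while with right padding they end after `seq_length` tokens. A toy sketch of that indexing (a hypothetical helper for illustration only, deliberately ignoring the extra `- 1` EOS offset the comment refers to):

```rust
// Index of the last real (non-padding) token for a sequence of `seq_length`
// tokens padded out to `max_length` columns.
fn last_token_index(seq_length: usize, max_length: usize, left_padded: bool) -> usize {
    if left_padded {
        // Padding sits on the left, so the sequence always ends in the last column.
        max_length - 1
    } else {
        // Padding sits on the right, so the sequence ends after `seq_length` tokens.
        seq_length - 1
    }
}

fn main() {
    // A 4-token sequence padded to length 6.
    assert_eq!(last_token_index(4, 6, true), 5); // left padding
    assert_eq!(last_token_index(4, 6, false), 3); // right padding
}
```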
@@ -556,7 +534,7 @@ impl Backend for OrtBackend {
                     input_lengths = input_lengths.select(Axis(0), &indices);
                 };
 
-                match self.padding_side {
+                match self.tokenizer_config.padding_side {
                     PaddingSide::Left => {
                         let mut mean_embeddings = Vec::new();
                         for (batch_idx, &seq_length) in input_lengths.iter().enumerate() {
@@ -609,7 +587,7 @@ impl Backend for OrtBackend {
        // member of the batch that require pooling
        // or if batch_size > 1 and the members of the batch have different lengths
        let raw_embeddings = if (masking || has_pooling_requests) && batch_size > 1 {
-            match self.padding_side {
+            match self.tokenizer_config.padding_side {
                 PaddingSide::Left => {
                     let mut final_indices: Vec<usize> =
                         Vec::with_capacity(batch_size * max_length);
@@ -680,7 +658,7 @@ impl Backend for OrtBackend {
     fn predict(&self, batch: Batch) -> Result<Predictions, BackendError> {
         let batch_size = batch.len();
 
-        let (model_inputs, _) = self.prepare_inputs(&batch, &self.padding_side)?;
+        let (model_inputs, _) = self.prepare_inputs(&batch, &self.tokenizer_config.padding_side)?;
 
         let inputs = self.prepare_ort_inputs(
             model_inputs.input_ids,
