You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
Copy file name to clipboardExpand all lines: backends/ort/src/lib.rs
+33 -55 Lines changed: 33 additions & 55 deletions
Original file line number
Diff line number
Diff line change
@@ -50,13 +50,20 @@ impl From<ConfigValidator> for Config {
50
50
}
51
51
}
52
52
53
-
#[derive(Debug,Clone,Deserialize)]
53
+
#[derive(Debug,Clone,Default,Deserialize)]
54
54
#[serde(rename_all = "lowercase")]
55
55
enum PaddingSide {
56
56
Left,
57
+
#[default]
57
58
Right,
58
59
}
59
60
61
+
#[derive(Debug,Clone,Deserialize)]
62
+
pub struct TokenizerConfig {
63
+
#[serde(default)]
64
+
padding_side:PaddingSide,
65
+
}
66
+
60
67
struct ModelInputs {
61
68
pub input_ids: ndarray::Array2<i64>,
62
69
pub attention_mask: ndarray::Array2<i64>,
@@ -69,6 +76,7 @@ struct ModelInputs {
69
76
pub struct OrtBackend {
70
77
session:Mutex<Session>,
71
78
config:Config,
79
+
tokenizer_config:TokenizerConfig,
72
80
73
81
token_type_ids:bool,
74
82
// NOTE: required since the key can either be `token_type_ids` or `input_type`
@@ -77,7 +85,6 @@ pub struct OrtBackend {
77
85
past_key_values:bool,
78
86
79
87
pool:Pool,
80
-
padding_side:PaddingSide,
81
88
}
82
89
83
90
impl OrtBackend {
@@ -113,11 +120,26 @@ impl OrtBackend {
113
120
};
114
121
115
122
let config:Config = {
116
-
let content = std::fs::read_to_string(&model_path.join("config.json"))
123
+
let content = std::fs::read_to_string(model_path.join("config.json"))
117
124
.map_err(|e| BackendError::Start(format!("Failed to read `config.json`: {}", e)))?;
118
125
serde_json::from_str(&content)
119
126
.map_err(|e| BackendError::Start(format!("Failed to parse `config.json`: {}", e)))?
120
127
};
128
+
129
+
let tokenizer_config_path = model_path.join("tokenizer_config.json");
130
+
let tokenizer_config:TokenizerConfig = if tokenizer_config_path.exists(){
131
+
let content = std::fs::read_to_string(&tokenizer_config_path).map_err(|e| {
132
+
BackendError::Start(format!("Failed to read `tokenizer_config.json`: {}", e))
133
+
})?;
134
+
serde_json::from_str(&content).map_err(|e| {
135
+
BackendError::Start(format!("Failed to parse `tokenizer_config.json`: {}", e))
136
+
})?
137
+
}else{
138
+
TokenizerConfig{
139
+
padding_side:PaddingSide::default(),
140
+
}
141
+
};
142
+
121
143
let session = Session::builder()
122
144
.s()?
123
145
.with_intra_threads(num_cpus::get())
@@ -150,60 +172,15 @@ impl OrtBackend {
150
172
}
151
173
}
152
174
153
-
154
-
let pad_token_id = config
155
-
.as_ref()
156
-
.and_then(|c| c.pad_token_id.or(c.eos_token_id))
157
-
.unwrap_or(0);
158
-
159
-
// NOTE: given that `hidden_size`, `num_hidden_layers`, and `num_key_value_heads` are set
160
-
// to `Option` in the `Config`, but required if `past_key_values` is an input of ONNX, then
161
-
// those should be validated in advance
162
-
if past_key_values {
163
-
match &config {
164
-
Some(config) => {
165
-
if config.hidden_size.is_none()
166
-
|| config.num_hidden_layers.is_none()
167
-
|| config.num_key_value_heads.is_none()
168
-
{
169
-
return Err(BackendError::Start(
170
-
"`config.json` doesn't contain all required keys: `hidden_size`, `num_hidden_layers`, and `num_key_value_heads`.".into()
171
-
));
172
-
}
173
-
}
174
-
None => {
175
-
return Err(BackendError::Start(format!(
176
-
"`config.json` not found at {config_path:?}, but it's required as this ONNX expects `past_key_values` as input, meaning that the `config.json` file should contain at least the keys: `hidden_size`, `num_hidden_layers`, and `num_key_value_heads`."
177
-
)));
178
-
}
179
-
}
180
-
}
181
-
182
-
let padding_side = model_path
183
-
.join("tokenizer_config.json")
184
-
.exists()
185
-
.then(|| {
186
-
let content = std::fs::read_to_string(model_path.join("tokenizer_config.json")).ok()?;
0 commit comments