Skip to content

Commit 24dd9e4

Browse files
committed
Use serde_json to parse PaddingSide instead
1 parent ee1457d commit 24dd9e4

File tree

1 file changed

+6
-21
lines changed

1 file changed

+6
-21
lines changed

backends/ort/src/lib.rs

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -22,24 +22,13 @@ pub struct Config {
2222
pub num_key_value_heads: Option<usize>,
2323
}
2424

25-
#[derive(Debug, Clone)]
25+
#[derive(Debug, Clone, Deserialize)]
26+
#[serde(rename_all = "lowercase")]
2627
enum PaddingSide {
2728
Left,
2829
Right,
2930
}
3031

31-
impl std::str::FromStr for PaddingSide {
32-
type Err = String;
33-
34-
fn from_str(s: &str) -> Result<Self, Self::Err> {
35-
match s.to_lowercase().as_str() {
36-
"left" => Ok(PaddingSide::Left),
37-
"right" => Ok(PaddingSide::Right),
38-
_ => Err(format!("unrecognized `padding_side` value: {}", s)),
39-
}
40-
}
41-
}
42-
4332
struct ModelInputs {
4433
pub input_ids: ndarray::Array2<i64>,
4534
pub attention_mask: ndarray::Array2<i64>,
@@ -171,15 +160,11 @@ impl OrtBackend {
171160
.join("tokenizer_config.json")
172161
.exists()
173162
.then(|| {
174-
std::fs::read_to_string(model_path.join("tokenizer_config.json"))
175-
.ok()?
176-
.parse::<serde_json::Value>()
163+
let content = std::fs::read_to_string(model_path.join("tokenizer_config.json")).ok()?;
164+
serde_json::from_str::<serde_json::Value>(&content)
177165
.ok()?
178-
.get("padding_side")?
179-
.as_str()?
180-
.parse::<PaddingSide>()
181-
.map_err(|e| tracing::warn!("Failed to parse `padding_side` from `tokenizer_config.json`: {}, hence using `right` padding by default.", e))
182-
.ok()
166+
.get("padding_side")
167+
.and_then(|v| serde_json::from_value::<PaddingSide>(v.clone()).ok())
183168
})
184169
.flatten()
185170
.unwrap_or_else(|| {

0 commit comments

Comments
 (0)