113 changes: 113 additions & 0 deletions tokenizers/benches/whitespace_benchmark.rs
@@ -0,0 +1,113 @@
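// Benchmark comparing the regex-based `Whitespace` pre-tokenizer against the
// byte-scanning `WhitespaceOptimized` variant on short inputs with different
// character mixes, plus one large repeated text. Run with
// `cargo bench --bench whitespace_benchmark` (assuming the crate declares a
// matching `[[bench]]` target with `harness = false`).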
#[macro_use]
extern crate criterion;

use criterion::{Criterion, Throughput};
use tokenizers::pre_tokenizers::whitespace::{Whitespace, WhitespaceOptimized};
use tokenizers::{OffsetReferential, OffsetType, PreTokenizedString, PreTokenizer};

fn bench_whitespace_comparison(c: &mut Criterion) {
    let mut group = c.benchmark_group("whitespace-pre-tokenizers");

    // Test data with various characteristics
    let test_cases = vec![
        ("simple", "Hello world! How are you doing?"),
        (
            "mixed",
            "This is a test with numbers 123 and symbols @#$% and unicode: café résumé",
        ),
        (
            "whitespace_heavy",
            "Multiple spaces\tand\nnewlines\r\nhere",
        ),
        ("symbol_heavy", "Hello!@#$%^&*()world?><>{}[]|\\"),
        (
            "word_heavy",
            "This is a very long sentence with many words that should be tokenized properly",
        ),
        ("unicode_heavy", "αβγ δέζ ηθι κλμ νξο πρσ τυφ χψω"),
        ("mixed_unicode", "Hello 123 αβγ !@# world δέζ ηθι"),
    ];

    for (name, text) in test_cases {
        let data_len = text.len() as u64;
        group.throughput(Throughput::Bytes(data_len));

        // Benchmark original regex-based implementation
        group.bench_function(format!("{}-original", name), |b| {
            b.iter(|| {
                let mut pretokenized = PreTokenizedString::from(text);
                let pretok = Whitespace {};
                pretok.pre_tokenize(&mut pretokenized).unwrap();
                let _result = pretokenized
                    .get_splits(OffsetReferential::Original, OffsetType::Byte)
                    .into_iter()
                    .map(|(s, o, _)| (s, o))
                    .collect::<Vec<_>>();
            })
        });

        // Benchmark optimized byte-level implementation
        group.bench_function(format!("{}-optimized", name), |b| {
            b.iter(|| {
                let mut pretokenized = PreTokenizedString::from(text);
                let pretok = WhitespaceOptimized {};
                pretok.pre_tokenize(&mut pretokenized).unwrap();
                let _result = pretokenized
                    .get_splits(OffsetReferential::Original, OffsetType::Byte)
                    .into_iter()
                    .map(|(s, o, _)| (s, o))
                    .collect::<Vec<_>>();
            })
        });
    }

    group.finish();
}

fn bench_large_text(c: &mut Criterion) {
    let mut group = c.benchmark_group("whitespace-large-text");

    // Create a large text by repeating patterns
    let base_text =
        "Hello world! This is a test with numbers 123 and symbols @#$% and unicode: café résumé. ";
    let large_text: String = base_text.repeat(1000); // ~90KB of text
    let data_len = large_text.len() as u64;

    group.throughput(Throughput::Bytes(data_len));

    group.bench_function("large-original", |b| {
        b.iter(|| {
            let mut pretokenized = PreTokenizedString::from(large_text.as_str());
            let pretok = Whitespace {};
            pretok.pre_tokenize(&mut pretokenized).unwrap();
            let _result = pretokenized
                .get_splits(OffsetReferential::Original, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>();
        })
    });

    group.bench_function("large-optimized", |b| {
        b.iter(|| {
            let mut pretokenized = PreTokenizedString::from(large_text.as_str());
            let pretok = WhitespaceOptimized {};
            pretok.pre_tokenize(&mut pretokenized).unwrap();
            let _result = pretokenized
                .get_splits(OffsetReferential::Original, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>();
        })
    });

    group.finish();
}

criterion_group! {
    name = whitespace_benches;
    config = Criterion::default().sample_size(20);
    targets = bench_whitespace_comparison, bench_large_text
}

criterion_main!(whitespace_benches);
2 changes: 1 addition & 1 deletion tokenizers/src/models/wordpiece/mod.rs
@@ -175,7 +175,7 @@ impl WordPiece {
    pub fn read_bytes(vocab: &[u8]) -> Result<Vocab> {
        let file = BufReader::new(vocab);

-       let mut vocab = HashMap::new();
+       let mut vocab = AHashMap::new();
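        // `AHashMap` is the drop-in `HashMap` replacement from the `ahash` crate
        // (presumably already imported in this file); it swaps SipHash for a
        // faster hash function at the cost of DoS resistance.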
        for (index, line) in file.lines().enumerate() {
            let line = line?;
            vocab.insert(line.trim_end().to_owned(), index as u32);
218 changes: 218 additions & 0 deletions tokenizers/src/pre_tokenizers/whitespace.rs
@@ -28,6 +28,141 @@ impl PreTokenizer for Whitespace {
    }
}

/// Optimized whitespace pre-tokenizer that uses byte-level scanning instead of regex.
/// This provides better performance but may have slightly different behavior in edge cases
/// compared to the regex-based implementation.
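///
/// # Example
///
/// A minimal, illustrative sketch (assuming `WhitespaceOptimized` is exported
/// alongside `Whitespace`); the expected splits mirror the regex `\w+|[^\w\s]+`:
///
/// ```no_run
/// use tokenizers::pre_tokenizers::whitespace::WhitespaceOptimized;
/// use tokenizers::{OffsetReferential, OffsetType, PreTokenizedString, PreTokenizer};
///
/// let mut pretokenized = PreTokenizedString::from("Hello world!");
/// WhitespaceOptimized {}.pre_tokenize(&mut pretokenized).unwrap();
/// let splits: Vec<_> = pretokenized
///     .get_splits(OffsetReferential::Original, OffsetType::Byte)
///     .into_iter()
///     .map(|(s, o, _)| (s, o))
///     .collect();
/// assert_eq!(splits, vec![("Hello", (0, 5)), ("world", (6, 11)), ("!", (11, 12))]);
/// ```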
#[derive(Clone, Debug, PartialEq, Eq)]
#[macro_rules_attribute(impl_serde_type!)]
pub struct WhitespaceOptimized;

impl Default for WhitespaceOptimized {
    fn default() -> Self {
        Self
    }
}

impl PreTokenizer for WhitespaceOptimized {
    fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
        pretokenized.split(|_, normalized| {
            normalized.split(Invert(WhitespacePattern), SplitDelimiterBehavior::Removed)
        })
    }
}

/// Custom pattern implementation for optimized whitespace splitting
/// This implements the equivalent of the regex r"\w+|[^\w\s]+" but with manual byte scanning
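///
/// For example (illustrative), `"ab !?"` yields `[((0, 2), true), ((3, 5), true)]`:
/// the word run `"ab"` and the symbol run `"!?"`, with the whitespace gap omitted.
/// Wrapped in `Invert`, those flags are flipped so the runs are kept as
/// non-delimiters, and the omitted whitespace never reaches the output under
/// `SplitDelimiterBehavior::Removed`.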
struct WhitespacePattern;

impl crate::tokenizer::pattern::Pattern for WhitespacePattern {
    fn find_matches(&self, inside: &str) -> Result<Vec<(crate::Offsets, bool)>> {
        if inside.is_empty() {
            return Ok(vec![((0, 0), false)]);
        }

        let mut matches = Vec::new();
        let mut current_start = 0;
        let mut current_end = 0;
        let mut current_type = None; // None = whitespace, Some(true) = word, Some(false) = symbol

        let mut i = 0;
        while i < inside.len() {
            let char_start = inside[i..].chars().next().unwrap();
            let char_len = char_start.len_utf8();

            let is_whitespace = char_start.is_whitespace();
            let is_word_char = char_start.is_alphanumeric() || char_start == '_';
            let is_symbol = !is_whitespace && !is_word_char;

            match (current_type, is_whitespace, is_word_char, is_symbol) {
                (None, true, _, _) => {
                    // Continue in whitespace
                    i += char_len;
                }
                (None, false, true, _) => {
                    // Transition from whitespace to word
                    current_start = i;
                    current_end = i + char_len;
                    current_type = Some(true);
                    i += char_len;
                }
                (None, false, false, true) => {
                    // Transition from whitespace to symbol
                    current_start = i;
                    current_end = i + char_len;
                    current_type = Some(false);
                    i += char_len;
                }
                (None, false, false, false) => {
                    // This shouldn't happen since a char is either whitespace, word, or symbol,
                    // but handle it gracefully by treating it as a symbol
                    current_start = i;
                    current_end = i + char_len;
                    current_type = Some(false);
                    i += char_len;
                }
                (Some(true), true, _, _) => {
                    // Transition from word to whitespace - finish word
                    matches.push(((current_start, current_end), true));
                    current_type = None;
                    i += char_len;
                }
                (Some(true), false, true, _) => {
                    // Continue in word
                    current_end = i + char_len;
                    i += char_len;
                }
                (Some(true), false, false, true) => {
                    // Transition from word to symbol - finish word, start symbol
                    matches.push(((current_start, current_end), true));
                    current_start = i;
                    current_end = i + char_len;
                    current_type = Some(false);
                    i += char_len;
                }
                (Some(true), false, false, false) => {
                    // This shouldn't happen, but handle as symbol
                    matches.push(((current_start, current_end), true));
                    current_start = i;
                    current_end = i + char_len;
                    current_type = Some(false);
                    i += char_len;
                }
                (Some(false), true, _, _) => {
                    // Transition from symbol to whitespace - finish symbol
                    matches.push(((current_start, current_end), true));
                    current_type = None;
                    i += char_len;
                }
                (Some(false), false, true, _) => {
                    // Transition from symbol to word - finish symbol, start word
                    matches.push(((current_start, current_end), true));
                    current_start = i;
                    current_end = i + char_len;
                    current_type = Some(true);
                    i += char_len;
                }
                (Some(false), false, false, true) => {
                    // Continue in symbol
                    current_end = i + char_len;
                    i += char_len;
                }
                (Some(false), false, false, false) => {
                    // This shouldn't happen, but handle as symbol
                    current_end = i + char_len;
                    i += char_len;
                }
            }
        }

        // Don't forget the last token
        if current_type.is_some() {
            matches.push(((current_start, current_end), true));
        }

        Ok(matches)
    }
}

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[macro_rules_attribute(impl_serde_type!)]
pub struct WhitespaceSplit;
@@ -102,4 +237,87 @@ mod tests {
            );
        }
    }

    #[test]
    fn optimized_compatibility() {
        // Test that the optimized version produces the same results as the original
        let test_cases = vec![
            "Hello world!",
            "How are you doing?",
            "This is a test with numbers 123 and symbols @#$%",
            "Multiple spaces",
            "Tabs\tand\nnewlines",
            "Unicode: café résumé naïve",
            "Mixed: Hello123!@# world",
            "Edge cases: a.b,c;d:e",
            "Empty string:",
            "Only spaces: ",
            "Only symbols: !@#$%",
            "Only words: hello world",
            "Numbers: 123 456 789",
            "Underscores: hello_world test_case",
            "Special chars: αβγ δέζ ηθι",
        ];

        for test_case in test_cases {
            let mut original = PreTokenizedString::from(test_case);
            let mut optimized = PreTokenizedString::from(test_case);

            let original_pretok = Whitespace {};
            let optimized_pretok = WhitespaceOptimized {};

            original_pretok.pre_tokenize(&mut original).unwrap();
            optimized_pretok.pre_tokenize(&mut optimized).unwrap();

            let original_splits = original
                .get_splits(OffsetReferential::Original, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>();

            let optimized_splits = optimized
                .get_splits(OffsetReferential::Original, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>();

            assert_eq!(
                original_splits, optimized_splits,
                "Mismatch for test case: '{}'",
                test_case
            );
        }
    }

    #[test]
    fn optimized_edge_cases() {
        let pretok = WhitespaceOptimized {};

        // Test various edge cases
        let edge_cases = vec![
            ("", vec![]),
            (" ", vec![]),
            ("  ", vec![]),
            ("a", vec![("a", (0, 1))]),
            ("!", vec![("!", (0, 1))]),
            ("a!", vec![("a", (0, 1)), ("!", (1, 2))]),
            ("!a", vec![("!", (0, 1)), ("a", (1, 2))]),
            ("a b", vec![("a", (0, 1)), ("b", (2, 3))]),
            ("a  b", vec![("a", (0, 1)), ("b", (3, 4))]),
            ("a\tb", vec![("a", (0, 1)), ("b", (2, 3))]),
            ("a\nb", vec![("a", (0, 1)), ("b", (2, 3))]),
            ("a\r\nb", vec![("a", (0, 1)), ("b", (3, 4))]),
        ];

        for (input, expected) in edge_cases {
            let mut pretokenized = PreTokenizedString::from(input);
            pretok.pre_tokenize(&mut pretokenized).unwrap();
            let result = pretokenized
                .get_splits(OffsetReferential::Original, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>();
            assert_eq!(result, expected, "Failed for input: '{}'", input);
        }
    }
}