diff --git a/rust/lance-tokenizer/src/icu.rs b/rust/lance-tokenizer/src/icu.rs index 4e36f23115f..e10e7eecc82 100644 --- a/rust/lance-tokenizer/src/icu.rs +++ b/rust/lance-tokenizer/src/icu.rs @@ -33,6 +33,40 @@ pub struct IcuTokenStream { index: usize, } +fn push_token(tokens: &mut Vec, text: &str, offset_from: usize, offset_to: usize) { + if offset_from == offset_to { + return; + } + + let token_text = &text[offset_from..offset_to]; + if token_text.chars().any(char::is_alphanumeric) { + tokens.push(Token { + offset_from, + offset_to, + position: tokens.len(), + text: token_text.to_owned(), + position_length: 1, + }); + } +} + +fn push_tokens_split_on_non_alphanumeric( + tokens: &mut Vec, + text: &str, + offset_from: usize, + offset_to: usize, +) { + let mut part_start = offset_from; + for (relative_offset, c) in text[offset_from..offset_to].char_indices() { + if !c.is_alphanumeric() { + let delimiter_offset = offset_from + relative_offset; + push_token(tokens, text, part_start, delimiter_offset); + part_start = delimiter_offset + c.len_utf8(); + } + } + push_token(tokens, text, part_start, offset_to); +} + impl TokenStream for IcuTokenStream { fn advance(&mut self) -> bool { if self.index < self.tokens.len() { @@ -63,16 +97,7 @@ impl Tokenizer for IcuTokenizer { }; for offset_to in boundaries { - let token_text = &text[offset_from..offset_to]; - if token_text.chars().any(char::is_alphanumeric) { - tokens.push(Token { - offset_from, - offset_to, - position: tokens.len(), - text: token_text.to_owned(), - position_length: 1, - }); - } + push_tokens_split_on_non_alphanumeric(&mut tokens, text, offset_from, offset_to); offset_from = offset_to; } @@ -121,7 +146,27 @@ mod tests { .iter() .map(|token| token.text.as_str()) .collect::>(), - vec!["Mark'd", "ye", "his", "words"] + vec!["Mark", "d", "ye", "his", "words"] + ); + } + + #[test] + fn test_icu_tokenizer_splits_on_non_alphanumeric() { + let tokens = collect_tokens("foo_bar__baz-alpha.beta"); + + assert_eq!( + tokens + .iter() + .map(|token| token.text.as_str()) + .collect::>(), + vec!["foo", "bar", "baz", "alpha", "beta"] + ); + assert_eq!( + tokens + .iter() + .map(|token| (token.offset_from, token.offset_to, token.position)) + .collect::>(), + vec![(0, 3, 0), (4, 7, 1), (9, 12, 2), (13, 18, 3), (19, 23, 4)] ); } }