Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit ad5720e

Browse files
committed
Add Aho-Corasick, make trie more bare-bones
1 parent 218b095 commit ad5720e

File tree

2 files changed

+165
-66
lines changed

2 files changed

+165
-66
lines changed

‎README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
A collection of classic data structures and algorithms, emphasizing usability, beauty and clarity over full generality. As such, this should be viewed not as a blackbox *library*, but as a whitebox *cookbook* demonstrating the design and implementation of algorithms. I hope it will be useful to students and educators, as well as fans of algorithmic programming contests.
77

8-
This repository is distributed under the [MIT License](LICENSE). The license text need not be included in contest submissions, though I would appreciate linking back to this repo for others to find. Enjoy!
8+
This repository is distributed under the [MIT License](LICENSE). Contest submissions need not include the license text. Enjoy!
99

1010
## For Students and Educators
1111

@@ -52,4 +52,4 @@ Rather than try to persuade you with words, this repository aims to show by exam
5252
- [Arithmetic](src/math/num.rs): rational and complex numbers, linear algebra, safe modular arithmetic
5353
- [FFT](src/math/fft.rs): fast Fourier transform, number theoretic transform, convolution
5454
- [Scanner](src/scanner.rs): utility for reading input data ergonomically
55-
- [String processing](src/string_proc.rs): Knuth-Morris-Pratt string matching, suffix arrays, Manacher's palindrome search
55+
- [String processing](src/string_proc.rs): Knuth-Morris-Pratt and Aho-Corasick string matching, suffix array, Manacher's linear-time palindrome search

‎src/string_proc.rs

Lines changed: 163 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,60 @@
11
//! String processing algorithms.
22
use std::cmp::{max, min};
3+
use std::collections::{hash_map::Entry, HashMap, VecDeque};
34

4-
/// Data structure for Knuth-Morris-Pratt string matching against a pattern.
5-
pub struct Matcher<'a, T> {
5+
/// Prefix trie, easily augmentable by adding more fields and/or methods.
///
/// Nodes live in a flat arena and are referred to by index. Node 0 is always
/// the root, and `links[v]` maps each symbol leaving node `v` to the index of
/// the child reached through it.
#[derive(Clone, Debug)]
pub struct Trie<C: std::hash::Hash + Eq> {
    links: Vec<HashMap<C, usize>>,
}

impl<C: std::hash::Hash + Eq> Default for Trie<C> {
    /// Creates an empty trie with a root node.
    fn default() -> Self {
        Self {
            links: vec![HashMap::new()],
        }
    }
}

impl<C: std::hash::Hash + Eq> Trie<C> {
    /// Inserts a word into the trie, and returns the index of its node.
    ///
    /// Missing nodes are appended to the arena as the word is walked, so a
    /// node's index never changes once it has been handed out. Inserting a
    /// word that is already present returns its existing node, and inserting
    /// the empty word returns the root (0).
    pub fn insert(&mut self, word: impl IntoIterator<Item = C>) -> usize {
        let mut node = 0;

        for ch in word {
            // Index the next freshly created node would receive.
            let fresh = self.links.len();
            node = match self.links[node].entry(ch) {
                Entry::Occupied(slot) => *slot.get(),
                Entry::Vacant(slot) => {
                    slot.insert(fresh);
                    self.links.push(HashMap::new());
                    fresh
                }
            };
        }
        node
    }

    /// Finds a word in the trie, and returns the index of its node.
    ///
    /// Returns `None` as soon as the walk needs an edge that does not exist.
    /// The empty word always resolves to the root, `Some(0)`.
    pub fn get(&self, word: impl IntoIterator<Item = C>) -> Option<usize> {
        word.into_iter()
            .try_fold(0usize, |node: usize, ch| self.links[node].get(&ch).copied())
    }
}
47+
48+
/// Single-pattern matching with the Knuth-Morris-Pratt algorithm
49+
pub struct Matcher<'a, C: Eq> {
650
/// The string pattern to search for.
7-
pub pattern: &'a [T],
51+
pub pattern: &'a [C],
852
/// KMP match failure automaton. fail[i] is the length of the longest
9-
/// proper prefix-suffix of pattern[0...i].
53+
/// proper prefix-suffix of pattern[0..=i].
1054
pub fail: Vec<usize>,
1155
}
1256

13-
impl<'a, T: Eq> Matcher<'a, T> {
57+
impl<'a, C: Eq> Matcher<'a, C> {
1458
/// Precomputes the automaton that allows linear-time string matching.
1559
///
1660
/// # Example
@@ -33,7 +77,7 @@ impl<'a, T: Eq> Matcher<'a, T> {
3377
/// # Panics
3478
///
3579
/// Panics if pattern is empty.
36-
pub fn new(pattern: &'a [T]) -> Self {
80+
pub fn new(pattern: &'a [C]) -> Self {
3781
let mut fail = Vec::with_capacity(pattern.len());
3882
fail.push(0);
3983
let mut len = 0;
@@ -49,10 +93,10 @@ impl<'a, T: Eq> Matcher<'a, T> {
4993
Self { pattern, fail }
5094
}
5195

52-
/// KMP algorithm, sets matches[i] = length of longest prefix of pattern
53-
/// matching a suffix of text[0...i].
54-
pub fn kmp_match(&self, text: &[T]) -> Vec<usize> {
55-
let mut matches = Vec::with_capacity(text.len());
96+
/// KMP algorithm, sets match_lens[i] = length of longest prefix of pattern
97+
/// matching a suffix of text[0..=i].
98+
pub fn kmp_match(&self, text: &[C]) -> Vec<usize> {
99+
let mut match_lens = Vec::with_capacity(text.len());
56100
let mut len = 0;
57101
for ch in text {
58102
if len == self.pattern.len() {
@@ -64,9 +108,94 @@ impl<'a, T: Eq> Matcher<'a, T> {
64108
if self.pattern[len] == *ch {
65109
len += 1;
66110
}
67-
matches.push(len);
111+
match_lens.push(len);
68112
}
69-
matches
113+
match_lens
114+
}
115+
}
116+
117+
/// Multi-pattern matching with the Aho-Corasick algorithm
118+
pub struct MultiMatcher<C: std::hash::Hash + Eq> {
119+
/// A prefix trie storing the string patterns to search for.
120+
pub trie: Trie<C>,
121+
/// Stores which completed pattern string each node corresponds to.
122+
pub pat_id: Vec<Option<usize>>,
123+
/// Aho-Corasick failure automaton. fail[i] is the node corresponding to the
124+
/// longest prefix-suffix of the node corresponding to i.
125+
pub fail: Vec<usize>,
126+
/// Shortcut to the next match along the failure chain, or to the root.
127+
pub fast: Vec<usize>,
128+
}
129+
130+
impl<C: std::hash::Hash + Eq> MultiMatcher<C> {
131+
fn next(trie: &Trie<C>, fail: &[usize], mut node: usize, ch: &C) -> usize {
132+
loop {
133+
if let Some(&child) = trie.links[node].get(ch) {
134+
return child;
135+
} else if node == 0 {
136+
return 0;
137+
}
138+
node = fail[node];
139+
}
140+
}
141+
142+
/// Precomputes the automaton that allows linear-time string matching.
143+
/// If there are duplicate patterns, all but one copy will be ignored.
144+
pub fn new(patterns: Vec<impl IntoIterator<Item = C>>) -> Self {
145+
let mut trie = Trie::default();
146+
let pat_nodes: Vec<usize> = patterns.into_iter().map(|pat| trie.insert(pat)).collect();
147+
148+
let mut pat_id = vec![None; trie.links.len()];
149+
for (i, node) in pat_nodes.into_iter().enumerate() {
150+
pat_id[node] = Some(i);
151+
}
152+
153+
let mut fail = vec![0; trie.links.len()];
154+
let mut fast = vec![0; trie.links.len()];
155+
let mut q: VecDeque<usize> = trie.links[0].values().cloned().collect();
156+
157+
while let Some(node) = q.pop_front() {
158+
for (ch, &child) in &trie.links[node] {
159+
let nx = Self::next(&trie, &fail, fail[node], &ch);
160+
fail[child] = nx;
161+
fast[child] = if pat_id[nx].is_some() { nx } else { fast[nx] };
162+
q.push_back(child);
163+
}
164+
}
165+
166+
Self {
167+
trie,
168+
pat_id,
169+
fail,
170+
fast,
171+
}
172+
}
173+
174+
/// Aho-Corasick algorithm, sets match_nodes[i] = node corresponding to
175+
/// longest prefix of some pattern matching a suffix of text[0..=i].
176+
pub fn ac_match(&self, text: &[C]) -> Vec<usize> {
177+
let mut match_nodes = Vec::with_capacity(text.len());
178+
let mut node = 0;
179+
for ch in text {
180+
node = Self::next(&self.trie, &self.fail, node, &ch);
181+
match_nodes.push(node);
182+
}
183+
match_nodes
184+
}
185+
186+
/// For each non-empty match, returns where in the text it ends, and the index
187+
/// of the corresponding pattern.
188+
pub fn get_end_pos_and_pat_id(&self, match_nodes: &[usize]) -> Vec<(usize, usize)> {
189+
let mut res = vec![];
190+
for (text_pos, &(mut node)) in match_nodes.iter().enumerate() {
191+
while node != 0 {
192+
if let Some(id) = self.pat_id[node] {
193+
res.push((text_pos + 1, id));
194+
}
195+
node = self.fast[node];
196+
}
197+
}
198+
res
70199
}
71200
}
72201

@@ -155,39 +284,6 @@ impl SuffixArray {
155284
}
156285
}
157286

158-
/// Prefix trie that records, for every prefix, how many inserted words start
/// with it.
pub struct Trie<K: std::hash::Hash + Eq> {
    /// Number of inserted words whose path reaches this node.
    count: usize,
    /// Sub-tries, keyed by the next symbol of the word.
    branches: std::collections::HashMap<K, Trie<K>>,
}

// Implemented by hand rather than derived: the derive would demand
// `K: Default`, which an empty trie has no use for.
impl<K: std::hash::Hash + Eq> Default for Trie<K> {
    /// Creates an empty trie: zero words and no branches.
    fn default() -> Self {
        Self {
            count: 0,
            branches: std::collections::HashMap::new(),
        }
    }
}

impl<K: std::hash::Hash + Eq> Trie<K> {
    /// Inserts a word into the trie.
    ///
    /// Every node on the word's path, the root included, has its counter
    /// incremented, so inserting the same word twice counts it twice.
    pub fn insert(&mut self, word: impl IntoIterator<Item = K>) {
        let mut node = self;
        node.count += 1;

        for ch in word {
            node = node.branches.entry(ch).or_default();
            node.count += 1;
        }
    }

    /// Computes the number of inserted words that start with the given prefix.
    ///
    /// Returns 0 if no inserted word has this prefix. The empty prefix yields
    /// the total number of insertions.
    pub fn get(&self, prefix: impl IntoIterator<Item = K>) -> usize {
        let mut node = self;

        for ch in prefix {
            match node.branches.get(&ch) {
                Some(sub) => node = sub,
                None => return 0,
            }
        }
        node.count
    }
}
190-
191287
/// Manacher's algorithm for computing palindrome substrings in linear time.
192288
/// pal[2*i] = odd length of palindrome centred at text[i].
193289
/// pal[2*i+1] = even length of palindrome centred at text[i+0.5].
@@ -226,7 +322,7 @@ mod test {
226322
use super::*;
227323

228324
#[test]
229-
fn test_kmp() {
325+
fn test_kmp_matching() {
230326
let text = b"banana";
231327
let pattern = b"ana";
232328

@@ -235,6 +331,27 @@ mod test {
235331
assert_eq!(matches, vec![0, 1, 2, 3, 2, 3]);
236332
}
237333

334+
#[test]
fn test_ac_matching() {
    let text = b"banana bans, apple benefits.";
    // A pattern's id is its position in this list.
    let dict: Vec<_> = ["banana", "benefit", "banapple", "ban", "fit"]
        .iter()
        .map(|word| word.bytes())
        .collect();

    let matcher = MultiMatcher::new(dict);
    let match_nodes = matcher.ac_match(text);
    let end_pos_and_id = matcher.get_end_pos_and_pat_id(&match_nodes);

    // Pairs of (exclusive end position in text, pattern id). "benefit" and
    // its suffix "fit" both end at position 26, longest first.
    let expected = vec![(3, 3), (6, 0), (10, 3), (26, 1), (26, 4)];
    assert_eq!(end_pos_and_id, expected);
}
354+
238355
#[test]
239356
fn test_suffix_array() {
240357
let text1 = b"bobocel";
@@ -258,24 +375,6 @@ mod test {
258375
}
259376
}
260377

261-
#[test]
fn test_trie() {
    let dict = vec!["banana", "benefit", "banapple", "ban"];

    let mut trie = Trie::default();
    for word in dict {
        trie.insert(word.bytes());
    }

    // (prefix, number of dictionary words starting with it)
    let expected = [
        ("", 4),
        ("b", 4),
        ("ba", 3),
        ("ban", 3),
        ("bana", 2),
        ("banan", 1),
        ("bane", 0),
    ];
    for &(prefix, count) in &expected {
        assert_eq!(trie.get(prefix.bytes()), count);
    }
}
278-
279378
#[test]
280379
fn test_palindrome() {
281380
let text = b"banana";

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /