zerotrie/byte_phf/builder.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207
// This file is part of ICU4X. For terms of use, please see the file
// called LICENSE at the top level of the ICU4X source tree
// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ).
use super::*;
use crate::error::ZeroTrieBuildError;
use alloc::vec;
use alloc::vec::Vec;
/// Search budget for the level-2 parameter `q`.
///
/// While solving for one level-1 parameter `p`, a bucket may run through its whole `q` range
/// (bounded by `Q_FAST_MAX` at first) without finding free slots. Once that has happened more
/// than this many times, the search abandons the current `p` and tries the next one.
///
/// This keeps the search fast at a small cost in quality: roughly 1 in 10000 hash maps end up
/// falling back to the slow path.
const MAX_L2_SEARCH_MISSES: usize = 24;
/// Directly computes the perfect hash function for a set of bytes.
///
/// The search has two levels:
///
/// - The level-1 parameter `p` assigns every byte to one of `N` buckets via `f1(byte, p, N)`.
/// - For each bucket, the search looks for a level-2 parameter `q` such that `f2(byte, q, N)`
///   sends every byte of that bucket to a slot no other byte has claimed yet.
///
/// Buckets are placed one at a time, with backtracking to the previous bucket when a bucket runs
/// out of `q` values.
///
/// `p` is tried in increasing order from 0 up to `P_FAST_MAX`, first with `q` limited to
/// `0..=Q_FAST_MAX`. If every such `p` fails, the whole search is repeated from `p = 0` with `q`
/// allowed up to `Q_REAL_MAX`.
///
/// Returns `(p, [q_0, q_1, ..., q_(N-1)])`, where `q_i` is the level-2 parameter of bucket
/// `i = f1(byte, p, N)`. When `bytes` is empty or has exactly 256 elements, no search is done and
/// `(0, [0; bytes.len()])` is returned.
///
/// # Errors
///
/// Returns [`ZeroTrieBuildError::CouldNotSolvePerfectHash`] if no `(p, q)` assignment is found
/// within the limits above.
#[allow(unused_labels)] // every loop is labeled to name its nesting level, even where unused
pub fn find(bytes: &[u8]) -> Result<(u8, Vec<u8>), ZeroTrieBuildError> {
    let n_usize = bytes.len();
    let mut p = 0u8;
    // Final level-2 parameters, indexed by level-1 bucket index (`l1`).
    let mut qq = vec![0u8; n_usize];
    // Working level-2 parameters, indexed by position in the size-sorted `buckets` list.
    let mut bqs = vec![0u8; n_usize];
    // `seen[l2]` is true once a byte placed so far occupies slot `l2`.
    let mut seen = vec![false; n_usize];
    let max_allowable_p = P_FAST_MAX;
    // Starts at the fast limit; raised to `Q_REAL_MAX` only after every `p` failed with it.
    let mut max_allowable_q = Q_FAST_MAX;
    #[allow(non_snake_case)]
    let N = if n_usize > 0 && n_usize < 256 {
        n_usize as u8
    } else {
        // Empty input, and a full 256-byte input (256 does not fit in the `u8` `N`), skip the
        // search.
        // NOTE(review): other lengths (> 256, which implies repeated bytes) are only rejected by
        // this debug assertion. Release builds return `(0, [0; len])` — confirm callers never
        // pass them.
        debug_assert!(n_usize == 0 || n_usize == 256);
        return Ok((p, qq));
    };
    'p_loop: loop {
        // Group the bytes by level-1 hash. Each bucket remembers its original index `l1`, so
        // results can be mapped back after sorting.
        let mut buckets: Vec<(usize, Vec<u8>)> = (0..n_usize).map(|i| (i, vec![])).collect();
        for byte in bytes {
            let l1 = f1(*byte, p, N) as usize;
            buckets[l1].1.push(*byte);
        }
        // Place the largest buckets first. The sort is stable, so equal sizes keep `l1` order.
        buckets.sort_by_key(|(_, v)| -(v.len() as isize));
        // Position in `buckets` of the bucket currently being placed.
        let mut i = 0;
        // How many times a bucket has exhausted its `q` range under the current `p`.
        let mut num_max_q = 0;
        bqs.fill(0);
        seen.fill(false);
        'q_loop: loop {
            if i == buckets.len() {
                // Every bucket is placed: copy the sorted-order parameters back into `l1` order.
                for (local_j, real_j) in buckets.iter().map(|(j, _)| *j).enumerate() {
                    qq[real_j] = bqs[local_j];
                }
                return Ok((p, qq));
            }
            // Try to place bucket `i` using its current parameter `bqs[i]`.
            let mut bucket = buckets[i].1.as_slice();
            'byte_loop: for (j, byte) in bucket.iter().enumerate() {
                let l2 = f2(*byte, bqs[i], N) as usize;
                if seen[l2] {
                    // Collision. Release the slots this bucket's earlier bytes (`0..j`) claimed
                    // under the current `q`.
                    for k_byte in &bucket[0..j] {
                        let l2 = f2(*k_byte, bqs[i], N) as usize;
                        assert!(seen[l2]);
                        seen[l2] = false;
                    }
                    'reset_loop: loop {
                        // Retry the current bucket with the next `q`, if any remain.
                        if bqs[i] < max_allowable_q {
                            bqs[i] += 1;
                            continue 'q_loop;
                        }
                        // Bucket `i` has run through its whole `q` range.
                        num_max_q += 1;
                        bqs[i] = 0;
                        if i == 0 || num_max_q > MAX_L2_SEARCH_MISSES {
                            // Give up on the current `p`.
                            if p == max_allowable_p && max_allowable_q != Q_REAL_MAX {
                                // Every `p` failed with the fast `q` limit: restart from `p = 0`
                                // with the full `q` range.
                                max_allowable_q = Q_REAL_MAX;
                                p = 0;
                                continue 'p_loop;
                            } else if p == max_allowable_p {
                                // If a fallback algorithm for `p` is added, relax this assertion
                                // and re-run the loop with a higher `max_allowable_p`.
                                debug_assert_eq!(max_allowable_p, P_REAL_MAX);
                                return Err(ZeroTrieBuildError::CouldNotSolvePerfectHash);
                            } else {
                                p += 1;
                                continue 'p_loop;
                            }
                        }
                        // Backtrack. Release the previous bucket's slots so its `q` can be
                        // advanced on the next pass through this loop.
                        i -= 1;
                        bucket = buckets[i].1.as_slice();
                        for byte in bucket {
                            let l2 = f2(*byte, bqs[i], N) as usize;
                            assert!(seen[l2]);
                            seen[l2] = false;
                        }
                    }
                } else {
                    // Free slot: claim it for this byte.
                    let l2 = f2(*byte, bqs[i], N) as usize;
                    seen[l2] = true;
                }
            }
            // Bucket `i` fit without collisions; move on to the next one.
            i += 1;
        }
    }
}
impl PerfectByteHashMap<Vec<u8>> {
    /// Builds a [`PerfectByteHashMap`] over the given byte keys.
    ///
    /// The backing buffer has length `2 * keys.len() + 1`, laid out as:
    ///
    /// - the level-1 parameter `p`;
    /// - the `keys.len()` level-2 parameters returned by [`find`];
    /// - the keys themselves, each stored at the slot `f2(key, q, n)` selected for it.
    ///
    /// (this is a doc-hidden API)
    ///
    /// # Errors
    ///
    /// Propagates the error from [`find`] when no perfect hash can be computed for `keys`.
    pub fn try_new(keys: &[u8]) -> Result<Self, ZeroTrieBuildError> {
        let len = keys.len();
        // NOTE(review): a 256-key input wraps to 0 here; presumably `f1`/`f2` treat `n == 0` as
        // 256 — confirm.
        let n = len as u8;
        let (p, qq) = find(keys)?;

        // Put every key into the slot its two-level hash selects.
        let mut slots = vec![0u8; len];
        for &key in keys {
            let q = qq[f1(key, p, n) as usize];
            slots[f2(key, q, n) as usize] = key;
        }

        let mut buf = Vec::with_capacity(2 * len + 1);
        buf.push(p);
        buf.extend_from_slice(&qq);
        buf.extend_from_slice(&slots);
        Ok(Self(buf))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    extern crate std;
    use std::print;
    use std::println;

    /// Writes `byte` to stdout without a trailing newline. ASCII letters and digits are printed
    /// quoted as characters; any other byte is printed as uppercase hex with a `0x` prefix.
    fn print_byte_to_stdout(byte: u8) {
        if !byte.is_ascii_alphanumeric() {
            print!("0x{byte:X}");
            return;
        }
        let c = char::from(byte);
        print!("'{c}'");
    }

    /// Deterministically draws `len` distinct ASCII alphanumeric bytes from a PCG generator
    /// seeded with `seed`.
    fn random_alphanums(seed: u64, len: usize) -> Vec<u8> {
        use rand::seq::SliceRandom;
        use rand::SeedableRng;
        const BYTES: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
        let mut rng = rand_pcg::Lcg64Xsh32::seed_from_u64(seed);
        let chosen = BYTES.choose_multiple(&mut rng, len);
        chosen.copied().collect()
    }

    /// Runs `find` on 2500 random inputs (lengths 0..50 times seeds 0..50) and counts how many
    /// `p` and `q` values fall within the fast limits. Then prints the hash chain of one sample
    /// input.
    #[test]
    fn test_random_distributions() {
        let mut p_distr = vec![0usize; 256];
        let mut q_distr = vec![0usize; 256];
        for len in 0..50 {
            for seed in 0..50 {
                let sample = random_alphanums(seed, len);
                let (p, qq) = find(&sample).unwrap();
                p_distr[usize::from(p)] += 1;
                qq.iter().for_each(|&q| q_distr[usize::from(q)] += 1);
            }
        }
        println!("p_distr: {p_distr:?}");
        println!("q_distr: {q_distr:?}");

        // Total counts at or below a fast limit, and above it.
        let fast_slow = |distr: &[usize], fast_max: u8| -> (usize, usize) {
            let (fast, slow) = distr.split_at(usize::from(fast_max) + 1);
            (fast.iter().sum(), slow.iter().sum())
        };
        let (fast_p, slow_p) = fast_slow(p_distr.as_slice(), P_FAST_MAX);
        let (fast_q, slow_q) = fast_slow(q_distr.as_slice(), Q_FAST_MAX);
        assert_eq!(2500, fast_p);
        assert_eq!(0, slow_p);
        assert_eq!(61247, fast_q);
        assert_eq!(3, slow_q);

        let bytes = random_alphanums(0, 16);
        let n = u8::try_from(bytes.len()).unwrap();
        let (p, qq) = find(&bytes).unwrap();
        println!("Results:");
        for &byte in &bytes {
            print_byte_to_stdout(byte);
            let l1 = f1(byte, p, n) as usize;
            let q = qq[l1];
            let l2 = f2(byte, q, n) as usize;
            println!(" => l1 {l1} => q {q} => l2 {l2}");
        }
    }
}