1use serde::{Deserialize, Serialize};
4use std::{fmt::Display, hash::Hasher, net::Ipv6Addr};
5
6#[derive(Debug, Serialize, Deserialize, Default, Copy, Clone, Eq, PartialEq)]
7pub struct Path {
8 pub id: u128,
9 pub prefix: u8,
10}
11
12impl Path {
13 pub fn root() -> Self {
14 Self { id: 0, prefix: 0 }
15 }
16
17 pub fn leaf(id: u128) -> Self {
18 Self { id, prefix: 128 }
19 }
20
21 pub fn is_leaf(self) -> bool {
22 self.prefix == 128
23 }
24
25 fn mask(prefix: u8) -> u128 {
26 1u128.checked_shl(prefix as _).unwrap_or(0).wrapping_sub(1)
27 }
28
29 pub fn contains(self, other: Self) -> bool {
30 self.prefix <= other.prefix && (self.id ^ other.id) & Self::mask(self.prefix) == 0
31 }
32
33 pub fn children(self) -> Option<(Path, Path)> {
34 if self.prefix == 128 {
35 return None;
36 }
37
38 let left = Path {
39 id: self.id,
40 prefix: self.prefix + 1,
41 };
42 let right = Path {
43 id: self.id | (1 << self.prefix),
44 prefix: self.prefix + 1,
45 };
46 Some((left, right))
47 }
48
49 pub fn lca(self, other: Self) -> Self {
50 let prefix = (self.id ^ other.id).trailing_zeros() as u8;
51 let prefix = prefix.min(self.prefix).min(other.prefix);
52 let id = self.id & Self::mask(prefix);
53 Self { id, prefix }
54 }
55}
56
57impl Display for Path {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 let ipv6 = Ipv6Addr::from(self.id.to_le_bytes());
60 write!(f, "{}/{}", ipv6, self.prefix)
61 }
62}
63
64#[derive(Debug, Eq, PartialEq)]
65struct Node {
66 path: Path,
67 hash: u128,
68 left: Tree,
69 right: Tree,
70}
71
72type Tree = Option<Box<Node>>;
73
74fn insert(tree: Tree, id: u128, hash: u128) -> Tree {
75 let Some(mut node) = tree else {
76 return Some(Box::new(Node {
77 path: Path::leaf(id),
78 hash,
79 left: None,
80 right: None,
81 }));
82 };
83
84 let lca = node.path.lca(Path::leaf(id));
85 if lca.is_leaf() {
86 node.hash = hash;
87 return Some(node);
88 }
89
90 if lca != node.path {
91 let new_node = Box::new(Node {
92 path: lca,
93 hash: 0,
94 left: None,
95 right: None,
96 });
97 let old_node = std::mem::replace(&mut node, new_node);
98 if old_node.path.id & (1 << lca.prefix) != 0 {
99 node.right = Some(old_node);
100 } else {
101 node.left = Some(old_node);
102 }
103 }
104
105 if id & (1 << lca.prefix) != 0 {
106 node.right = insert(node.right, id, hash);
107 } else {
108 node.left = insert(node.left, id, hash);
109 }
110
111 let mut data = [0; 32];
112 data[..16].copy_from_slice(&node.left.as_ref().unwrap().hash.to_le_bytes());
113 data[16..].copy_from_slice(&node.right.as_ref().unwrap().hash.to_le_bytes());
114 node.hash = xxhash_rust::xxh3::xxh3_128(&data);
115 Some(node)
116}
117
118fn find(mut tree: &Tree, path: Path) -> Option<(Path, u128)> {
119 loop {
120 let Some(node) = tree else {
121 return None;
122 };
123
124 if path.contains(node.path) {
125 return Some((node.path, node.hash));
126 }
127
128 if !node.path.contains(path) {
129 return None;
130 }
131
132 tree = if path.id & (1 << node.path.prefix) != 0 {
133 &node.right
134 } else {
135 &node.left
136 };
137 }
138}
139
140#[derive(Debug, Eq, PartialEq)]
141pub struct Merkle {
142 tree: Tree,
143}
144
145impl Merkle {
146 pub const fn new() -> Self {
147 Self { tree: None }
148 }
149
150 pub fn insert(&mut self, id: u128, data: &[u8]) {
151 let mut xxh3 = xxhash_rust::xxh3::Xxh3::new();
152 xxh3.write_u128(id);
153 xxh3.write(data);
154 let hash = xxh3.digest128();
155 self.tree = insert(self.tree.take(), id, hash);
156 }
157
158 pub fn find(&self, path: Path) -> Option<(Path, u128)> {
159 find(&self.tree, path)
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    /// Total number of nodes (interior and leaf) in `tree`.
    fn count_nodes(tree: &Merkle) -> usize {
        fn subtree_size(tree: &Tree) -> usize {
            tree.as_ref().map_or(0, |node| {
                1 + subtree_size(&node.left) + subtree_size(&node.right)
            })
        }
        subtree_size(&tree.tree)
    }

    /// Builds a tree by inserting each integer, using its little-endian bytes
    /// as the payload.
    fn tree_from_ints(ints: impl Iterator<Item = u128>) -> Merkle {
        ints.fold(Merkle::new(), |mut tree, value| {
            tree.insert(value, &value.to_le_bytes());
            tree
        })
    }

    // Inserting `n` distinct ids must produce `n` leaves and `n - 1` interior nodes.
    #[test]
    fn tree_size() {
        for n in 1..100usize {
            let ascending = tree_from_ints(0..n as u128);
            assert_eq!(count_nodes(&ascending), 2 * n - 1);
        }
    }

    // Same size check with the ids inserted in descending order.
    #[test]
    fn tree_size_rev() {
        for n in 1..100usize {
            let descending = tree_from_ints((0..n as u128).rev());
            assert_eq!(count_nodes(&descending), 2 * n - 1);
        }
    }

    // The resulting tree must not depend on insertion order.
    #[test]
    fn test_shuffle_eq() {
        for n in 0..100u128 {
            let forward = tree_from_ints(0..n);
            let backward = tree_from_ints((0..n).rev());
            assert_eq!(forward, backward);
        }
    }
}