ktd/
merkle.rs

1//! Merkle tree implementation, used by the gossip protocol.
2
3use serde::{Deserialize, Serialize};
4use std::{fmt::Display, hash::Hasher, net::Ipv6Addr};
5
6#[derive(Debug, Serialize, Deserialize, Default, Copy, Clone, Eq, PartialEq)]
7pub struct Path {
8    pub id: u128,
9    pub prefix: u8,
10}
11
12impl Path {
13    pub fn root() -> Self {
14        Self { id: 0, prefix: 0 }
15    }
16
17    pub fn leaf(id: u128) -> Self {
18        Self { id, prefix: 128 }
19    }
20
21    pub fn is_leaf(self) -> bool {
22        self.prefix == 128
23    }
24
25    fn mask(prefix: u8) -> u128 {
26        1u128.checked_shl(prefix as _).unwrap_or(0).wrapping_sub(1)
27    }
28
29    pub fn contains(self, other: Self) -> bool {
30        self.prefix <= other.prefix && (self.id ^ other.id) & Self::mask(self.prefix) == 0
31    }
32
33    pub fn children(self) -> Option<(Path, Path)> {
34        if self.prefix == 128 {
35            return None;
36        }
37
38        let left = Path {
39            id: self.id,
40            prefix: self.prefix + 1,
41        };
42        let right = Path {
43            id: self.id | (1 << self.prefix),
44            prefix: self.prefix + 1,
45        };
46        Some((left, right))
47    }
48
49    pub fn lca(self, other: Self) -> Self {
50        let prefix = (self.id ^ other.id).trailing_zeros() as u8;
51        let prefix = prefix.min(self.prefix).min(other.prefix);
52        let id = self.id & Self::mask(prefix);
53        Self { id, prefix }
54    }
55}
56
57impl Display for Path {
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        let ipv6 = Ipv6Addr::from(self.id.to_le_bytes());
60        write!(f, "{}/{}", ipv6, self.prefix)
61    }
62}
63
64#[derive(Debug, Eq, PartialEq)]
65struct Node {
66    path: Path,
67    hash: u128,
68    left: Tree,
69    right: Tree,
70}
71
72type Tree = Option<Box<Node>>;
73
74fn insert(tree: Tree, id: u128, hash: u128) -> Tree {
75    let Some(mut node) = tree else {
76        return Some(Box::new(Node {
77            path: Path::leaf(id),
78            hash,
79            left: None,
80            right: None,
81        }));
82    };
83
84    let lca = node.path.lca(Path::leaf(id));
85    if lca.is_leaf() {
86        node.hash = hash;
87        return Some(node);
88    }
89
90    if lca != node.path {
91        let new_node = Box::new(Node {
92            path: lca,
93            hash: 0,
94            left: None,
95            right: None,
96        });
97        let old_node = std::mem::replace(&mut node, new_node);
98        if old_node.path.id & (1 << lca.prefix) != 0 {
99            node.right = Some(old_node);
100        } else {
101            node.left = Some(old_node);
102        }
103    }
104
105    if id & (1 << lca.prefix) != 0 {
106        node.right = insert(node.right, id, hash);
107    } else {
108        node.left = insert(node.left, id, hash);
109    }
110
111    let mut data = [0; 32];
112    data[..16].copy_from_slice(&node.left.as_ref().unwrap().hash.to_le_bytes());
113    data[16..].copy_from_slice(&node.right.as_ref().unwrap().hash.to_le_bytes());
114    node.hash = xxhash_rust::xxh3::xxh3_128(&data);
115    Some(node)
116}
117
118fn find(mut tree: &Tree, path: Path) -> Option<(Path, u128)> {
119    loop {
120        let Some(node) = tree else {
121            return None;
122        };
123
124        if path.contains(node.path) {
125            return Some((node.path, node.hash));
126        }
127
128        if !node.path.contains(path) {
129            return None;
130        }
131
132        tree = if path.id & (1 << node.path.prefix) != 0 {
133            &node.right
134        } else {
135            &node.left
136        };
137    }
138}
139
140#[derive(Debug, Eq, PartialEq)]
141pub struct Merkle {
142    tree: Tree,
143}
144
145impl Merkle {
146    pub const fn new() -> Self {
147        Self { tree: None }
148    }
149
150    pub fn insert(&mut self, id: u128, data: &[u8]) {
151        let mut xxh3 = xxhash_rust::xxh3::Xxh3::new();
152        xxh3.write_u128(id);
153        xxh3.write(data);
154        let hash = xxh3.digest128();
155        self.tree = insert(self.tree.take(), id, hash);
156    }
157
158    pub fn find(&self, path: Path) -> Option<(Path, u128)> {
159        find(&self.tree, path)
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    /// Total number of nodes (branches and leaves) in `tree`, counted with an explicit stack.
    fn count_nodes(tree: &Merkle) -> usize {
        let mut stack: Vec<&Node> = tree.tree.as_deref().into_iter().collect();
        let mut seen = 0;
        while let Some(node) = stack.pop() {
            seen += 1;
            stack.extend(node.left.as_deref());
            stack.extend(node.right.as_deref());
        }
        seen
    }

    /// Builds a tree holding each `i`, with `i`'s own little-endian bytes as payload.
    fn tree_from_ints(ints: impl Iterator<Item = u128>) -> Merkle {
        ints.fold(Merkle::new(), |mut acc, i| {
            acc.insert(i, &i.to_le_bytes());
            acc
        })
    }

    #[test]
    fn tree_size() {
        // A full binary tree with n leaves has n - 1 branch nodes.
        for n in 1..100usize {
            assert_eq!(count_nodes(&tree_from_ints(0..n as u128)), 2 * n - 1);
        }
    }

    #[test]
    fn tree_size_rev() {
        for n in 1..100usize {
            assert_eq!(count_nodes(&tree_from_ints((0..n as u128).rev())), 2 * n - 1);
        }
    }

    #[test]
    fn test_shuffle_eq() {
        // Insertion order must affect neither structure nor hashes.
        for n in 0..100u128 {
            assert_eq!(tree_from_ints(0..n), tree_from_ints((0..n).rev()));
        }
    }
}