1use std::{fmt, num::NonZeroU32};
2
/// Compact handle to a node stored in a [`Tree`].
///
/// The wrapped value is the node's index in the tree's node vector plus one, so it is
/// never zero; the `NonZeroU32` niche keeps `Option<NodeId>` the same size as `NodeId`.
/// Because the offset is uniform, the derived ordering follows node-index order.
#[repr(transparent)]
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(super) struct NodeId(NonZeroU32);
6
7impl NodeId {
8 fn new(i: usize) -> Self {
9 debug_assert!(i < u32::MAX as usize);
10 Self(NonZeroU32::new((i + 1) as u32).unwrap())
12 }
13
14 pub fn get(self) -> usize {
15 (self.0.get() - 1) as usize
16 }
17
18 pub fn next(self) -> Self {
19 let mut id = self.0.get();
20 id = id.saturating_add(1);
21 Self(NonZeroU32::new(id).unwrap())
22 }
23}
24impl fmt::Debug for NodeId {
25 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
26 write!(f, "NodeId({})", self.get())
27 }
28}
29
30pub(super) struct Tree<T> {
31 nodes: Vec<Node<T>>,
32}
33impl<T> Tree<T> {
34 pub(super) fn new(root: T) -> Self {
35 let nodes = vec![Node {
36 parent: None,
37 prev_sibling: None,
38 next_sibling: None,
39 last_child: None,
40 descendants_end: 1,
41 value: root,
42 }];
43
44 Tree { nodes }
45 }
46
47 pub fn index(&self, id: NodeId) -> NodeRef<T> {
48 #[cfg(debug_assertions)]
49 let _ = self.nodes[id.get()];
50 NodeRef { tree: self, id }
51 }
52
53 pub fn index_mut(&mut self, id: NodeId) -> NodeMut<T> {
54 #[cfg(debug_assertions)]
55 let _ = self.nodes[id.get()];
56 NodeMut { tree: self, id }
57 }
58
59 pub fn root(&self) -> NodeRef<T> {
60 self.index(NodeId::new(0))
61 }
62
63 pub fn root_mut(&mut self) -> NodeMut<T> {
64 self.index_mut(NodeId::new(0))
65 }
66
67 pub fn len(&self) -> usize {
68 self.nodes.len()
69 }
70
71 pub fn iter(&self) -> impl std::iter::ExactSizeIterator<Item = (NodeId, &T)> {
72 self.nodes.iter().enumerate().map(|(i, n)| (NodeId::new(i), &n.value))
73 }
74}
75impl<T> fmt::Debug for Tree<T> {
76 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77 f.debug_struct("Tree").field("nodes", &self.nodes).finish()
78 }
79}
80
81struct Node<T> {
82 parent: Option<NodeId>,
83 prev_sibling: Option<NodeId>,
84 next_sibling: Option<NodeId>,
85 last_child: Option<NodeId>,
86 descendants_end: u32,
87 value: T,
88}
89
90impl<T> fmt::Debug for Node<T> {
91 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
92 f.debug_struct("Node")
93 .field("parent", &self.parent)
94 .field("prev_sibling", &self.prev_sibling)
95 .field("next_sibling", &self.next_sibling)
96 .field("last_child", &self.last_child)
97 .field("descendant_end", &self.descendants_end)
98 .finish_non_exhaustive()
99 }
100}
101
102pub(super) struct NodeRef<'a, T> {
103 tree: &'a Tree<T>,
104 id: NodeId,
105}
106impl<T> Clone for NodeRef<'_, T> {
107 fn clone(&self) -> Self {
108 *self
109 }
110}
111impl<T> Copy for NodeRef<'_, T> {}
112impl<'a, T> NodeRef<'a, T> {
113 pub fn id(&self) -> NodeId {
114 self.id
115 }
116
117 pub fn parent(&self) -> Option<NodeRef<'a, T>> {
118 self.tree.nodes[self.id.get()].parent.map(|p| NodeRef { tree: self.tree, id: p })
119 }
120
121 pub fn prev_sibling(&self) -> Option<NodeRef<'a, T>> {
122 self.tree.nodes[self.id.get()]
123 .prev_sibling
124 .map(|p| NodeRef { tree: self.tree, id: p })
125 }
126
127 pub fn next_sibling(&self) -> Option<NodeRef<'a, T>> {
128 self.tree.nodes[self.id.get()]
129 .next_sibling
130 .map(|p| NodeRef { tree: self.tree, id: p })
131 }
132
133 pub fn has_siblings(&self) -> bool {
134 let node = &self.tree.nodes[self.id.get()];
135 node.prev_sibling.is_some() || node.next_sibling.is_some()
136 }
137
138 pub fn first_child(&self) -> Option<NodeRef<'a, T>> {
139 self.tree.nodes[self.id.get()].last_child.map(|_| NodeRef {
140 tree: self.tree,
141 id: self.id.next(), })
143 }
144
145 pub fn last_child(&self) -> Option<NodeRef<'a, T>> {
146 self.tree.nodes[self.id.get()]
147 .last_child
148 .map(|p| NodeRef { tree: self.tree, id: p })
149 }
150
151 pub fn has_children(&self) -> bool {
152 self.tree.nodes[self.id.get()].last_child.is_some()
153 }
154
155 pub fn children_count(&self) -> usize {
156 let mut r = 0;
157 if let Some(mut c) = self.first_child() {
158 r += 1;
159
160 while let Some(n) = c.next_sibling() {
161 c = n;
162 r += 1;
163 }
164 }
165 r
166 }
167
168 pub fn descendants_range(self) -> std::ops::Range<usize> {
169 let self_i = self.id.get();
170 let start = self_i + 1;
171 let end = self.tree.nodes[self_i].descendants_end as usize;
172 start..end
173 }
174
175 pub fn self_and_descendants(self) -> TreeIter {
176 let node = self.id.get();
177 TreeIter {
178 node,
179 next: node,
180 end: self.tree.nodes[self.id.get()].descendants_end as usize,
181 }
182 }
183
184 pub fn value(&self) -> &'a T {
185 &self.tree.nodes[self.id.get()].value
186 }
187}
188impl<T> PartialEq for NodeRef<'_, T> {
189 fn eq(&self, other: &Self) -> bool {
190 self.id == other.id
191 }
192}
193
/// Mutable handle to one node of a [`Tree`], used to build the tree by appending children.
///
/// Nodes must be pushed depth-first and each node closed (see `close`) once its subtree is
/// complete; the iterators and `first_child` rely on the resulting contiguous layout.
pub(super) struct NodeMut<'a, T> {
    tree: &'a mut Tree<T>,
    id: NodeId,
}
impl<T> NodeMut<'_, T> {
    /// Id of the node this handle points at.
    pub fn id(&self) -> NodeId {
        self.id
    }

    /// Appends a new node holding `value` at the end of the node vector and links it as this
    /// node's last child. Returns a handle to the new child.
    ///
    /// The new node gets a provisional `descendants_end` of `index + 1`. Neither this node's
    /// nor any ancestor's `descendants_end` is updated here; that happens in `close`.
    pub fn push_child(&mut self, value: T) -> NodeMut<T> {
        let len = self.tree.nodes.len();
        // The new node will occupy the slot at the current end of the vector.
        let new_id = NodeId::new(len);

        let self_node = &mut self.tree.nodes[self.id.get()];
        let mut new_node = Node {
            parent: Some(self.id),
            prev_sibling: None,
            next_sibling: None,
            last_child: None,
            descendants_end: len as u32 + 1,
            value,
        };

        if let Some(last) = &mut self_node.last_child {
            // Splice the new node after the current last child in the sibling chain.
            let prev_last = *last;
            new_node.prev_sibling = Some(prev_last);
            *last = new_id;
            self.tree.nodes[prev_last.get()].next_sibling = Some(new_id);
        } else {
            self_node.last_child = Some(new_id);
        }

        self.tree.nodes.push(new_node);

        NodeMut {
            tree: self.tree,
            id: new_id,
        }
    }

    /// Recursively copies the subtree rooted at `child` (which, given the borrows, belongs to a
    /// different tree) as a new last child of this node.
    ///
    /// Each copied value is produced by `reuse`, and children are copied in sibling order. Every
    /// copied node is closed after its own children are pushed.
    pub fn push_reuse(&mut self, child: NodeRef<T>, reuse: &mut impl FnMut(&T) -> T) {
        let mut clone = self.push_child(reuse(child.value()));

        if let Some(mut child) = child.first_child() {
            clone.push_reuse(child, reuse);

            while let Some(c) = child.next_sibling() {
                child = c;
                clone.push_reuse(c, reuse);
            }
        }

        clone.close();
    }

    /// Mutable handle to the first child, if any. The first child is assumed to be the slot
    /// right after this node (depth-first layout); `last_child` only signals that one exists.
    fn first_child(&mut self) -> Option<NodeMut<T>> {
        self.tree.nodes[self.id.get()].last_child.map(|_| NodeMut {
            tree: self.tree,
            id: self.id.next(),
        })
    }

    /// Transfers every child subtree of `split`'s root into this node as new children, in
    /// sibling order. `split` is consumed.
    ///
    /// The value of `split`'s root itself is not transferred. Each transferred value is produced
    /// by `take` from a mutable reference to the original (presumably moving it out, e.g. with
    /// `std::mem::take`; verify against callers).
    pub fn parallel_fold(&mut self, mut split: Tree<T>, take: &mut impl FnMut(&mut T) -> T) {
        if let Some(mut c) = split.root_mut().first_child() {
            self.parallel_fold_node(&mut c, take);

            // Reuse the first child's tree borrow to walk the remaining siblings by index.
            let tree = c.tree;
            let mut child_idx = c.id.get();
            while let Some(id) = tree.nodes[child_idx].next_sibling {
                self.parallel_fold_node(&mut NodeMut { tree, id }, take);
                child_idx = id.get();
            }
        }
    }

    /// Recursively transfers the subtree rooted at `split` as a new last child of this node,
    /// using `take` for each value, and closes each transferred node after its children.
    fn parallel_fold_node(&mut self, split: &mut NodeMut<T>, take: &mut impl FnMut(&mut T) -> T) {
        let mut clone = self.push_child(take(split.value()));

        if let Some(mut child) = split.first_child() {
            clone.parallel_fold_node(&mut child, take);

            // Reuse the child's tree borrow to walk the remaining siblings by index.
            let tree = child.tree;
            let mut child_idx = child.id.get();
            while let Some(id) = tree.nodes[child_idx].next_sibling {
                clone.parallel_fold_node(&mut NodeMut { tree, id }, take);
                child_idx = id.get();
            }
        }

        clone.close();
    }

    /// Finalizes this node by recording the current tree length as its `descendants_end`, so its
    /// subtree is the index range up to that point.
    ///
    /// Must be called after all of this node's descendants have been pushed.
    pub fn close(self) {
        let len = self.tree.len();
        self.tree.nodes[self.id.get()].descendants_end = len as u32;
    }

    /// Mutable access to this node's value.
    pub fn value(&mut self) -> &mut T {
        &mut self.tree.nodes[self.id.get()].value
    }
}
295
296pub(super) struct TreeIter {
297 node: usize, next: usize,
300 end: usize,
301}
302impl TreeIter {
303 pub fn next(&mut self) -> Option<NodeId> {
306 if self.next < self.end {
307 let next = NodeId::new(self.next);
308 self.next += 1;
309 Some(next)
310 } else {
311 None
312 }
313 }
314
315 pub fn close<T>(&mut self, tree: &Tree<T>, yielded: NodeId) {
317 let node = &tree.nodes[yielded.get()];
318 if let Some(next_sibling) = node.next_sibling {
319 self.next = next_sibling.get();
320 } else if let Some(parent) = node.parent {
321 let node = &tree.nodes[parent.get()];
322 self.next = self.end.min(node.descendants_end as usize);
323 } else {
324 self.next = self.end;
325 }
326 }
327
328 pub fn skip_to(&mut self, node: NodeId) {
329 let node = node.get();
330 if node > self.next {
331 if node > self.end {
332 self.next = self.end;
333 } else {
334 self.next = node;
335 }
336 }
337 }
338
339 pub fn len(&self) -> usize {
340 self.end - self.next
341 }
342
343 pub fn rev<T>(self, tree: &Tree<T>) -> RevTreeIter {
344 let mut count = self.next - self.node;
345
346 let mut iter = RevTreeIter {
347 next: self.node,
348 end: self.node,
349 started: false,
350 };
351
352 while count > 0 {
353 count -= 1;
354 iter.next(tree);
355 }
356
357 iter
358 }
359
360 pub fn empty() -> Self {
361 Self { node: 0, next: 0, end: 0 }
362 }
363}
364
365pub(super) struct RevTreeIter {
366 next: usize,
367 end: usize,
368 started: bool,
369}
370impl RevTreeIter {
371 pub fn next<T>(&mut self, tree: &Tree<T>) -> Option<NodeId> {
374 if self.next != self.end || !self.started {
375 self.started = true;
376
377 let next = NodeId::new(self.next);
378 let node = &tree.nodes[self.next];
379
380 if let Some(last_child) = node.last_child {
381 self.next = last_child.get();
382 } else if let Some(prev) = node.prev_sibling {
383 self.next = prev.get();
384 } else {
385 let mut node = node;
386 let mut changed = false;
387 while let Some(parent) = node.parent {
388 let parent = parent.get();
389 if parent == self.end {
390 self.next = self.end;
391 changed = true;
392 break;
393 }
394
395 node = &tree.nodes[parent];
396
397 if let Some(prev) = node.prev_sibling {
398 self.next = prev.get();
399 changed = true;
400 break;
401 }
402 }
403 if !changed {
404 self.next = self.end;
406 }
407 }
408
409 Some(next)
410 } else {
411 None
412 }
413 }
414
415 pub fn close<T>(&mut self, tree: &Tree<T>, yielded: NodeId) {
417 let mut node = &tree.nodes[yielded.get()];
418
419 if let Some(prev) = node.prev_sibling {
420 self.next = prev.get();
421 } else {
422 while let Some(parent) = node.parent {
423 let parent = parent.get();
424
425 if parent == self.end {
426 self.next = self.end;
427 break;
428 }
429
430 node = &tree.nodes[parent];
431
432 if let Some(prev) = node.prev_sibling {
433 self.next = prev.get();
434 break;
435 }
436 }
437 }
438 }
439
440 pub fn skip_to<T>(&mut self, tree: &Tree<T>, node: NodeId) {
441 let node = node.get();
442 if node > self.end {
443 let root = &tree.nodes[self.end];
444 if node >= root.descendants_end as usize {
445 self.next = self.end;
446 } else {
447 self.next = node;
448 self.started = true;
449 }
450 }
451 }
452
453 pub fn empty() -> Self {
454 Self {
455 next: 0,
456 end: 0,
457 started: true,
458 }
459 }
460}
461
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the fixture tree used by the iterator tests:
    ///
    /// ```text
    /// r
    ///   a
    ///     a.a
    ///     a.b
    ///       a.b.a
    ///       a.b.b
    ///     a.c
    ///   b
    /// ```
    ///
    /// Every node that receives children is closed, so `descendants_end` is accurate everywhere.
    fn iter_tree() -> Tree<&'static str> {
        let mut tree = Tree::new("r");
        let mut r = tree.root_mut();
        let mut a = r.push_child("a");
        a.push_child("a.a");
        let mut ab = a.push_child("a.b");
        ab.push_child("a.b.a");
        ab.push_child("a.b.b");
        // Without this, `a.b` keeps the provisional `descendants_end` from `push_child`, and its
        // subtree range would exclude its own children.
        ab.close();
        a.push_child("a.c");
        a.close();
        r.push_child("b");
        r.close();
        tree
    }

    #[test]
    fn iter_next() {
        let tree = iter_tree();
        let mut iter = tree.root().self_and_descendants();

        let mut r = vec![];
        while let Some(id) = iter.next() {
            r.push(*tree.index(id).value());
        }

        assert_eq!(r, vec!["r", "a", "a.a", "a.b", "a.b.a", "a.b.b", "a.c", "b"]);
    }

    #[test]
    fn iter_rev() {
        let tree = iter_tree();
        let mut iter = tree.root().self_and_descendants().rev(&tree);

        let mut r = vec![];
        while let Some(id) = iter.next(&tree) {
            r.push(*tree.index(id).value());
        }

        assert_eq!(r, vec!["r", "b", "a", "a.c", "a.b", "a.b.b", "a.b.a", "a.a"]);
    }

    #[test]
    fn iter_not_root() {
        let tree = iter_tree();
        let mut iter = tree.root().first_child().unwrap().self_and_descendants();

        let mut r = vec![];
        while let Some(id) = iter.next() {
            r.push(*tree.index(id).value());
        }

        assert_eq!(r, vec!["a", "a.a", "a.b", "a.b.a", "a.b.b", "a.c"]);
    }

    #[test]
    fn iter_rev_not_root() {
        let tree = iter_tree();
        let mut iter = tree.root().first_child().unwrap().self_and_descendants().rev(&tree);

        let mut r = vec![];
        while let Some(id) = iter.next(&tree) {
            r.push(*tree.index(id).value());
        }

        assert_eq!(r, vec!["a", "a.c", "a.b", "a.b.b", "a.b.a", "a.a"]);
    }

    #[test]
    fn iter_descendants() {
        let tree = iter_tree();
        let mut iter = tree.root().first_child().unwrap().self_and_descendants();
        iter.next();

        let mut r = vec![];
        while let Some(id) = iter.next() {
            r.push(*tree.index(id).value());
        }

        assert_eq!(r, vec!["a.a", "a.b", "a.b.a", "a.b.b", "a.c"]);
    }

    #[test]
    fn iter_rev_descendants() {
        let tree = iter_tree();
        let mut iter = tree.root().first_child().unwrap().self_and_descendants().rev(&tree);
        iter.next(&tree);

        let mut r = vec![];
        while let Some(id) = iter.next(&tree) {
            r.push(*tree.index(id).value());
        }

        assert_eq!(r, vec!["a.c", "a.b", "a.b.b", "a.b.a", "a.a"]);
    }

    #[test]
    fn iter_close() {
        let tree = iter_tree();
        let mut iter = tree.root().self_and_descendants();

        iter.next().unwrap();
        let a = iter.next().unwrap();

        iter.close(&tree, a);

        let mut r = vec![];
        while let Some(id) = iter.next() {
            r.push(*tree.index(id).value());
        }

        assert_eq!(r, vec!["b"]);
    }

    #[test]
    fn iter_close_last_child() {
        let tree = iter_tree();
        let mut iter = tree.root().self_and_descendants();

        let mut abb = None;
        while let Some(id) = iter.next() {
            if *tree.index(id).value() == "a.b.b" {
                abb = Some(id);
                break;
            }
        }
        iter.close(&tree, abb.unwrap());

        let mut r = vec![];
        while let Some(id) = iter.next() {
            r.push(*tree.index(id).value());
        }

        assert_eq!(r, vec!["a.c", "b"]);
    }

    #[test]
    fn descendants_range_covers_subtree() {
        let tree = iter_tree();
        let root = tree.root();
        assert_eq!(root.descendants_range(), 1..8);

        let a = root.first_child().unwrap();
        assert_eq!(a.descendants_range(), 2..7);

        let ab = a.first_child().unwrap().next_sibling().unwrap();
        assert_eq!(*ab.value(), "a.b");
        assert_eq!(ab.descendants_range(), 4..6);
    }

    #[test]
    fn iter_rev_close() {
        let tree = iter_tree();
        let mut iter = tree.root().self_and_descendants().rev(&tree);

        iter.next(&tree).unwrap();
        let b = iter.next(&tree).unwrap();

        iter.close(&tree, b);

        let mut r = vec![];
        while let Some(id) = iter.next(&tree) {
            r.push(*tree.index(id).value());
        }

        assert_eq!(r, vec!["a", "a.c", "a.b", "a.b.b", "a.b.a", "a.a"]);
    }

    #[test]
    fn iter_skip_to() {
        let tree = iter_tree();

        let mut iter = tree.root().self_and_descendants();
        let mut all = vec![];
        while let Some(id) = iter.next() {
            all.push(id);
        }

        for (i, id) in all.iter().enumerate() {
            let mut iter = tree.root().self_and_descendants();
            iter.skip_to(*id);

            let mut result = vec![];
            while let Some(id) = iter.next() {
                result.push(tree.nodes[id.get()].value);
            }

            let expected: Vec<_> = all[i..].iter().map(|id| tree.nodes[id.get()].value).collect();

            assert_eq!(expected, result);
        }
    }

    #[test]
    fn iter_rev_skip_to() {
        let tree = iter_tree();

        let mut iter = tree.root().self_and_descendants().rev(&tree);
        let mut all = vec![];
        while let Some(id) = iter.next(&tree) {
            all.push(id);
        }

        for (i, id) in all.iter().enumerate() {
            let mut iter = tree.root().self_and_descendants().rev(&tree);
            iter.skip_to(&tree, *id);

            let mut result = vec![];
            while let Some(id) = iter.next(&tree) {
                result.push(tree.nodes[id.get()].value);
            }

            let expected: Vec<_> = all[i..].iter().map(|id| tree.nodes[id.get()].value).collect();
            assert_eq!(expected, result);
        }
    }
}