RC: W5 D1 — Implementing an iterator for red-black trees
March 11, 2024
One important feature of a storage engine is the ability to serve range queries. In an LSM engine, the data is distributed over multiple in-memory components (memtables) and on-disk components (SSTables). The naive way to serve a range query would be to run it over each of those components, store all the results in memory, and then keep only the most recent value for each key that has been written to multiple components. However, proceeding this way would likely lead to out-of-memory errors, especially when the range of the query is large and/or the components hold a large amount of data. Thus, that is not the approach that LSM engines usually use. Instead, LSM engines use iterators: each component has its own iterator (so only one value is accessed at a time), and a separate process then merges those iterators into one.
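To make the idea of merging per-component iterators concrete, here is a rough sketch of one possible approach. It is not the merge used by this engine. It assumes each component yields `(key, value)` pairs sorted by key, that keys are unique within a component, and that components are listed from newest to oldest. The names `tag` and `merged_scan` are hypothetical.

```python
import heapq
from typing import Iterable, Iterator, List, Tuple


def tag(component: Iterable[Tuple[str, str]], age: int) -> Iterator[Tuple[str, int, str]]:
    # Attach the component's age (0 = newest) so that, for equal keys,
    # the entry from the most recent component sorts first.
    for key, value in component:
        yield key, age, value


def merged_scan(components: List[Iterable[Tuple[str, str]]]) -> Iterator[Tuple[str, str]]:
    # Assumption: components are ordered from newest to oldest.
    tagged = [tag(component, age) for age, component in enumerate(components)]
    last_key = None
    # heapq.merge lazily merges sorted inputs, holding one entry per input.
    for key, _age, value in heapq.merge(*tagged):
        if key != last_key:
            # First occurrence of a key comes from the newest component.
            yield key, value
            last_key = key


newest = [("a", "1-new"), ("c", "3")]
older = [("a", "1-old"), ("b", "2")]
print(list(merged_scan([iter(newest), iter(older)])))
# → [('a', '1-new'), ('b', '2'), ('c', '3')]
```

Only one pending entry per component is held in memory at any time, which is what avoids the out-of-memory problem of the naive approach.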
So far, I only have in-memory components (memtables), which are implemented as red-black trees. So, the first thing I needed to do was to implement an iterator for those trees. Elements in a red-black tree are already sorted, so we only need to do a simple in-order traversal of the tree. Since we want the values over a given range, we don’t need to visit the whole tree: we can skip the left subtree of any node whose key is not greater than the lower bound, and skip the right subtree of any node whose key is not smaller than the upper bound.
Here is what the implementation looks like in practice:
from typing import Iterator, Optional


class Node:
    Key = Optional[str]
    Data = Optional[str]

    def __init__(self, key: Key, data: Data, left: Optional["Node"], right: Optional["Node"]):
        self.key = key
        self.data = data
        self.left = left
        self.right = right

    def in_order_traversal(self, lower: Key, upper: Key) -> Iterator["Node"]:
        # Optimization: skip the left subtree when this key is not greater than
        # the lower bound (every key on the left is smaller still)
        if self.key > lower and self.left is not RedBlackTree.NIL_LEAF:
            yield from self.left.in_order_traversal(lower=lower, upper=upper)
        # Regular in-order step: yield this node if its key is within the range
        if lower <= self.key <= upper:
            yield self
        # Optimization: skip the right subtree when this key is not smaller than
        # the upper bound (every key on the right is greater still)
        if self.key < upper and self.right is not RedBlackTree.NIL_LEAF:
            yield from self.right.in_order_traversal(lower=lower, upper=upper)


class RedBlackTree:
    NIL_LEAF = Node(key=None, data=None, left=None, right=None)

    def __init__(self):
        self.root = self.NIL_LEAF

    def scan(self, lower: Node.Key, upper: Node.Key) -> Iterator[Node.Data]:
        if self.root is self.NIL_LEAF:
            return
        for node in self.root.in_order_traversal(lower=lower, upper=upper):
            yield node.data
It is interesting to analyze the complexity of this algorithm, which boils down to:
- Finding the start node (i.e. the node whose key is equal to the lower bound – or the node whose key is the smallest among those bigger than the lower bound)
- Traversing the tree until we reach the upper bound.
If n is the number of nodes in the tree, then since the tree is balanced (and thus its height is O(log n)), the first part executes in logarithmic time, O(log n).
As for the second part, it depends on the number of nodes between the lower and the upper bound: the more nodes, the longer it takes. Assuming that there are k nodes within that range, the time complexity of the second part is O(k).
Therefore, the time complexity of a range scan in a red-black tree is O(k + log n).
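The two steps of this analysis can be made explicit with a stack-based traversal. The following is a minimal sketch, not the implementation used in this series. It assumes a plain binary search tree with unique keys, no colors, and `None` instead of a NIL sentinel. `BSTNode`, `build_balanced` and `range_scan` are hypothetical names introduced for illustration.

```python
from typing import Iterator, List, Optional


class BSTNode:
    """Hypothetical minimal node, for illustration only."""

    def __init__(self, key: str, left: "Optional[BSTNode]" = None,
                 right: "Optional[BSTNode]" = None):
        self.key = key
        self.left = left
        self.right = right


def build_balanced(keys: List[str]) -> Optional[BSTNode]:
    # Build a height-balanced tree from an already sorted list of keys.
    if not keys:
        return None
    mid = len(keys) // 2
    return BSTNode(keys[mid], build_balanced(keys[:mid]), build_balanced(keys[mid + 1:]))


def range_scan(root: Optional[BSTNode], lower: str, upper: str) -> Iterator[str]:
    stack: List[BSTNode] = []
    node = root
    # Step 1: descend towards the lower bound, one node per level (O(height)).
    # Nodes with key >= lower are pushed: the last one pushed is the start node.
    while node is not None:
        if node.key >= lower:
            stack.append(node)
            node = node.left
        else:
            node = node.right
    # Step 2: walk forward in key order until we pass the upper bound.
    while stack:
        node = stack.pop()
        if node.key > upper:
            return
        yield node.key
        # The next keys in order are on the leftmost path of the right subtree.
        node = node.right
        while node is not None:
            stack.append(node)
            node = node.left


keys = [f"k{i:02d}" for i in range(20)]
root = build_balanced(keys)
print(list(range_scan(root, "k05", "k09")))
# → ['k05', 'k06', 'k07', 'k08', 'k09']
```

Step 1 touches at most one node per level of the tree. Step 2 does an amount of work proportional to the number of keys yielded, plus at most one more root-to-leaf path. This matches the O(k + log n) bound when the tree is balanced.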
That’s it for today. Tomorrow, I will need to merge all iterators!