RC: W5 D1 — Implementing an iterator for red-black trees

March 11, 2024

One important feature of a storage engine is the ability to serve range queries. In an LSM engine, the data is distributed over multiple in-memory components (memtables) and on-disk components (SSTables). The naive way to serve a range query would be to run it against each of those components, hold all the results in memory, and then keep only the most recent value for each key that has been written to several components. However, proceeding this way would likely lead to out-of-memory errors, especially when the range of the query is large and/or the components hold a lot of data. Thus, that is not the approach that LSM engines usually use. Instead, LSM engines use iterators: each component has its own iterator (so only one value per component is accessed at a time), and a separate process then merges those iterators into one.
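To make the idea of merging iterators a bit more concrete, here is a minimal sketch using `heapq.merge` from the standard library. It is not my engine's code: `merge_components`, `tag` and the `(key, value)` entry shape are hypothetical names chosen for illustration, and the sketch assumes each component yields its entries sorted by key, with components listed from newest to oldest. Deletions (tombstones) are ignored here.

```python
import heapq
from typing import Iterable, Iterator, List, Optional, Tuple

Entry = Tuple[str, str]  # hypothetical (key, value) pair


def tag(component: Iterable[Entry], age: int) -> Iterator[Tuple[str, int, str]]:
    # Attach the component's age so that, for equal keys, newer entries sort first
    for key, value in component:
        yield key, age, value


def merge_components(components: List[Iterable[Entry]]) -> Iterator[Entry]:
    """Lazily merge sorted components; components[0] is assumed to be the newest."""
    tagged = [tag(component, age) for age, component in enumerate(components)]
    last_key: Optional[str] = None
    # heapq.merge only holds one pending entry per component at a time
    for key, _age, value in heapq.merge(*tagged):
        if key != last_key:  # first occurrence of a key comes from the newest component
            yield key, value
            last_key = key


newest = [("a", "1-new"), ("c", "3")]
oldest = [("a", "1-old"), ("b", "2")]
print(list(merge_components([newest, oldest])))
# → [('a', '1-new'), ('b', '2'), ('c', '3')]
```

Because everything is a generator, nothing is materialized until the caller consumes the merged stream, which is the property the naive approach lacks.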

So far, I only have in-memory components (memtables), which are implemented as red-black trees. So, the first thing I needed to do was to implement an iterator for those trees. Elements in a red-black tree are already sorted, so a simple in-order traversal is enough. Since we only want the values within a given range, we don’t need to traverse the whole tree: we can skip the right subtree of any node whose key is at or above the upper bound, and skip the left subtree of any node whose key is at or below the lower bound.

Here is what the implementation looks like in practice:

from typing import Optional, Iterator


class Node:
    Key = Optional[str]
    Data = Optional[str]

    def __init__(self, key: Key, data: Data, left: Optional["Node"], right: Optional["Node"]):
        self.key = key
        self.data = data
        self.left = left
        self.right = right

    def in_order_traversal(self, lower: Key, upper: Key) -> Iterator["Node"]:
        # Note: returning early when the key is out of range would skip whole
        # subtrees that may still hold in-range keys, so we prune per branch.

        # Optimization: only go left if this key is above the lower bound
        # (otherwise every key in the left subtree is below the range)
        if self.key > lower and self.left is not RedBlackTree.NIL_LEAF:
            yield from self.left.in_order_traversal(lower=lower, upper=upper)
        # Yield this node only if its key is within the range
        if lower <= self.key <= upper:
            yield self
        # Optimization: only go right if this key is below the upper bound
        # (otherwise every key in the right subtree is above the range)
        if self.key < upper and self.right is not RedBlackTree.NIL_LEAF:
            yield from self.right.in_order_traversal(lower=lower, upper=upper)


class RedBlackTree:
    NIL_LEAF = Node(key=None, data=None, left=None, right=None)

    def __init__(self):
        self.root = self.NIL_LEAF

    def scan(self, lower: Node.Key, upper: Node.Key) -> Iterator[Node.Data]:
        if self.root is self.NIL_LEAF:
            return

        for node in self.root.in_order_traversal(lower=lower, upper=upper):
            yield node.data

It is interesting to analyze the complexity of this algorithm, which boils down to:

  1. Finding the start node (i.e. the node whose key is equal to the lower bound, or otherwise the node with the smallest key greater than the lower bound)
  2. Traversing the tree until we reach the upper bound.

If n is the number of nodes in the tree, then since the tree is balanced (and thus its height is O(log n)), the first part executes in logarithmic time, O(log n). As for the second part, it depends on the number of nodes between the lower and the upper bound (the more nodes, the longer it takes). Assuming that there are k nodes within that range, the time complexity for the second part is O(k). Therefore, the time complexity of a range scan in a red-black tree is O(k + log n).
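One way to check this bound empirically is to count how many nodes a pruned range scan actually touches. The sketch below is only an illustration under simplifying assumptions: it uses a plain balanced binary search tree built from sorted integers rather than a real red-black tree, and `BSTNode`, `build_balanced`, `range_scan` and `count_visits` are hypothetical helpers that are not part of my memtable code.

```python
from typing import Iterator, List, Optional, Tuple


class BSTNode:
    """Hypothetical plain BST node (no colors, None instead of a NIL leaf)."""

    def __init__(self, key: int):
        self.key = key
        self.left: Optional["BSTNode"] = None
        self.right: Optional["BSTNode"] = None


def build_balanced(keys: List[int]) -> Optional[BSTNode]:
    """Build a balanced BST from an already sorted list of keys."""
    if not keys:
        return None
    mid = len(keys) // 2
    node = BSTNode(keys[mid])
    node.left = build_balanced(keys[:mid])
    node.right = build_balanced(keys[mid + 1:])
    return node


def range_scan(node: Optional[BSTNode], lower: int, upper: int,
               visited: List[int]) -> Iterator[int]:
    """Pruned in-order traversal that records every node it touches."""
    if node is None:
        return
    visited.append(node.key)
    if node.key > lower:  # the left subtree may still hold in-range keys
        yield from range_scan(node.left, lower, upper, visited)
    if lower <= node.key <= upper:
        yield node.key
    if node.key < upper:  # the right subtree may still hold in-range keys
        yield from range_scan(node.right, lower, upper, visited)


def count_visits(n: int, lower: int, upper: int) -> Tuple[List[int], int]:
    """Scan [lower, upper] in a tree holding keys 0..n-1; return (keys found, nodes touched)."""
    root = build_balanced(list(range(n)))
    visited: List[int] = []
    found = list(range_scan(root, lower, upper, visited))
    return found, len(visited)


# 1023 keys form a tree with 10 levels; the range [500, 509] holds k = 10 keys
found, visits = count_visits(1023, 500, 509)
print(len(found), visits)  # visits stays far below n = 1023
```

The nodes touched outside the range all lie on the two search paths towards the lower and upper bounds, so the number of visits is at most k plus twice the height of the tree, which matches O(k + log n).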

That’s it for today. Tomorrow, I will need to merge all iterators!