view redblacktree.py @ 18:92e2879e2e33

Rework the red-black tree based on Julienne Walker's tutorial. Insertion is implemented now. Deletion will come next.
author Brian Neal <bgneal@gmail.com>
date Wed, 26 Dec 2012 19:59:17 -0600
parents 977628018b4b
children 3c74185c5047
line wrap: on
line source
"""
Copyright (C) 2012 Brian G. Neal.

Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

----

A red-black tree for Section 3.4, Exercise 4 in Allen Downey's _Think
Complexity_ book.

http://greenteapress.com/complexity

Red black trees are described on Wikipedia:
http://en.wikipedia.org/wiki/Red-black_tree.

This is basically a Python implementation of Julienne Walker's Red Black Trees
tutorial found at:
http://www.eternallyconfuzzled.com/tuts/datastructures/jsw_tut_rbtree.aspx
We implement Julienne's top-down insertion algorithm here; deletion is not
implemented yet.

Some ideas were also borrowed from code by Darren Hart at
http://dvhart.com/darren/files/rbtree.py

"""

# Node colors.
BLACK = 0
RED = 1

# Indices into a node's two-element ``link`` list of children.
LEFT = 0
RIGHT = 1


class TreeError(Exception):
    """Root of the exception hierarchy for red-black tree failures."""


class Node(object):
    """A node class for red-black trees.

    * A node has a color, either RED or BLACK.
    * A node has a key and an optional value.

        * The key determines the node's position in the tree, so keys stored
          in the same tree must be comparable with one another.
        * The value is useful for using the red-black tree to implement a map
          data structure. A value of None means "no value".

    * Nodes have exactly 2 link fields which we represent as a list of
      2 elements to represent the left and right children of the node. This list
      representation was borrowed from Julienne Walker as it simplifies some of
      the code. Element 0 is the LEFT child and element 1 is the RIGHT child.

    """

    def __init__(self, key, value=None, color=RED):
        """Create a childless node.

        key   -- the ordering key for this node.
        value -- optional payload associated with the key (default None).
        color -- RED or BLACK (default RED).

        """
        self.key = key
        self.value = value
        self.color = color
        # [left child, right child]; None marks a missing child.
        self.link = [None, None]

    def __str__(self):
        """Return '(C: key)' or '(C: key => value)' where C is 'B' or 'R'.

        The value is shown whenever one is present, i.e. whenever it is not
        None, so falsy payloads such as 0 or '' are still displayed.

        """
        c = 'B' if self.color == BLACK else 'R'
        if self.value is not None:
            return '({}: {} => {})'.format(c, self.key, self.value)
        return '({}: {})'.format(c, self.key)


def is_red(n):
    """Tell whether the Node n is colored RED.

    A missing node (None) counts as BLACK, so False is returned for it.

    """
    if n is None:
        return False
    return n.color == RED


class Tree(object):
    """A red-black Tree class.

    A red-black tree is a binary search tree with the following properties:

        1. A node is either red or black.
        2. The root is black.
        3. All leaves are black.
        4. Both children of every red node are black.
        5. Every simple path from a given node to any descendant leaf contains
           the same number of black nodes.

    These rules ensure that the path from the root to the furthest leaf is no
    more than twice as long as the path from the root to the nearest leaf. Thus
    the tree is roughly height-balanced.

    Iterating over a Tree yields its Node objects in ascending key order.

    """

    def __init__(self):
        # An empty tree has no root node.
        self.root = None

    def __iter__(self):
        """Return an iterator over the tree's nodes in ascending key order.

        An empty tree produces an empty iteration.

        """
        return self._inorder(self.root)

    def _inorder(self, node):
        """A generator to perform an inorder traversal of the nodes in the
        tree starting at the given node.

        If node is None nothing is yielded. The traversal keeps an explicit
        stack of pending ancestors instead of recursing.

        """
        pending = []
        while pending or node is not None:
            # Descend as far left as possible, remembering the path.
            while node is not None:
                pending.append(node)
                node = node.link[LEFT]

            node = pending.pop()
            yield node

            # The right subtree comes after the node itself.
            node = node.link[RIGHT]

    def _single_rotate(self, root, d):
        """Perform a single rotation about the node 'root' in the given
        direction 'd' (LEFT or RIGHT).

        The child of 'root' on the side opposite to 'd' becomes the new root
        of the subtree and 'root' becomes its child on side 'd'.

        The old root is set to RED and the new root is set to BLACK.

        Returns the new root.

        """
        nd = not d
        save = root.link[nd]

        # Hand the new root's inner subtree over to the old root.
        root.link[nd] = save.link[d]
        save.link[d] = root

        root.color = RED
        save.color = BLACK

        return save

    def _double_rotate(self, root, d):
        """Perform two single rotations about the node root in the direction d.

        The child of root opposite to d is first rotated away from d, then
        root itself is rotated in direction d.

        The new root is returned.

        """
        nd = not d
        root.link[nd] = self._single_rotate(root.link[nd], nd)
        return self._single_rotate(root, d)

    def validate(self, root=None):
        """Checks to see if the red-black tree validates at the given node.
        If root is None, the root of the tree is used as the starting point.

        The checks performed are: no red node has a red child, every node's
        key lies strictly between the keys of its immediate children, and
        both subtrees of every node have the same black height.

        If any rules are violated, a TreeError is raised.

        Returns the black height of the tree rooted at root (1 for an empty
        tree).

        """
        if root is None:
            root = self.root

        return self._validate(root)

    def _validate(self, root):
        """Internal implementation of the validate() method."""

        # A missing node is a black leaf.
        if root is None:
            return 1

        ln = root.link[LEFT]
        rn = root.link[RIGHT]

        # Check for consecutive red links

        if is_red(root) and (is_red(ln) or is_red(rn)):
            raise TreeError('red violation')

        lh = self._validate(ln)
        rh = self._validate(rn)

        # Check for invalid binary search tree

        if ((ln is not None and ln.key >= root.key) or
                (rn is not None and rn.key <= root.key)):
            raise TreeError('binary tree violation')

        # Check for black height mismatch
        if lh != rh:
            raise TreeError('black violation')

        # Only count black links

        return lh if is_red(root) else lh + 1

    def insert(self, key, value=None):
        """Insert the (key, value) pair into the tree.

        The tree is rebalanced on the way down from the root: color flips and
        rotations are applied while searching for the insertion point.

        If a node with an equal key is already present, no new node is added
        and the existing node's value is left untouched.

        """

        # Check for the empty tree case:
        if self.root is None:
            self.root = Node(key=key, value=value, color=BLACK)
            return

        # False/dummy tree root. The real root hangs off its RIGHT link so
        # that a rotation at the top of the tree has a parent link to update.
        head = Node(key=None, value=None, color=BLACK)
        d = LEFT            # direction taken from p to q
        last = LEFT         # direction taken from g to p

        t = head            # great-grandparent of q
        g = p = None        # grandparent and parent of q
        t.link[RIGHT] = self.root
        q = self.root       # node currently being examined

        # Search down the tree
        while True:
            if q is None:
                # Insert new node at the bottom
                p.link[d] = q = Node(key=key, value=value, color=RED)
            elif is_red(q.link[LEFT]) and is_red(q.link[RIGHT]):
                # Color flip
                q.color = RED
                q.link[LEFT].color = BLACK
                q.link[RIGHT].color = BLACK

            # Fix red violation
            if is_red(q) and is_red(p):
                # Which link of the great-grandparent holds the grandparent?
                d2 = t.link[RIGHT] is g

                if q is p.link[last]:
                    # q lies on the same side of p as p does of g.
                    t.link[d2] = self._single_rotate(g, not last)
                else:
                    t.link[d2] = self._double_rotate(g, not last)

            # Stop if found
            if q.key == key:
                break

            last = d
            d = q.key < key

            # Update helpers
            if g is not None:
                t = g
            g = p
            p = q
            q = q.link[d]

        # Update root
        self.root = head.link[RIGHT]
        self.root.color = BLACK


if __name__ == '__main__':
    # Smoke test: build a number of trees from random keys, print each tree's
    # keys in iteration order and validate the tree before and after filling.
    import random

    for _trial in range(30):
        tree = Tree()
        tree.validate()

        for _ in range(25):
            tree.insert(random.randint(0, 100))

        # One print() call with a single argument produces the same output
        # under Python 2's print statement and Python 3's print function.
        print(' '.join(str(node.key) for node in tree))

        tree.validate()