view redblacktree.py @ 31:a2358c64d9af

Chapter 5.4, exercise 5: Pareto distribution and city populations.
author Brian Neal <bgneal@gmail.com>
date Mon, 07 Jan 2013 20:41:44 -0600
parents 0326803882ad
children
line wrap: on
line source
"""
Copyright (C) 2012 Brian G. Neal.

Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

----

A red-black tree for Section 3.4, Exercise 4 in Allen Downey's _Think
Complexity_ book.

http://greenteapress.com/complexity

Red-black trees are described on Wikipedia:
http://en.wikipedia.org/wiki/Red-black_tree

This is essentially a Python port of the code in Julienne Walker's "Red Black
Trees" tutorial, found at:
http://www.eternallyconfuzzled.com/tuts/datastructures/jsw_tut_rbtree.aspx
It implements Julienne's top-down insertion and deletion algorithms.

Some ideas were also borrowed from code by Darren Hart at
http://dvhart.com/darren/files/rbtree.py

"""

# Node colors.
BLACK, RED = 0, 1
# Child directions; a direction is used directly as an index into Node.link.
LEFT, RIGHT = 0, 1


class TreeError(Exception):
    """Raised when a red-black tree fails one of its structural checks."""


class Node(object):
    """A node class for red-black trees.

    * A node has a color, either RED or BLACK.
    * A node has a key and an optional value.

        * The tree orders nodes by their keys, comparing keys with the "<"
          and "==" operators.
        * The value is useful for using the red-black tree to implement a map
          data structure.

    * Nodes have exactly 2 link fields, represented as a list of 2 elements
      holding the left and right children of the node. This list
      representation was borrowed from Julienne Walker as it simplifies some
      of the code: a direction (LEFT or RIGHT) can index the list directly.
      Element 0 is the LEFT child and element 1 is the RIGHT child.

    """

    def __init__(self, key, value=None, color=RED):
        self.key = key
        self.value = value
        self.color = color
        self.link = [None, None]    # [LEFT child, RIGHT child]

    def free(self):
        """Call this function when removing a node from a tree.

        It clears both child links to encourage garbage collection.

        """
        self.link[0] = self.link[1] = None

    def __str__(self):
        """Return '(C: key)' or '(C: key => value)', where C is 'B' or 'R'.

        The value is shown whenever it is not None, so falsy values such as
        0 or '' are still displayed.

        """
        c = 'B' if self.color == BLACK else 'R'
        # Test against None rather than truthiness: a stored value of 0, ''
        # or False is real data and must not be hidden.
        if self.value is not None:
            return '({}: {} => {})'.format(c, self.key, self.value)
        return '({}: {})'.format(c, self.key)


def is_red(n):
    """Tell whether the Node n is colored RED.

    A missing node (None) stands for a leaf and counts as BLACK, so passing
    None yields False.

    """
    if n is None:
        return False
    return n.color == RED


class Tree(object):
    """A red-black Tree class.

    A red-black tree is a binary search tree with the following properties:

        1. A node is either red or black.
        2. The root is black.
        3. All leaves (the None links) are black.
        4. Both children of every red node are black.
        5. Every simple path from a given node to any descendant leaf contains
           the same number of black nodes.

    These rules ensure that the path from the root to the furthest leaf is no
    more than twice as long as the path from the root to the nearest leaf. Thus
    the tree is roughly height-balanced.

    """

    def __init__(self):
        self.root = None
        self._size = 0

    def __len__(self):
        """Return the number of elements in the tree (supports len())."""
        return self._size

    def __iter__(self):
        """Return an iterator performing an inorder traversal of the nodes."""
        return self._inorder(self.root)

    def _inorder(self, node):
        """Generate the nodes of the subtree rooted at node in key order.

        Uses an explicit stack instead of nested recursive generators, so each
        node is yielded once rather than being re-yielded through the
        generator of every one of its ancestors.

        """
        stack = []
        while stack or node is not None:
            if node is not None:
                # Descend as far left as possible, remembering the path.
                stack.append(node)
                node = node.link[0]
            else:
                node = stack.pop()
                yield node
                node = node.link[1]

    def _single_rotate(self, root, d):
        """Perform a single rotation about the node 'root' in the given
        direction 'd' (LEFT or RIGHT).

        The old root is set to RED and the new root is set to BLACK.

        Returns the new root.

        """
        nd = not d
        save = root.link[nd]

        root.link[nd] = save.link[d]
        save.link[d] = root

        root.color = RED
        save.color = BLACK

        return save

    def _double_rotate(self, root, d):
        """Perform two single rotations about the node root in the direction d.

        The child on the opposite side is first rotated the other way, then
        root itself is rotated in direction d. The new root is returned.

        """
        nd = not d
        root.link[nd] = self._single_rotate(root.link[nd], nd)
        return self._single_rotate(root, d)

    def validate(self, root=None):
        """Check that the red-black tree rules hold at the given node.

        If root is None, the root of the tree is used as the starting point,
        and rule 2 (the root is black) is checked as well. An explicitly given
        subtree root may legitimately be red, so that check is skipped then.

        If any rules are violated, a TreeError is raised.

        Returns the black height of the tree rooted at root.

        """
        if root is None:
            root = self.root
            if root is not None and is_red(root):
                raise TreeError('red root violation')

        return self._validate(root)

    def _validate(self, root, lo=None, hi=None):
        """Internal implementation of the validate() method.

        lo and hi are the nearest ancestor nodes whose keys bound this subtree
        from below and above (None when unbounded). Checking every key against
        these bounds, rather than only against its direct children, catches a
        key that is on the wrong side of any ancestor.

        """
        if root is None:
            return 1

        ln = root.link[0]
        rn = root.link[1]

        # Check for consecutive red links
        if is_red(root) and (is_red(ln) or is_red(rn)):
            raise TreeError('red violation')

        # Check for invalid binary search tree: the key must lie strictly
        # between its bounding ancestors' keys. Only "<" is needed, matching
        # the comparisons insert() and find() make.
        if lo is not None and not lo.key < root.key:
            raise TreeError('binary tree violation')
        if hi is not None and not root.key < hi.key:
            raise TreeError('binary tree violation')

        lh = self._validate(ln, lo, root)
        rh = self._validate(rn, root, hi)

        # Check for black height mismatch
        if lh != rh:
            raise TreeError('black violation')

        # Only count black links
        return lh if is_red(root) else lh + 1

    def insert(self, key, value=None):
        """Insert the (key, value) pair into the tree. Duplicate keys are not
        allowed in the tree.

        Returns a tuple of the form (node, flag) where node is either the newly
        inserted tree node or an already existing node that has the same key.
        The flag member is a Boolean that will be True if the node was inserted
        and False if a node with the given key already exists in the tree. An
        existing node's value is left unchanged.

        This is a single top-down pass: color flips and rotations are applied
        on the way down, so no walk back up the tree is needed.

        """

        # Check for the empty tree case:
        if self.root is None:
            self.root = Node(key=key, value=value, color=BLACK)
            self._size = 1
            return self.root, True

        # False/dummy tree root. The real root hangs off its RIGHT link, so a
        # rotation at the top of the tree is just another t.link[d2] update.
        head = Node(key=None, value=None, color=BLACK)
        d = last = LEFT     # direction variables

        # Set up helpers: t is the great-grandparent, g the grandparent,
        # p the parent and q the current node.
        t = head
        g = p = None
        q = t.link[1] = self.root

        # Return values
        target, insert_flag = (None, False)

        # Search down the tree
        while True:
            if q is None:
                # Insert new node at the bottom
                p.link[d] = q = Node(key=key, value=value, color=RED)
                self._size += 1
                insert_flag = True
            elif is_red(q.link[0]) and is_red(q.link[1]):
                # Color flip
                q.color = RED
                q.link[0].color = BLACK
                q.link[1].color = BLACK

            # Fix red violation between q and its parent p by rotating about
            # the grandparent g; d2 says which side of t the grandparent is on.
            if is_red(q) and is_red(p):
                d2 = t.link[1] is g

                if q is p.link[last]:
                    t.link[d2] = self._single_rotate(g, not last)
                else:
                    t.link[d2] = self._double_rotate(g, not last)

            # Stop if found
            if q.key == key:
                target = q
                break

            last = d
            d = q.key < key

            # Update helpers
            if g is not None:
                t = g
            g = p
            p = q
            q = q.link[d]

        # Update root
        self.root = head.link[1]
        self.root.color = BLACK

        return target, insert_flag

    def remove(self, key):
        """Remove the given key from the tree.

        Returns True if the key was found and removed and False if the key was
        not found in the tree.

        This is a top-down deletion. The search continues past a matching
        node down to the last node on the search path. That node's key and
        value are copied into the matching node, and then the last node is
        unlinked. NOTE: because of that copy, a Node object obtained earlier
        (e.g. from insert()) may end up holding a different key and value.

        """
        remove_flag = False     # return value

        if self.root is not None:
            # False/dummy tree root
            head = Node(key=None, value=None, color=BLACK)
            f = None    # found item
            d = RIGHT   # direction

            # Set up helpers: g is the grandparent, p the parent and q the
            # current node.
            q = head
            g = p = None
            q.link[d] = self.root

            # Search and push a red down to fix red violations as we go
            while q.link[d]:
                last = d

                # Move the helpers down
                g = p
                p = q
                q = q.link[d]
                d = q.key < key

                # Save the node with the matching data and keep going; we'll do
                # removal tasks at the end
                if q.key == key:
                    f = q

                # Push the red node down with rotations and color flips
                if not is_red(q) and not is_red(q.link[d]):
                    if is_red(q.link[not d]):
                        p.link[last] = self._single_rotate(q, d)
                        p = p.link[last]
                    elif not is_red(q.link[not d]):
                        s = p.link[not last]    # sibling of q

                        if s:
                            if not is_red(s.link[not last]) and not is_red(s.link[last]):
                                # Color flip
                                p.color = BLACK
                                s.color = RED
                                q.color = RED
                            else:
                                d2 = g.link[1] is p

                                if is_red(s.link[last]):
                                    g.link[d2] = self._double_rotate(p, last)
                                elif is_red(s.link[not last]):
                                    g.link[d2] = self._single_rotate(p, last)

                                # Ensure correct coloring
                                q.color = g.link[d2].color = RED
                                g.link[d2].link[0].color = BLACK
                                g.link[d2].link[1].color = BLACK

            # Replace and remove the saved node: q is the last node on the
            # search path, so it has at most one child, which takes its place.
            if f:
                f.key, f.value = q.key, q.value
                p.link[p.link[1] is q] = q.link[q.link[0] is None]
                q.free()
                self._size -= 1
                remove_flag = True

            # Update root and make it black
            self.root = head.link[1]
            if self.root:
                self.root.color = BLACK

        return remove_flag

    def find(self, key):
        """Look up key in the tree and return the corresponding value.

        Raises KeyError, carrying the missing key, if key is not in the tree.

        """
        p = self.root
        while p is not None:
            if p.key == key:
                return p.value
            # A boolean direction indexes link directly (False=LEFT, True=RIGHT).
            p = p.link[p.key < key]

        raise KeyError(key)


if __name__ == '__main__':
    # Randomized self-test: repeatedly build a tree from random distinct keys,
    # then delete every key in a random order, validating the red-black rules
    # after each operation (validate() raises TreeError on any violation).
    #
    # Output goes through sys.stdout.write rather than the print statement so
    # the script runs unchanged on both Python 2 and Python 3.
    import random
    import sys

    def show(tree):
        """Write the tree's keys in order, followed by its size."""
        keys_text = ' '.join(str(node.key) for node in tree)
        sys.stdout.write('{} - len: {}\n'.format(keys_text, len(tree)))

    for _trial in range(30):
        tree = Tree()
        tree.validate()

        # 25 distinct keys from 0..100, already in a random insertion order.
        vals = random.sample(range(101), 25)
        for val in vals:
            tree.insert(val)
            tree.validate()

        sys.stdout.write('Inserted in this order: {}\n'.format(vals))
        keys = [node.key for node in tree]
        # An inorder traversal must yield exactly the inserted keys, sorted.
        assert keys == sorted(vals)
        show(tree)

        # delete in a random order
        random.shuffle(keys)

        for k in keys:
            sys.stdout.write('Removing {}\n'.format(k))
            # Keep the call outside the assert so it still runs under -O.
            removed = tree.remove(k)
            assert removed, 'key {} was not removed'.format(k)
            show(tree)
            tree.validate()

        assert len(tree) == 0 and tree.root is None