diff redblacktree.py @ 19:3c74185c5047

Added the remove() method to the red-black tree. Made insert() and remove() return useful return values to support use in maps or sets.
author Brian Neal <bgneal@gmail.com>
date Thu, 27 Dec 2012 13:46:12 -0600
parents 92e2879e2e33
children 0326803882ad
line wrap: on
line diff
--- a/redblacktree.py	Wed Dec 26 19:59:17 2012 -0600
+++ b/redblacktree.py	Thu Dec 27 13:46:12 2012 -0600
@@ -71,6 +71,14 @@
         self.color = color
         self.link = [None, None]
 
+    def free(self):
+        """Call this function when removing a node from a tree.
+
+        It updates the links to encourage garbage collection.
+
+        """
+        self.link[0] = self.link[1] = None
+
     def __str__(self):
         c = 'B' if self.color == BLACK else 'R'
         if self.value:
@@ -109,8 +117,17 @@
 
     def __init__(self):
         self.root = None
+        self._size = 0
+
+    def __len__(self):
+        """To support the len() function by returning the number of elements in
+        the tree.
+
+        """
+        return self._size
 
     def __iter__(self):
+        """Return an iterator to perform an inorder traversal of the tree."""
         return self._inorder(self.root)
 
     def _inorder(self, node):
@@ -118,15 +135,16 @@
         tree starting at the given node.
 
         """
-        if node.link[0]:
-            for n in self._inorder(node.link[0]):
-                yield n
+        if node:
+            if node.link[0]:
+                for n in self._inorder(node.link[0]):
+                    yield n
 
-        yield node
+            yield node
 
-        if node.link[1]:
-            for n in self._inorder(node.link[1]):
-                yield n
+            if node.link[1]:
+                for n in self._inorder(node.link[1]):
+                    yield n
 
     def _single_rotate(self, root, d):
         """Perform a single rotation about the node 'root' in the given
@@ -204,28 +222,41 @@
         return lh if is_red(root) else lh + 1
 
     def insert(self, key, value=None):
-        """Insert the (key, value) pair into the tree."""
+        """Insert the (key, value) pair into the tree. Duplicate keys are not
+        allowed in the tree.
+
+        Returns a tuple of the form (node, flag) where node is either the newly
+        inserted tree node or an already existing node that has the same key.
+        The flag member is a Boolean that will be True if the node was inserted
+        and False if a node with the given key already exists in the tree.
+
+        """
 
         # Check for the empty tree case:
         if self.root is None:
             self.root = Node(key=key, value=value, color=BLACK)
-            return
+            self._size = 1
+            return self.root, True
 
         # False/dummy tree root
         head = Node(key=None, value=None, color=BLACK)
-        d = LEFT
-        last = LEFT
+        d = last = LEFT     # direction variables
 
+        # Set up helpers
         t = head
         g = p = None
-        t.link[1] = self.root
-        q = self.root
+        q = t.link[1] = self.root
+
+        # Return values
+        target, insert_flag = (None, False)
 
         # Search down the tree
         while True:
             if q is None:
                 # Insert new node at the bottom
                 p.link[d] = q = Node(key=key, value=value, color=RED)
+                self._size += 1
+                insert_flag = True
             elif is_red(q.link[0]) and is_red(q.link[1]):
                 # Color flip
                 q.color = RED
@@ -243,6 +274,7 @@
 
             # Stop if found
             if q.key == key:
+                target = q
                 break
 
             last = d
@@ -259,6 +291,85 @@
         self.root = head.link[1]
         self.root.color = BLACK
 
+        return target, insert_flag
+
+    def remove(self, key):
+        """Remove the given key from the tree.
+
+        Returns True if the key was found and removed and False if the key was
+        not found in the tree.
+
+        """
+        remove_flag = False     # return value
+
+        if self.root is not None:
+            # False/dummy tree root
+            head = Node(key=None, value=None, color=BLACK)
+            f = None    # found item
+            d = RIGHT   # direction
+
+            # Set up helpers
+            q = head
+            g = p = None
+            q.link[d] = self.root
+
+            # Search and push a red down to fix red violations as we go
+            while q.link[d]:
+                last = d
+
+                # Move the helpers down
+                g = p
+                p = q
+                q = q.link[d]
+                d = q.key < key
+
+                # Save the node with the matching data and keep going; we'll do
+                # removal tasks at the end
+                if q.key == key:
+                    f = q
+
+                # Push the red node down with rotations and color flips
+                if not is_red(q) and not is_red(q.link[d]):
+                    if is_red(q.link[not d]):
+                        p.link[last] = self._single_rotate(q, d)
+                        p = p.link[last]
+                    elif not is_red(q.link[not d]):
+                        s = p.link[not last]
+
+                        if s:
+                            if not is_red(s.link[not last]) and not is_red(s.link[last]):
+                                # Color flip
+                                p.color = BLACK
+                                s.color = RED
+                                q.color = RED
+                            else:
+                                d2 = g.link[1] is p
+
+                                if is_red(s.link[last]):
+                                    g.link[d2] = self._double_rotate(p, last)
+                                elif is_red(s.link[not last]):
+                                    g.link[d2] = self._single_rotate(p, last)
+
+                                # Ensure correct coloring
+                                q.color = g.link[d2].color = RED
+                                g.link[d2].link[0].color = BLACK
+                                g.link[d2].link[1].color = BLACK
+
+            # Replace and remove the saved node
+            if f:
+                f.key, f.value = q.key, q.value
+                p.link[p.link[1] is q] = q.link[q.link[0] is None]
+                q.free()
+                self._size -= 1
+                remove_flag = True
+
+            # Update root and make it black
+            self.root = head.link[1]
+            if self.root:
+                self.root.color = BLACK
+
+        return remove_flag
+
 
 if __name__ == '__main__':
     import random
@@ -267,12 +378,33 @@
         tree = Tree()
         tree.validate()
 
+        vals = []
+        vals_set = set()
         for i in range(25):
             val = random.randint(0, 100)
+            while val in vals_set:
+                val = random.randint(0, 100)
+            vals.append(val)
+            vals_set.add(val)
+
             tree.insert(val)
+            tree.validate()
 
+        print 'Inserted in this order:', vals
+        assert len(vals) == len(vals_set)
+        keys = []
         for n in tree:
             print n.key,
-        print
+            keys.append(n.key)
+        print ' - len:', len(tree)
 
-        tree.validate()
+        # delete in a random order
+        random.shuffle(keys)
+
+        for k in keys:
+            print 'Removing', k
+            tree.remove(k)
+            for n in tree:
+                print n.key,
+            print ' - len:', len(tree)
+            tree.validate()