diff ch3ex5.py @ 20:0326803882ad

Finally completing Ch. 3, exercise 5: write a TreeMap that uses a red-black tree.
author Brian Neal <bgneal@gmail.com>
date Thu, 27 Dec 2012 15:30:49 -0600
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/ch3ex5.py	Thu Dec 27 15:30:49 2012 -0600
@@ -0,0 +1,71 @@
+"""Chapter 3, exercise 5.
+
+"A drawback of hashtables is that the elements have to be hashable, which usually
+means they have to be immutable. That's why, in Python, you can use tuples but
+not lists as keys in a dictionary. An alternative is to use a tree-based map.
+Write an implementation of the map interface called TreeMap that uses
+a red-black tree to perform add and get in log time."
+
+
+"""
+import redblacktree
+
+
+class TreeMap(object):
+    """A tree-based map class."""
+
+    def __init__(self):
+        self.tree = redblacktree.Tree()
+
+    def get(self, k):
+        """Looks up the key (k) and returns the corresponding value, or raises
+        a KeyError if the key is not found.
+
+        """
+        return self.tree.find(k)
+
+    def add(self, k, v):
+        """Adds the key/value pair (k, v) to the tree. If the key already
+        exists, the value is updated to the new value v.
+
+        Returns True if the pair was inserted, and False if the key already
+        existed and the tree was updated.
+
+        """
+        node, inserted = self.tree.insert(k, v)
+        if not inserted:
+            node.value = v
+        return inserted
+
+    def remove(self, k):
+        """Removes the mapping with the given key (k).
+        Raises a KeyError if the mapping was not found.
+
+        """
+        result = self.tree.remove(k)
+        if not result:
+            raise KeyError
+
+    def __len__(self):
+        """Returns the number of mappings in the map."""
+        return len(self.tree)
+
+
+def main(script):
+    import string
+    m = TreeMap()
+    s = string.ascii_lowercase
+
+    for k, v in enumerate(s):
+        m.add([k], v)
+
+    for k in range(len(s)):
+        key = [k]
+        print key, m.get(key)
+        m.remove(key)
+
+    assert len(m) == 0
+
+if __name__ == '__main__':
+    import sys
+    main(*sys.argv)