Mercurial > public > think_complexity
diff ch3ex5.py @ 20:0326803882ad
Finally completing Ch. 3, exercise 5: write a TreeMap that uses a red-black
tree.
author | Brian Neal <bgneal@gmail.com> |
---|---|
date | Thu, 27 Dec 2012 15:30:49 -0600 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/ch3ex5.py Thu Dec 27 15:30:49 2012 -0600 @@ -0,0 +1,71 @@ +"""Chapter 3, exercise 5. + +"A drawback of hashtables is that the elements have to be hashable, which usually +means they have to be immutable. That's why, in Python, you can use tuples but +not lists as keys in a dictionary. An alternative is to use a tree-based map. +Write an implementation of the map interface called TreeMap that uses +a red-black tree to perform add and get in log time." + + +""" +import redblacktree + + +class TreeMap(object): + """A tree-based map class.""" + + def __init__(self): + self.tree = redblacktree.Tree() + + def get(self, k): + """Looks up the key (k) and returns the corresponding value, or raises + a KeyError if the key is not found. + + """ + return self.tree.find(k) + + def add(self, k, v): + """Adds the key/value pair (k, v) to the tree. If the key already + exists, the value is updated to the new value v. + + Returns True if the pair was inserted, and False if the key already + existed and the tree was updated. + + """ + node, inserted = self.tree.insert(k, v) + if not inserted: + node.value = v + return inserted + + def remove(self, k): + """Removes the mapping with the given key (k). + Raises a KeyError if the mapping was not found. + + """ + result = self.tree.remove(k) + if not result: + raise KeyError + + def __len__(self): + """Returns the number of mappings in the map.""" + return len(self.tree) + + +def main(script): + import string + m = TreeMap() + s = string.ascii_lowercase + + for k, v in enumerate(s): + m.add([k], v) + + for k in range(len(s)): + key = [k] + print key, m.get(key) + m.remove(key) + + assert len(m) == 0 + +if __name__ == '__main__': + import sys + main(*sys.argv)