Mercurial > public > think_complexity
comparison redblacktree.py @ 19:3c74185c5047
Added the remove() method to the red-black tree.
Made insert() and remove() return useful values to support use in maps
or sets.
author | Brian Neal <bgneal@gmail.com> |
---|---|
date | Thu, 27 Dec 2012 13:46:12 -0600 |
parents | 92e2879e2e33 |
children | 0326803882ad |
comparison
equal
deleted
inserted
replaced
18:92e2879e2e33 | 19:3c74185c5047 |
---|---|
69 self.key = key | 69 self.key = key |
70 self.value = value | 70 self.value = value |
71 self.color = color | 71 self.color = color |
72 self.link = [None, None] | 72 self.link = [None, None] |
73 | 73 |
74 def free(self): | |
75 """Call this function when removing a node from a tree. | |
76 | |
77 It updates the links to encourage garbage collection. | |
78 | |
79 """ | |
80 self.link[0] = self.link[1] = None | |
81 | |
74 def __str__(self): | 82 def __str__(self): |
75 c = 'B' if self.color == BLACK else 'R' | 83 c = 'B' if self.color == BLACK else 'R' |
76 if self.value: | 84 if self.value: |
77 return '({}: {} => {})'.format(c, self.key, self.value) | 85 return '({}: {} => {})'.format(c, self.key, self.value) |
78 else: | 86 else: |
107 | 115 |
108 """ | 116 """ |
109 | 117 |
110 def __init__(self): | 118 def __init__(self): |
111 self.root = None | 119 self.root = None |
120 self._size = 0 | |
121 | |
122 def __len__(self): | |
123 """To support the len() function by returning the number of elements in | |
124 the tree. | |
125 | |
126 """ | |
127 return self._size | |
112 | 128 |
113 def __iter__(self): | 129 def __iter__(self): |
130 """Return an iterator to perform an inorder traversal of the tree.""" | |
114 return self._inorder(self.root) | 131 return self._inorder(self.root) |
115 | 132 |
116 def _inorder(self, node): | 133 def _inorder(self, node): |
117 """A generator to perform an inorder traversal of the nodes in the | 134 """A generator to perform an inorder traversal of the nodes in the |
118 tree starting at the given node. | 135 tree starting at the given node. |
119 | 136 |
120 """ | 137 """ |
121 if node.link[0]: | 138 if node: |
122 for n in self._inorder(node.link[0]): | 139 if node.link[0]: |
123 yield n | 140 for n in self._inorder(node.link[0]): |
124 | 141 yield n |
125 yield node | 142 |
126 | 143 yield node |
127 if node.link[1]: | 144 |
128 for n in self._inorder(node.link[1]): | 145 if node.link[1]: |
129 yield n | 146 for n in self._inorder(node.link[1]): |
147 yield n | |
130 | 148 |
131 def _single_rotate(self, root, d): | 149 def _single_rotate(self, root, d): |
132 """Perform a single rotation about the node 'root' in the given | 150 """Perform a single rotation about the node 'root' in the given |
133 direction 'd' (LEFT or RIGHT). | 151 direction 'd' (LEFT or RIGHT). |
134 | 152 |
202 # Only count black links | 220 # Only count black links |
203 | 221 |
204 return lh if is_red(root) else lh + 1 | 222 return lh if is_red(root) else lh + 1 |
205 | 223 |
206 def insert(self, key, value=None): | 224 def insert(self, key, value=None): |
207 """Insert the (key, value) pair into the tree.""" | 225 """Insert the (key, value) pair into the tree. Duplicate keys are not |
226 allowed in the tree. | |
227 | |
228 Returns a tuple of the form (node, flag) where node is either the newly | |
229 inserted tree node or an already existing node that has the same key. | |
230 The flag member is a Boolean that will be True if the node was inserted | |
231 and False if a node with the given key already exists in the tree. | |
232 | |
233 """ | |
208 | 234 |
209 # Check for the empty tree case: | 235 # Check for the empty tree case: |
210 if self.root is None: | 236 if self.root is None: |
211 self.root = Node(key=key, value=value, color=BLACK) | 237 self.root = Node(key=key, value=value, color=BLACK) |
212 return | 238 self._size = 1 |
239 return self.root, True | |
213 | 240 |
214 # False/dummy tree root | 241 # False/dummy tree root |
215 head = Node(key=None, value=None, color=BLACK) | 242 head = Node(key=None, value=None, color=BLACK) |
216 d = LEFT | 243 d = last = LEFT # direction variables |
217 last = LEFT | 244 |
218 | 245 # Set up helpers |
219 t = head | 246 t = head |
220 g = p = None | 247 g = p = None |
221 t.link[1] = self.root | 248 q = t.link[1] = self.root |
222 q = self.root | 249 |
250 # Return values | |
251 target, insert_flag = (None, False) | |
223 | 252 |
224 # Search down the tree | 253 # Search down the tree |
225 while True: | 254 while True: |
226 if q is None: | 255 if q is None: |
227 # Insert new node at the bottom | 256 # Insert new node at the bottom |
228 p.link[d] = q = Node(key=key, value=value, color=RED) | 257 p.link[d] = q = Node(key=key, value=value, color=RED) |
258 self._size += 1 | |
259 insert_flag = True | |
229 elif is_red(q.link[0]) and is_red(q.link[1]): | 260 elif is_red(q.link[0]) and is_red(q.link[1]): |
230 # Color flip | 261 # Color flip |
231 q.color = RED | 262 q.color = RED |
232 q.link[0].color = BLACK | 263 q.link[0].color = BLACK |
233 q.link[1].color = BLACK | 264 q.link[1].color = BLACK |
241 else: | 272 else: |
242 t.link[d2] = self._double_rotate(g, not last) | 273 t.link[d2] = self._double_rotate(g, not last) |
243 | 274 |
244 # Stop if found | 275 # Stop if found |
245 if q.key == key: | 276 if q.key == key: |
277 target = q | |
246 break | 278 break |
247 | 279 |
248 last = d | 280 last = d |
249 d = q.key < key | 281 d = q.key < key |
250 | 282 |
257 | 289 |
258 # Update root | 290 # Update root |
259 self.root = head.link[1] | 291 self.root = head.link[1] |
260 self.root.color = BLACK | 292 self.root.color = BLACK |
261 | 293 |
294 return target, insert_flag | |
295 | |
296 def remove(self, key): | |
297 """Remove the given key from the tree. | |
298 | |
299 Returns True if the key was found and removed and False if the key was | |
300 not found in the tree. | |
301 | |
302 """ | |
303 remove_flag = False # return value | |
304 | |
305 if self.root is not None: | |
306 # False/dummy tree root | |
307 head = Node(key=None, value=None, color=BLACK) | |
308 f = None # found item | |
309 d = RIGHT # direction | |
310 | |
311 # Set up helpers | |
312 q = head | |
313 g = p = None | |
314 q.link[d] = self.root | |
315 | |
316 # Search and push a red down to fix red violations as we go | |
317 while q.link[d]: | |
318 last = d | |
319 | |
320 # Move the helpers down | |
321 g = p | |
322 p = q | |
323 q = q.link[d] | |
324 d = q.key < key | |
325 | |
326 # Save the node with the matching data and keep going; we'll do | |
327 # removal tasks at the end | |
328 if q.key == key: | |
329 f = q | |
330 | |
331 # Push the red node down with rotations and color flips | |
332 if not is_red(q) and not is_red(q.link[d]): | |
333 if is_red(q.link[not d]): | |
334 p.link[last] = self._single_rotate(q, d) | |
335 p = p.link[last] | |
336 elif not is_red(q.link[not d]): | |
337 s = p.link[not last] | |
338 | |
339 if s: | |
340 if not is_red(s.link[not last]) and not is_red(s.link[last]): | |
341 # Color flip | |
342 p.color = BLACK | |
343 s.color = RED | |
344 q.color = RED | |
345 else: | |
346 d2 = g.link[1] is p | |
347 | |
348 if is_red(s.link[last]): | |
349 g.link[d2] = self._double_rotate(p, last) | |
350 elif is_red(s.link[not last]): | |
351 g.link[d2] = self._single_rotate(p, last) | |
352 | |
353 # Ensure correct coloring | |
354 q.color = g.link[d2].color = RED | |
355 g.link[d2].link[0].color = BLACK | |
356 g.link[d2].link[1].color = BLACK | |
357 | |
358 # Replace and remove the saved node | |
359 if f: | |
360 f.key, f.value = q.key, q.value | |
361 p.link[p.link[1] is q] = q.link[q.link[0] is None] | |
362 q.free() | |
363 self._size -= 1 | |
364 remove_flag = True | |
365 | |
366 # Update root and make it black | |
367 self.root = head.link[1] | |
368 if self.root: | |
369 self.root.color = BLACK | |
370 | |
371 return remove_flag | |
372 | |
262 | 373 |
263 if __name__ == '__main__': | 374 if __name__ == '__main__': |
264 import random | 375 import random |
265 | 376 |
266 for n in range(30): | 377 for n in range(30): |
267 tree = Tree() | 378 tree = Tree() |
268 tree.validate() | 379 tree.validate() |
269 | 380 |
381 vals = [] | |
382 vals_set = set() | |
270 for i in range(25): | 383 for i in range(25): |
271 val = random.randint(0, 100) | 384 val = random.randint(0, 100) |
385 while val in vals_set: | |
386 val = random.randint(0, 100) | |
387 vals.append(val) | |
388 vals_set.add(val) | |
389 | |
272 tree.insert(val) | 390 tree.insert(val) |
273 | 391 tree.validate() |
392 | |
393 print 'Inserted in this order:', vals | |
394 assert len(vals) == len(vals_set) | |
395 keys = [] | |
274 for n in tree: | 396 for n in tree: |
275 print n.key, | 397 print n.key, |
276 print | 398 keys.append(n.key) |
277 | 399 print ' - len:', len(tree) |
278 tree.validate() | 400 |
401 # delete in a random order | |
402 random.shuffle(keys) | |
403 | |
404 for k in keys: | |
405 print 'Removing', k | |
406 tree.remove(k) | |
407 for n in tree: | |
408 print n.key, | |
409 print ' - len:', len(tree) | |
410 tree.validate() |