diff --git a/algorithms/binary_tree.py b/algorithms/binary_tree.py index 28bac22..2a8d71d 100644 --- a/algorithms/binary_tree.py +++ b/algorithms/binary_tree.py @@ -18,16 +18,19 @@ def insert(self, data): @param data node data object to insert """ - if data < self.data: - if self.left is None: - self.left = Node(data) - else: - self.left.insert(data) - elif data > self.data: - if self.right is None: - self.right = Node(data) - else: - self.right.insert(data) + if self.data: + if data < self.data: + if self.left is None: + self.left = Node(data) + else: + self.left.insert(data) + elif data > self.data: + if self.right is None: + self.right = Node(data) + else: + self.right.insert(data) + else: + self.data = data def lookup(self, data, parent=None): """ @@ -60,11 +63,14 @@ def delete(self, data): children_count = node.children_count() if children_count == 0: # if node has no children, just remove it - if parent.left is node: - parent.left = None + if parent: + if parent.left is node: + parent.left = None + else: + parent.right = None + del node else: - parent.right = None - del node + self.data = None elif children_count == 1: # if node has 1 child # replace node by its child @@ -77,7 +83,11 @@ def delete(self, data): parent.left = n else: parent.right = n - del node + del node + else: + self.left = n.left + self.right = n.right + self.data = n.data else: # if node has 2 children # find its successor @@ -157,4 +167,4 @@ def children_count(self): cnt += 1 if self.right: cnt += 1 - return cnt + return cnt \ No newline at end of file diff --git a/algorithms/tests/test_binary_tree.py b/algorithms/tests/test_binary_tree.py index 229ea7e..080ef5f 100644 --- a/algorithms/tests/test_binary_tree.py +++ b/algorithms/tests/test_binary_tree.py @@ -64,6 +64,22 @@ def test_binary_tree(self): t.append(d) self.assertEquals(t, [7, 10, 11, 14, 17]) + # check for root deletion + root = binary_tree.Node(1) + root.insert(2) + root.insert(0) + root.delete(1) + self.assertEquals(root.data, 2) + root.delete(2) + self.assertEquals(root.data, 0) + root.delete(0) + self.assertEquals(root.data, None) + root.insert(1) + self.assertEquals(root.data, 1) + self.assertEquals(root.left, None) + self.assertEquals(root.right, None) + + if __name__ == '__main__': unittest.main()