Skip to content

Commit 9eac127

Browse files
committed
add binary_tree and avl_tree python code
1 parent 2a1bb23 commit 9eac127

File tree

9 files changed

+968
-23
lines changed

9 files changed

+968
-23
lines changed

codes/python/chapter_tree/avl_tree.py

Lines changed: 281 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
import sys, os.path as osp
2+
import typing
3+
4+
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
5+
from include import *
6+
7+
8+
class AVLTreeNode:
9+
def __init__(
10+
self,
11+
val=None,
12+
height: int = 0,
13+
left: typing.Optional["AVLTreeNode"] = None,
14+
right: typing.Optional["AVLTreeNode"] = None,
15+
):
16+
self.val = val
17+
self.height = height
18+
self.left = left
19+
self.right = right
20+
21+
def __str__(self):
22+
val = self.val
23+
left_val = self.left.val if self.left else None
24+
right_val = self.right.val if self.right else None
25+
return "<AVLTreeNode: {}, leftAVLTreeNode: {}, rightAVLTreeNode: {}>".format(
26+
val, left_val, right_val
27+
)
28+
29+
30+
class AVLTree:
31+
def __init__(self, root: typing.Optional[AVLTreeNode] = None):
32+
self.root = root
33+
34+
@staticmethod
35+
def height(node: typing.Optional[AVLTreeNode]) -> int:
36+
"""
37+
获取结点高度
38+
Args:
39+
node:起始结点
40+
41+
Returns: 高度 or -1
42+
43+
"""
44+
# 空结点高度为 -1 ,叶结点高度为 0
45+
if node is not None:
46+
return node.height
47+
return -1
48+
49+
def __update_height(self, node: AVLTreeNode):
50+
"""
51+
更新结点高度
52+
Args:
53+
node: 要更新高度的结点
54+
55+
Returns: None
56+
57+
"""
58+
# 结点高度等于最高子树高度 + 1
59+
node.height = max([self.height(node.left), self.height(node.right)]) + 1
60+
61+
def balance_factor(self, node: AVLTreeNode) -> int:
62+
"""
63+
获取结点平衡因子
64+
Args:
65+
node: 要获取平衡因子的结点
66+
67+
Returns: 平衡因子
68+
69+
"""
70+
# 空结点平衡因子为 0
71+
if node is None:
72+
return 0
73+
# 结点平衡因子 = 左子树高度 - 右子树高度
74+
return self.height(node.left) - self.height(node.right)
75+
76+
def __right_rotate(self, node: AVLTreeNode) -> AVLTreeNode:
77+
child = node.left
78+
grand_child = child.right
79+
# 以 child 为原点,将 node 向右旋转
80+
child.right = node
81+
node.left = grand_child
82+
# 更新结点高度
83+
self.__update_height(node)
84+
self.__update_height(child)
85+
# 返回旋转后子树的根节点
86+
return child
87+
88+
def __left_rotate(self, node: AVLTreeNode) -> AVLTreeNode:
89+
child = node.right
90+
grand_child = child.left
91+
# 以 child 为原点,将 node 向左旋转
92+
child.left = node
93+
node.right = grand_child
94+
# 更新结点高度
95+
self.__update_height(node)
96+
self.__update_height(child)
97+
# 返回旋转后子树的根节点
98+
return child
99+
100+
def rotate(self, node: AVLTreeNode):
101+
"""
102+
执行旋转操作,使该子树重新恢复平衡
103+
Args:
104+
node: 要旋转的根结点
105+
106+
Returns: 旋转后的根结点
107+
108+
"""
109+
# 获取结点 node 的平衡因子
110+
balance_factor = self.balance_factor(node)
111+
# 左偏树
112+
if balance_factor > 1:
113+
if self.balance_factor(node.left) >= 0:
114+
# 右旋
115+
return self.__right_rotate(node)
116+
else:
117+
# 先左旋后右旋
118+
node.left = self.__left_rotate(node.left)
119+
return self.__right_rotate(node)
120+
# 右偏树
121+
elif balance_factor < -1:
122+
if self.balance_factor(node.right) <= 0:
123+
# 左旋
124+
return self.__left_rotate(node)
125+
else:
126+
# 先右旋后左旋
127+
node.right = self.__right_rotate(node.right)
128+
return self.__left_rotate(node)
129+
# 平衡树,无需旋转,直接返回
130+
return node
131+
132+
def insert(self, val) -> AVLTreeNode:
133+
"""
134+
插入结点
135+
Args:
136+
val: 结点的值
137+
138+
Returns:
139+
node: 插入结点后的根结点
140+
"""
141+
self.root = self.insert_helper(self.root, val)
142+
return self.root
143+
144+
def insert_helper(
145+
self, node: typing.Optional[AVLTreeNode], val: int
146+
) -> AVLTreeNode:
147+
"""
148+
递归插入结点(辅助函数)
149+
Args:
150+
node: 要插入的根结点
151+
val: 要插入的结点的值
152+
153+
Returns: 插入结点后的根结点
154+
155+
"""
156+
if node is None:
157+
return AVLTreeNode(val)
158+
# 1. 查找插入位置,并插入结点
159+
if val < node.val:
160+
node.left = self.insert_helper(node.left, val)
161+
elif val > node.val:
162+
node.right = self.insert_helper(node.right, val)
163+
else:
164+
# 重复结点不插入,直接返回
165+
return node
166+
# 更新结点高度
167+
self.__update_height(node)
168+
# 2. 执行旋转操作,使该子树重新恢复平衡
169+
return self.rotate(node)
170+
171+
def remove(self, val: int):
172+
"""
173+
删除结点
174+
Args:
175+
val: 要删除的结点的值
176+
177+
Returns:
178+
179+
"""
180+
root = self.remove_helper(self.root, val)
181+
return root
182+
183+
def remove_helper(
184+
self, node: typing.Optional[AVLTreeNode], val: int
185+
) -> typing.Optional[AVLTreeNode]:
186+
"""
187+
递归删除结点(辅助函数)
188+
Args:
189+
node: 删除的起始结点
190+
val: 要删除的结点的值
191+
192+
Returns: 删除目标结点后的起始结点
193+
194+
"""
195+
if node is None:
196+
return None
197+
# 1. 查找结点,并删除之
198+
if val < node.val:
199+
node.left = self.remove_helper(node.left, val)
200+
elif val > node.val:
201+
node.right = self.remove_helper(node.right, val)
202+
else:
203+
if node.left is None or node.right is None:
204+
child = node.left or node.right
205+
# 子结点数量 = 0 ,直接删除 node 并返回
206+
if child is None:
207+
return None
208+
# 子结点数量 = 1 ,直接删除 node
209+
else:
210+
node = child
211+
else: # 子结点数量 = 2 ,则将中序遍历的下个结点删除,并用该结点替换当前结点
212+
temp = self.min_node(node.right)
213+
node.right = self.remove_helper(node.right, temp.val)
214+
node.val = temp.val
215+
# 更新结点高度
216+
self.__update_height(node)
217+
# 2. 执行旋转操作,使该子树重新恢复平衡
218+
return self.rotate(node)
219+
220+
def min_node(
221+
self, node: typing.Optional[AVLTreeNode]
222+
) -> typing.Optional[AVLTreeNode]:
223+
# 获取最小结点
224+
if node is None:
225+
return None
226+
# 循环访问左子结点,直到叶结点时为最小结点,跳出
227+
while node.left is not None:
228+
node = node.left
229+
return node
230+
231+
def search(self, val: int):
232+
cur = self.root
233+
while cur is not None:
234+
if cur.val < val:
235+
cur = cur.right
236+
elif cur.val > val:
237+
cur = cur.left
238+
else:
239+
break
240+
return cur
241+
242+
243+
if __name__ == "__main__":
244+
245+
def test_insert(tree: AVLTree, val: int):
246+
tree.insert(val)
247+
print("\n插入结点 {} 后,AVL 树为".format(val))
248+
print_tree(tree.root)
249+
250+
def test_remove(tree: AVLTree, val: int):
251+
tree.remove(val)
252+
print("\n删除结点 {} 后,AVL 树为".format(val))
253+
print_tree(tree.root)
254+
255+
# 初始化空 AVL 树
256+
avl_tree = AVLTree()
257+
258+
# 插入结点
259+
# 请关注插入结点后,AVL 树是如何保持平衡的
260+
test_insert(avl_tree, 1)
261+
test_insert(avl_tree, 2)
262+
test_insert(avl_tree, 3)
263+
test_insert(avl_tree, 4)
264+
test_insert(avl_tree, 5)
265+
test_insert(avl_tree, 8)
266+
test_insert(avl_tree, 7)
267+
test_insert(avl_tree, 9)
268+
test_insert(avl_tree, 10)
269+
test_insert(avl_tree, 6)
270+
271+
# 插入重复结点
272+
test_insert(avl_tree, 7)
273+
274+
# 删除结点
275+
# 请关注删除结点后,AVL 树是如何保持平衡的
276+
test_remove(avl_tree, 8) # 删除度为 0 的结点
277+
test_remove(avl_tree, 5) # 删除度为 1 的结点
278+
test_remove(avl_tree, 4) # 删除度为 2 的结点
279+
280+
result_node = avl_tree.search(7)
281+
print("\n查找到的结点对象为 {},结点值 = {}".format(result_node, result_node.val))

0 commit comments

Comments
 (0)