Skip to content

Commit e854bbb

Browse files
authored
Merge pull request krahets#138 from a16su/master
add binary_tree and avl_tree python code
2 parents 36507b8 + f9cc3a5 commit e854bbb

11 files changed

+815
-36
lines changed

codes/cpp/chapter_tree/binary_tree_bfs.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@ vector<int> hierOrder(TreeNode* root) {
1515
vector<int> vec;
1616
while (!queue.empty()) {
1717
TreeNode* node = queue.front();
18-
queue.pop(); // 队列出队
19-
vec.push_back(node->val); // 保存结点
18+
queue.pop(); // 队列出队
19+
vec.push_back(node->val); // 保存结点
2020
if (node->left != nullptr)
21-
queue.push(node->left); // 左子结点入队
21+
queue.push(node->left); // 左子结点入队
2222
if (node->right != nullptr)
23-
queue.push(node->right); // 右子结点入队
23+
queue.push(node->right); // 右子结点入队
2424
}
2525
return vec;
2626
}

codes/python/chapter_tree/avl_tree.py

+208
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
"""
2+
File: avl_tree.py
3+
Created Time: 2022-12-20
4+
Author: a16su ([email protected])
5+
"""
6+
7+
import sys, os.path as osp
8+
import typing
9+
10+
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
11+
from include import *
12+
13+
14+
class AVLTree:
15+
def __init__(self, root: typing.Optional[TreeNode] = None):
16+
self.root = root
17+
18+
""" 获取结点高度 """
19+
def height(self, node: typing.Optional[TreeNode]) -> int:
20+
# 空结点高度为 -1 ,叶结点高度为 0
21+
if node is not None:
22+
return node.height
23+
return -1
24+
25+
""" 更新结点高度 """
26+
def __update_height(self, node: TreeNode):
27+
# 结点高度等于最高子树高度 + 1
28+
node.height = max([self.height(node.left), self.height(node.right)]) + 1
29+
30+
""" 获取平衡因子 """
31+
def balance_factor(self, node: TreeNode) -> int:
32+
# 空结点平衡因子为 0
33+
if node is None:
34+
return 0
35+
# 结点平衡因子 = 左子树高度 - 右子树高度
36+
return self.height(node.left) - self.height(node.right)
37+
38+
""" 右旋操作 """
39+
def __right_rotate(self, node: TreeNode) -> TreeNode:
40+
child = node.left
41+
grand_child = child.right
42+
# 以 child 为原点,将 node 向右旋转
43+
child.right = node
44+
node.left = grand_child
45+
# 更新结点高度
46+
self.__update_height(node)
47+
self.__update_height(child)
48+
# 返回旋转后子树的根节点
49+
return child
50+
51+
""" 左旋操作 """
52+
def __left_rotate(self, node: TreeNode) -> TreeNode:
53+
child = node.right
54+
grand_child = child.left
55+
# 以 child 为原点,将 node 向左旋转
56+
child.left = node
57+
node.right = grand_child
58+
# 更新结点高度
59+
self.__update_height(node)
60+
self.__update_height(child)
61+
# 返回旋转后子树的根节点
62+
return child
63+
64+
""" 执行旋转操作,使该子树重新恢复平衡 """
65+
def __rotate(self, node: TreeNode) -> TreeNode:
66+
# 获取结点 node 的平衡因子
67+
balance_factor = self.balance_factor(node)
68+
# 左偏树
69+
if balance_factor > 1:
70+
if self.balance_factor(node.left) >= 0:
71+
# 右旋
72+
return self.__right_rotate(node)
73+
else:
74+
# 先左旋后右旋
75+
node.left = self.__left_rotate(node.left)
76+
return self.__right_rotate(node)
77+
# 右偏树
78+
elif balance_factor < -1:
79+
if self.balance_factor(node.right) <= 0:
80+
# 左旋
81+
return self.__left_rotate(node)
82+
else:
83+
# 先右旋后左旋
84+
node.right = self.__right_rotate(node.right)
85+
return self.__left_rotate(node)
86+
# 平衡树,无需旋转,直接返回
87+
return node
88+
89+
""" 插入结点 """
90+
def insert(self, val) -> TreeNode:
91+
self.root = self.__insert_helper(self.root, val)
92+
return self.root
93+
94+
""" 递归插入结点(辅助函数)"""
95+
def __insert_helper(self, node: typing.Optional[TreeNode], val: int) -> TreeNode:
96+
if node is None:
97+
return TreeNode(val)
98+
# 1. 查找插入位置,并插入结点
99+
if val < node.val:
100+
node.left = self.__insert_helper(node.left, val)
101+
elif val > node.val:
102+
node.right = self.__insert_helper(node.right, val)
103+
else:
104+
# 重复结点不插入,直接返回
105+
return node
106+
# 更新结点高度
107+
self.__update_height(node)
108+
# 2. 执行旋转操作,使该子树重新恢复平衡
109+
return self.__rotate(node)
110+
111+
""" 删除结点 """
112+
def remove(self, val: int):
113+
root = self.__remove_helper(self.root, val)
114+
return root
115+
116+
""" 递归删除结点(辅助函数) """
117+
def __remove_helper(self, node: typing.Optional[TreeNode], val: int) -> typing.Optional[TreeNode]:
118+
if node is None:
119+
return None
120+
# 1. 查找结点,并删除之
121+
if val < node.val:
122+
node.left = self.__remove_helper(node.left, val)
123+
elif val > node.val:
124+
node.right = self.__remove_helper(node.right, val)
125+
else:
126+
if node.left is None or node.right is None:
127+
child = node.left or node.right
128+
# 子结点数量 = 0 ,直接删除 node 并返回
129+
if child is None:
130+
return None
131+
# 子结点数量 = 1 ,直接删除 node
132+
else:
133+
node = child
134+
else: # 子结点数量 = 2 ,则将中序遍历的下个结点删除,并用该结点替换当前结点
135+
temp = self.__min_node(node.right)
136+
node.right = self.__remove_helper(node.right, temp.val)
137+
node.val = temp.val
138+
# 更新结点高度
139+
self.__update_height(node)
140+
# 2. 执行旋转操作,使该子树重新恢复平衡
141+
return self.__rotate(node)
142+
143+
""" 获取最小结点 """
144+
def __min_node(self, node: typing.Optional[TreeNode]) -> typing.Optional[TreeNode]:
145+
if node is None:
146+
return None
147+
# 循环访问左子结点,直到叶结点时为最小结点,跳出
148+
while node.left is not None:
149+
node = node.left
150+
return node
151+
152+
""" 查找结点 """
153+
def search(self, val: int):
154+
cur = self.root
155+
# 循环查找,越过叶结点后跳出
156+
while cur is not None:
157+
# 目标结点在 root 的右子树中
158+
if cur.val < val:
159+
cur = cur.right
160+
# 目标结点在 root 的左子树中
161+
elif cur.val > val:
162+
cur = cur.left
163+
# 找到目标结点,跳出循环
164+
else:
165+
break
166+
# 返回目标结点
167+
return cur
168+
169+
170+
""" Driver Code """
171+
if __name__ == "__main__":
172+
def test_insert(tree: AVLTree, val: int):
173+
tree.insert(val)
174+
print("\n插入结点 {} 后,AVL 树为".format(val))
175+
print_tree(tree.root)
176+
177+
def test_remove(tree: AVLTree, val: int):
178+
tree.remove(val)
179+
print("\n删除结点 {} 后,AVL 树为".format(val))
180+
print_tree(tree.root)
181+
182+
# 初始化空 AVL 树
183+
avl_tree = AVLTree()
184+
185+
# 插入结点
186+
# 请关注插入结点后,AVL 树是如何保持平衡的
187+
test_insert(avl_tree, 1)
188+
test_insert(avl_tree, 2)
189+
test_insert(avl_tree, 3)
190+
test_insert(avl_tree, 4)
191+
test_insert(avl_tree, 5)
192+
test_insert(avl_tree, 8)
193+
test_insert(avl_tree, 7)
194+
test_insert(avl_tree, 9)
195+
test_insert(avl_tree, 10)
196+
test_insert(avl_tree, 6)
197+
198+
# 插入重复结点
199+
test_insert(avl_tree, 7)
200+
201+
# 删除结点
202+
# 请关注删除结点后,AVL 树是如何保持平衡的
203+
test_remove(avl_tree, 8) # 删除度为 0 的结点
204+
test_remove(avl_tree, 5) # 删除度为 1 的结点
205+
test_remove(avl_tree, 4) # 删除度为 2 的结点
206+
207+
result_node = avl_tree.search(7)
208+
print("\n查找到的结点对象为 {},结点值 = {}".format(result_node, result_node.val))
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,167 @@
11
"""
22
File: binary_search_tree.py
3-
Created Time: 2022-11-25
4-
Author: Krahets (krahets@163.com)
3+
Created Time: 2022-12-20
4+
Author: a16su (lpluls001@gmail.com)
55
"""
66

77
import sys, os.path as osp
8+
import typing
9+
810
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
911
from include import *
1012

13+
14+
""" 二叉搜索树 """
15+
class BinarySearchTree:
16+
def __init__(self, nums: typing.List[int]) -> None:
17+
nums.sort()
18+
self.__root = self.build_tree(nums, 0, len(nums) - 1)
19+
20+
""" 构建二叉搜索树 """
21+
def build_tree(self, nums: typing.List[int], start_index: int, end_index: int) -> typing.Optional[TreeNode]:
22+
if start_index > end_index:
23+
return None
24+
25+
# 将数组中间结点作为根结点
26+
mid = (start_index + end_index) // 2
27+
root = TreeNode(nums[mid])
28+
# 递归建立左子树和右子树
29+
root.left = self.build_tree(nums=nums, start_index=start_index, end_index=mid - 1)
30+
root.right = self.build_tree(nums=nums, start_index=mid + 1, end_index=end_index)
31+
return root
32+
33+
@property
34+
def root(self) -> typing.Optional[TreeNode]:
35+
return self.__root
36+
37+
""" 查找结点 """
38+
def search(self, num: int) -> typing.Optional[TreeNode]:
39+
cur = self.root
40+
# 循环查找,越过叶结点后跳出
41+
while cur is not None:
42+
# 目标结点在 root 的右子树中
43+
if cur.val < num:
44+
cur = cur.right
45+
# 目标结点在 root 的左子树中
46+
elif cur.val > num:
47+
cur = cur.left
48+
# 找到目标结点,跳出循环
49+
else:
50+
break
51+
return cur
52+
53+
""" 插入结点 """
54+
def insert(self, num: int) -> typing.Optional[TreeNode]:
55+
root = self.root
56+
# 若树为空,直接提前返回
57+
if root is None:
58+
return None
59+
60+
cur = root
61+
pre = None
62+
63+
# 循环查找,越过叶结点后跳出
64+
while cur is not None:
65+
# 找到重复结点,直接返回
66+
if cur.val == num:
67+
return None
68+
pre = cur
69+
70+
if cur.val < num: # 插入位置在 root 的右子树中
71+
cur = cur.right
72+
else: # 插入位置在 root 的左子树中
73+
cur = cur.left
74+
75+
# 插入结点 val
76+
node = TreeNode(num)
77+
if pre.val < num:
78+
pre.right = node
79+
else:
80+
pre.left = node
81+
return node
82+
83+
""" 删除结点 """
84+
def remove(self, num: int) -> typing.Optional[TreeNode]:
85+
root = self.root
86+
# 若树为空,直接提前返回
87+
if root is None:
88+
return None
89+
90+
cur = root
91+
pre = None
92+
93+
# 循环查找,越过叶结点后跳出
94+
while cur is not None:
95+
# 找到待删除结点,跳出循环
96+
if cur.val == num:
97+
break
98+
pre = cur
99+
if cur.val < num: # 待删除结点在 root 的右子树中
100+
cur = cur.right
101+
else: # 待删除结点在 root 的左子树中
102+
cur = cur.left
103+
104+
# 若无待删除结点,则直接返回
105+
if cur is None:
106+
return None
107+
108+
# 子结点数量 = 0 or 1
109+
if cur.left is None or cur.right is None:
110+
# 当子结点数量 = 0 / 1 时, child = null / 该子结点
111+
child = cur.left or cur.right
112+
# 删除结点 cur
113+
if pre.left == cur:
114+
pre.left = child
115+
else:
116+
pre.right = child
117+
# 子结点数量 = 2
118+
else:
119+
# 获取中序遍历中 cur 的下一个结点
120+
nex = self.min(cur.right)
121+
tmp = nex.val
122+
# 递归删除结点 nex
123+
self.remove(nex.val)
124+
# 将 nex 的值复制给 cur
125+
cur.val = tmp
126+
return cur
127+
128+
""" 获取最小结点 """
129+
def min(self, root: typing.Optional[TreeNode]) -> typing.Optional[TreeNode]:
130+
if root is None:
131+
return root
132+
133+
# 循环访问左子结点,直到叶结点时为最小结点,跳出
134+
while root.left is not None:
135+
root = root.left
136+
return root
137+
138+
139+
""" Driver Code """
140+
if __name__ == "__main__":
141+
# 初始化二叉搜索树
142+
nums = list(range(1, 16))
143+
bst = BinarySearchTree(nums=nums)
144+
print("\n初始化的二叉树为\n")
145+
print_tree(bst.root)
146+
147+
# 查找结点
148+
node = bst.search(5)
149+
print("\n查找到的结点对象为: {},结点值 = {}".format(node, node.val))
150+
151+
# 插入结点
152+
ndoe = bst.insert(16)
153+
print("\n插入结点 16 后,二叉树为\n")
154+
print_tree(bst.root)
155+
156+
# 删除结点
157+
bst.remove(1)
158+
print("\n删除结点 1 后,二叉树为\n")
159+
print_tree(bst.root)
160+
161+
bst.remove(2)
162+
print("\n删除结点 2 后,二叉树为\n")
163+
print_tree(bst.root)
164+
165+
bst.remove(4)
166+
print("\n删除结点 4 后,二叉树为\n")
167+
print_tree(bst.root)

0 commit comments

Comments
 (0)