@@ -27,6 +27,7 @@ def quicksort_inplace(array, beg, end): # 注意这里我们都用左闭右
27
27
28
28
29
29
def partition (array , beg , end ):
30
+ """对给定数组执行 partition 操作,返回新的 pivot 位置"""
30
31
pivot_index = beg
31
32
pivot = array [pivot_index ]
32
33
left = pivot_index + 1
@@ -55,7 +56,7 @@ def test_partition():
55
56
l = [1 , 2 , 3 , 4 ]
56
57
assert partition (l , 0 , len (l )) == 0
57
58
l = [4 , 3 , 2 , 1 ]
58
- assert partition (l , 0 , len (l ))
59
+ assert partition (l , 0 , len (l )) == 3
59
60
60
61
61
62
def test_quicksort_inplace ():
@@ -65,3 +66,31 @@ def test_quicksort_inplace():
65
66
sorted_seq = sorted (seq )
66
67
quicksort_inplace (seq , 0 , len (seq ))
67
68
assert seq == sorted_seq
69
+
70
+
71
+ def nth_element (array , beg , end , nth ):
72
+ """查找一个数组第 n 大元素"""
73
+ if beg < end :
74
+ pivot_idx = partition (array , beg , end )
75
+ if pivot_idx == nth - 1 : # 数组小标从 0 开始
76
+ return array [pivot_idx ]
77
+ elif pivot_idx > nth - 1 :
78
+ return nth_element (array , beg , pivot_idx , nth )
79
+ else :
80
+ return nth_element (array , pivot_idx + 1 , end , nth )
81
+
82
+
83
+ def test_nth_element ():
84
+ l1 = [3 , 5 , 4 , 2 , 1 ]
85
+ assert nth_element (l1 , 0 , len (l1 ), 3 ) == 3
86
+ assert nth_element (l1 , 0 , len (l1 ), 2 ) == 2
87
+
88
+ l = [1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 ]
89
+ for i in l :
90
+ assert nth_element (l , 0 , len (l ), i ) == i
91
+ for i in reversed (l ):
92
+ assert nth_element (l , 0 , len (l ), i ) == i
93
+
94
+
95
+ if __name__ == '__main__' :
96
+ test_nth_element ()
0 commit comments