|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | + |
| 3 | +# 从数组和列表章复制的代码 |
| 4 | + |
| 5 | + |
class Array(object):
    """Fixed-size array backed by a Python list.

    The array is created with ``size`` positions, all initialised to
    ``None``.  Items are read and written by index; the reported length is
    the size given at construction time.
    """

    def __init__(self, size=32):
        """Create an array with ``size`` positions, each set to ``None``."""
        self._size = size
        self._items = [None] * size

    def __getitem__(self, index):
        """Return the item stored at ``index``."""
        return self._items[index]

    def __setitem__(self, index, value):
        """Store ``value`` at ``index``."""
        self._items[index] = value

    def __len__(self):
        """Return the size the array was created with."""
        return self._size

    def clear(self, value=None):
        """Overwrite every position with ``value`` (``None`` by default)."""
        # Slice assignment keeps the same underlying list object.  The
        # previous code called range() on the list itself, which raised
        # TypeError as soon as clear() was used.
        self._items[:] = [value] * len(self._items)

    def __iter__(self):
        """Iterate over the stored items in index order."""
        for item in self._items:
            yield item
| 28 | + |
| 29 | + |
class Slot(object):
    """One entry (key/value pair) stored in the hash table's backing array.

    A position in the table is in one of three states.  Compared with
    chaining, removing a key under open addressing is slightly more involved:

    1. Never used (``HashTable.UNUSED``): nothing was ever stored or collided
       here, so a lookup can stop probing as soon as it reaches this state.
    2. Used and then removed (``HashTable.EMPTY``): keys may still be stored
       further along the probe sequence, so a lookup must keep going.
    3. In use: the position holds a ``Slot`` instance.
    """

    def __init__(self, key, value):
        self.key = key
        self.value = value
| 40 | + |
| 41 | + |
class HashTable(object):
    """Hash table using open addressing with tombstones.

    Collisions are resolved by probing.  The probe sequence for a key starts
    with quadratic steps from the key's home index and is followed by a
    linear sweep of the whole table, so every position is eventually visited
    and probing always terminates.

    Each position of the backing array is in one of three states:
    ``UNUSED`` (never held an entry), ``EMPTY`` (held an entry that was
    removed; lookups must continue past it) or a live ``Slot``.
    """

    UNUSED = None  # never-used position; always compared with ``is``
    EMPTY = Slot(None, None)  # shared tombstone marking a removed entry

    def __init__(self):
        """Create an empty table with 7 positions."""
        self._table = Array(7)
        self.length = 0

    @property
    def _load_factor(self):
        """Ratio of live entries to table size; add() rehashes at >= 0.8."""
        return self.length / float(len(self._table))

    def __len__(self):
        """Return the number of live entries."""
        return self.length

    def _hash1(self, key):
        """Return the home index of ``key`` in the current table."""
        return abs(hash(key)) % len(self._table)

    def _probe(self, key):
        """Yield the indexes to inspect for ``key``, in probe order.

        Lookup and insertion must walk exactly the same sequence, otherwise
        a stored key can become unreachable.  The quadratic part alone does
        not reach every position (for a table of 7 it only reaches 5 of
        them), so it is followed by a linear sweep that covers the rest.
        The sequence is finite: at most ``2 * len(table)`` indexes.
        """
        size = len(self._table)
        base = self._hash1(key)
        index = base
        for step in range(1, size + 1):
            yield index
            index = (index + step * step) % size  # simple quadratic probing
        for offset in range(size):
            yield (base + offset) % size

    def _find_slot(self, key, for_insert=False):
        """Locate the position for ``key``.

        :param key: key to look up or insert
        :param for_insert: if true, return the first position that can take
            a new entry; otherwise return the position holding ``key``
        :return: slot index, or None when nothing suitable was found
        """
        if for_insert:
            # First free (EMPTY or UNUSED) position along the probe sequence.
            for index in self._probe(key):
                if self._slot_can_insert(index):
                    return index
            return None

        for index in self._probe(key):
            slot = self._table[index]
            if slot is HashTable.UNUSED:
                # Insertion would have used this position, so the key
                # cannot be stored any further along the sequence.
                return None
            # Tombstones are skipped: the key may live behind them.
            if slot is not HashTable.EMPTY and slot.key == key:
                return index
        return None

    def _slot_can_insert(self, index):
        """Return True if the position at ``index`` holds no live entry."""
        slot = self._table[index]
        return slot is HashTable.EMPTY or slot is HashTable.UNUSED

    def __contains__(self, key):  # ``in`` operator
        """Return True if ``key`` is stored in the table."""
        return self._find_slot(key, for_insert=False) is not None

    def add(self, key, value):
        """Insert ``key`` or update its value.

        :return: True if a new key was inserted, False if an existing key
            had its value replaced
        """
        index = self._find_slot(key, for_insert=False)
        if index is not None:  # existing key: keep it, use the new value
            self._table[index].value = value
            return False

        # A free position always exists here: the table is grown as soon as
        # the load factor reaches 0.8, and the probe sequence visits every
        # position.
        index = self._find_slot(key, for_insert=True)
        self._table[index] = Slot(key, value)
        self.length += 1
        if self._load_factor >= 0.8:  # threshold reached: grow and rehash
            self._rehash()
        return True

    def _rehash(self):
        """Move all live entries into a new table of size ``2 * n + 1``.

        Tombstones are not copied, so the new table contains only live
        entries and never-used positions.
        """
        old_table = self._table
        newsize = len(self._table) * 2 + 1
        self._table = Array(newsize)

        self.length = 0

        for slot in old_table:
            if slot is not HashTable.UNUSED and slot is not HashTable.EMPTY:
                index = self._find_slot(slot.key, for_insert=True)
                self._table[index] = slot
                self.length += 1

    def get(self, key, default=None):
        """Return the value stored for ``key``, or ``default`` if absent."""
        index = self._find_slot(key, for_insert=False)
        if index is None:
            return default
        return self._table[index].value

    def remove(self, key):
        """Remove ``key`` and return its value.

        :raises AssertionError: if ``key`` is not in the table
        """
        index = self._find_slot(key, for_insert=False)
        if index is None:
            # Raised explicitly (instead of via ``assert``) so the check is
            # not stripped under ``python -O``; the exception type and
            # message are the ones callers already see.
            raise AssertionError('keyerror')
        value = self._table[index].value
        self.length -= 1
        # Leave a tombstone so keys stored further along the probe sequence
        # stay reachable.
        self._table[index] = HashTable.EMPTY
        return value

    def __iter__(self):
        """Iterate over the stored keys (like ``dict``, values are not yielded)."""
        for slot in self._table:
            if slot is not HashTable.EMPTY and slot is not HashTable.UNUSED:
                yield slot.key
| 143 | + |
| 144 | + |
| 145 | +######################################### |
| 146 | +# 上边是从 哈希表章 拷贝过来的代码,我们会直接继承 HashTable 实现 集合 set |
| 147 | +######################################### |
| 148 | + |
class SetADT(HashTable):
    """Set implemented on top of HashTable.

    Elements are stored as the table's keys; the associated value is always
    ``True`` and is never read.
    """

    def add(self, key):
        """Add ``key`` to the set.

        :return: True if the element was new, False if it was already present
        """
        # A set is just a hash table whose values are a constant placeholder.
        return super(SetADT, self).add(key, True)

    def __and__(self, other_set):
        """Intersection ``A & B``: elements present in both sets."""
        new_set = SetADT()
        # One pass is enough: every element of the intersection is an
        # element of ``self``, so scanning ``other_set`` as well would only
        # re-add the same keys.
        for element in self:
            if element in other_set:
                new_set.add(element)
        return new_set

    def __sub__(self, other_set):
        """Difference ``A - B``: elements of ``self`` not in ``other_set``."""
        new_set = SetADT()
        for element in self:
            if element not in other_set:
                new_set.add(element)
        return new_set

    def __or__(self, other_set):
        """Union ``A | B``: elements present in either set."""
        new_set = SetADT()
        for element in self:
            new_set.add(element)
        for element in other_set:
            new_set.add(element)
        return new_set
| 183 | + |
| 184 | + |
def test_set_adt():
    """Exercise membership, intersection, difference and union of SetADT."""
    sa = SetADT()
    sa.add(1)
    sa.add(2)
    sa.add(3)
    # add() plus __contains__ are the minimum a set needs to be usable.
    assert 1 in sa

    sb = SetADT()
    sb.add(3)
    sb.add(4)
    sb.add(5)

    # These comparisons were previously bare expressions and checked nothing.
    assert sorted(sa & sb) == [3]
    assert sorted(sa - sb) == [1, 2]
    assert sorted(sa | sb) == [1, 2, 3, 4, 5]
| 200 | + |
| 201 | + |
# Run the self-test when this file is executed as a script.
if __name__ == "__main__":
    test_set_adt()
0 commit comments