|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | + |
| 3 | +# 从数组和列表章复制的代码 |
| 4 | + |
| 5 | + |
| 6 | +class Array(object): |
| 7 | + |
| 8 | + def __init__(self, size=32): |
| 9 | + self._size = size |
| 10 | + self._items = [None] * size |
| 11 | + |
| 12 | + def __getitem__(self, index): |
| 13 | + return self._items[index] |
| 14 | + |
| 15 | + def __setitem__(self, index, value): |
| 16 | + self._items[index] = value |
| 17 | + |
| 18 | + def __len__(self): |
| 19 | + return self._size |
| 20 | + |
| 21 | + def clear(self, value=None): |
| 22 | + for i in range(self._items): |
| 23 | + self._items[i] = value |
| 24 | + |
| 25 | + def __iter__(self): |
| 26 | + for item in self._items: |
| 27 | + yield item |
| 28 | + |
| 29 | + |
| 30 | +class Slot(object): |
| 31 | + """定义一个 hash 表 数组的槽 |
| 32 | + 注意,一个槽有三种状态,看你能否想明白。相比链接法解决冲突,二次探查法删除一个 key 的操作稍微复杂。 |
| 33 | + 1.从未使用 HashMap.UNUSED。此槽没有被使用和冲突过,查找时只要找到 UNUSED 就不用再继续探查了 |
| 34 | + 2.使用过但是 remove 了,此时是 HashMap.EMPTY,该探查点后边的元素扔可能是有key |
| 35 | + 3.槽正在使用 Slot 节点 |
| 36 | + """ |
| 37 | + |
| 38 | + def __init__(self, key, value): |
| 39 | + self.key, self.value = key, value |
| 40 | + |
| 41 | + |
| 42 | +class HashTable(object): |
| 43 | + |
| 44 | + UNUSED = None # 没被使用过的槽,作为该类变量的一个单例,下边都是is 判断 |
| 45 | + EMPTY = Slot(None, None) # 使用过但是被删除的槽 |
| 46 | + |
| 47 | + def __init__(self): |
| 48 | + self._table = Array(7) |
| 49 | + self.length = 0 |
| 50 | + |
| 51 | + @property |
| 52 | + def _load_factor(self): |
| 53 | + # load factor 超过 2/3 就重新分配空间 |
| 54 | + return self.length / float(len(self._table)) |
| 55 | + |
| 56 | + def __len__(self): |
| 57 | + return self.length |
| 58 | + |
| 59 | + def _hash1(self, key): |
| 60 | + """ 计算key的hash值""" |
| 61 | + return abs(hash(key)) % len(self._table) |
| 62 | + |
| 63 | + def _hash2(self, key): |
| 64 | + """ key冲突时候用来计算新槽的位置""" |
| 65 | + return 1 + abs(hash(key)) % (len(self._table) - 2) |
| 66 | + |
| 67 | + def _find_slot(self, key, for_insert=False): |
| 68 | + """_find_slot |
| 69 | +
|
| 70 | + :param key: |
| 71 | + :param for_insert: 是否插入,还是仅仅查询 |
| 72 | + :return: slot index or None |
| 73 | + """ |
| 74 | + index = self._hash1(key) |
| 75 | + step = self._hash2(key) |
| 76 | + _len = len(self._table) |
| 77 | + |
| 78 | + if not for_insert: # 查找是否存在 key |
| 79 | + while self._table[index] is not HashTable.UNUSED: |
| 80 | + if self._table[index] is HashTable.EMPTY: |
| 81 | + index = (index + step) % _len |
| 82 | + continue |
| 83 | + elif self._table[index].key == key: |
| 84 | + return index |
| 85 | + index = (index + step) % _len |
| 86 | + return None |
| 87 | + else: |
| 88 | + while not self._slot_can_insert(index): # 循环直到找到一个可以插入的槽 |
| 89 | + index = (index + step) % _len |
| 90 | + return index |
| 91 | + |
| 92 | + def _slot_can_insert(self, index): |
| 93 | + return (self._table[index] is HashTable.EMPTY or self._table[index] is HashTable.UNUSED) |
| 94 | + |
| 95 | + def __contains__(self, key): # in operator |
| 96 | + index = self._find_slot(key, for_insert=False) |
| 97 | + return index is not None |
| 98 | + |
| 99 | + def add(self, key, value): |
| 100 | + if key in self: # key 相同值不一样的时候,用新的值 |
| 101 | + index = self._find_slot(key, for_insert=False) |
| 102 | + self._table[index].value = value |
| 103 | + return False |
| 104 | + else: |
| 105 | + index = self._find_slot(key, for_insert=True) |
| 106 | + self._table[index] = Slot(key, value) |
| 107 | + self.length += 1 |
| 108 | + if self._load_factor >= 0.8: # 注意超过了 阈值 rehashing |
| 109 | + self._rehash() |
| 110 | + return True |
| 111 | + |
| 112 | + def _rehash(self): |
| 113 | + old_table = self._table |
| 114 | + newsize = len(self._table) * 2 + 1 # 扩大 2*n + 1 |
| 115 | + self._table = Array(newsize) |
| 116 | + |
| 117 | + self.length = 0 |
| 118 | + |
| 119 | + for slot in old_table: |
| 120 | + if slot is not HashTable.UNUSED and slot is not HashTable.EMPTY: |
| 121 | + index = self._find_slot(slot.key, for_insert=True) |
| 122 | + self._table[index] = slot |
| 123 | + self.length += 1 |
| 124 | + |
| 125 | + def get(self, key, default=None): |
| 126 | + index = self._find_slot(key, for_insert=False) |
| 127 | + if index is None: |
| 128 | + return default |
| 129 | + else: |
| 130 | + return self._table[index].value |
| 131 | + |
| 132 | + def remove(self, key): |
| 133 | + assert key in self, 'keyerror' |
| 134 | + index = self._find_slot(key, for_insert=False) |
| 135 | + value = self._table[index].value |
| 136 | + self.length -= 1 |
| 137 | + self._table[index] = HashTable.EMPTY |
| 138 | + return value |
| 139 | + |
| 140 | + def __iter__(self): |
| 141 | + for slot in self._table: |
| 142 | + if slot not in (HashTable.EMPTY, HashTable.UNUSED): |
| 143 | + yield slot.key # 和 python dict 一样,默认遍历 key,需要value 的话写个 items() 方法 |
| 144 | + |
| 145 | + |
| 146 | +def test_hash_table(): |
| 147 | + h = HashTable() |
| 148 | + h.add('a', 0) |
| 149 | + h.add('b', 1) |
| 150 | + h.add('c', 2) |
| 151 | + |
| 152 | + assert len(h) == 3 |
| 153 | + assert h.get('a') == 0 |
| 154 | + assert h.get('b') == 1 |
| 155 | + assert h.get('hehe') is None |
| 156 | + |
| 157 | + h.remove('a') |
| 158 | + assert h.get('a') is None |
| 159 | + |
| 160 | + assert sorted(list(h)) == ['b', 'c'] |
| 161 | + |
| 162 | + for i in range(50): |
| 163 | + h.add(i, i) |
| 164 | + |
| 165 | + for i in range(50): |
| 166 | + assert h.get(i) == i |
0 commit comments