Skip to content

Commit 4e7c352

Browse files
committed
[Embedding] Fix core dump in HBM_DRAM Restore when id_num > cache size
Signed-off-by: RobertLou <[email protected]>
1 parent 04413cf commit 4e7c352

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

tensorflow/core/framework/embedding/hbm_dram_storage.h

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -49,6 +49,7 @@ class HbmDramStorage : public MultiTierStorage<K, V> {
4949

5050
~HbmDramStorage() override {
5151
MultiTierStorage<K, V>::DeleteFromEvictionManager();
52+
// No explicit delete for restore_cache_: it is a std::unique_ptr and releases itself.
5253
delete hbm_;
5354
delete dram_;
5455
delete dram_feat_desc_;
@@ -227,7 +228,7 @@ class HbmDramStorage : public MultiTierStorage<K, V> {
227228
}
228229

229230
void BatchEviction() override {
230-
constexpr int EvictionSize = 10000;
231+
constexpr int EvictionSize = 5000;
231232
K evic_ids[EvictionSize];
232233
if (!MultiTierStorage<K, V>::ready_eviction_) {
233234
return;
@@ -287,16 +288,18 @@ class HbmDramStorage : public MultiTierStorage<K, V> {
287288
partition_id, partition_num,
288289
is_incr, reset_version, reader);
289290

291+
restore_cache_.reset(CacheFactory::Create<K>(CacheStrategy::LFU, "ads"));
290292
restorer.RestoreCkpt(emb_config, device);
291293

292294
int64 num_of_hbm_ids =
293295
std::min(MultiTierStorage<K, V>::cache_capacity_,
294-
(int64)MultiTierStorage<K, V>::cache_->size());
296+
(int64)restore_cache_->size());
297+
295298
if (num_of_hbm_ids > 0) {
296299
K* hbm_ids = new K[num_of_hbm_ids];
297300
int64* hbm_freqs = new int64[num_of_hbm_ids];
298301
int64* hbm_versions = nullptr;
299-
MultiTierStorage<K, V>::cache_->get_cached_ids(hbm_ids, num_of_hbm_ids,
302+
restore_cache_->get_cached_ids(hbm_ids, num_of_hbm_ids,
300303
hbm_versions, hbm_freqs);
301304
ImportToHbm(hbm_ids, num_of_hbm_ids, value_len, emb_config.emb_index);
302305
MultiTierStorage<K, V>::cache_thread_pool_->Schedule(
@@ -329,10 +332,10 @@ class HbmDramStorage : public MultiTierStorage<K, V> {
329332
Status s = filter->Restore(key_num, bucket_num, partition_id,
330333
partition_num, value_len, is_filter,
331334
true/*to_dram*/, is_incr, restore_buff);
332-
333-
MultiTierStorage<K, V>::cache_->update((K*)restore_buff.key_buffer, key_num,
334-
(int64*)restore_buff.version_buffer,
335-
(int64*)restore_buff.freq_buffer);
335+
336+
restore_cache_->update((K*)restore_buff.key_buffer, key_num,
337+
(int64*)restore_buff.version_buffer,
338+
(int64*)restore_buff.freq_buffer);
336339
return s;
337340
}
338341

@@ -574,6 +577,7 @@ class HbmDramStorage : public MultiTierStorage<K, V> {
574577
DramStorage<K, V>* dram_ = nullptr;
575578
FeatureDescriptor<V>* hbm_feat_desc_ = nullptr;
576579
FeatureDescriptor<V>* dram_feat_desc_ = nullptr;
580+
std::unique_ptr<BatchCache<K>> restore_cache_ = nullptr;
577581
Allocator* gpu_alloc_;
578582
const int copyback_flag_offset_bits_ = 60;
579583
};

0 commit comments

Comments
 (0)