Skip to content

Commit 76c16f1

Browse files
authored
only prepare data on rank 0 (deepspeedai#409)
* only prepare data on rank 0
* fix hash
1 parent 794660f commit 76c16f1

File tree

1 file changed

+9
-12
lines changed

1 file changed

+9
-12
lines changed

applications/DeepSpeed-Chat/training/utils/data/data_utils.py

Lines changed: 9 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
from datasets import load_dataset
1313
import numpy as np
1414
import os
15+
import hashlib
1516
from itertools import chain
1617
from . import raw_datasets
1718

@@ -75,7 +76,7 @@ def get_raw_dataset_split_index(local_rank, output_path, dataset_name, seed,
7576
split_name, data_split, split_index,
7677
data_size):
7778
index_file_name = f"{output_path}/{dataset_name}_seed{seed}_{split_name}_{data_split}_{split_index}.npy"
78-
if not os.path.isfile(index_file_name) and local_rank <= 0:
79+
if not os.path.isfile(index_file_name):
7980
splits = [float(s) for s in data_split.split(',')]
8081
splits_sum = sum(splits)
8182
splits = [split / splits_sum for split in splits]
@@ -96,7 +97,6 @@ def get_raw_dataset_split_index(local_rank, output_path, dataset_name, seed,
9697
np.save(shuffle_idx_split_file_name,
9798
shuffle_idx_split,
9899
allow_pickle=True)
99-
torch.distributed.barrier()
100100
index = np.load(index_file_name, allow_pickle=True)
101101
return index.tolist()
102102

@@ -254,18 +254,16 @@ def create_prompt_dataset(local_rank,
254254
tokenizer_name = tokenizer.init_kwargs["name_or_path"].replace("/", "_")
255255
fname = f"{fname}_split{data_split}_phase{train_phase}_seed{seed}_tokenizer{tokenizer_name}_seqlen{max_seq_len}_sft{sft_cache_key}"
256256
fname = "_".join(fname.split("/"))
257-
fname = str(hash(fname)) # hash the file name to avoid too long file name
257+
fname = hashlib.sha256(fname.encode()).hexdigest(
258+
) # hash the file name to avoid too long file name
258259
train_fname = f"{output_path}/traindata_{fname}.pt"
259260
eval_fname = f"{output_path}/evaldata_{fname}.pt"
260261

261262
cache_found = os.path.isfile(train_fname) and os.path.isfile(eval_fname)
262263
buf_create_cache = torch.ByteTensor([not cache_found]).cuda()
263264
torch.distributed.all_reduce(buf_create_cache)
264265

265-
# Skip creating cache if we found it on all the nodes.
266-
if buf_create_cache.item() == 0:
267-
return torch.load(train_fname), torch.load(eval_fname)
268-
else:
266+
if local_rank <= 0 and buf_create_cache.item() != 0:
269267
if len(data_path) == 1: # Single dataset.
270268
train_dataset, eval_dataset = create_dataset(
271269
local_rank, data_path[0], data_split, output_path, train_phase,
@@ -323,11 +321,10 @@ def create_prompt_dataset(local_rank,
323321
eval_dataset = ConcatDataset([eval_dataset, sft_eval_dataset])
324322
shuffle_idx = get_shuffle_idx(seed, len(eval_dataset))
325323
eval_dataset = Subset(eval_dataset, shuffle_idx.tolist())
326-
327-
if local_rank <= 0:
328-
torch.save(train_dataset, train_fname)
329-
torch.save(eval_dataset, eval_fname)
330-
return train_dataset, eval_dataset
324+
torch.save(train_dataset, train_fname)
325+
torch.save(eval_dataset, eval_fname)
326+
torch.distributed.barrier()
327+
return torch.load(train_fname), torch.load(eval_fname)
331328

332329

333330
class DataCollatorReward:

0 commit comments

Comments (0)