@@ -12,6 +12,7 @@
 from datasets import load_dataset
 import numpy as np
 import os
+import hashlib
 from itertools import chain
 from . import raw_datasets
 
@@ -75,7 +76,7 @@ def get_raw_dataset_split_index(local_rank, output_path, dataset_name, seed,
                                 split_name, data_split, split_index,
                                 data_size):
     index_file_name = f"{output_path}/{dataset_name}_seed{seed}_{split_name}_{data_split}_{split_index}.npy"
-    if not os.path.isfile(index_file_name) and local_rank <= 0:
+    if not os.path.isfile(index_file_name):
         splits = [float(s) for s in data_split.split(',')]
         splits_sum = sum(splits)
         splits = [split / splits_sum for split in splits]
@@ -96,7 +97,6 @@ def get_raw_dataset_split_index(local_rank, output_path, dataset_name, seed,
             np.save(shuffle_idx_split_file_name,
                     shuffle_idx_split,
                     allow_pickle=True)
-    torch.distributed.barrier()
     index = np.load(index_file_name, allow_pickle=True)
     return index.tolist()
 
@@ -254,18 +254,16 @@ def create_prompt_dataset(local_rank,
     tokenizer_name = tokenizer.init_kwargs["name_or_path"].replace("/", "_")
     fname = f"{fname}_split{data_split}_phase{train_phase}_seed{seed}_tokenizer{tokenizer_name}_seqlen{max_seq_len}_sft{sft_cache_key}"
     fname = "_".join(fname.split("/"))
-    fname = str(hash(fname))  # hash the file name to avoid too long file name
+    fname = hashlib.sha256(fname.encode()).hexdigest(
+    )  # hash the file name to avoid too long file name
     train_fname = f"{output_path}/traindata_{fname}.pt"
     eval_fname = f"{output_path}/evaldata_{fname}.pt"
 
     cache_found = os.path.isfile(train_fname) and os.path.isfile(eval_fname)
     buf_create_cache = torch.ByteTensor([not cache_found]).cuda()
     torch.distributed.all_reduce(buf_create_cache)
 
-    # Skip creating cache if we found it on all the nodes.
-    if buf_create_cache.item() == 0:
-        return torch.load(train_fname), torch.load(eval_fname)
-    else:
+    if local_rank <= 0 and buf_create_cache.item() != 0:
         if len(data_path) == 1:  # Single dataset.
             train_dataset, eval_dataset = create_dataset(
                 local_rank, data_path[0], data_split, output_path, train_phase,
@@ -323,11 +321,10 @@ def create_prompt_dataset(local_rank,
             eval_dataset = ConcatDataset([eval_dataset, sft_eval_dataset])
             shuffle_idx = get_shuffle_idx(seed, len(eval_dataset))
             eval_dataset = Subset(eval_dataset, shuffle_idx.tolist())
-
-    if local_rank <= 0:
-        torch.save(train_dataset, train_fname)
-        torch.save(eval_dataset, eval_fname)
-    return train_dataset, eval_dataset
+        torch.save(train_dataset, train_fname)
+        torch.save(eval_dataset, eval_fname)
+    torch.distributed.barrier()
+    return torch.load(train_fname), torch.load(eval_fname)
 
 
 class DataCollatorReward: