@@ -101,7 +101,7 @@ def quantize_q5_1(x):
 def quantize_q8_0(x):
     assert x.shape[-1] % QK8_0 == 0 and x.shape[-1] > QK8_0
     x = x.reshape(-1, QK8_0)
-    amax = np.max(np.abs(x), axis=-1, keepdims=True)
+    amax = np.max(np.abs(x), axis=-1, keepdims=True)
     d = amax / ((1 << 7) - 1)
     qs = (x / d).round().clip(min=-128, max=127).astype(np.int8)
     d = d.astype(np.float16).view(np.int8)
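For context: Q8_0 stores each block of QK8_0 values as one fp16 scale `d` plus QK8_0 signed bytes, with `d = amax / 127` so the largest magnitude in the block maps to ±127. A minimal sketch of the inverse (a hypothetical `dequantize_q8_0` helper, assuming QK8_0 = 32 as in ggml; not part of this script):

```python
import numpy as np

QK8_0 = 32  # assumed ggml block size

def dequantize_q8_0(d, qs):
    # d:  per-block fp16 scales, reinterpreted as int8 pairs (as above)
    # qs: per-block int8 quants, shape (n_blocks, QK8_0)
    scales = d.view(np.float16).astype(np.float32).reshape(-1, 1)
    return (scales * qs.astype(np.float32)).reshape(-1)  # x ≈ d * qs
```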
@@ -178,7 +178,7 @@ def preprocess(state_dict):
         print("no alphas_cumprod in file, generate new one")
         alphas_cumprod = get_alpha_comprod()
         state_dict["alphas_cumprod"] = alphas_cumprod
-
+
     new_state_dict = {}
     for name, w in state_dict.items():
         # ignore unused tensors
@@ -192,7 +192,7 @@ def preprocess(state_dict):
         if skip:
             continue
 
-        # # convert BF16 to FP16
+        # convert BF16 to FP16
         if w.dtype == torch.bfloat16:
             w = w.to(torch.float16)
 
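The BF16 cast above matters downstream: NumPy has no bfloat16 dtype, so the later `state_dict[name].numpy()` call in `convert` would raise on bf16 tensors. A small illustration:

```python
import torch

w = torch.randn(4, 4, dtype=torch.bfloat16)
# w.numpy() would raise a TypeError here, since NumPy has no bfloat16;
# the cast performed in the hunk above avoids that:
w = w.to(torch.float16)
print(w.numpy().dtype)  # float16
```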
@@ -251,7 +251,7 @@ def preprocess(state_dict):
             new_state_dict[new_name] = w
             print(f"preprocess {name} => {new_name}")
             continue
-
+
         # convert unet transformer linear to conv2d 1x1
         if name.startswith("model.diffusion_model.") and (name.endswith("proj_in.weight") or name.endswith("proj_out.weight")):
             if len(w.shape) == 2:
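On the linear-to-conv2d conversion: a 1x1 conv with kernel `[out, in, 1, 1]` applies the same per-position map as a linear layer with weight `[out, in]`, so only a reshape is needed. A sketch with assumed shapes:

```python
import torch

w = torch.randn(320, 320)           # linear proj_in/proj_out weight [out, in]
w_conv = w.reshape(*w.shape, 1, 1)  # conv2d 1x1 kernel [out, in, 1, 1]

x = torch.randn(1, 320, 8, 8)
y_conv = torch.nn.functional.conv2d(x, w_conv)
y_lin = torch.einsum("oi,bihw->bohw", w, x)
print(torch.allclose(y_conv, y_lin, atol=1e-5))  # True
```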
@@ -342,6 +342,11 @@ def preprocess_lora(state_dict):
     for name, w in state_dict.items():
         if not isinstance(w, torch.Tensor):
             continue
+
+        # convert BF16 to FP16
+        if w.dtype == torch.bfloat16:
+            w = w.to(torch.float16)
+
         name_without_network_parts, network_part = name.split(".", 1)
         new_name_without_network_parts = convert_diffusers_name_to_compvis(name_without_network_parts)
         if new_name_without_network_parts == None:
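The `split(".", 1)` works because LoRA keys use underscores inside the module path and dots only before the network part. With a hypothetical key:

```python
name = "lora_unet_down_blocks_0_attentions_0_proj_in.lora_up.weight"
name_without_network_parts, network_part = name.split(".", 1)
# name_without_network_parts -> "lora_unet_down_blocks_0_attentions_0_proj_in"
# network_part               -> "lora_up.weight"
```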
@@ -421,6 +426,7 @@ def convert(model_path, out_type = None, out_file=None, lora=False):
             continue
         if name in unused_tensors:
             continue
+
         data = state_dict[name].numpy()
 
         n_dims = len(data.shape)
@@ -452,7 +458,7 @@ def convert(model_path, out_type = None, out_file=None, lora=False):
         else:
             data = data.astype(np.float32)
             ttype = "f32"
-
+
         print("Processing tensor: {} with shape {}, {} -> {}".format(name, data.shape, old_type, ttype))
 
         # header
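The `else` branch above is the catch-all: tensors not eligible for the chosen quantized type are widened and written as f32. A hypothetical sketch of that eligibility shape (the real checks live earlier in `convert`):

```python
import numpy as np

QK8_0 = 32  # assumed block size

def choose_ttype(data, out_type):
    # Only 2-D tensors whose rows fill whole blocks get quantized here;
    # everything else falls through to the f32 branch shown in the diff.
    if out_type == "q8_0" and data.ndim == 2 and data.shape[-1] % QK8_0 == 0:
        return "q8_0"
    return "f32"
```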