@@ -494,7 +494,8 @@ def compile(args, pte_filename, tokenizer):
                 annotate_linear_16a8w_in_affine_layer,
             )
         if args.ptq != None:
-            kv_quant_attrs = {}
+            import hashlib
+            kv_quant_attrs, parameter_hash = {}, []
             for i, llama_instance in enumerate(llama_instance_list):
                 llama_instance.quantize(
                     quant_dtype=quant_dtype,
@@ -517,6 +518,31 @@ def compile(args, pte_filename, tokenizer):
                             kv_quant_attrs=kv_quant_attrs,
                         ),
                     )
+
+                tensor_to_md5 = {}
+                for name, buffer in llama_instance.llama_model.named_buffers():
+                    md5_buffer = hashlib.md5(buffer.numpy().tobytes()).hexdigest()
+                    if md5_buffer in tensor_to_md5:
+                        tensor_to_md5[md5_buffer].append(name)
+                    else:
+                        tensor_to_md5[md5_buffer] = [name]
+                parameter_hash.append(tensor_to_md5)
+
+            # check tensors in prefill & decode are exactly the same
+            assert len(parameter_hash[0]) == len(parameter_hash[1])
+            num_keys = len(parameter_hash[0])
+            # Remove common keys from both dictionaries
+            for key in set(parameter_hash[0]).intersection(set(parameter_hash[1])):
+                del parameter_hash[0][key]
+                del parameter_hash[1][key]
+            print(f"{num_keys - len(parameter_hash[0])} / {num_keys} tensors are matched")
+
+            for buf, name in parameter_hash[0].items():  # kv
+                print(f"KV buffers: {name} cannot find a match")
+            for buf, name in parameter_hash[1].items():  # prefill
+                print(f"Prefill buffers: {name} cannot find a match")
+
+
         end_quantize_ts = time.time()
         logging.info(f"Time for quantizing: {end_quantize_ts - start_quantize_ts}")

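The block added in the second hunk checks that the kv (decode) and prefill model instances still share identical weights after quantization: it hashes each registered buffer's raw bytes with MD5, buckets buffer names by digest, and diffs the two digest sets. Below is a minimal, self-contained sketch of the same check; hash_buffers, report_matches, and the Toy module are illustrative stand-ins rather than names from this PR, and it assumes, as the original's buffer.numpy() call does, that all buffers live on CPU.

import hashlib

import torch
import torch.nn as nn


def hash_buffers(module: nn.Module) -> dict[str, list[str]]:
    # Map md5(raw tensor bytes) -> [buffer names]; identical tensors collide.
    tensor_to_md5: dict[str, list[str]] = {}
    for name, buffer in module.named_buffers():
        digest = hashlib.md5(buffer.numpy().tobytes()).hexdigest()
        tensor_to_md5.setdefault(digest, []).append(name)
    return tensor_to_md5


def report_matches(kv_hash, prefill_hash):
    # Drop digests present in both models, then report whatever is left over.
    num_keys = len(kv_hash)
    for key in set(kv_hash).intersection(prefill_hash):
        del kv_hash[key]
        del prefill_hash[key]
    print(f"{num_keys - len(kv_hash)} / {num_keys} tensors are matched")
    for names in kv_hash.values():
        print(f"KV buffers: {names} cannot find a match")
    for names in prefill_hash.values():
        print(f"Prefill buffers: {names} cannot find a match")


class Toy(nn.Module):
    # Hypothetical stand-in for llama_instance.llama_model.
    def __init__(self, scale):
        super().__init__()
        self.register_buffer("shared", torch.ones(4))           # same in both
        self.register_buffer("cache", torch.full((4,), scale))  # differs


if __name__ == "__main__":
    report_matches(hash_buffers(Toy(2.0)), hash_buffers(Toy(3.0)))
    # Expected output:
    #   1 / 2 tensors are matched
    #   KV buffers: ['cache'] cannot find a match
    #   Prefill buffers: ['cache'] cannot find a match

Byte-level hashing is strict by design: two buffers match only if their dtype, memory layout, and every value agree, so a quantization pass that perturbs one instance's buffers will show up immediately in the unmatched lists.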