@@ -247,7 +247,7 @@ class WeightOnlyInt8Linear(torch.nn.Module):
     __constants__ = ["in_features", "out_features"]
     in_features: int
     out_features: int
-    weight: torch.Tensor
+    # weight: torch.Tensor

     def __init__(
         self,
@@ -260,10 +260,15 @@ def __init__(
         super().__init__()
         self.in_features = in_features
         self.out_features = out_features
-        self.register_buffer(
-            "weight", torch.empty((out_features, in_features), dtype=torch.int8)
+        self.register_parameter(
+            "weight",
+            torch.nn.Parameter(
+                torch.empty((out_features, in_features), dtype=torch.int8)
+            ),
+        )
+        self.register_parameter(
+            "scales", torch.nn.Parameter(torch.ones(out_features, dtype=torch.bfloat16))
         )
-        self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))

     def forward(self, input: torch.Tensor) -> torch.Tensor:
         return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
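
For context on what this hunk switches to, below is a minimal standalone sketch of the same pattern. The class name Int8WeightLinearSketch and the explicit requires_grad=False flags are assumptions added here, not part of the patch; integer tensors cannot require gradients, so a sketch like this creates the int8 weight Parameter with gradients disabled.

import torch
import torch.nn.functional as F

class Int8WeightLinearSketch(torch.nn.Module):
    # Hypothetical minimal version of the register_parameter pattern above.
    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        # int8 weights cannot carry gradients, so requires_grad=False is set
        # explicitly; the bfloat16 scales are frozen as well (inference only).
        self.register_parameter(
            "weight",
            torch.nn.Parameter(
                torch.empty((out_features, in_features), dtype=torch.int8),
                requires_grad=False,
            ),
        )
        self.register_parameter(
            "scales",
            torch.nn.Parameter(
                torch.ones(out_features, dtype=torch.bfloat16),
                requires_grad=False,
            ),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Dequantize on the fly: cast the int8 weight to the activation dtype,
        # do a regular matmul, then apply the per-output-channel scales.
        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales

In real use the empty int8 weight and unit scales would be overwritten by loading a quantized state_dict; the sketch only mirrors the registration and forward path.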
@@ -372,17 +377,24 @@ def __init__(
             group_size = embedding_dim
         self.group_size = group_size
         self.dtype = dtype
-        self.register_buffer(
-            "weight", torch.empty((vocab_size, embedding_dim), dtype=torch.int8)
+        self.register_parameter(
+            "weight",
+            torch.nn.Parameter(
+                torch.empty((vocab_size, embedding_dim), dtype=torch.int8)
+            ),
         )
         groups_per_row = (embedding_dim + group_size - 1) // group_size
         if groups_per_row > 1:
-            self.register_buffer(
-                "scales", torch.ones((vocab_size, groups_per_row), dtype=torch.float16)
+            self.register_parameter(
+                "scales",
+                torch.nn.Parameter(
+                    torch.ones((vocab_size, groups_per_row), dtype=torch.float16)
+                ),
             )
         else:
-            self.register_buffer(
-                "scales", torch.ones((vocab_size,), dtype=torch.float16)
+            self.register_parameter(
+                "scales",
+                torch.nn.Parameter(torch.ones((vocab_size,), dtype=torch.float16)),
             )

     @torch.no_grad()
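
A quick shape check for the grouped scales this hunk registers; the sizes below are illustrative only and do not come from the patch.

import torch

# Illustrative sizes, not taken from the patch.
vocab_size, embedding_dim, group_size = 32000, 4096, 256

# Same round-up division as in the hunk above.
groups_per_row = (embedding_dim + group_size - 1) // group_size
assert groups_per_row == 16

# groups_per_row > 1, so scales carry one fp16 value per (row, group) pair.
scales = torch.nn.Parameter(
    torch.ones((vocab_size, groups_per_row), dtype=torch.float16),
    requires_grad=False,  # assumption: frozen for inference
)
print(scales.shape)  # torch.Size([32000, 16])

# With group_size == embedding_dim there is exactly one group per row and the
# else branch registers a 1-D (vocab_size,) scale vector instead.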
@@ -583,7 +595,7 @@ class Int8DynActInt4WeightLinear(torch.nn.Module):

     in_features: int
     out_features: int
-    weight: torch.Tensor
+    # weight: torch.Tensor

     """
     This module implements a dynamic quantized linear layer with int4 weight.
@@ -624,28 +636,30 @@ def __init__(
         self.precision = precision

         # currently storing unpacked int8 weights
-        # TODO: ????!!!!!
-        # weights should be registers as parameters, since they're
-        # read-only for inference
-        self.register_buffer(
+        self.register_parameter(
             "weight",
-            torch.empty((out_features, in_features), dtype=torch.int8),
+            torch.nn.Parameter(
+                torch.empty((out_features, in_features), dtype=torch.int8)
+            ),
         )
-        self.register_buffer(
+        self.register_parameter(
             "scales",
-            torch.empty(
-                (out_features, in_features // group_size),
-                dtype=scales_precision,
+            torch.nn.Parameter(
+                torch.empty(
+                    (out_features, in_features // group_size), dtype=scales_precision
+                ),
             ),
         )
         # TODO:
         # Let's not store 0 - and then have to process them?!
         # All our quantization is symmetric.
-        self.register_buffer(
+        self.register_parameter(
             "zeros",
-            torch.empty(
-                (out_features, in_features // group_size),
-                dtype=scales_precision,
+            torch.nn.Parameter(
+                torch.empty(
+                    (out_features, in_features // group_size),
+                    dtype=scales_precision,
+                )
             ),
         )

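
One way to see the net effect of these hunks is to check where the quantized tensors are reported after registration. Below is a rough check on a bare module, assuming the same tensor names as in the diff; the shapes and requires_grad=False flag are assumptions added here.

import torch

# Bare module standing in for one of the quantized layers in this diff.
m = torch.nn.Module()
m.register_parameter(
    "weight",
    torch.nn.Parameter(
        torch.empty((8, 16), dtype=torch.int8), requires_grad=False
    ),
)
m.register_parameter(
    "scales",
    torch.nn.Parameter(torch.ones(8, dtype=torch.bfloat16), requires_grad=False),
)

# Registered as parameters, the tensors are visible to parameters() and
# parameter-based tooling rather than being listed as buffers.
print([name for name, _ in m.named_parameters()])  # ['weight', 'scales']
print([name for name, _ in m.named_buffers()])     # []

# They appear in the state_dict either way (buffers are persistent by default).
print(list(m.state_dict().keys()))  # ['weight', 'scales']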