
Commit f0e3ed0

Michael Gschwind authored and facebook-github-bot committed
Remove exir.capture and its warnings
Summary: Remove exir.capture and its warnings

Differential Revision: D55039873
1 parent 0475af0 commit f0e3ed0
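The diff hunks shown below touch only examples/models/llama2/quantize.py; the exir.capture removal itself presumably sits in the second changed file, which is not expanded here. For background, here is a minimal, hedged sketch of the torch.export + to_edge flow that the ExecuTorch documentation recommends in place of the deprecated exir.capture; the model and example inputs are illustrative and not part of this commit:

import torch
from executorch.exir import to_edge


class TinyModel(torch.nn.Module):  # illustrative stand-in, not from this commit
    def forward(self, x):
        return torch.nn.functional.relu(x)


# torch.export.export(...) followed by to_edge(...) is the usual replacement
# for the old exir.capture(...) entry point.
exported = torch.export.export(TinyModel(), (torch.randn(1, 8),))
edge_program = to_edge(exported)
executorch_program = edge_program.to_executorch()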

File tree

2 files changed: +230 additions, -198 deletions


examples/models/llama2/quantize.py

Lines changed: 38 additions & 24 deletions
@@ -247,7 +247,7 @@ class WeightOnlyInt8Linear(torch.nn.Module):
     __constants__ = ["in_features", "out_features"]
     in_features: int
     out_features: int
-    weight: torch.Tensor
+    # weight: torch.Tensor
 
     def __init__(
         self,
@@ -260,10 +260,15 @@ def __init__(
         super().__init__()
         self.in_features = in_features
         self.out_features = out_features
-        self.register_buffer(
-            "weight", torch.empty((out_features, in_features), dtype=torch.int8)
+        self.register_parameter(
+            "weight",
+            torch.nn.Parameter(
+                torch.empty((out_features, in_features), dtype=torch.int8)
+            ),
+        )
+        self.register_parameter(
+            "scales", torch.nn.Parameter(torch.ones(out_features, dtype=torch.bfloat16))
         )
-        self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
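The practical effect of the register_parameter change above is that the int8 weight and the scales now show up in named_parameters() and travel with the module's state_dict as parameters rather than buffers. A minimal sketch of the pattern (not the repository's class; note that PyTorch only lets floating-point tensors require gradients, so an integer-dtype Parameter is created here with requires_grad=False):

import torch
import torch.nn.functional as F


class Int8LinearSketch(torch.nn.Module):
    # Hypothetical stand-in for WeightOnlyInt8Linear, for illustration only.
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        # Register the quantized weight and its scales as (non-trainable)
        # parameters instead of buffers.
        self.register_parameter(
            "weight",
            torch.nn.Parameter(
                torch.zeros((out_features, in_features), dtype=torch.int8),
                requires_grad=False,
            ),
        )
        self.register_parameter(
            "scales",
            torch.nn.Parameter(
                torch.ones(out_features, dtype=torch.bfloat16), requires_grad=False
            ),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales


m = Int8LinearSketch(8, 4)
print([name for name, _ in m.named_parameters()])  # ['weight', 'scales']
print(list(m.buffers()))                           # [] -- nothing is a buffer now
print(m(torch.randn(2, 8)).shape)                  # torch.Size([2, 4])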
@@ -372,17 +377,24 @@ def __init__(
             group_size = embedding_dim
         self.group_size = group_size
         self.dtype = dtype
-        self.register_buffer(
-            "weight", torch.empty((vocab_size, embedding_dim), dtype=torch.int8)
+        self.register_parameter(
+            "weight",
+            torch.nn.Parameter(
+                torch.empty((vocab_size, embedding_dim), dtype=torch.int8)
+            ),
         )
         groups_per_row = (embedding_dim + group_size - 1) // group_size
         if groups_per_row > 1:
-            self.register_buffer(
-                "scales", torch.ones((vocab_size, groups_per_row), dtype=torch.float16)
+            self.register_parameter(
+                "scales",
+                torch.nn.Parameter(
+                    torch.ones((vocab_size, groups_per_row), dtype=torch.float16)
+                ),
             )
         else:
-            self.register_buffer(
-                "scales", torch.ones((vocab_size,), dtype=torch.float16)
+            self.register_parameter(
+                "scales",
+                torch.nn.Parameter(torch.ones((vocab_size,), dtype=torch.float16)),
             )
 
     @torch.no_grad()
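The scales shape in the embedding hunk above follows from ceiling division: groups_per_row = (embedding_dim + group_size - 1) // group_size. A tiny worked example with illustrative values (not taken from the diff):

# ceil(4096 / 256) == 16, so scales would have shape (vocab_size, 16)
embedding_dim, group_size = 4096, 256
groups_per_row = (embedding_dim + group_size - 1) // group_size
assert groups_per_row == 16
# When group_size == embedding_dim, groups_per_row == 1 and scales collapses
# to a single value per vocabulary row, shape (vocab_size,).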
@@ -583,7 +595,7 @@ class Int8DynActInt4WeightLinear(torch.nn.Module):
 
     in_features: int
     out_features: int
-    weight: torch.Tensor
+    # weight: torch.Tensor
 
     """
     This module implements a dynamic quantized linear layer with int4 weight.
@@ -624,28 +636,30 @@ def __init__(
         self.precision = precision
 
         # currently storing unpacked int8 weights
-        # TODO: ????!!!!!
-        # weights should be registers as parameters, since they're
-        # read-only for inference
-        self.register_buffer(
+        self.register_parameter(
             "weight",
-            torch.empty((out_features, in_features), dtype=torch.int8),
+            torch.nn.Parameter(
+                torch.empty((out_features, in_features), dtype=torch.int8)
+            ),
         )
-        self.register_buffer(
+        self.register_parameter(
             "scales",
-            torch.empty(
-                (out_features, in_features // group_size),
-                dtype=scales_precision,
+            torch.nn.Parameter(
+                torch.empty(
+                    (out_features, in_features // group_size), dtype=scales_precision
+                ),
             ),
         )
         # TODO:
         # Let's not store 0 - and then have to process them?!
         # All our quantization is symmetric.
-        self.register_buffer(
+        self.register_parameter(
             "zeros",
-            torch.empty(
-                (out_features, in_features // group_size),
-                dtype=scales_precision,
+            torch.nn.Parameter(
+                torch.empty(
+                    (out_features, in_features // group_size),
+                    dtype=scales_precision,
+                )
             ),
         )
 
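The TODO retained in the last hunk notes that the quantization scheme is symmetric, so the registered zeros carry no information (every zero point is 0). A minimal sketch of symmetric per-group quantization, under the assumption that in_features is divisible by the group size (the helper name and values are illustrative, not the repository's code):

import torch


def symmetric_quantize_per_group(w: torch.Tensor, group_size: int, n_bits: int = 4):
    # w: (out_features, in_features); quantize each row in groups of `group_size`.
    out_features, in_features = w.shape
    qmax = 2 ** (n_bits - 1) - 1  # e.g. 7 for 4-bit
    grouped = w.reshape(out_features, in_features // group_size, group_size)
    scales = grouped.abs().amax(dim=-1).clamp(min=1e-8) / qmax  # one scale per group
    q = torch.clamp(torch.round(grouped / scales.unsqueeze(-1)), -qmax - 1, qmax)
    # Symmetric scheme: the zero point is 0 for every group, so a separate
    # `zeros` tensor holds nothing that cannot be reconstructed.
    return q.to(torch.int8).reshape(out_features, in_features), scales


w = torch.randn(4, 16)
q, scales = symmetric_quantize_per_group(w, group_size=8)
print(q.shape, scales.shape)  # torch.Size([4, 16]) torch.Size([4, 2])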