
Commit 7977cc2

Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 0da7b40 commit 7977cc2

2 files changed: +3 -3 lines changed


examples/models/llama2/export_llama_lib.py

Lines changed: 2 additions & 2 deletions
@@ -29,7 +29,7 @@
     Transformer,
 )
 from executorch.exir.backend.backend_details import CompileSpec
-
+from executorch.exir.passes import *
 from executorch.sdk.etrecord import generate_etrecord
 from executorch.util.activation_memory_profiler import generate_memory_trace
 from sentencepiece import SentencePieceProcessor
@@ -541,7 +541,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
         bitwidth = int(bitwidth)
         transforms.append(
             lambda model: EmbeddingOnlyInt8QuantHandler(
-                model, bitwidth=bitwidth, group_size=group_size
+                model, bitwidth=bitwidth, group_size=group_size, packed=(bitwidth==4),
             ).quantized_model()
         )
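The export_llama_lib.py hunk threads a packed flag into the embedding quantization transform, turning packing on exactly when the requested bitwidth is 4. Below is a minimal, hypothetical sketch of how a transform list can carry that flag through to a quant handler; ToyEmbeddingQuantHandler and build_transforms are illustrative stand-ins, not the real EmbeddingOnlyInt8QuantHandler API or its internals.

# Sketch only: mirrors the shape of the transform list in _prepare_for_llama_export.
# The handler below just records the settings; the real class rewrites embedding weights.
import torch
import torch.nn as nn


class ToyEmbeddingQuantHandler:
    def __init__(self, model: nn.Module, bitwidth: int, group_size: int, packed: bool = False):
        self.model = model
        self.bitwidth = bitwidth
        self.group_size = group_size
        self.packed = packed  # only meaningful for 4-bit weights (two values per byte)

    def quantized_model(self) -> nn.Module:
        # A real handler would swap nn.Embedding weights for int8/int4 buffers here;
        # this sketch only attaches the chosen settings to the model.
        self.model.quant_config = {
            "bitwidth": self.bitwidth,
            "group_size": self.group_size,
            "packed": self.packed,
        }
        return self.model


def build_transforms(bitwidth: int, group_size: int):
    transforms = []
    transforms.append(
        # Mirrors the diff: packed is enabled exactly when bitwidth == 4.
        lambda model: ToyEmbeddingQuantHandler(
            model, bitwidth=bitwidth, group_size=group_size, packed=(bitwidth == 4)
        ).quantized_model()
    )
    return transforms


model = nn.Sequential(nn.Embedding(1000, 64))
for transform in build_transforms(bitwidth=4, group_size=32):
    model = transform(model)
print(model.quant_config)  # {'bitwidth': 4, 'group_size': 32, 'packed': True}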

examples/models/llama2/quantize.py

Lines changed: 1 addition & 1 deletion
@@ -438,6 +438,6 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
                 self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
             )
         else: # 4bit packed
-            return torch.ops.llama_quantized.embedding_4bit.dtype(
+            return torch.ops.quantized_decomposed.embedding_4bit.dtype(
                 self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
             )
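The quantize.py hunk only renames the operator namespace used on the packed 4-bit path, from llama_quantized to quantized_decomposed; the call arguments are unchanged. For context, here is a hedged, plain-PyTorch sketch of what a packed 4-bit embedding lookup conceptually computes. It is not the quantized_decomposed.embedding_4bit kernel; the helper names, nibble ordering, and per-group scale layout are assumptions made for illustration.

# Illustrative sketch: two 4-bit values packed per uint8 byte, per-group scales.
import torch


def unpack_4bit_rows(packed: torch.Tensor) -> torch.Tensor:
    # packed: (rows, dim // 2) uint8 -> (rows, dim) signed values in [-8, 7].
    low = (packed & 0x0F).to(torch.int8)
    high = (packed >> 4).to(torch.int8)
    # Map unsigned nibbles [0, 15] back to signed [-8, 7].
    low = torch.where(low > 7, low - 16, low)
    high = torch.where(high > 7, high - 16, high)
    # Assumed ordering: low nibble first, then high nibble, interleaved per byte.
    return torch.stack([low, high], dim=-1).reshape(packed.shape[0], -1)


def embedding_4bit_dequant(
    packed_weight: torch.Tensor,  # (num_embeddings, dim // 2) uint8
    scales: torch.Tensor,         # (num_embeddings, dim // group_size) per-group scales
    indices: torch.Tensor,        # 1-D token indices
    group_size: int,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    rows = unpack_4bit_rows(packed_weight[indices])            # (len(indices), dim) int8
    s = scales[indices].repeat_interleave(group_size, dim=-1)  # broadcast one scale per group
    return rows.to(dtype) * s.to(dtype)


# Tiny smoke test: 10 rows of a 4-wide table, one scale group per row.
packed = torch.randint(0, 256, (10, 2), dtype=torch.uint8)
scales = torch.rand(10, 1)
out = embedding_4bit_dequant(packed, scales, torch.tensor([0, 3]), group_size=4)
print(out.shape)  # torch.Size([2, 4])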
