Skip to content

Commit a28e73b

Browse files
committed
Define embedding_4bit ops
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent f938acb commit a28e73b

File tree

3 files changed

+6
-4
lines changed

3 files changed

+6
-4
lines changed

examples/models/llama2/custom_ops/op_sdpa.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,6 @@ void cpu_flash_attention(
240240
" and num kv heads=%" PRId64,
241241
num_head,
242242
num_heads_kv);
243-
244243
int64_t num_reps = num_head / num_heads_kv;
245244

246245
bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel();

examples/models/llama2/export_llama_lib.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
Transformer,
3030
)
3131
from executorch.exir.backend.backend_details import CompileSpec
32-
from executorch.exir.passes import *
32+
3333
from executorch.sdk.etrecord import generate_etrecord
3434
from executorch.util.activation_memory_profiler import generate_memory_trace
3535
from sentencepiece import SentencePieceProcessor
@@ -539,7 +539,10 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
539539
bitwidth = int(bitwidth)
540540
transforms.append(
541541
lambda model: EmbeddingQuantHandler(
542-
model, bitwidth=bitwidth, group_size=group_size, packed=(bitwidth==4),
542+
model,
543+
bitwidth=bitwidth,
544+
group_size=group_size,
545+
packed=(bitwidth == 4),
543546
).quantized_model()
544547
)
545548

exir/passes/_quant_patterns_and_replacements.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def embedding_byte_dtype_out_meta(
189189

190190
quantized_decomposed_lib.define(
191191
"embedding_4bit.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
192-
"int weight_quant_min, int weight_quant_max, Tensor indices, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)",
192+
"int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)",
193193
)
194194

195195

0 commit comments

Comments (0)