Skip to content

Commit 4b7050d

Browse files
committed
Define embedding_4bit ops
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent f938acb commit 4b7050d

File tree

3 files changed

+2
-3
lines changed

3 files changed

+2
-3
lines changed

examples/models/llama2/custom_ops/op_sdpa.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,6 @@ void cpu_flash_attention(
240240
" and num kv heads=%" PRId64,
241241
num_head,
242242
num_heads_kv);
243-
244243
int64_t num_reps = num_head / num_heads_kv;
245244

246245
bool has_attn_mask = attn_mask.has_value() && attn_mask.value().numel();

examples/models/llama2/export_llama_lib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
Transformer,
3030
)
3131
from executorch.exir.backend.backend_details import CompileSpec
32-
from executorch.exir.passes import *
32+
3333
from executorch.sdk.etrecord import generate_etrecord
3434
from executorch.util.activation_memory_profiler import generate_memory_trace
3535
from sentencepiece import SentencePieceProcessor

exir/passes/_quant_patterns_and_replacements.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def embedding_byte_dtype_out_meta(
189189

190190
quantized_decomposed_lib.define(
191191
"embedding_4bit.dtype_out(Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, "
192-
"int weight_quant_min, int weight_quant_max, Tensor indices, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)",
192+
"int weight_quant_min, int weight_quant_max, Tensor indices, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)",
193193
)
194194

195195

0 commit comments

Comments (0)