Skip to content

Commit 00d9e0f

Browse files
author
Joey Tsai
committed
Change bert to batch prefill
1 parent 0cff7c9 commit 00d9e0f

File tree

6 files changed

+21
-21
lines changed

6 files changed

+21
-21
lines changed

examples/qualcomm/oss_scripts/llama2/llama.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor:
229229
print(f"calibration data:\n{sp_model.decode(token_list)}")
230230

231231

232-
def _bert_calibrate(
232+
def _batch_prefill_calibrate(
233233
example_inputs,
234234
user_prompts,
235235
module: torch.fx.GraphModule,
@@ -273,7 +273,7 @@ def calibrate(
273273
max_seq_len=512,
274274
):
275275
if len(example_inputs) == 2:
276-
_bert_calibrate(
276+
_batch_prefill_calibrate(
277277
example_inputs,
278278
user_prompts,
279279
module,
@@ -332,7 +332,7 @@ def _tag_kv_ios(self, gm: torch.fx.GraphModule, kv_type):
332332
== self.llama_meta["get_head_dim"]
333333
):
334334
a.meta[QCOM_QUANTIZED_IO] = kv_type
335-
# single head, bert mode
335+
# single head, batch_prefill mode
336336
elif a.meta["val"].flatten().size()[0] == self.llama_meta[
337337
"get_head_dim"
338338
] * (self.llama_meta["get_max_seq_len"] - 1):
@@ -416,7 +416,7 @@ def compile(args):
416416

417417
if args.model_mode == "kv":
418418
use_kv_cache = output_new_cache_only = True
419-
elif args.model_mode == "bert" or args.model_mode == "hybrid":
419+
elif args.model_mode == "batch_prefill" or args.model_mode == "hybrid":
420420
raise NotImplementedError(
421421
f"model_mode {args.model_mode} is not implemented yet."
422422
)
@@ -653,9 +653,9 @@ def post_process():
653653

654654
parser.add_argument(
655655
"--model_mode",
656-
help="Export and inference bert mode, kv mode or hybrid(TBD) mode",
656+
help="Export and inference batch_prefill mode, kv mode or hybrid(TBD) mode",
657657
default="kv",
658-
choices=["bert", "kv", "hybrid"],
658+
choices=["batch_prefill", "kv", "hybrid"],
659659
type=str,
660660
)
661661

examples/qualcomm/oss_scripts/llama2/model/static_llama.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def apply_rotary_emb_single(
2121
) -> torch.Tensor:
2222
x_r, x_i = x[..., ::2], x[..., 1::2]
2323

24-
# brodcast for bert mode input x
24+
# broadcast for batch_prefill mode input x
2525
if x.dim() == 4:
2626
freqs_cos = freqs_cos[None, :, None, :]
2727
freqs_sin = freqs_sin[None, :, None, :]
@@ -111,7 +111,7 @@ def forward_sha(
111111
for i, _ in enumerate(k_caches):
112112
kh.append(torch.cat([k_caches[i], k[i]], dim=-1))
113113
vh.append(torch.cat([v_caches[i], v[i]], dim=1))
114-
# bert/prefill mode
114+
# batch_prefill mode
115115
else:
116116
kh = k
117117
vh = v
@@ -131,7 +131,7 @@ def forward_sha(
131131
if self.output_new_cache_only:
132132
if k_caches and v_caches:
133133
return y, k, v
134-
# bert mode. Consider to remove, it's not really used
134+
# batch_prefill mode. Consider removing; it's not really used
135135
return y, k[-1], v[-1]
136136

137137
return y, kh, vh
@@ -172,7 +172,7 @@ def forward(
172172

173173
output_y.append(y)
174174

175-
# bert/prefill mode
175+
# batch_prefill mode
176176
else:
177177
kh = k
178178
vh = v

examples/qualcomm/oss_scripts/llama3_2/llama.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def _kv_calibrate(
103103
print(f"calibration data:\n{sp_model.decode(token_list)}")
104104

105105

106-
def _bert_calibrate(
106+
def _batch_prefill_calibrate(
107107
example_inputs,
108108
user_prompts,
109109
module: torch.fx.GraphModule,
@@ -147,7 +147,7 @@ def calibrate(
147147
max_seq_len=512,
148148
):
149149
if len(example_inputs) == 2:
150-
_bert_calibrate(
150+
_batch_prefill_calibrate(
151151
example_inputs,
152152
user_prompts,
153153
module,
@@ -206,7 +206,7 @@ def _tag_kv_ios(self, gm: torch.fx.GraphModule, kv_type, sharding_type):
206206
== self.llama_meta["get_head_dim"]
207207
):
208208
a.meta[QCOM_QUANTIZED_IO] = kv_type
209-
# single head, bert mode
209+
# single head, batch_prefill mode
210210
elif a.meta["val"].flatten().size()[0] == self.llama_meta[
211211
"get_head_dim"
212212
] * (self.llama_meta["get_max_seq_len"] - 1):
@@ -319,7 +319,7 @@ def compile(args):
319319

320320
if args.model_mode == "kv":
321321
use_kv_cache = output_new_cache_only = True
322-
elif args.model_mode == "bert":
322+
elif args.model_mode == "batch_prefill":
323323
use_kv_cache = output_new_cache_only = False
324324
elif args.model_mode == "hybrid":
325325
raise NotImplementedError(
@@ -409,7 +409,7 @@ def compile(args):
409409
def inference(args, pre_gen_pte=""):
410410
workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/single_llama"
411411

412-
if args.model_mode == "bert":
412+
if args.model_mode == "batch_prefill":
413413
eval_mode = 0
414414
elif args.model_mode == "kv":
415415
eval_mode = 1
@@ -576,9 +576,9 @@ def post_process():
576576

577577
parser.add_argument(
578578
"--model_mode",
579-
help="Export and inference bert mode, kv mode or hybrid(TBD) mode",
579+
help="Export and inference batch_prefill mode, kv mode or hybrid(TBD) mode",
580580
default="kv",
581-
choices=["bert", "kv", "hybrid"],
581+
choices=["batch_prefill", "kv", "hybrid"],
582582
type=str,
583583
)
584584

examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ DEFINE_int32(
4646
DEFINE_int32(
4747
eval_mode,
4848
0,
49-
"0: PromptProcessor(bert) / 1: TokenGenerator(kv) / 2: HybridMode (TBD)");
49+
"0: PromptProcessor(batch_prefill) / 1: TokenGenerator(kv) / 2: HybridMode (TBD)");
5050

5151
int main(int argc, char** argv) {
5252
gflags::ParseCommandLineFlags(&argc, &argv, true);

examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ Error Runner::load() {
107107

108108
// prepare io
109109
auto methods_meta = get_methods_meta();
110-
if (eval_mode_ == EvalMode::kBert) {
110+
if (eval_mode_ == EvalMode::kBatchPrefill) {
111111
io_mem_->prepare_prefill_io(methods_meta);
112112
} else {
113113
io_mem_->prepare_kv_io(methods_meta);
@@ -217,7 +217,7 @@ Error Runner::generate(
217217
HybridMemory::IO* ptr =
218218
static_cast<HybridMemory::IO*>(io_mem_->get_mutable_ptr());
219219

220-
if (eval_mode_ == EvalMode::kBert) {
220+
if (eval_mode_ == EvalMode::kBatchPrefill) {
221221
for (int i = 0; i < num_prompt_tokens; i++) {
222222
ptr->prefill_input_toks[i] = static_cast<int32_t>(prompt_tokens[i]);
223223
auto piece_res = tokenizer_->decode(prompt_tokens[i], prompt_tokens[i]);

examples/qualcomm/oss_scripts/llama3_2/runner/runner.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ class Runner {
7272

7373
private:
7474
enum EvalMode {
75-
kBert = 0,
75+
kBatchPrefill = 0,
7676
kKVCached,
7777
kUnsupported,
7878
};

0 commit comments

Comments
 (0)