Commit d89d3e7

Authored by chunit-quic and Joey Tsai
Qualcomm AI Engine Direct - Support batch prefill mode for llama3.2 (#6983)
* Qualcomm AI Engine Direct - Support bert mode for llama3.2
  - Enable bert mode
  - Change input sequence of static_llama
  - Tag bert output as uint8
  - Unify both 1b and 3b in 1 runner
  - Add hybrid IO memory for llama3_2 runner
  - Align timer with llama
* Rebase and minor fix
  - Fix rebase conflict
  - Change input dtype of calibration function
* Change bert to batch prefill
* Fix compile error
* Fix lint
  - Fix transformers version
  - Refine pass quantization tagging function
  - Rebase
* Add one line at the end of CMakeLists.txt
* Remove trailing line of CMakeLists.txt
* Move noqa to correct line number

---------

Co-authored-by: Joey Tsai <[email protected]>
1 parent 22a75be commit d89d3e7

File tree: 12 files changed (+702, -246 lines)


backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 5 additions & 2 deletions
@@ -22,7 +22,9 @@
 from torch.fx import Node
 
 
-def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None:  # noqa: C901
+def annotate_matmul_16a8w(  # noqa: C901
+    gm: torch.fx.GraphModule, traverse_input1=True
+) -> None:
     """
     This function is specific for matmul op 16a8w.
     """
@@ -99,7 +101,8 @@ def annotate_matmul_input1(node: Node):
     for node in gm.graph.nodes:
         if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
             annotate_matmul(node, quantization_config_16a8w)
-            annotate_matmul_input1(node.args[1])
+            if traverse_input1:
+                annotate_matmul_input1(node.args[1])
 
 
 def custom_annotate_llama_matmul_16a8w(gm: torch.fx.GraphModule) -> None:  # noqa: C901
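
The new traverse_input1 flag lets a caller annotate the matmul ops themselves while skipping the recursive walk over each matmul's second input, e.g. when those inputs are tagged by a separate pass. A minimal usage sketch, not taken from this commit; the import path is the expected module location, and the quantize() call is shown commented out because the exact invocation lives in the example script:

from functools import partial

# Assumed module path for the annotator shown in the diff above.
from executorch.backends.qualcomm.quantizer.custom_annotation import (
    annotate_matmul_16a8w,
)

# Annotate matmul in 16a8w but leave the second inputs untouched.
matmul_only = partial(annotate_matmul_16a8w, traverse_input1=False)

# Hypothetical wiring into SingleLlama.quantize(), which accepts a tuple of
# custom annotation callables:
# single_llama.quantize(quant_dtype, custom_annotations=(matmul_only,))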

examples/qualcomm/oss_scripts/llama2/llama.py

Lines changed: 120 additions & 11 deletions
@@ -177,14 +177,15 @@ def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None
     )
 
 
-def calibrate(
+def _kv_calibrate(
     example_inputs,
     user_prompts,
     module: torch.fx.GraphModule,
     tokenizer_model_path="tokenizer.model",
+    max_seq_len=512,
 ):
     sp_model = SentencePieceProcessor(model_file=tokenizer_model_path)
-    _, _, atten_mask, k_caches, v_caches = example_inputs
+    _, atten_mask, _, k_caches, v_caches = example_inputs
 
     # TODO: change criteria & support batch inputs if necessary
     pos = torch.tensor(0, dtype=torch.int32)
@@ -202,11 +203,11 @@ def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor:
         return probs_indices.gather(dim=-1, index=next_token)
 
     with torch.no_grad():
-        while token_list[-1] != sp_model.eos_id() and pos < 128:
+        while token_list[-1] != sp_model.eos_id() and pos < max_seq_len - 1:
             logits, new_k_caches, new_v_caches = module(
                 torch.full((1, 1), token_list[pos]),
-                torch.full((1, 1), pos),
                 atten_mask,
+                torch.full((1, 1), pos),
                 *k_caches,
                 *v_caches,
             )
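
To make the reordering concrete, this is what one kv-mode decode step looks like after the change, written as a standalone helper; argument names mirror the calibration code above and nothing here is new API:

import torch

def kv_decode_step(module, token_id, pos, atten_mask, k_caches, v_caches):
    # New argument order: (tokens, atten_mask, pos, *k_caches, *v_caches);
    # previously the position id came before the attention mask.
    return module(
        torch.full((1, 1), token_id),  # current token
        atten_mask,                    # attention mask from the example inputs
        torch.full((1, 1), pos),       # current position
        *k_caches,
        *v_caches,
    )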
@@ -228,15 +229,84 @@ def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor:
     print(f"calibration data:\n{sp_model.decode(token_list)}")
 
 
+def _batch_prefill_calibrate(
+    example_inputs,
+    user_prompts,
+    module: torch.fx.GraphModule,
+    tokenizer_model_path="tokenizer.model",
+    max_seq_len=512,
+):
+    sp_model = SentencePieceProcessor(model_file=tokenizer_model_path)
+    _, atten_mask = example_inputs
+    max_cache_len = max_seq_len - 1
+
+    # TODO: change criteria & support batch inputs if necessary
+    token_list = sp_model.encode(user_prompts, bos=True, eos=False)
+    token_list = torch.tensor(token_list)[:max_cache_len].reshape(1, -1)
+    last_prompt_pos = token_list.numel()
+    if last_prompt_pos < max_cache_len:
+        token_list = torch.cat(
+            [
+                token_list,
+                torch.zeros((1, max_cache_len - last_prompt_pos), dtype=torch.int32),
+            ],
+            dim=1,
+        )
+    else:
+        token_list = token_list[:, :max_cache_len]
+
+    with torch.no_grad():
+        logits, new_k_caches, new_v_caches = module(
+            token_list,
+            atten_mask,
+        )
+    predict = [torch.argmax(logits[:, last_prompt_pos - 1], dim=-1).item()]
+
+    print(f"calibration data:\n{sp_model.decode(predict)}")
+
+
+def calibrate(
+    example_inputs,
+    user_prompts,
+    module: torch.fx.GraphModule,
+    tokenizer_model_path="tokenizer.model",
+    max_seq_len=512,
+):
+    if len(example_inputs) == 2:
+        _batch_prefill_calibrate(
+            example_inputs,
+            user_prompts,
+            module,
+            tokenizer_model_path,
+            max_seq_len,
+        )
+    elif len(example_inputs) == 5:
+        _kv_calibrate(
+            example_inputs,
+            user_prompts,
+            module,
+            tokenizer_model_path,
+            max_seq_len,
+        )
+    else:
+        raise RuntimeError("Get wrong inputs")
+
+
 class SingleLlama:
     def __init__(self, llama_model) -> None:
         super().__init__()
         self.llama_model = llama_model
         self.quant_dtype = None
         self.llama_meta = self.llama_model.get_metadata()
         self.has_quant_io = False
-        tokens, pos_ids, atten_mask, k_caches, v_caches = self.get_example_inputs()
-        self.inputs = (tokens, pos_ids, atten_mask, *k_caches, *v_caches)
+        if self.llama_meta["get_use_kv_cache"]:
+            tokens, atten_mask, pos_ids, k_caches, v_caches = self.get_example_inputs(
+                use_kv_cache=True
+            )
+            self.inputs = (tokens, atten_mask, pos_ids, *k_caches, *v_caches)
+        else:
+            tokens, atten_mask = self.get_example_inputs(use_kv_cache=False)
+            self.inputs = (tokens, atten_mask)
 
     def _tag_kv_ios(self, gm: torch.fx.GraphModule, kv_type):
         if not self.has_quant_io:
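
The calibrate() wrapper added above dispatches purely on the arity of example_inputs: a 2-tuple routes to _batch_prefill_calibrate and a 5-tuple to _kv_calibrate. A small sketch of that contract, with placeholder strings standing in for the real tensors returned by LlamaModel.get_example_inputs:

# Placeholder tuples only; the real inputs are tensors (and lists of cache tensors).
prefill_like = ("tokens", "atten_mask")                                # len == 2
kv_like = ("tokens", "atten_mask", "pos_ids", "k_caches", "v_caches")  # len == 5

assert len(prefill_like) == 2  # -> _batch_prefill_calibrate
assert len(kv_like) == 5       # -> _kv_calibrate
# Any other arity raises RuntimeError("Get wrong inputs").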
@@ -256,11 +326,17 @@ def _tag_kv_ios(self, gm: torch.fx.GraphModule, kv_type):
                     n.meta[QCOM_QUANTIZED_IO] = kv_type
                 elif n.op == "output":
                     for a in n.args[0]:
+                        # single head, kv mode
                         if (
                             a.meta["val"].flatten().size()[0]
                             == self.llama_meta["get_head_dim"]
                         ):
                             a.meta[QCOM_QUANTIZED_IO] = kv_type
+                        # single head, batch_prefill mode
+                        elif a.meta["val"].flatten().size()[0] == self.llama_meta[
+                            "get_head_dim"
+                        ] * (self.llama_meta["get_max_seq_len"] - 1):
+                            a.meta[QCOM_QUANTIZED_IO] = kv_type
 
     def quantize(self, quant_dtype, custom_annotations=()):
         self.quant_dtype = quant_dtype
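
The extra elif distinguishes the two modes by output size: in kv mode a single-head cache output carries head_dim values per step, while batch prefill emits the caches for the whole prefill window at once, i.e. head_dim * (max_seq_len - 1) values. A worked example with assumed metadata (head_dim = 64, max_seq_len = 512; the real values come from llama_meta):

head_dim, max_seq_len = 64, 512  # illustrative values, not from the model config

kv_output_numel = head_dim                           # 64: one position per head
prefill_output_numel = head_dim * (max_seq_len - 1)  # 32704: all positions at once

assert prefill_output_numel == 64 * 511 == 32704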
@@ -281,11 +357,13 @@ def quantize(self, quant_dtype, custom_annotations=()):
         ).module()
         fx_graph_module = prepare_pt2e(fx_graph_module, quantizer)
         print("Quantizing the model...")
+
         calibrate(
-            self.get_example_inputs(),
+            self.get_example_inputs(self.llama_meta["get_use_kv_cache"]),
             args.prompt,
             fx_graph_module,
             tokenizer_model_path=args.tokenizer_model,
+            max_seq_len=args.seq_len,
         )
 
         self.llama_model = convert_pt2e(fx_graph_module)
@@ -328,18 +406,29 @@ def lowering_modules(
         with open(f"{work_space}/{pte_filename}.pte", "wb") as file:
             exec_prog_mgr.write_to_file(file)
 
-    def get_example_inputs(self):
-        return self.llama_model.get_example_inputs()
+    def get_example_inputs(self, use_kv_cache=True):
+        return self.llama_model.get_example_inputs(use_kv_cache)
 
 
 def compile(args):
     os.makedirs(args.artifact, exist_ok=True)
     start_ts = time.time()
+
+    if args.model_mode == "kv":
+        use_kv_cache = output_new_cache_only = True
+    elif args.model_mode == "batch_prefill" or args.model_mode == "hybrid":
+        raise NotImplementedError(
+            f"model_mode {args.model_mode} is not implemented yet."
+        )
+    else:
+        raise RuntimeError(f"No such model_mode {args.model_mode}.")
+
     with open(args.params) as f:
         config = ModelArgs(**json.load(f))
     # TODO: support batch inputs if necessary
     config.max_batch_size = 1
-    config.max_seq_len = 1024
+    config.max_seq_len = args.seq_len
+    config.use_kv_cache = use_kv_cache
     state_dict = torch.load(
         args.checkpoint, weights_only=True, map_location="cpu", mmap=True
     )
@@ -348,7 +437,7 @@ def compile(args):
 
     llama_instance = None
     with torch.device("meta"):
-        llama_instance = LlamaModel(config, output_new_cache_only=True)
+        llama_instance = LlamaModel(config, output_new_cache_only=output_new_cache_only)
     if "model" in state_dict:
         state_dict = state_dict["model"]
     llama_instance.load_state_dict(
@@ -398,6 +487,11 @@ def compile(args):
 def inference(args, pre_gen_pte=""):
     workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/single_llama"
 
+    if args.model_mode != "kv":
+        raise NotImplementedError(
+            f"model_mode {args.model_mode} is not implemented yet."
+        )
+
     runner_args = " ".join(
         [
             f"--model_path {pte_filename}.pte",
@@ -550,6 +644,21 @@ def post_process():
         type=str,
     )
 
+    parser.add_argument(
+        "--num_sharding",
+        type=int,
+        default=0,
+        help="Specify the number of splits by inserting the fallback custom op. The graph will be split evenly by layers.",
+    )
+
+    parser.add_argument(
+        "--model_mode",
+        help="Export and inference batch_prefill mode, kv mode or hybrid(TBD) mode",
+        default="kv",
+        choices=["batch_prefill", "kv", "hybrid"],
+        type=str,
+    )
+
     args = parser.parse_args()
     if args.compile_only and args.pre_gen_pte:
         exit("Cannot set both compile_only and pre_gen_pte as true")
