
Commit 0dc9488

cccclai authored and facebook-github-bot committed
Enable eval for pt2e quantization (#3652)
Summary: Pull Request resolved: #3652

Make the eval framework work with the pt2e quantizer, unblocking accuracy measurement for the QNN quantizer.

Reviewed By: Jack-Khuu

Differential Revision: D57316602

fbshipit-source-id: ca2b6f0dc6a6ebb4f2f323d09c27c20978cec569
1 parent 435ea9d commit 0dc9488
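
In practice, the commit makes the eval path reuse the same pt2e machinery as export. Below is a minimal sketch of the flow it enables, assembled only from functions touched in this diff; the `args` namespace stands in for the parsed export/eval CLI arguments and is an assumption, as is the "llama2" model name.

    from executorch.examples.models.llama2.export_llama_lib import (
        _prepare_for_llama_export,
        get_quantizer_and_quant_params,
    )

    # Build the quantizer list from the CLI flags (helper added in this commit).
    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)

    manager = _prepare_for_llama_export("llama2", args)
    if quantizers:
        # New two-step builder API: capture the pre-autograd graph, then pt2e-quantize it.
        manager = manager.capture_pre_autograd_graph().pt2e_quantize(quantizers)
        model = manager.pre_autograd_graph_module  # quantized graph module, used for eval
    else:
        model = manager.model.eval()  # fall back to the eager model when no quantizer is configured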

File tree

3 files changed: +45 -13 lines changed

examples/models/llama2/builder.py

Lines changed: 14 additions & 5 deletions
@@ -253,6 +253,16 @@ def _get_metadata(self):
         self.metadata = metadata
         return self.metadata
 
+    def capture_pre_autograd_graph(self) -> "LlamaEdgeManager":
+        dynamic_shape = self._get_dynamic_shape()
+        # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
+        # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
+        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
+            self.pre_autograd_graph_module = capture_pre_autograd_graph(
+                self.model, self.example_inputs, dynamic_shapes=dynamic_shape
+            )
+        return self
+
     def pt2e_quantize(
         self, quantizers: Optional[List[Quantizer]]
     ) -> "LlamaEdgeManager":
@@ -265,19 +275,18 @@ def pt2e_quantize(
             self.edge_manager is None
         ), "export_to_edge is already called, please call pt2e_quantize before export_to_edge"
         logging.info(f"Using pt2e {quantizers} to quantizing the model...")
-        dynamic_shape = self._get_dynamic_shape()
 
         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
        # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
         if quantizers:
             with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
-                m = capture_pre_autograd_graph(
-                    self.model, self.example_inputs, dynamic_shapes=dynamic_shape
-                )
                 if self.verbose:
                     logging.info(f"Applied quantizers: {quantizers}")
                 composed_quantizer = ComposableQuantizer(quantizers)
-                m = prepare_pt2e(m, composed_quantizer)
+                assert (
+                    self.pre_autograd_graph_module is not None
+                ), "Please run capture_pre_autograd_graph first"
+                m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
                 # Calibrate
                 m(*self.example_inputs)
                 m = convert_pt2e(m)
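
For readers less familiar with pt2e, the sequence the builder now performs (capture a pre-autograd graph, prepare with a quantizer, calibrate, convert) can be reproduced on a toy module. The sketch below is illustrative only: the toy model and the XNNPACK quantizer are assumptions and not part of this diff, but capture_pre_autograd_graph, prepare_pt2e and convert_pt2e are the same torch APIs the builder calls.

    import torch
    from torch._export import capture_pre_autograd_graph
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )

    # Toy stand-in for the Llama model (illustrative only).
    class Toy(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x):
            return self.linear(x)

    example_inputs = (torch.randn(1, 4),)
    m = capture_pre_autograd_graph(Toy(), example_inputs)  # pre-autograd ATen graph
    quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
    m = prepare_pt2e(m, quantizer)  # insert observers
    m(*example_inputs)              # calibrate on example inputs
    m = convert_pt2e(m)             # fold observers into quant/dequant ops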

examples/models/llama2/eval_llama_lib.py

Lines changed: 22 additions & 6 deletions
@@ -11,7 +11,9 @@
 
 import lm_eval
 import torch
-
+from executorch.examples.models.llama2.export_llama_lib import (
+    get_quantizer_and_quant_params,
+)
 from executorch.examples.models.llama2.tokenizer.tiktoken import Tokenizer as Tiktoken
 from executorch.examples.models.llama2.tokenizer.tokenizer import (
     Tokenizer as SentencePieceTokenizer,
@@ -233,13 +235,27 @@ def gen_eval_wrapper(
         max_seq_length=args.max_seq_length - 1,
     )
 
+    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
     # GPTFastEvalWrapper: Create a wrapper around a pre-exported model
     manager: LlamaEdgeManager = _prepare_for_llama_export(model_name, args)
-    model = (
-        manager.model.eval().to(device="cuda")
-        if torch.cuda.is_available()
-        else manager.model.to(device="cpu")
-    )
+
+    if len(quantizers) != 0:
+        manager = manager.capture_pre_autograd_graph().pt2e_quantize(quantizers)
+        model = (
+            manager.pre_autograd_graph_module.to(device="cuda")
+            if torch.cuda.is_available()
+            else manager.pre_autograd_graph_module.to(device="cpu")
+        )
+    else:
+        # TODO: use manager.pre_autograd_graph_module for the eval to remove the if-else branch
+        # for quantizers. Currently capture_pre_autograd_graph only works with --kv_cache, but
+        # fails without the kv_cache mode
+        model = (
+            manager.model.eval().to(device="cuda")
+            if torch.cuda.is_available()
+            else manager.model.eval().to(device="cpu")
+        )
+
     return EagerEvalWrapper(
         model=model,
         tokenizer=tokenizer,
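
Downstream, the wrapper returned by gen_eval_wrapper plugs into the lm_eval harness exactly as before; the only difference is that the wrapped model may now be a pt2e-quantized graph module. A rough usage sketch follows; the simple_evaluate call and the task name are illustrative assumptions, and the real entry point in eval_llama_lib drives this for you.

    import lm_eval

    # args: parsed eval CLI arguments; "llama2" is an illustrative model name.
    eval_wrapper = gen_eval_wrapper("llama2", args)
    # Hypothetical harness invocation; the in-repo eval code wires this up itself.
    results = lm_eval.simple_evaluate(model=eval_wrapper, tasks=["wikitext"])
    print(results["results"])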

examples/models/llama2/export_llama_lib.py

Lines changed: 9 additions & 2 deletions
@@ -371,18 +371,25 @@ def _prepare_for_llama_export(modelname: str, args) -> LlamaEdgeManager:
     )
 
 
-def _export_llama(modelname, args) -> str:  # noqa: C901
-    # export_to_edge
+def get_quantizer_and_quant_params(args):
     pt2e_quant_params = _get_pt2e_quantization_params(args)
     quantizers = get_pt2e_quantizers(pt2e_quant_params, args)
     quant_dtype = None
     if args.qnn and args.pt2e_quantize:
         assert len(quantizers) == 0, "Should not enable both xnnpack and qnn"
         qnn_quantizer, quant_dtype = get_qnn_quantizer(args)
         quantizers.append(qnn_quantizer)
+    logging.info(f"Applying quantizers: {quantizers}")
+    return pt2e_quant_params, quantizers, quant_dtype
+
 
+def _export_llama(modelname, args) -> str:  # noqa: C901
+    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
+
+    # export_to_edge
     builder_exported_to_edge = (
         _prepare_for_llama_export(modelname, args)
+        .capture_pre_autograd_graph()
         .pt2e_quantize(quantizers)
         .export_to_edge()
     )
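
With quantizer construction factored into get_quantizer_and_quant_params, the export path and the eval path now share the same setup, and the export pipeline reads as one explicit chain. Condensed from the diff above (final state, not verbatim repo code):

    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)

    builder_exported_to_edge = (
        _prepare_for_llama_export(modelname, args)
        .capture_pre_autograd_graph()   # new explicit step: trace to a pre-autograd graph module
        .pt2e_quantize(quantizers)      # quantizes the captured graph when quantizers are present
        .export_to_edge()
    )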
