@@ -142,6 +142,8 @@ def __init__(
         verbose: bool = False,
     ):
         self.model = model
+        # graph module returned from capture_pre_autograd_graph
+        self.pre_autograd_graph_module: Optional[torch.fx.GraphModule] = None
         self.modelname = modelname
         self.weight_type = weight_type
         self.dtype = dtype
@@ -251,25 +253,27 @@ def _get_metadata(self):
         self.metadata = metadata
         return self.metadata

-    def export_to_edge(
+    def pt2e_quantize(
         self, quantizers: Optional[List[Quantizer]]
     ) -> "LlamaEdgeManager":
         """
-        Export the model to Edge dialect and retrieve a EdgeManager.
+        Quantize the model via the pt2e flow and retrieve the LlamaEdgeManager, including the quantized model.
         Args:
             quantizers (Optional[List[Quantizer]]): A list of quantizers.
         """
+        assert (
+            self.edge_manager is None
+        ), "export_to_edge has already been called; call pt2e_quantize before export_to_edge"
+        logging.info(f"Using pt2e {quantizers} to quantize the model...")
         dynamic_shape = self._get_dynamic_shape()
-        edge_config = self._get_edge_config()
-        metadata = self._get_metadata()

         # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
         # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
-        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
-            m = capture_pre_autograd_graph(
-                self.model, self.example_inputs, dynamic_shapes=dynamic_shape
-            )
-            if quantizers:
+        if quantizers:
+            with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
+                m = capture_pre_autograd_graph(
+                    self.model, self.example_inputs, dynamic_shapes=dynamic_shape
+                )
                 if self.verbose:
                     logging.info(f"Applied quantizers: {quantizers}")
                 composed_quantizer = ComposableQuantizer(quantizers)
@@ -278,8 +282,29 @@ def export_to_edge(
                 m(*self.example_inputs)
                 m = convert_pt2e(m)
                 DuplicateDynamicQuantChainPass()(m)
+            self.pre_autograd_graph_module = m
+            return self
+        else:
+            logging.info("No quantizer provided, passing...")
+            return self
+
+    def export_to_edge(self) -> "LlamaEdgeManager":
+        """
+        Export the model to Edge dialect and retrieve a LlamaEdgeManager.
+        """
+        dynamic_shape = self._get_dynamic_shape()
+        edge_config = self._get_edge_config()
+        metadata = self._get_metadata()
+
+        # 1. torch.nn.attention.sdpa_kernel([SDPBackend.MATH]) is for bypassing the dynamo error when tracing
+        # 2. torch.no_grad() is for getting rid of the dropout (not sure why training ops will show up)
+        with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
+            if self.pre_autograd_graph_module is None:
+                self.pre_autograd_graph_module = capture_pre_autograd_graph(
+                    self.model, self.example_inputs, dynamic_shapes=dynamic_shape
+                )
             self.edge_manager = export_to_edge(
-                m,
+                self.pre_autograd_graph_module,
                 self.example_inputs,
                 dynamic_shapes=dynamic_shape,
                 edge_constant_methods=metadata,
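Taken together, these hunks split the old single-step export into a two-step builder flow: quantize first via pt2e (a no-op when no quantizers are given), then lower to Edge dialect. A minimal usage sketch, assuming an already-constructed LlamaEdgeManager named `manager` and a caller-provided `quantizers` list (both names are hypothetical and not part of this diff):

    # Hypothetical caller-side usage of the new two-step flow.
    # pt2e_quantize captures and quantizes the graph when quantizers are provided;
    # export_to_edge then lowers the (possibly quantized) graph to Edge dialect.
    manager = manager.pt2e_quantize(quantizers).export_to_edge()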