
Commit 987a9f0

cccclai authored and facebook-github-bot committed
use partitioner instance directly in to_backend (#2513)
Summary:
Pull Request resolved: #2513

`to_backend` takes either a partitioner instance or a dict of partitioners of the form `key: method_name, value: partitioner`. We shouldn't use the backend name as the key and the partitioner as the value, so pass the partitioner instance directly.

Differential Revision: D55078939
1 parent 12d9e25 commit 987a9f0
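For illustration, a minimal sketch of the call-site change this commit makes. The `builder_exported_to_edge` name and the partitioner import path are taken from the examples/models/llama2 export flow below and may differ in your checkout:

```python
# Hedged sketch of the change in this commit; the builder object and import
# path are assumptions borrowed from the examples/models/llama2 export flow.
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    XnnpackDynamicallyQuantizedPartitioner,
)

# Before: a dict keyed by the partitioner class name (misusing the dict form).
# partitioners = {
#     XnnpackDynamicallyQuantizedPartitioner.__name__: XnnpackDynamicallyQuantizedPartitioner()
# }
# builder = builder_exported_to_edge.to_backend(partitioners).to_executorch()

# After: pass the partitioner instance directly.
partitioner = XnnpackDynamicallyQuantizedPartitioner()
builder = builder_exported_to_edge.to_backend(partitioner).to_executorch()
```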

File tree

2 files changed: +10 −25 lines


examples/models/llama2/builder.py

Lines changed: 2 additions & 15 deletions
@@ -286,9 +286,7 @@ def export_to_edge(
         )
         return self
 
-    def to_backend(
-        self, partitioner: Union[Partitioner, Dict[str, Partitioner]]
-    ) -> "LlamaEdgeManager":
+    def to_backend(self, partitioner: Partitioner) -> "LlamaEdgeManager":
         """
         Partition the model and lower to different backends. The signature is
         aligned with the signature of `to_backend` method of EdgeManager.
@@ -297,18 +295,7 @@ def to_backend(
             partitioner to be sent to EdgeManager.to_backend().
         """
         assert self.edge_manager is not None, "Need to run export_to_edge() first"
-        if isinstance(partitioner, dict):
-            for key, p in partitioner.items():
-                assert self.edge_manager is not None
-                self.edge_manager = self.edge_manager.to_backend(p)
-                if self.verbose:
-                    logging.info(
-                        print_delegated_graph(
-                            self.edge_manager.exported_program().graph_module
-                        )
-                    )
-                    logging.info(f"Applied partitioners: {key}")
-        elif isinstance(partitioner, Partitioner):
+        if isinstance(partitioner, Partitioner):
             assert self.edge_manager is not None
             self.edge_manager = self.edge_manager.to_backend(partitioner)
             if self.verbose:
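With the simplified signature above, lowering through `LlamaEdgeManager` becomes a single-partitioner call. A minimal usage sketch, assuming a manager instance named `manager` that has already run `export_to_edge()`:

```python
# Minimal sketch: `manager` is an assumed LlamaEdgeManager that has already
# completed export_to_edge(); the import path follows the ExecuTorch examples.
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    XnnpackDynamicallyQuantizedPartitioner,
)

# Partition/lower with a single partitioner instance, then emit the program.
manager = manager.to_backend(XnnpackDynamicallyQuantizedPartitioner())
executorch_manager = manager.to_executorch()
```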

examples/models/llama2/export_llama_lib.py

Lines changed: 8 additions & 10 deletions
@@ -490,21 +490,17 @@ def _export_llama(modelname, args) -> str:  # noqa: C901
     ).export_to_edge(quantizers)
 
     # to_backend
-    partitioners = {}
+    partitioner = None
     if pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None:
-        partitioners[XnnpackDynamicallyQuantizedPartitioner.__name__] = (
-            XnnpackDynamicallyQuantizedPartitioner()
-        )
+        partitioner = XnnpackDynamicallyQuantizedPartitioner()
         modelname = f"xnnpack_dq_{modelname}"
 
     if args.xnnpack:
         # Following changes due to.
         # 1. We need dynamically quantized partitioner for both pt2e_quantize options
         #    as well as "qmode int4" which is also dynamic quantizes linear layers.
         # 2. XNNPACK partitioner seems to result in seg fault for non dqlinear ops.
-        partitioners[XnnpackDynamicallyQuantizedPartitioner.__name__] = (
-            XnnpackDynamicallyQuantizedPartitioner()
-        )
+        partitioner = XnnpackDynamicallyQuantizedPartitioner()
         # partitioners[XnnpackPartitioner.__name__] = XnnpackPartitioner()
         modelname = f"xnnpack_{modelname}"
 
@@ -516,7 +512,7 @@ def _export_llama(modelname, args) -> str:  # noqa: C901
             args.quantization_mode is None
         ), "Vulkan backend does not support quantization at the moment"
 
-        partitioners[VulkanPartitioner.__name__] = VulkanPartitioner()
+        partitioner = VulkanPartitioner()
         modelname = f"vulkan_{modelname}"
 
     if args.mps:
@@ -545,7 +541,8 @@ def _export_llama(modelname, args) -> str:  # noqa: C901
         logging.info("Generating etrecord")
         # Copy the edge manager which will be serialized into etrecord. This is memory-wise expensive.
         edge_manager_copy = copy.deepcopy(builder_exported_to_edge.edge_manager)
-        builder = builder_exported_to_edge.to_backend(partitioners).to_executorch()
+        # pyre-ignore: pyre can't recognize the type of the instance
+        builder = builder_exported_to_edge.to_backend(partitioner).to_executorch()
 
         # Generate ETRecord
         if edge_manager_copy:
@@ -556,7 +553,8 @@ def _export_llama(modelname, args) -> str:  # noqa: C901
             )
            logging.info("Generated etrecord.bin")
     else:
-        builder = builder_exported_to_edge.to_backend(partitioners).to_executorch()
+        # pyre-ignore: pyre can't recognize the type of the instance
+        builder = builder_exported_to_edge.to_backend(partitioner).to_executorch()
 
     if args.profile_memory:
         generate_memory_trace(builder.export_program, "memory_profile.json")
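As the summary notes, the dict form of `to_backend` maps a method name to a partitioner, not a backend name to a partitioner. A hedged sketch of that intended usage on an edge manager, with the method name "forward" and the `edge_manager` variable assumed for illustration:

```python
# Hypothetical sketch of the dict form described in the commit summary:
# keys are exported method names, values are the partitioner for that method.
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    XnnpackDynamicallyQuantizedPartitioner,
)

partitioner_by_method = {
    "forward": XnnpackDynamicallyQuantizedPartitioner(),
}
edge_manager = edge_manager.to_backend(partitioner_by_method)
```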
