|
18 | 18 | PartitionResult,
|
19 | 19 | )
|
20 | 20 | from executorch.exir.backend.utils import tag_constant_data
|
| 21 | +from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param |
21 | 22 | from torch.export.exported_program import ExportedProgram
|
22 | 23 | from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
|
23 | 24 | from torch.fx.passes.operator_support import OperatorSupportBase
|
@@ -88,7 +89,25 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
|
88 | 89 | node.meta["delegation_tag"] = tag
|
89 | 90 | partition_tags[tag] = self.delegation_spec
|
90 | 91 |
|
91 |
| - tag_constant_data(exported_program) |
| 92 | + is_attr = ( |
| 93 | + node.op == "placeholder" |
| 94 | + and ( |
| 95 | + is_param(exported_program, node) |
| 96 | + or is_buffer(exported_program, node) |
| 97 | + or is_lifted_tensor_constant(exported_program, node) |
| 98 | + ) |
| 99 | + ) |
| 100 | + # if all users of const/param/buffer nodes are partitioned then partition |
| 101 | + if is_attr: |
| 102 | + user_tags = set() |
| 103 | + for user in node.users: |
| 104 | + user_tag = user.meta.get("delegation_tag", None) |
| 105 | + if user_tag is not None: |
| 106 | + user_tags.add(user_tag) |
| 107 | + if len(user_tags) >= 1: |
| 108 | + # There are more than one user tag, just pick one and there is a pass later to copy the constant data |
| 109 | + node.meta["delegation_tag"] = user_tags.pop() |
| 110 | + |
92 | 111 |
|
93 | 112 | return PartitionResult(
|
94 | 113 | tagged_exported_program=exported_program, partition_tags=partition_tags
|
|
0 commit comments