21
21
from torch .export .exported_program import ExportedProgram
22
22
from torch .fx .passes .infra .partitioner import CapabilityBasedPartitioner
23
23
from torch .fx .passes .operator_support import OperatorSupportBase
24
+ from torch ._export .utils import is_buffer , is_lifted_tensor_constant , is_param
24
25
25
26
logger = logging .getLogger (__name__ )
26
27
logger .setLevel (logging .WARNING )
@@ -73,7 +74,6 @@ def __init__(
73
74
def partition (self , exported_program : ExportedProgram ) -> PartitionResult :
74
75
# Run the CapabilityBasedPartitioner to return the largest possible
75
76
# subgraphs containing the nodes with the tags
76
- logger .info ("CoreMLPartitioner::partition" )
77
77
partition_tags = {}
78
78
79
79
capability_partitioner = CapabilityBasedPartitioner (
@@ -87,8 +87,20 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
87
87
tag = f"tag{ partition .id } "
88
88
node .meta ["delegation_tag" ] = tag
89
89
partition_tags [tag ] = self .delegation_spec
90
-
91
- tag_constant_data (exported_program )
90
+ node_args = node .args
91
+ for arg in node_args :
92
+ # tag the data nodes if they're used by coreml backend
93
+ if isinstance (arg , torch .fx .Node ):
94
+ is_attr = (
95
+ arg .op == "placeholder"
96
+ and (
97
+ is_param (exported_program , arg )
98
+ or is_buffer (exported_program , arg )
99
+ or is_lifted_tensor_constant (exported_program , arg )
100
+ )
101
+ ) or (arg .op == "get_attr" )
102
+ if is_attr :
103
+ arg .meta ["delegation_tag" ] = tag
92
104
93
105
return PartitionResult (
94
106
tagged_exported_program = exported_program , partition_tags = partition_tags
0 commit comments