@@ -150,3 +150,76 @@ def test_split_by_tags(self) -> None:
150
150
},
151
151
f"{ orig_to_split_fqn_mapping = } " ,
152
152
)
153
+
154
+ class TestSplitOutputType (TestCase ):
155
+ class TestModule (torch .nn .Module ):
156
+ def __init__ (self ):
157
+ super ().__init__ ()
158
+ self .conv = torch .nn .Conv2d (3 , 16 , 3 , stride = 1 , bias = True )
159
+ self .relu = torch .nn .ReLU ()
160
+
161
+ def forward (self , x ):
162
+ conv = self .conv (x )
163
+ conv = conv * 0.5
164
+ relu = self .relu (conv )
165
+ return relu
166
+
167
+ @staticmethod
168
+ def trace_and_tag (
169
+ module : torch .nn .Module , inputs : torch .Tensor , tags : List [str ]
170
+ ) -> Tuple [torch .fx .GraphModule , Dict [str , List [str ]]]:
171
+ """
172
+ Test simple gm consists of nodes with tag (only show call_module nodes here):
173
+ conv - tag: "red"
174
+ mul - tag: "blue"
175
+ relu - tag: "green"
176
+
177
+ At the beginning we have:
178
+ gm:
179
+ conv
180
+ mul
181
+ relu
182
+
183
+ split_gm = split_by_tags(gm, tags)
184
+
185
+ Then we have:
186
+ split_gm:
187
+ red:
188
+ conv
189
+ blue:
190
+ mul
191
+ green:
192
+ relu
193
+ """
194
+ tag_node = defaultdict (list )
195
+ gm : torch .fx .GraphModule = torch .export .export (module , (inputs ,)).module ()
196
+ # Add tag to all nodes and build dictionary record tag to call_module nodes
197
+ for node in gm .graph .nodes :
198
+ if "conv" in node .name :
199
+ node .tag = tags [0 ]
200
+ tag_node [tags [0 ]].append (node .name )
201
+ elif "mul" in node .name :
202
+ node .tag = tags [1 ]
203
+ tag_node [tags [1 ]].append (node .name )
204
+ else :
205
+ node .tag = tags [2 ]
206
+ if node .op == "call_module" :
207
+ tag_node [tags [2 ]].append (node .name )
208
+ return gm , tag_node
209
+
210
+ def test_split_by_tags (self ) -> None :
211
+ tags = ["red" , "blue" , "green" ]
212
+ module = TestSplitOutputType .TestModule ()
213
+
214
+ inputs = torch .randn ((1 , 3 , 224 , 224 ))
215
+
216
+ gm , tag_node = TestSplitOutputType .trace_and_tag (module , inputs , tags )
217
+ split_gm , orig_to_split_fqn_mapping = split_by_tags (
218
+ gm , tags , return_fqn_mapping = True
219
+ )
220
+
221
+ gm_output = module (inputs )
222
+ split_gm_output = split_gm (inputs )
223
+
224
+ self .assertTrue (type (gm_output ) == type (split_gm_output ))
225
+ self .assertTrue (torch .equal (gm_output , split_gm_output ))
0 commit comments