Skip to content

Commit 0a1b3be

Browse files
peri044pytorchmergebot
authored andcommitted
chore: add unit test to verify split_by_tags output_type (pytorch#121262)
Add a test case as per pytorch#120361 (comment) Pull Request resolved: pytorch#121262 Approved by: https://github.com/atalman
1 parent 676a771 commit 0a1b3be

File tree

1 file changed

+73
-0
lines changed

1 file changed

+73
-0
lines changed

test/fx/test_fx_split.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,76 @@ def test_split_by_tags(self) -> None:
150150
},
151151
f"{orig_to_split_fqn_mapping=}",
152152
)
153+
154+
class TestSplitOutputType(TestCase):
155+
class TestModule(torch.nn.Module):
156+
def __init__(self):
157+
super().__init__()
158+
self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
159+
self.relu = torch.nn.ReLU()
160+
161+
def forward(self, x):
162+
conv = self.conv(x)
163+
conv = conv * 0.5
164+
relu = self.relu(conv)
165+
return relu
166+
167+
@staticmethod
168+
def trace_and_tag(
169+
module: torch.nn.Module, inputs: torch.Tensor, tags: List[str]
170+
) -> Tuple[torch.fx.GraphModule, Dict[str, List[str]]]:
171+
"""
172+
Test simple gm consists of nodes with tag (only show call_module nodes here):
173+
conv - tag: "red"
174+
mul - tag: "blue"
175+
relu - tag: "green"
176+
177+
At the beginning we have:
178+
gm:
179+
conv
180+
mul
181+
relu
182+
183+
split_gm = split_by_tags(gm, tags)
184+
185+
Then we have:
186+
split_gm:
187+
red:
188+
conv
189+
blue:
190+
mul
191+
green:
192+
relu
193+
"""
194+
tag_node = defaultdict(list)
195+
gm: torch.fx.GraphModule = torch.export.export(module, (inputs,)).module()
196+
# Add tag to all nodes and build dictionary record tag to call_module nodes
197+
for node in gm.graph.nodes:
198+
if "conv" in node.name:
199+
node.tag = tags[0]
200+
tag_node[tags[0]].append(node.name)
201+
elif "mul" in node.name:
202+
node.tag = tags[1]
203+
tag_node[tags[1]].append(node.name)
204+
else:
205+
node.tag = tags[2]
206+
if node.op == "call_module":
207+
tag_node[tags[2]].append(node.name)
208+
return gm, tag_node
209+
210+
def test_split_by_tags(self) -> None:
211+
tags = ["red", "blue", "green"]
212+
module = TestSplitOutputType.TestModule()
213+
214+
inputs = torch.randn((1, 3, 224, 224))
215+
216+
gm, tag_node = TestSplitOutputType.trace_and_tag(module, inputs, tags)
217+
split_gm, orig_to_split_fqn_mapping = split_by_tags(
218+
gm, tags, return_fqn_mapping=True
219+
)
220+
221+
gm_output = module(inputs)
222+
split_gm_output = split_gm(inputs)
223+
224+
self.assertTrue(type(gm_output) == type(split_gm_output))
225+
self.assertTrue(torch.equal(gm_output, split_gm_output))

0 commit comments

Comments
 (0)