Commit dfa5dbc

Reusable test framework for the Inspector tests
Differential Revision: D75803288
Pull Request resolved: #11314
1 parent 450d1f9 commit dfa5dbc
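
The commit swaps the hard-coded test module in the capturer test for a shared model registry plus a comparison helper, so other Inspector tests can reuse the same models and expected outputs. Below is a minimal sketch of how the pieces fit together, mirroring the updated test in this diff; all names are taken from the diff itself, and the private _edge_programs access follows the test code.

# Sketch (not part of the commit): exercise the new helpers end to end.
from executorch.devtools.inspector._intermediate_output_capturer import (
    IntermediateOutputCapturer,
)
from executorch.devtools.inspector.tests.inspector_test_utils import (
    check_if_final_outputs_match,
    model_registry,
)
from executorch.exir import EdgeCompileConfig, to_edge
from torch.export import export

model = model_registry["ConvLinearModel"]()
input_tensor = model.get_input()
edge_program_manager = to_edge(
    export(model, (input_tensor,), strict=True),
    compile_config=EdgeCompileConfig(_check_ir_validity=True),
)
graph_module = edge_program_manager._edge_programs["forward"].module()
intermediate_outputs = IntermediateOutputCapturer(graph_module).run_and_capture(
    input_tensor
)
assert check_if_final_outputs_match("ConvLinearModel", intermediate_outputs)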

File tree

devtools/inspector/tests/TARGETS
devtools/inspector/tests/inspector_test_utils.py
devtools/inspector/tests/intermediate_output_capturer_test.py

3 files changed: +177 -105 lines changed

devtools/inspector/tests/TARGETS

Lines changed: 13 additions & 0 deletions
@@ -1,4 +1,5 @@
 load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
+load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
 
 oncall("executorch")
 
@@ -13,6 +14,7 @@ python_unittest(
         "//executorch/devtools/inspector:inspector",
         "//executorch/devtools/inspector:lib",
         "//executorch/exir:lib",
+        "//executorch/devtools/inspector/tests:inspector_test_utils",
     ],
 )
 
@@ -48,5 +50,16 @@ python_unittest(
         "//executorch/devtools/inspector:lib",
         "//executorch/devtools/inspector:intermediate_output_capturer",
         "//executorch/exir:lib",
+        "//executorch/devtools/inspector/tests:inspector_test_utils",
+    ],
+)
+
+python_library(
+    name = "inspector_test_utils",
+    srcs = [
+        "inspector_test_utils.py",
+    ],
+    deps = [
+        "//caffe2:torch",
     ],
 )
devtools/inspector/tests/inspector_test_utils.py

Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ConvLinearModel(nn.Module):
+    """
+    A neural network model with a convolutional layer followed by a linear layer.
+    """
+
+    def __init__(self):
+        super(ConvLinearModel, self).__init__()
+        self.conv_layer = nn.Conv2d(
+            in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1
+        )
+        self.conv_layer.weight = nn.Parameter(
+            torch.tensor([[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]]])
+        )
+        self.conv_layer.bias = nn.Parameter(torch.tensor([0.0]))
+
+        self.linear_layer = nn.Linear(in_features=4, out_features=2)
+        self.linear_layer.weight = nn.Parameter(
+            torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]])
+        )
+        self.linear_layer.bias = nn.Parameter(torch.tensor([0.0, 0.0]))
+        self.additional_bias = nn.Parameter(
+            torch.tensor([0.5, -0.5]), requires_grad=False
+        )
+        self.scale_factor = nn.Parameter(torch.tensor([2.0, 0.5]), requires_grad=False)
+
+    def forward(self, x):
+        x = self.conv_layer(x)
+        x = x.view(x.size(0), -1)
+        x = self.linear_layer(x)
+        x = x + self.additional_bias
+        x = x - 0.1
+        x = x * self.scale_factor
+        x = x / (self.scale_factor + 1.0)
+        x = F.relu(x)
+        x = torch.sigmoid(x)
+        output1, output2 = torch.split(x, 1, dim=1)
+        return output1, output2
+
+    @staticmethod
+    def get_input():
+        """
+        Returns the pre-defined input tensor for this model.
+        """
+        return torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]], requires_grad=True)
+
+    @staticmethod
+    def get_expected_intermediate_outputs():
+        """
+        Returns the expected mapping from debug handle to intermediate output for this model, given the input from get_input().
+        """
+        return {
+            (10,): torch.tensor([[[[7.7000, 6.7000], [4.7000, 3.7000]]]]),
+            (11,): torch.tensor([[7.7000, 6.7000, 4.7000, 3.7000]]),
+            (12,): torch.tensor(
+                [
+                    [0.1000, 0.5000],
+                    [0.2000, 0.6000],
+                    [0.3000, 0.7000],
+                    [0.4000, 0.8000],
+                ]
+            ),
+            (13,): torch.tensor([[5.0000, 14.1200]]),
+            (14,): torch.tensor([[5.5000, 13.6200]]),
+            (15,): torch.tensor([[5.4000, 13.5200]]),
+            (16,): torch.tensor([[10.8000, 6.7600]]),
+            (17,): torch.tensor([3.0000, 1.5000]),
+            (18,): torch.tensor([[3.6000, 4.5067]]),
+            (19,): torch.tensor([[3.6000, 4.5067]]),
+            (20,): torch.tensor([[0.9734, 0.9891]]),
+            (21,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])],
+        }
+
+
+# Global model registry
+model_registry = {
+    "ConvLinearModel": ConvLinearModel,
+    # Add new models here
+}
+
+
+def check_if_final_outputs_match(model_name, actual_outputs_with_handles):
+    """
+    Checks if the actual outputs match the expected outputs for the specified model.
+    Returns True if all outputs match, otherwise returns False.
+    """
+    model_instance = model_registry[model_name]
+    expected_outputs_with_handles = model_instance.get_expected_intermediate_outputs()
+    if len(actual_outputs_with_handles) != len(expected_outputs_with_handles):
+        return False
+    for debug_handle, expected_output in expected_outputs_with_handles.items():
+        actual_output = actual_outputs_with_handles.get(debug_handle)
+        if actual_output is None:
+            return False
+        if isinstance(expected_output, list):
+            if not isinstance(actual_output, list):
+                return False
+            if len(actual_output) != len(expected_output):
+                return False
+            for actual, expected in zip(actual_output, expected_output):
+                if not torch.allclose(actual, expected, rtol=1e-4, atol=1e-5):
+                    return False
+        else:
+            if not torch.allclose(actual_output, expected_output, rtol=1e-4, atol=1e-5):
+                return False
+    return True
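
The "# Add new models here" comment above is the extension point of the new framework. The following is a hypothetical sketch of registering a second model; the class name, layers, and empty expected-output mapping are illustrative only and are not part of this commit (real expected values would have to be captured from a reference run and filled in before test_models could pass for it).

import torch
import torch.nn as nn
import torch.nn.functional as F

from executorch.devtools.inspector.tests.inspector_test_utils import model_registry


class TwoLinearModel(nn.Module):  # hypothetical example model
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 2)
        self.fc2 = nn.Linear(2, 1)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))

    @staticmethod
    def get_input():
        return torch.tensor([[1.0, 2.0]])

    @staticmethod
    def get_expected_intermediate_outputs():
        # Placeholder: populate with debug-handle -> tensor pairs captured from a
        # reference run before relying on check_if_final_outputs_match.
        return {}


# Register the model so test_models picks it up automatically.
model_registry["TwoLinearModel"] = TwoLinearModel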

devtools/inspector/tests/intermediate_output_capturer_test.py

Lines changed: 46 additions & 105 deletions
@@ -6,127 +6,68 @@
 
 # pyre-unsafe
 
-
 import unittest
 
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
 from executorch.devtools.inspector._intermediate_output_capturer import (
     IntermediateOutputCapturer,
 )
-
+from executorch.devtools.inspector.tests.inspector_test_utils import (
+    check_if_final_outputs_match,
+    model_registry,
+)
 from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
 from torch.export import export, ExportedProgram
 from torch.fx import GraphModule
 
 
 class TestIntermediateOutputCapturer(unittest.TestCase):
-    @classmethod
-    def setUpClass(cls):
-        class TestModule(nn.Module):
-            def __init__(self):
-                super(TestModule, self).__init__()
-                self.conv = nn.Conv2d(
-                    in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1
-                )
-                self.conv.weight = nn.Parameter(
-                    torch.tensor(
-                        [[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]]]
-                    )
-                )
-                self.conv.bias = nn.Parameter(torch.tensor([0.0]))
-
-                self.linear = nn.Linear(in_features=4, out_features=2)
-                self.linear.weight = nn.Parameter(
-                    torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]])
-                )
-                self.linear.bias = nn.Parameter(torch.tensor([0.0, 0.0]))
-                self.bias = nn.Parameter(torch.tensor([0.5, -0.5]), requires_grad=False)
-                self.scale = nn.Parameter(torch.tensor([2.0, 0.5]), requires_grad=False)
-
-            def forward(self, x):
-                x = self.conv(x)
-                x = x.view(x.size(0), -1)
-                x = self.linear(x)
-                x = x + self.bias
-                x = x - 0.1
-                x = x * self.scale
-                x = x / (self.scale + 1.0)
-                x = F.relu(x)
-                x = torch.sigmoid(x)
-                x1, x2 = torch.split(x, 1, dim=1)
-                return x1, x2
-
-        cls.model = TestModule()
-        cls.input = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]], requires_grad=True)
-        cls.aten_model: ExportedProgram = export(cls.model, (cls.input,), strict=True)
-        cls.edge_program_manager: EdgeProgramManager = to_edge(
-            cls.aten_model, compile_config=EdgeCompileConfig(_check_ir_validity=True)
+    def _set_up_model(self, model_name):
+        model = model_registry[model_name]()
+        input_tensor = model.get_input()
+        aten_model: ExportedProgram = export(model, (input_tensor,), strict=True)
+        edge_program_manager: EdgeProgramManager = to_edge(
+            aten_model, compile_config=EdgeCompileConfig(_check_ir_validity=True)
         )
-        cls.graph_module: GraphModule = cls.edge_program_manager._edge_programs[
+        graph_module: GraphModule = edge_program_manager._edge_programs[
             "forward"
         ].module()
-        cls.capturer = IntermediateOutputCapturer(cls.graph_module)
-        cls.intermediate_outputs = cls.capturer.run_and_capture(cls.input)
-
-    def test_keying_with_debug_handle_tuple(self):
-        for key in self.intermediate_outputs.keys():
-            self.assertIsInstance(key, tuple)
-
-    def test_tensor_cloning_and_detaching(self):
-        for output in self.intermediate_outputs.values():
-            if isinstance(output, torch.Tensor):
-                self.assertFalse(output.requires_grad)
-                self.assertTrue(output.is_leaf)
-
-    def test_placeholder_nodes_are_skipped(self):
-        for node in self.graph_module.graph.nodes:
-            if node.op == "placeholder":
-                self.assertNotIn(
-                    node.meta.get("debug_handle"), self.intermediate_outputs
+        capturer = IntermediateOutputCapturer(graph_module)
+        intermediate_outputs = capturer.run_and_capture(input_tensor)
+        return input_tensor, graph_module, capturer, intermediate_outputs
+
+    def test_models(self):
+        available_models = list(model_registry.keys())
+        for model_name in available_models:
+            with self.subTest(model=model_name):
+                input_tensor, graph_module, capturer, intermediate_outputs = (
+                    self._set_up_model(model_name)
                 )
 
-    def test_multiple_outputs_capture(self):
-        outputs = self.capturer.run_and_capture(self.input)
-        for output in outputs.values():
-            if isinstance(output, tuple):
-                self.assertEqual(len(output), 2)
-                for part in output:
-                    self.assertIsInstance(part, torch.Tensor)
-
-    def test_capture_correct_outputs(self):
-        expected_outputs_with_handles = {
-            (10,): torch.tensor([[[[7.7000, 6.7000], [4.7000, 3.7000]]]]),
-            (11,): torch.tensor([[7.7000, 6.7000, 4.7000, 3.7000]]),
-            (12,): torch.tensor(
-                [[0.1000, 0.5000], [0.2000, 0.6000], [0.3000, 0.7000], [0.4000, 0.8000]]
-            ),
-            (13,): torch.tensor([[5.0000, 14.1200]]),
-            (14,): torch.tensor([[5.5000, 13.6200]]),
-            (15,): torch.tensor([[5.4000, 13.5200]]),
-            (16,): torch.tensor([[10.8000, 6.7600]]),
-            (17,): torch.tensor([3.0000, 1.5000]),
-            (18,): torch.tensor([[3.6000, 4.5067]]),
-            (19,): torch.tensor([[3.6000, 4.5067]]),
-            (20,): torch.tensor([[0.9734, 0.9891]]),
-            (21,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])],
-        }
-        self.assertEqual(
-            len(self.intermediate_outputs), len(expected_outputs_with_handles)
-        )
-
-        for debug_handle, expected_output in expected_outputs_with_handles.items():
-            actual_output = self.intermediate_outputs.get(debug_handle)
-            self.assertIsNotNone(actual_output)
-            if isinstance(expected_output, list):
-                self.assertIsInstance(actual_output, list)
-                self.assertEqual(len(actual_output), len(expected_output))
-                for actual, expected in zip(actual_output, expected_output):
-                    self.assertTrue(
-                        torch.allclose(actual, expected, rtol=1e-4, atol=1e-5)
-                    )
-            else:
+                # Test keying with debug handle tuple
+                for key in intermediate_outputs.keys():
+                    self.assertIsInstance(key, tuple)
+
+                # Test tensor cloning and detaching
+                for output in intermediate_outputs.values():
+                    if isinstance(output, torch.Tensor):
+                        self.assertFalse(output.requires_grad)
+                        self.assertTrue(output.is_leaf)
+
+                # Test placeholder nodes are skipped
+                for node in graph_module.graph.nodes:
+                    if node.op == "placeholder":
+                        self.assertNotIn(
+                            node.meta.get("debug_handle"), intermediate_outputs
+                        )
+
+                # Test multiple outputs capture
+                outputs = capturer.run_and_capture(input_tensor)
+                for output in outputs.values():
+                    if isinstance(output, tuple):
+                        self.assertEqual(len(output), 2)
+                        for part in output:
+                            self.assertIsInstance(part, torch.Tensor)
+
+                # Test capture correct outputs
                 self.assertTrue(
-                    torch.allclose(actual_output, expected_output, rtol=1e-4, atol=1e-5)
+                    check_if_final_outputs_match(model_name, intermediate_outputs)
                 )
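
The refactor folds the separate test methods into a single test_models that loops over the registry; unittest's subTest keeps per-model reporting, so a failure for one registered model is recorded without stopping the remaining models. A tiny self-contained illustration of that behavior (the model names here are placeholders, not the registry contents):

import unittest


class SubTestReportingDemo(unittest.TestCase):
    def test_per_model_reporting(self):
        # Each loop iteration is reported as its own sub-test; an assertion
        # failure in one iteration is recorded and the next name still runs.
        for model_name in ["ModelA", "ModelB"]:  # placeholder names
            with self.subTest(model=model_name):
                self.assertIsInstance(model_name, str)


if __name__ == "__main__":
    unittest.main()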
