Skip to content

Commit 59c0956

Browse files
Juntian777facebook-github-bot
authored andcommitted
Reusable test framework the Inspector tests (#11314)
Summary: This Diff introduces a reusable test framework. The inspector_test_utils.py file provides methods to instantiate the model, retrieve expected outputs, and assert the correctness of actual outputs in a extensible way. Also, update the intermediate_output_capturer_test to take advantage of the reusable framework and make the intermediate_output_capturer_test more extensible by using reusable setup method, making it easier to add new models and tests. Differential Revision: D75803288
1 parent 5ef38d3 commit 59c0956

File tree

3 files changed

+173
-111
lines changed

3 files changed

+173
-111
lines changed

devtools/inspector/tests/TARGETS

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
2+
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
23

34
oncall("executorch")
45

@@ -13,6 +14,7 @@ python_unittest(
1314
"//executorch/devtools/inspector:inspector",
1415
"//executorch/devtools/inspector:lib",
1516
"//executorch/exir:lib",
17+
"//executorch/devtools/inspector/tests:inspector_test_utils",
1618
],
1719
)
1820

@@ -48,5 +50,16 @@ python_unittest(
4850
"//executorch/devtools/inspector:lib",
4951
"//executorch/devtools/inspector:intermediate_output_capturer",
5052
"//executorch/exir:lib",
53+
"//executorch/devtools/inspector/tests:inspector_test_utils",
54+
],
55+
)
56+
57+
python_library(
58+
name = "inspector_test_utils",
59+
srcs = [
60+
"inspector_test_utils.py",
61+
],
62+
deps = [
63+
"//caffe2:torch",
5164
],
5265
)
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
import torch
10+
import torch.nn as nn
11+
import torch.nn.functional as F
12+
13+
class ConvlLinearModel(nn.Module):
14+
"""
15+
A neural network model with a convolutional layer followed by a linear layer.
16+
"""
17+
def __init__(self):
18+
super(ConvlLinearModel, self).__init__()
19+
self.conv_layer = nn.Conv2d(
20+
in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1
21+
)
22+
self.conv_layer.weight = nn.Parameter(
23+
torch.tensor([[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]]])
24+
)
25+
self.conv_layer.bias = nn.Parameter(torch.tensor([0.0]))
26+
27+
self.linear_layer = nn.Linear(in_features=4, out_features=2)
28+
self.linear_layer.weight = nn.Parameter(
29+
torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]])
30+
)
31+
self.linear_layer.bias = nn.Parameter(torch.tensor([0.0, 0.0]))
32+
self.additional_bias = nn.Parameter(torch.tensor([0.5, -0.5]), requires_grad=False)
33+
self.scale_factor = nn.Parameter(torch.tensor([2.0, 0.5]), requires_grad=False)
34+
35+
def forward(self, x):
36+
x = self.conv_layer(x)
37+
x = x.view(x.size(0), -1)
38+
x = self.linear_layer(x)
39+
x = x + self.additional_bias
40+
x = x - 0.1
41+
x = x * self.scale_factor
42+
x = x / (self.scale_factor + 1.0)
43+
x = F.relu(x)
44+
x = torch.sigmoid(x)
45+
output1, output2 = torch.split(x, 1, dim=1)
46+
return output1, output2
47+
48+
@staticmethod
49+
def get_input():
50+
"""
51+
Returns the pre-defined input tensor for this model.
52+
"""
53+
return torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]], requires_grad=True)
54+
55+
@staticmethod
56+
def get_expected_intermediate_outputs():
57+
"""
58+
Returns the expected outputs of the debug handles and intermediate output mapping for this model for the given input.
59+
"""
60+
return {
61+
(10,): torch.tensor([[[[7.7000, 6.7000], [4.7000, 3.7000]]]]),
62+
(11,): torch.tensor([[7.7000, 6.7000, 4.7000, 3.7000]]),
63+
(12,): torch.tensor(
64+
[
65+
[0.1000, 0.5000],
66+
[0.2000, 0.6000],
67+
[0.3000, 0.7000],
68+
[0.4000, 0.8000],
69+
]
70+
),
71+
(13,): torch.tensor([[5.0000, 14.1200]]),
72+
(14,): torch.tensor([[5.5000, 13.6200]]),
73+
(15,): torch.tensor([[5.4000, 13.5200]]),
74+
(16,): torch.tensor([[10.8000, 6.7600]]),
75+
(17,): torch.tensor([3.0000, 1.5000]),
76+
(18,): torch.tensor([[3.6000, 4.5067]]),
77+
(19,): torch.tensor([[3.6000, 4.5067]]),
78+
(20,): torch.tensor([[0.9734, 0.9891]]),
79+
(21,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])],
80+
}
81+
82+
# Global model registry
83+
model_registry = {
84+
"ConvLinearModel": ConvlLinearModel,
85+
# Add new models here
86+
}
87+
88+
89+
def assert_final_outputs_match(test_case, model_name, actual_outputs_with_handles):
90+
"""
91+
Asserts that the actual outputs match the expected outputs for the specified model.
92+
"""
93+
model_instance = model_registry[model_name]
94+
expected_outputs_with_handles = model_instance.get_expected_intermediate_outputs()
95+
96+
test_case.assertEqual(
97+
len(actual_outputs_with_handles),
98+
len(expected_outputs_with_handles),
99+
)
100+
101+
for debug_handle, expected_output in expected_outputs_with_handles.items():
102+
actual_output = actual_outputs_with_handles.get(debug_handle)
103+
test_case.assertIsNotNone(actual_output)
104+
105+
if isinstance(expected_output, list):
106+
test_case.assertIsInstance(actual_output, list)
107+
test_case.assertEqual(len(actual_output), len(expected_output))
108+
for actual, expected in zip(actual_output, expected_output):
109+
test_case.assertTrue(
110+
torch.allclose(actual, expected, rtol=1e-4, atol=1e-5)
111+
)
112+
else:
113+
test_case.assertTrue(
114+
torch.allclose(actual_output, expected_output, rtol=1e-4, atol=1e-5)
115+
)

devtools/inspector/tests/intermediate_output_capturer_test.py

Lines changed: 45 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -6,127 +6,61 @@
66

77
# pyre-unsafe
88

9-
109
import unittest
11-
1210
import torch
13-
import torch.nn as nn
14-
import torch.nn.functional as F
1511
from executorch.devtools.inspector._intermediate_output_capturer import (
1612
IntermediateOutputCapturer,
1713
)
18-
14+
from executorch.devtools.inspector.tests.inspector_test_utils import model_registry, assert_final_outputs_match
1915
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
2016
from torch.export import export, ExportedProgram
2117
from torch.fx import GraphModule
2218

23-
2419
class TestIntermediateOutputCapturer(unittest.TestCase):
25-
@classmethod
26-
def setUpClass(cls):
27-
class TestModule(nn.Module):
28-
def __init__(self):
29-
super(TestModule, self).__init__()
30-
self.conv = nn.Conv2d(
31-
in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1
32-
)
33-
self.conv.weight = nn.Parameter(
34-
torch.tensor(
35-
[[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]]]
36-
)
37-
)
38-
self.conv.bias = nn.Parameter(torch.tensor([0.0]))
39-
40-
self.linear = nn.Linear(in_features=4, out_features=2)
41-
self.linear.weight = nn.Parameter(
42-
torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]])
43-
)
44-
self.linear.bias = nn.Parameter(torch.tensor([0.0, 0.0]))
45-
self.bias = nn.Parameter(torch.tensor([0.5, -0.5]), requires_grad=False)
46-
self.scale = nn.Parameter(torch.tensor([2.0, 0.5]), requires_grad=False)
47-
48-
def forward(self, x):
49-
x = self.conv(x)
50-
x = x.view(x.size(0), -1)
51-
x = self.linear(x)
52-
x = x + self.bias
53-
x = x - 0.1
54-
x = x * self.scale
55-
x = x / (self.scale + 1.0)
56-
x = F.relu(x)
57-
x = torch.sigmoid(x)
58-
x1, x2 = torch.split(x, 1, dim=1)
59-
return x1, x2
60-
61-
cls.model = TestModule()
62-
cls.input = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]], requires_grad=True)
63-
cls.aten_model: ExportedProgram = export(cls.model, (cls.input,), strict=True)
64-
cls.edge_program_manager: EdgeProgramManager = to_edge(
65-
cls.aten_model, compile_config=EdgeCompileConfig(_check_ir_validity=True)
20+
def _setUpModel(self, model_name):
21+
model = model_registry[model_name]()
22+
input_tensor = model.get_input()
23+
aten_model: ExportedProgram = export(model, (input_tensor,), strict=True)
24+
edge_program_manager: EdgeProgramManager = to_edge(
25+
aten_model, compile_config=EdgeCompileConfig(_check_ir_validity=True)
6626
)
67-
cls.graph_module: GraphModule = cls.edge_program_manager._edge_programs[
27+
graph_module: GraphModule = edge_program_manager._edge_programs[
6828
"forward"
6929
].module()
70-
cls.capturer = IntermediateOutputCapturer(cls.graph_module)
71-
cls.intermediate_outputs = cls.capturer.run_and_capture(cls.input)
72-
73-
def test_keying_with_debug_handle_tuple(self):
74-
for key in self.intermediate_outputs.keys():
75-
self.assertIsInstance(key, tuple)
76-
77-
def test_tensor_cloning_and_detaching(self):
78-
for output in self.intermediate_outputs.values():
79-
if isinstance(output, torch.Tensor):
80-
self.assertFalse(output.requires_grad)
81-
self.assertTrue(output.is_leaf)
82-
83-
def test_placeholder_nodes_are_skipped(self):
84-
for node in self.graph_module.graph.nodes:
85-
if node.op == "placeholder":
86-
self.assertNotIn(
87-
node.meta.get("debug_handle"), self.intermediate_outputs
88-
)
89-
90-
def test_multiple_outputs_capture(self):
91-
outputs = self.capturer.run_and_capture(self.input)
92-
for output in outputs.values():
93-
if isinstance(output, tuple):
94-
self.assertEqual(len(output), 2)
95-
for part in output:
96-
self.assertIsInstance(part, torch.Tensor)
97-
98-
def test_capture_correct_outputs(self):
99-
expected_outputs_with_handles = {
100-
(10,): torch.tensor([[[[7.7000, 6.7000], [4.7000, 3.7000]]]]),
101-
(11,): torch.tensor([[7.7000, 6.7000, 4.7000, 3.7000]]),
102-
(12,): torch.tensor(
103-
[[0.1000, 0.5000], [0.2000, 0.6000], [0.3000, 0.7000], [0.4000, 0.8000]]
104-
),
105-
(13,): torch.tensor([[5.0000, 14.1200]]),
106-
(14,): torch.tensor([[5.5000, 13.6200]]),
107-
(15,): torch.tensor([[5.4000, 13.5200]]),
108-
(16,): torch.tensor([[10.8000, 6.7600]]),
109-
(17,): torch.tensor([3.0000, 1.5000]),
110-
(18,): torch.tensor([[3.6000, 4.5067]]),
111-
(19,): torch.tensor([[3.6000, 4.5067]]),
112-
(20,): torch.tensor([[0.9734, 0.9891]]),
113-
(21,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])],
114-
}
115-
self.assertEqual(
116-
len(self.intermediate_outputs), len(expected_outputs_with_handles)
117-
)
118-
119-
for debug_handle, expected_output in expected_outputs_with_handles.items():
120-
actual_output = self.intermediate_outputs.get(debug_handle)
121-
self.assertIsNotNone(actual_output)
122-
if isinstance(expected_output, list):
123-
self.assertIsInstance(actual_output, list)
124-
self.assertEqual(len(actual_output), len(expected_output))
125-
for actual, expected in zip(actual_output, expected_output):
126-
self.assertTrue(
127-
torch.allclose(actual, expected, rtol=1e-4, atol=1e-5)
128-
)
129-
else:
130-
self.assertTrue(
131-
torch.allclose(actual_output, expected_output, rtol=1e-4, atol=1e-5)
132-
)
30+
capturer = IntermediateOutputCapturer(graph_module)
31+
intermediate_outputs = capturer.run_and_capture(input_tensor)
32+
return input_tensor, graph_module, capturer, intermediate_outputs
33+
34+
def test_models(self):
35+
available_models = list(model_registry.keys())
36+
for model_name in available_models:
37+
with self.subTest(model=model_name):
38+
input_tensor, graph_module, capturer, intermediate_outputs = self._setUpModel(model_name)
39+
40+
# Test keying with debug handle tuple
41+
for key in intermediate_outputs.keys():
42+
self.assertIsInstance(key, tuple)
43+
44+
# Test tensor cloning and detaching
45+
for output in intermediate_outputs.values():
46+
if isinstance(output, torch.Tensor):
47+
self.assertFalse(output.requires_grad)
48+
self.assertTrue(output.is_leaf)
49+
50+
# Test placeholder nodes are skipped
51+
for node in graph_module.graph.nodes:
52+
if node.op == "placeholder":
53+
self.assertNotIn(
54+
node.meta.get("debug_handle"), node.meta
55+
)
56+
57+
# Test multiple outputs capture
58+
outputs = capturer.run_and_capture(input_tensor)
59+
for output in outputs.values():
60+
if isinstance(output, tuple):
61+
self.assertEqual(len(output), 2)
62+
for part in output:
63+
self.assertIsInstance(part, torch.Tensor)
64+
65+
# Test capture correct outputs
66+
assert_final_outputs_match(self, model_name, intermediate_outputs)

0 commit comments

Comments
 (0)