Skip to content

Reusable test framework the Inspector tests #11314

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions devtools/inspector/tests/TARGETS
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

oncall("executorch")

Expand All @@ -13,6 +14,7 @@ python_unittest(
"//executorch/devtools/inspector:inspector",
"//executorch/devtools/inspector:lib",
"//executorch/exir:lib",
"//executorch/devtools/inspector/tests:inspector_test_utils",
],
)

Expand Down Expand Up @@ -48,5 +50,16 @@ python_unittest(
"//executorch/devtools/inspector:lib",
"//executorch/devtools/inspector:intermediate_output_capturer",
"//executorch/exir:lib",
"//executorch/devtools/inspector/tests:inspector_test_utils",
],
)

python_library(
name = "inspector_test_utils",
srcs = [
"inspector_test_utils.py",
],
deps = [
"//caffe2:torch",
],
)
118 changes: 118 additions & 0 deletions devtools/inspector/tests/inspector_test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvlLinearModel(nn.Module):
"""
A neural network model with a convolutional layer followed by a linear layer.
"""

def __init__(self):
super(ConvlLinearModel, self).__init__()
self.conv_layer = nn.Conv2d(
in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1
)
self.conv_layer.weight = nn.Parameter(
torch.tensor([[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]]])
)
self.conv_layer.bias = nn.Parameter(torch.tensor([0.0]))

self.linear_layer = nn.Linear(in_features=4, out_features=2)
self.linear_layer.weight = nn.Parameter(
torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]])
)
self.linear_layer.bias = nn.Parameter(torch.tensor([0.0, 0.0]))
self.additional_bias = nn.Parameter(
torch.tensor([0.5, -0.5]), requires_grad=False
)
self.scale_factor = nn.Parameter(torch.tensor([2.0, 0.5]), requires_grad=False)

def forward(self, x):
x = self.conv_layer(x)
x = x.view(x.size(0), -1)
x = self.linear_layer(x)
x = x + self.additional_bias
x = x - 0.1
x = x * self.scale_factor
x = x / (self.scale_factor + 1.0)
x = F.relu(x)
x = torch.sigmoid(x)
output1, output2 = torch.split(x, 1, dim=1)
return output1, output2

@staticmethod
def get_input():
"""
Returns the pre-defined input tensor for this model.
"""
return torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]], requires_grad=True)

@staticmethod
def get_expected_intermediate_outputs():
"""
Returns the expected outputs of the debug handles and intermediate output mapping for this model for the given input.
"""
return {
(10,): torch.tensor([[[[7.7000, 6.7000], [4.7000, 3.7000]]]]),
(11,): torch.tensor([[7.7000, 6.7000, 4.7000, 3.7000]]),
(12,): torch.tensor(
[
[0.1000, 0.5000],
[0.2000, 0.6000],
[0.3000, 0.7000],
[0.4000, 0.8000],
]
),
(13,): torch.tensor([[5.0000, 14.1200]]),
(14,): torch.tensor([[5.5000, 13.6200]]),
(15,): torch.tensor([[5.4000, 13.5200]]),
(16,): torch.tensor([[10.8000, 6.7600]]),
(17,): torch.tensor([3.0000, 1.5000]),
(18,): torch.tensor([[3.6000, 4.5067]]),
(19,): torch.tensor([[3.6000, 4.5067]]),
(20,): torch.tensor([[0.9734, 0.9891]]),
(21,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])],
}


# Global model registry
model_registry = {
"ConvLinearModel": ConvlLinearModel,
# Add new models here
}


def check_if_final_outputs_match(model_name, actual_outputs_with_handles):
"""
Checks if the actual outputs match the expected outputs for the specified model.
Returns True if all outputs match, otherwise returns False.
"""
model_instance = model_registry[model_name]
expected_outputs_with_handles = model_instance.get_expected_intermediate_outputs()
if len(actual_outputs_with_handles) != len(expected_outputs_with_handles):
return False
for debug_handle, expected_output in expected_outputs_with_handles.items():
actual_output = actual_outputs_with_handles.get(debug_handle)
if actual_output is None:
return False
if isinstance(expected_output, list):
if not isinstance(actual_output, list):
return False
if len(actual_output) != len(expected_output):
return False
for actual, expected in zip(actual_output, expected_output):
if not torch.allclose(actual, expected, rtol=1e-4, atol=1e-5):
return False
else:
if not torch.allclose(actual_output, expected_output, rtol=1e-4, atol=1e-5):
return False
return True
151 changes: 46 additions & 105 deletions devtools/inspector/tests/intermediate_output_capturer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,127 +6,68 @@

# pyre-unsafe


import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F
from executorch.devtools.inspector._intermediate_output_capturer import (
IntermediateOutputCapturer,
)

from executorch.devtools.inspector.tests.inspector_test_utils import (
check_if_final_outputs_match,
model_registry,
)
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
from torch.export import export, ExportedProgram
from torch.fx import GraphModule


class TestIntermediateOutputCapturer(unittest.TestCase):
@classmethod
def setUpClass(cls):
class TestModule(nn.Module):
def __init__(self):
super(TestModule, self).__init__()
self.conv = nn.Conv2d(
in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1
)
self.conv.weight = nn.Parameter(
torch.tensor(
[[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]]]
)
)
self.conv.bias = nn.Parameter(torch.tensor([0.0]))

self.linear = nn.Linear(in_features=4, out_features=2)
self.linear.weight = nn.Parameter(
torch.tensor([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]])
)
self.linear.bias = nn.Parameter(torch.tensor([0.0, 0.0]))
self.bias = nn.Parameter(torch.tensor([0.5, -0.5]), requires_grad=False)
self.scale = nn.Parameter(torch.tensor([2.0, 0.5]), requires_grad=False)

def forward(self, x):
x = self.conv(x)
x = x.view(x.size(0), -1)
x = self.linear(x)
x = x + self.bias
x = x - 0.1
x = x * self.scale
x = x / (self.scale + 1.0)
x = F.relu(x)
x = torch.sigmoid(x)
x1, x2 = torch.split(x, 1, dim=1)
return x1, x2

cls.model = TestModule()
cls.input = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]], requires_grad=True)
cls.aten_model: ExportedProgram = export(cls.model, (cls.input,), strict=True)
cls.edge_program_manager: EdgeProgramManager = to_edge(
cls.aten_model, compile_config=EdgeCompileConfig(_check_ir_validity=True)
def _set_up_model(self, model_name):
model = model_registry[model_name]()
input_tensor = model.get_input()
aten_model: ExportedProgram = export(model, (input_tensor,), strict=True)
edge_program_manager: EdgeProgramManager = to_edge(
aten_model, compile_config=EdgeCompileConfig(_check_ir_validity=True)
)
cls.graph_module: GraphModule = cls.edge_program_manager._edge_programs[
graph_module: GraphModule = edge_program_manager._edge_programs[
"forward"
].module()
cls.capturer = IntermediateOutputCapturer(cls.graph_module)
cls.intermediate_outputs = cls.capturer.run_and_capture(cls.input)

def test_keying_with_debug_handle_tuple(self):
for key in self.intermediate_outputs.keys():
self.assertIsInstance(key, tuple)

def test_tensor_cloning_and_detaching(self):
for output in self.intermediate_outputs.values():
if isinstance(output, torch.Tensor):
self.assertFalse(output.requires_grad)
self.assertTrue(output.is_leaf)

def test_placeholder_nodes_are_skipped(self):
for node in self.graph_module.graph.nodes:
if node.op == "placeholder":
self.assertNotIn(
node.meta.get("debug_handle"), self.intermediate_outputs
capturer = IntermediateOutputCapturer(graph_module)
intermediate_outputs = capturer.run_and_capture(input_tensor)
return input_tensor, graph_module, capturer, intermediate_outputs

def test_models(self):
available_models = list(model_registry.keys())
for model_name in available_models:
with self.subTest(model=model_name):
input_tensor, graph_module, capturer, intermediate_outputs = (
self._set_up_model(model_name)
)

def test_multiple_outputs_capture(self):
outputs = self.capturer.run_and_capture(self.input)
for output in outputs.values():
if isinstance(output, tuple):
self.assertEqual(len(output), 2)
for part in output:
self.assertIsInstance(part, torch.Tensor)

def test_capture_correct_outputs(self):
expected_outputs_with_handles = {
(10,): torch.tensor([[[[7.7000, 6.7000], [4.7000, 3.7000]]]]),
(11,): torch.tensor([[7.7000, 6.7000, 4.7000, 3.7000]]),
(12,): torch.tensor(
[[0.1000, 0.5000], [0.2000, 0.6000], [0.3000, 0.7000], [0.4000, 0.8000]]
),
(13,): torch.tensor([[5.0000, 14.1200]]),
(14,): torch.tensor([[5.5000, 13.6200]]),
(15,): torch.tensor([[5.4000, 13.5200]]),
(16,): torch.tensor([[10.8000, 6.7600]]),
(17,): torch.tensor([3.0000, 1.5000]),
(18,): torch.tensor([[3.6000, 4.5067]]),
(19,): torch.tensor([[3.6000, 4.5067]]),
(20,): torch.tensor([[0.9734, 0.9891]]),
(21,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])],
}
self.assertEqual(
len(self.intermediate_outputs), len(expected_outputs_with_handles)
)

for debug_handle, expected_output in expected_outputs_with_handles.items():
actual_output = self.intermediate_outputs.get(debug_handle)
self.assertIsNotNone(actual_output)
if isinstance(expected_output, list):
self.assertIsInstance(actual_output, list)
self.assertEqual(len(actual_output), len(expected_output))
for actual, expected in zip(actual_output, expected_output):
self.assertTrue(
torch.allclose(actual, expected, rtol=1e-4, atol=1e-5)
)
else:
# Test keying with debug handle tuple
for key in intermediate_outputs.keys():
self.assertIsInstance(key, tuple)

# Test tensor cloning and detaching
for output in intermediate_outputs.values():
if isinstance(output, torch.Tensor):
self.assertFalse(output.requires_grad)
self.assertTrue(output.is_leaf)

# Test placeholder nodes are skipped
for node in graph_module.graph.nodes:
if node.op == "placeholder":
self.assertNotIn(node.meta.get("debug_handle"), node.meta)

# Test multiple outputs capture
outputs = capturer.run_and_capture(input_tensor)
for output in outputs.values():
if isinstance(output, tuple):
self.assertEqual(len(output), 2)
for part in output:
self.assertIsInstance(part, torch.Tensor)

# Test capture correct outputs
self.assertTrue(
torch.allclose(actual_output, expected_output, rtol=1e-4, atol=1e-5)
check_if_final_outputs_match(model_name, intermediate_outputs)
)
Loading