Skip to content

Commit a740969

Browse files
angelayipytorchmergebot
authored andcommitted
[export] Verifier for exported program (#109519)
Summary: X-link: pytorch/executorch#292 Added a verifier for the graph signature in a exported program Test Plan: CI Differential Revision: D48926643 Pull Request resolved: #109519 Approved by: https://github.com/zhxchen17
1 parent 0a60219 commit a740969

File tree

3 files changed

+426
-242
lines changed

3 files changed

+426
-242
lines changed

test/export/test_verifier.py

Lines changed: 170 additions & 165 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,8 @@
11
# Owner(s): ["module: dynamo"]
2-
import copy
3-
from typing import Tuple
42
import unittest
53

6-
import torch # noqa: F401
7-
import torch.nn as nn
8-
import torch._dynamo as torchdynamo
9-
from functorch import make_fx
10-
from functorch.experimental import functionalize, control_flow
4+
import torch
5+
from functorch.experimental import control_flow
116
from torch import Tensor
127
from torch.testing._internal.common_utils import run_tests, TestCase
138
from torch._dynamo.eval_frame import is_dynamo_supported
@@ -16,208 +11,218 @@
1611
SpecViolationError,
1712
Verifier,
1813
ATenDialectVerifier,
19-
)
20-
21-
22-
@torch.no_grad()
23-
def capture(f, args):
24-
torchdynamo.config.allow_rnn = True
25-
torchdynamo.reset()
26-
graphmodule, _ = torchdynamo.export(
27-
f,
28-
*copy.deepcopy(args),
29-
aten_graph=True,
30-
)
31-
32-
def graph_with_interpreter(*args):
33-
with torch.fx.traceback.preserve_node_meta():
34-
return torch.fx.Interpreter(graphmodule).run(*args)
35-
36-
functionalized_callable = functionalize(
37-
graph_with_interpreter,
38-
remove='mutations_and_views',
39-
)
40-
gm = make_fx(functionalized_callable, tracing_mode='fake', _allow_non_fake_inputs=True)(*args)
41-
return gm
42-
43-
44-
class Transpose(nn.Module):
45-
def __init__(self) -> None:
46-
super().__init__()
47-
48-
def forward(self, x: Tensor, dim0: int, dim1: int) -> Tensor:
49-
return x.transpose(dim0, dim1)
5014

15+
)
16+
from torch._export import export
5117

52-
class Mul(nn.Module):
53-
def __init__(self) -> None:
54-
super().__init__()
55-
56-
def forward(self, input: Tensor, other: Tensor) -> Tensor:
57-
# or return torch.mul(input, other)
58-
return input * other
59-
60-
def get_random_inputs(self) -> Tuple[Tensor, Tensor]:
61-
return (torch.randn(3, 2), torch.randn(3, 2))
62-
63-
64-
class ElementwiseAdd(nn.Module):
65-
def __init__(self) -> None:
66-
super().__init__()
67-
68-
def forward(self, x: Tensor, y: Tensor) -> Tensor:
69-
return x + y
18+
@unittest.skipIf(not is_dynamo_supported(), "dynamo isn't supported")
19+
class TestVerifier(TestCase):
20+
def test_verifier_basic(self) -> None:
21+
def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
22+
return x + y
7023

71-
def get_random_inputs(self) -> Tuple[Tensor, Tensor]:
72-
return (torch.randn(1, 3), torch.randn(1, 3))
24+
ep = export(f, (torch.randn(100), torch.randn(100)))
7325

26+
verifier = Verifier()
27+
verifier(ep.graph_module)
7428

75-
class Cat(nn.Module):
76-
def __init__(self) -> None:
77-
super().__init__()
29+
def test_verifier_call_module(self) -> None:
30+
class M(torch.nn.Module):
31+
def __init__(self) -> None:
32+
super().__init__()
33+
self.linear = torch.nn.Linear(10, 10)
7834

79-
# def forward(self, tensors, dim=0):
80-
def forward(self, *args: Tensor, dim: int) -> Tensor:
81-
tensors = args[:-1]
82-
return torch.cat(tensors, dim)
35+
def forward(self, x: Tensor) -> Tensor:
36+
return self.linear(x)
8337

38+
gm = torch.fx.symbolic_trace(M())
8439

85-
class FeedForwardBlock(nn.Module):
86-
def __init__(self, input_dim: int, hidden_dim: int) -> None:
87-
super().__init__()
88-
self.input_dim = input_dim
89-
self.hidden_dim = hidden_dim
40+
verifier = Verifier()
41+
with self.assertRaises(SpecViolationError):
42+
verifier(gm)
9043

91-
self.layer_norm = nn.LayerNorm(input_dim)
44+
def test_verifier_no_functional(self) -> None:
45+
def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
46+
return x + y
9247

93-
self.relu = nn.ReLU()
48+
ep = export(f, (torch.randn(100), torch.randn(100)))
49+
for node in ep.graph.nodes:
50+
if node.target == torch.ops.aten.add.Tensor:
51+
node.target = torch.ops.aten.add_.Tensor
9452

95-
self.linear1 = nn.Linear(input_dim, hidden_dim)
96-
self.dropout1 = nn.Dropout()
53+
verifier = Verifier()
54+
with self.assertRaises(SpecViolationError):
55+
verifier(ep.graph_module)
9756

98-
self.linear2 = nn.Linear(hidden_dim, input_dim)
99-
self.dropout2 = nn.Dropout()
57+
def test_verifier_higher_order(self) -> None:
58+
def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
59+
def true_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
60+
return x + y
10061

101-
def forward(self, x: Tensor) -> Tensor:
102-
# LayerNorm -> Linear -> Dropout -> ReLU -> Linear -> Dropout
103-
y = self.layer_norm(x)
104-
y = self.linear1(y)
105-
y = self.dropout1(y)
106-
y = self.relu(y)
107-
y = self.linear2(y)
108-
y = self.dropout2(y)
109-
return y
62+
def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
63+
return x - y
11064

111-
class ControlFlow(nn.Module):
65+
return control_flow.cond(
66+
x.shape[0] > 2, true_fn, false_fn, [x, y]
67+
)
11268

113-
def forward(self, pred: Tensor, x: Tensor) -> Tensor:
114-
return control_flow.cond(pred, lambda x: x.sin(), lambda x: x.cos(), (x,))
69+
ep = export(f, (torch.randn(3, 3), torch.randn(3, 3)))
11570

71+
verifier = Verifier()
72+
verifier(ep.graph_module)
11673

117-
class VerifierTest(TestCase):
74+
def test_verifier_nested_invalid_module(self) -> None:
75+
def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
76+
def true_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
77+
return x + y
11878

119-
@unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
120-
def test_verifier(self) -> None:
121-
m = ElementwiseAdd()
122-
egm = capture(m, (torch.randn(100), torch.randn(100)))
123-
# assert not throw
124-
verifier = Verifier()
125-
verifier(egm)
126-
self.assertTrue(verifier.is_valid(egm))
79+
def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
80+
return x - y
12781

128-
@unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
129-
def test_verifier_call_module(self) -> None:
130-
m = FeedForwardBlock(10, 10)
131-
gm = torch.fx.symbolic_trace(m)
132-
# this would have modules that are not delegates
133-
verifier = Verifier()
134-
with self.assertRaises(SpecViolationError):
135-
verifier(gm)
136-
self.assertFalse(verifier.is_valid(gm))
82+
return control_flow.cond(
83+
x.shape[0] > 2, true_fn, false_fn, [x, y]
84+
)
13785

138-
@unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
139-
def test_verifier_no_functional(self) -> None:
140-
m = ElementwiseAdd()
141-
egm = capture(m, (torch.randn(100), torch.randn(100)))
142-
for node in egm.graph.nodes:
86+
ep = export(f, (torch.randn(3, 3), torch.randn(3, 3)))
87+
for node in ep.graph_module.true_graph_0.graph.nodes:
14388
if node.target == torch.ops.aten.add.Tensor:
144-
node.target = torch.ops.aten.add.out
89+
node.target = torch.ops.aten.add_.Tensor
90+
14591
verifier = Verifier()
14692
with self.assertRaises(SpecViolationError):
147-
verifier(egm)
148-
self.assertFalse(verifier.is_valid(egm))
149-
150-
@unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
151-
def test_aten_dialect(self) -> None:
152-
m = ElementwiseAdd()
153-
egm = capture(m, (torch.randn(100), torch.randn(100)))
154-
verifier = ATenDialectVerifier()
155-
verifier(egm)
156-
self.assertTrue(verifier.is_valid(egm))
93+
verifier(ep.graph_module)
15794

158-
@unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
159-
def test_aten_wrong_mem_format(self) -> None:
95+
def test_aten_verifier_wrong_op(self) -> None:
16096
class TestModel(torch.nn.Module):
16197
def __init__(self):
16298
super().__init__()
163-
self.a = torch.nn.parameter.Parameter(
164-
torch.randn(1, 3, 100, 100).to(memory_format=torch.channels_last)
165-
)
16699

167100
def forward(self, x):
168-
return self.a + x
101+
return torch.ops.aten._add_relu(x, x)
169102

170103
m = TestModel()
171-
egm = capture(m, (torch.randn(1, 3, 100, 100),))
172-
egm._apply(lambda t: t.to(memory_format=torch.channels_last))
104+
egm = torch.fx.symbolic_trace(m)
173105
verifier = ATenDialectVerifier()
174106
with self.assertRaises(SpecViolationError):
175107
verifier(egm)
176108
self.assertFalse(verifier.is_valid(egm))
177109

178-
@unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
179-
def test_aten_wrong_mem_format_buffer(self) -> None:
180-
class TestModel(torch.nn.Module):
110+
def test_ep_verifier_basic(self) -> None:
111+
class M(torch.nn.Module):
112+
def __init__(self) -> None:
113+
super().__init__()
114+
self.linear = torch.nn.Linear(10, 10)
115+
116+
def forward(self, x: Tensor) -> Tensor:
117+
return self.linear(x)
118+
119+
ep = export(M(), (torch.randn(10, 10),))
120+
ep._validate()
121+
122+
def test_ep_verifier_invalid_param(self) -> None:
123+
def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
124+
return x + y
125+
126+
ep = export(f, (torch.randn(100), torch.randn(100)))
127+
128+
# Parameter doesn't exist in the state dict
129+
ep.graph_signature.parameters.append("bad_param")
130+
with self.assertRaisesRegex(SpecViolationError, "not in the state dict"):
131+
ep._validate()
132+
133+
# Add non-torch.nn.Parameter parameter to the state dict
134+
ep.state_dict["bad_param"] = torch.randn(100)
135+
with self.assertRaisesRegex(
136+
SpecViolationError, "not an instance of torch.nn.Parameter"
137+
):
138+
ep._validate()
139+
140+
# Add torch.nn.Parameter to state dict, but this should still error
141+
# because there are an incorrect number of placeholder nodes
142+
ep.state_dict["bad_param"] = torch.nn.Parameter(torch.randn(100))
143+
with self.assertRaisesRegex(
144+
SpecViolationError, "not found in the exported program's parameter list"
145+
):
146+
ep._validate()
147+
148+
def test_ep_verifier_invalid_buffer(self) -> None:
149+
def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
150+
return x + y
151+
152+
ep = export(f, (torch.randn(100), torch.randn(100)))
153+
154+
# Buffer doesn't exist in the state dict
155+
ep.graph_signature.buffers.append("bad_buffer")
156+
with self.assertRaisesRegex(SpecViolationError, "not in the state dict"):
157+
ep._validate()
158+
159+
# Incorrect number of placeholder nodes
160+
ep.state_dict["bad_buffer"] = torch.randn(100)
161+
with self.assertRaisesRegex(
162+
SpecViolationError, "not found in the exported program's buffer list"
163+
):
164+
ep._validate()
165+
166+
def test_ep_verifier_buffer_mutate(self) -> None:
167+
class M(torch.nn.Module):
181168
def __init__(self):
182169
super().__init__()
183-
self.register_buffer(
184-
"a",
185-
torch.randn(1, 3, 100, 100).to(memory_format=torch.channels_last),
186-
)
187170

188-
def forward(self, x):
189-
return self.a + x
171+
self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))
190172

191-
m = TestModel()
192-
egm = capture(m, (torch.randn(1, 3, 100, 100),))
193-
egm._apply(lambda t: t.to(memory_format=torch.channels_last))
194-
verifier = ATenDialectVerifier()
195-
with self.assertRaises(SpecViolationError):
196-
verifier(egm)
197-
self.assertFalse(verifier.is_valid(egm))
173+
self.register_buffer("my_buffer1", torch.tensor(3.0))
174+
self.register_buffer("my_buffer2", torch.tensor(4.0))
198175

199-
def test_aten_wrong_op(self) -> None:
200-
class TestModel(torch.nn.Module):
201-
def __init__(self):
202-
super().__init__()
176+
def forward(self, x1, x2):
177+
# Use the parameter, buffers, and both inputs in the forward method
178+
output = (
179+
x1 + self.my_parameter
180+
) * self.my_buffer1 + x2 * self.my_buffer2
203181

204-
def forward(self, x):
205-
return torch.ops.aten._add_relu(x, x)
182+
# Mutate one of the buffers (e.g., increment it by 1)
183+
self.my_buffer2.add_(1.0)
184+
return output
206185

207-
m = TestModel()
208-
egm = torch.fx.symbolic_trace(m)
209-
verifier = ATenDialectVerifier()
210-
with self.assertRaises(SpecViolationError):
211-
verifier(egm)
212-
self.assertFalse(verifier.is_valid(egm))
186+
ep = export(M(), (torch.tensor(5.0), torch.tensor(6.0)))
187+
ep._validate()
213188

214-
@unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
215-
def test_verifier_control_flow_success(self) -> None:
216-
m = ControlFlow()
217-
gm = torch._export.export(m, (torch.tensor(True), torch.randn(3, 4))).graph_module
218-
# No error should be raised
219-
verifier = ATenDialectVerifier()
220-
verifier(gm)
189+
def test_ep_verifier_invalid_output(self) -> None:
190+
class M(torch.nn.Module):
191+
def __init__(self):
192+
super().__init__()
193+
194+
self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))
195+
196+
self.register_buffer("my_buffer1", torch.tensor(3.0))
197+
self.register_buffer("my_buffer2", torch.tensor(4.0))
198+
199+
def forward(self, x1, x2):
200+
# Use the parameter, buffers, and both inputs in the forward method
201+
output = (
202+
x1 + self.my_parameter
203+
) * self.my_buffer1 + x2 * self.my_buffer2
204+
205+
# Mutate one of the buffers (e.g., increment it by 1)
206+
self.my_buffer2.add_(1.0)
207+
return output
208+
209+
ep = export(M(), (torch.tensor(5.0), torch.tensor(6.0)))
210+
211+
output_node = list(ep.graph.nodes)[-1]
212+
with ep.graph.inserting_before(output_node):
213+
additional_output_node = ep.graph.call_function(
214+
torch.add, args=(output_node.args[0][0], output_node.args[0][0])
215+
)
216+
output_node.args = (
217+
(
218+
output_node.args[0][0],
219+
additional_output_node,
220+
output_node.args[0][1],
221+
),
222+
)
223+
224+
with self.assertRaisesRegex(SpecViolationError, "Number of output nodes"):
225+
ep._validate()
221226

222227

223228
if __name__ == '__main__':

0 commit comments

Comments
 (0)