 # Owner(s): ["module: dynamo"]
-import copy
-from typing import Tuple
 import unittest
 
-import torch  # noqa: F401
-import torch.nn as nn
-import torch._dynamo as torchdynamo
-from functorch import make_fx
-from functorch.experimental import functionalize, control_flow
+import torch
+from functorch.experimental import control_flow
 from torch import Tensor
 from torch.testing._internal.common_utils import run_tests, TestCase
 from torch._dynamo.eval_frame import is_dynamo_supported
 
 from torch._export.verifier import (
     SpecViolationError,
     Verifier,
     ATenDialectVerifier,
-)
-
-
-@torch.no_grad()
-def capture(f, args):
-    torchdynamo.config.allow_rnn = True
-    torchdynamo.reset()
-    graphmodule, _ = torchdynamo.export(
-        f,
-        *copy.deepcopy(args),
-        aten_graph=True,
-    )
-
-    def graph_with_interpreter(*args):
-        with torch.fx.traceback.preserve_node_meta():
-            return torch.fx.Interpreter(graphmodule).run(*args)
-
-    functionalized_callable = functionalize(
-        graph_with_interpreter,
-        remove='mutations_and_views',
-    )
-    gm = make_fx(functionalized_callable, tracing_mode='fake', _allow_non_fake_inputs=True)(*args)
-    return gm
-
-
-class Transpose(nn.Module):
-    def __init__(self) -> None:
-        super().__init__()
-
-    def forward(self, x: Tensor, dim0: int, dim1: int) -> Tensor:
-        return x.transpose(dim0, dim1)
 
+)
+from torch._export import export
 
-class Mul(nn.Module):
-    def __init__(self) -> None:
-        super().__init__()
-
-    def forward(self, input: Tensor, other: Tensor) -> Tensor:
-        # or return torch.mul(input, other)
-        return input * other
-
-    def get_random_inputs(self) -> Tuple[Tensor, Tensor]:
-        return (torch.randn(3, 2), torch.randn(3, 2))
-
-
-class ElementwiseAdd(nn.Module):
-    def __init__(self) -> None:
-        super().__init__()
-
-    def forward(self, x: Tensor, y: Tensor) -> Tensor:
-        return x + y
+@unittest.skipIf(not is_dynamo_supported(), "dynamo isn't supported")
+class TestVerifier(TestCase):
+    def test_verifier_basic(self) -> None:
+        def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            return x + y
 
-    def get_random_inputs(self) -> Tuple[Tensor, Tensor]:
-        return (torch.randn(1, 3), torch.randn(1, 3))
+        ep = export(f, (torch.randn(100), torch.randn(100)))
 
+        verifier = Verifier()
+        verifier(ep.graph_module)
 
-class Cat(nn.Module):
-    def __init__(self) -> None:
-        super().__init__()
+    def test_verifier_call_module(self) -> None:
+        class M(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.linear = torch.nn.Linear(10, 10)
 
-    # def forward(self, tensors, dim=0):
-    def forward(self, *args: Tensor, dim: int) -> Tensor:
-        tensors = args[:-1]
-        return torch.cat(tensors, dim)
+            def forward(self, x: Tensor) -> Tensor:
+                return self.linear(x)
 
+        gm = torch.fx.symbolic_trace(M())
 
-class FeedForwardBlock(nn.Module):
-    def __init__(self, input_dim: int, hidden_dim: int) -> None:
-        super().__init__()
-        self.input_dim = input_dim
-        self.hidden_dim = hidden_dim
+        verifier = Verifier()
+        with self.assertRaises(SpecViolationError):
+            verifier(gm)
 
-        self.layer_norm = nn.LayerNorm(input_dim)
+    def test_verifier_no_functional(self) -> None:
+        def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            return x + y
 
-        self.relu = nn.ReLU()
+        ep = export(f, (torch.randn(100), torch.randn(100)))
+        for node in ep.graph.nodes:
+            if node.target == torch.ops.aten.add.Tensor:
+                node.target = torch.ops.aten.add_.Tensor
 
-        self.linear1 = nn.Linear(input_dim, hidden_dim)
-        self.dropout1 = nn.Dropout()
+        verifier = Verifier()
+        with self.assertRaises(SpecViolationError):
+            verifier(ep.graph_module)
 
-        self.linear2 = nn.Linear(hidden_dim, input_dim)
-        self.dropout2 = nn.Dropout()
+    def test_verifier_higher_order(self) -> None:
+        def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            def true_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+                return x + y
 
-    def forward(self, x: Tensor) -> Tensor:
-        # LayerNorm -> Linear -> Dropout -> ReLU -> Linear -> Dropout
-        y = self.layer_norm(x)
-        y = self.linear1(y)
-        y = self.dropout1(y)
-        y = self.relu(y)
-        y = self.linear2(y)
-        y = self.dropout2(y)
-        return y
+            def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+                return x - y
 
-class ControlFlow(nn.Module):
+            return control_flow.cond(
+                x.shape[0] > 2, true_fn, false_fn, [x, y]
+            )
 
-    def forward(self, pred: Tensor, x: Tensor) -> Tensor:
-        return control_flow.cond(pred, lambda x: x.sin(), lambda x: x.cos(), (x,))
+        ep = export(f, (torch.randn(3, 3), torch.randn(3, 3)))
 
+        verifier = Verifier()
+        verifier(ep.graph_module)
 
-class VerifierTest(TestCase):
+    def test_verifier_nested_invalid_module(self) -> None:
+        def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            def true_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+                return x + y
 
-    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
-    def test_verifier(self) -> None:
-        m = ElementwiseAdd()
-        egm = capture(m, (torch.randn(100), torch.randn(100)))
-        # assert not throw
-        verifier = Verifier()
-        verifier(egm)
-        self.assertTrue(verifier.is_valid(egm))
+            def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+                return x - y
 
-    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
-    def test_verifier_call_module(self) -> None:
-        m = FeedForwardBlock(10, 10)
-        gm = torch.fx.symbolic_trace(m)
-        # this would have modules that are not delegates
-        verifier = Verifier()
-        with self.assertRaises(SpecViolationError):
-            verifier(gm)
-        self.assertFalse(verifier.is_valid(gm))
+            return control_flow.cond(
+                x.shape[0] > 2, true_fn, false_fn, [x, y]
+            )
 
-    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
-    def test_verifier_no_functional(self) -> None:
-        m = ElementwiseAdd()
-        egm = capture(m, (torch.randn(100), torch.randn(100)))
-        for node in egm.graph.nodes:
+        ep = export(f, (torch.randn(3, 3), torch.randn(3, 3)))
+        for node in ep.graph_module.true_graph_0.graph.nodes:
             if node.target == torch.ops.aten.add.Tensor:
-                node.target = torch.ops.aten.add.out
+                node.target = torch.ops.aten.add_.Tensor
+
         verifier = Verifier()
         with self.assertRaises(SpecViolationError):
-            verifier(egm)
-        self.assertFalse(verifier.is_valid(egm))
-
-    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
-    def test_aten_dialect(self) -> None:
-        m = ElementwiseAdd()
-        egm = capture(m, (torch.randn(100), torch.randn(100)))
-        verifier = ATenDialectVerifier()
-        verifier(egm)
-        self.assertTrue(verifier.is_valid(egm))
+            verifier(ep.graph_module)
 
-    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
-    def test_aten_wrong_mem_format(self) -> None:
+    def test_aten_verifier_wrong_op(self) -> None:
         class TestModel(torch.nn.Module):
             def __init__(self):
                 super().__init__()
-                self.a = torch.nn.parameter.Parameter(
-                    torch.randn(1, 3, 100, 100).to(memory_format=torch.channels_last)
-                )
 
             def forward(self, x):
-                return self.a + x
+                return torch.ops.aten._add_relu(x, x)
 
         m = TestModel()
-        egm = capture(m, (torch.randn(1, 3, 100, 100),))
-        egm._apply(lambda t: t.to(memory_format=torch.channels_last))
+        egm = torch.fx.symbolic_trace(m)
         verifier = ATenDialectVerifier()
         with self.assertRaises(SpecViolationError):
             verifier(egm)
         self.assertFalse(verifier.is_valid(egm))
 
-    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
-    def test_aten_wrong_mem_format_buffer(self) -> None:
-        class TestModel(torch.nn.Module):
+    def test_ep_verifier_basic(self) -> None:
+        class M(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.linear = torch.nn.Linear(10, 10)
+
+            def forward(self, x: Tensor) -> Tensor:
+                return self.linear(x)
+
+        ep = export(M(), (torch.randn(10, 10),))
+        ep._validate()
+
+    def test_ep_verifier_invalid_param(self) -> None:
+        def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            return x + y
+
+        ep = export(f, (torch.randn(100), torch.randn(100)))
+
+        # Parameter doesn't exist in the state dict
+        ep.graph_signature.parameters.append("bad_param")
+        with self.assertRaisesRegex(SpecViolationError, "not in the state dict"):
+            ep._validate()
+
+        # Add non-torch.nn.Parameter parameter to the state dict
+        ep.state_dict["bad_param"] = torch.randn(100)
+        with self.assertRaisesRegex(
+            SpecViolationError, "not an instance of torch.nn.Parameter"
+        ):
+            ep._validate()
+
+        # Add torch.nn.Parameter to state dict, but this should still error
+        # because there are an incorrect number of placeholder nodes
+        ep.state_dict["bad_param"] = torch.nn.Parameter(torch.randn(100))
+        with self.assertRaisesRegex(
+            SpecViolationError, "not found in the exported program's parameter list"
+        ):
+            ep._validate()
+
+    def test_ep_verifier_invalid_buffer(self) -> None:
+        def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            return x + y
+
+        ep = export(f, (torch.randn(100), torch.randn(100)))
+
+        # Buffer doesn't exist in the state dict
+        ep.graph_signature.buffers.append("bad_buffer")
+        with self.assertRaisesRegex(SpecViolationError, "not in the state dict"):
+            ep._validate()
+
+        # Incorrect number of placeholder nodes
+        ep.state_dict["bad_buffer"] = torch.randn(100)
+        with self.assertRaisesRegex(
+            SpecViolationError, "not found in the exported program's buffer list"
+        ):
+            ep._validate()
+
+    def test_ep_verifier_buffer_mutate(self) -> None:
+        class M(torch.nn.Module):
             def __init__(self):
                 super().__init__()
-                self.register_buffer(
-                    "a",
-                    torch.randn(1, 3, 100, 100).to(memory_format=torch.channels_last),
-                )
 
-            def forward(self, x):
-                return self.a + x
+                self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))
 
-        m = TestModel()
-        egm = capture(m, (torch.randn(1, 3, 100, 100),))
-        egm._apply(lambda t: t.to(memory_format=torch.channels_last))
-        verifier = ATenDialectVerifier()
-        with self.assertRaises(SpecViolationError):
-            verifier(egm)
-        self.assertFalse(verifier.is_valid(egm))
+                self.register_buffer("my_buffer1", torch.tensor(3.0))
+                self.register_buffer("my_buffer2", torch.tensor(4.0))
 
-    def test_aten_wrong_op(self) -> None:
-        class TestModel(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
+            def forward(self, x1, x2):
+                # Use the parameter, buffers, and both inputs in the forward method
+                output = (
+                    x1 + self.my_parameter
+                ) * self.my_buffer1 + x2 * self.my_buffer2
 
-            def forward(self, x):
-                return torch.ops.aten._add_relu(x, x)
+                # Mutate one of the buffers (e.g., increment it by 1)
+                self.my_buffer2.add_(1.0)
+                return output
 
-        m = TestModel()
-        egm = torch.fx.symbolic_trace(m)
-        verifier = ATenDialectVerifier()
-        with self.assertRaises(SpecViolationError):
-            verifier(egm)
-        self.assertFalse(verifier.is_valid(egm))
+        ep = export(M(), (torch.tensor(5.0), torch.tensor(6.0)))
+        ep._validate()
 
-    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
-    def test_verifier_control_flow_success(self) -> None:
-        m = ControlFlow()
-        gm = torch._export.export(m, (torch.tensor(True), torch.randn(3, 4))).graph_module
-        # No error should be raised
-        verifier = ATenDialectVerifier()
-        verifier(gm)
+    def test_ep_verifier_invalid_output(self) -> None:
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+                self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))
+
+                self.register_buffer("my_buffer1", torch.tensor(3.0))
+                self.register_buffer("my_buffer2", torch.tensor(4.0))
+
+            def forward(self, x1, x2):
+                # Use the parameter, buffers, and both inputs in the forward method
+                output = (
+                    x1 + self.my_parameter
+                ) * self.my_buffer1 + x2 * self.my_buffer2
+
+                # Mutate one of the buffers (e.g., increment it by 1)
+                self.my_buffer2.add_(1.0)
+                return output
+
+        ep = export(M(), (torch.tensor(5.0), torch.tensor(6.0)))
+
+        output_node = list(ep.graph.nodes)[-1]
+        with ep.graph.inserting_before(output_node):
+            additional_output_node = ep.graph.call_function(
+                torch.add, args=(output_node.args[0][0], output_node.args[0][0])
+            )
+            output_node.args = (
+                (
+                    output_node.args[0][0],
+                    additional_output_node,
+                    output_node.args[0][1],
+                ),
+            )
+
+        with self.assertRaisesRegex(SpecViolationError, "Number of output nodes"):
+            ep._validate()
 
 
 if __name__ == '__main__':