 import math

 from dataclasses import dataclass
-from typing import List

 import torch

 import torch._dynamo.test_case
 import torch._dynamo.testing
 import torch._dynamo.utils
-from functorch.compile import aot_module_simplified
 from torch.testing._internal.triton_utils import HAS_CUDA, requires_cuda

 if HAS_CUDA:
@@ -225,69 +223,6 @@ def forward(self, x):
         return self.f(x)


-@torch.library.custom_op("_torch_testing::custom_op_forward", mutates_args=())
-def custom_op_forward(
-    foo: torch.Tensor,
-    bar: torch.Tensor,
-    shape: List[int],
-) -> torch.Tensor:
-    return torch.ones_like(foo)
-
-
-@custom_op_forward.register_fake
-def _(foo, bar, weight):
-    return torch.empty_like(foo)
-
-
-@torch.library.custom_op("_torch_testing::custom_op_backward", mutates_args=())
-def custom_op_backward(
-    grad_output: torch.Tensor,
-    foo: torch.Tensor,
-    bar: torch.Tensor,
-    shape: List[int],
-) -> torch.Tensor:
-    assert list(bar.shape) == shape
-    return torch.ones_like(bar)
-
-
-@custom_op_backward.register_fake
-def _(grad_output, foo, bar, shape):
-    return torch.empty_like(bar)
-
-
-class CustomOpFunc(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, input, weight, normalized_shape):
-        ctx.normalized_shape = normalized_shape
-        input_ = input.contiguous()
-        weight_ = weight.contiguous()
-        output = custom_op_forward(input_, weight_, ctx.normalized_shape)
-        ctx.save_for_backward(input_, weight_)
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        input_, weight_ = ctx.saved_tensors
-        # grad_weight = a_func(grad_output, input_, weight_, ctx.normalized_shape)
-        grad_weight = custom_op_backward(
-            grad_output.contiguous(),
-            input_,
-            weight_,
-            ctx.normalized_shape,
-        )
-        return None, grad_weight, None
-
-
-class CustomOpModule(torch.nn.Module):
-    def __init__(self, shape):
-        super().__init__()
-        self.shape = shape
-        self.weight = torch.nn.Parameter(torch.ones(self.shape))
-
-    def forward(self, x):
-        return CustomOpFunc.apply(x, self.weight, self.shape)
-
-
 class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
     # Sound behaviors, tested for working capture
     def test_autograd_function_equivalence(self):
@@ -592,29 +527,18 @@ def forward(self, L_x_: "f32[]", L_z_: "f32[]", L_weird_b: "f32[]", L_weird_c: "

     class GraphModule(torch.nn.Module):
         def forward(self, ctx, x: "f32[]", z: "f32[]", l_weird_b: "f32[]", l_weird_c: "f32[]"):
-            ctx_1 = ctx
-            x_1 = x
-            z_1 = z
-            l_weird_b_1 = l_weird_b
-            l_weird_c_1 = l_weird_c
-
-            mul: "f32[]" = l_weird_b_1 * l_weird_c_1
-            clone: "f32[]" = x_1.clone(); x_1 = None
+            mul: "f32[]" = l_weird_b * l_weird_c
+            clone: "f32[]" = x.clone(); x = None
             mul_1: "f32[]" = mul * clone; mul = clone = None
-            return (mul_1, [l_weird_b_1, l_weird_c_1])
+            return (mul_1, [l_weird_b, l_weird_c])

     class GraphModule(torch.nn.Module):
         def forward(self, ctx, grad: "f32[]", l_weird_b: "f32[]", l_weird_c: "f32[]"):
-            ctx_1 = ctx
-            grad_1 = grad
-            l_weird_b_1 = l_weird_b
-            l_weird_c_1 = l_weird_c
-
             _set_grad_enabled = torch._C._set_grad_enabled(False)

-            mul: "f32[]" = grad_1 * l_weird_b_1; l_weird_b_1 = None
-            mul_1: "f32[]" = mul * l_weird_c_1; mul = l_weird_c_1 = None
-            mul_2: "f32[]" = grad_1 * 2; grad_1 = None
+            mul: "f32[]" = grad * l_weird_b; l_weird_b = None
+            mul_1: "f32[]" = mul * l_weird_c; mul = l_weird_c = None
+            mul_2: "f32[]" = grad * 2; grad = None

             _set_grad_enabled_1 = torch._C._set_grad_enabled(True)
             return (mul_1, mul_2)
@@ -1179,22 +1103,6 @@ def fn():
         self.assertEqual(cnt.frame_count, 1)
         self.assertEqual(cnt.op_count, 2)

-    def test_custom_op(self):
-        shape = [7]
-        x = torch.rand(128, shape[0])
-        model = CustomOpModule(shape)
-        out = model(x)
-
-        def backend(gm, example_inputs):
-            return aot_module_simplified(
-                gm, example_inputs, fw_compiler=lambda gm, _: gm
-            )
-
-        opt_model = torch.compile(model, backend=backend)
-        opt_out = opt_model(x)
-        opt_out.mean().backward()
-        self.assertEqual(out, opt_out)
-
     @requires_cuda
     def test_triton_kernel_basic(self):
         class Add(torch.autograd.Function):