 import math

 from dataclasses import dataclass
+from typing import List

 import torch

 import torch._dynamo.test_case
 import torch._dynamo.testing
 import torch._dynamo.utils
+from functorch.compile import aot_module_simplified
 from torch.testing._internal.triton_utils import HAS_CUDA, requires_cuda

 if HAS_CUDA:
@@ -223,6 +225,69 @@ def forward(self, x):
         return self.f(x)


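+# Dummy custom ops exercised by the test below: the forward op ignores its
+# inputs' values and returns ones shaped like ``foo``; the backward op checks
+# that ``bar`` matches ``shape`` and returns ones shaped like ``bar``.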
+@torch.library.custom_op("_torch_testing::custom_op_forward", mutates_args=())
+def custom_op_forward(
+    foo: torch.Tensor,
+    bar: torch.Tensor,
+    shape: List[int],
+) -> torch.Tensor:
+    return torch.ones_like(foo)
+
+
+@custom_op_forward.register_fake
+def _(foo, bar, shape):
+    return torch.empty_like(foo)
+
+
+@torch.library.custom_op("_torch_testing::custom_op_backward", mutates_args=())
+def custom_op_backward(
+    grad_output: torch.Tensor,
+    foo: torch.Tensor,
+    bar: torch.Tensor,
+    shape: List[int],
+) -> torch.Tensor:
+    assert list(bar.shape) == shape
+    return torch.ones_like(bar)
+
+
+@custom_op_backward.register_fake
+def _(grad_output, foo, bar, shape):
+    return torch.empty_like(bar)
+
+
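+# autograd.Function wrapper that routes forward/backward through the custom
+# ops above; backward produces a gradient only for the weight argument.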
+class CustomOpFunc(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, weight, normalized_shape):
+        ctx.normalized_shape = normalized_shape
+        input_ = input.contiguous()
+        weight_ = weight.contiguous()
+        output = custom_op_forward(input_, weight_, ctx.normalized_shape)
+        ctx.save_for_backward(input_, weight_)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input_, weight_ = ctx.saved_tensors
+        # Only the weight gets a gradient; input and shape receive None.
+        grad_weight = custom_op_backward(
+            grad_output.contiguous(),
+            input_,
+            weight_,
+            ctx.normalized_shape,
+        )
+        return None, grad_weight, None
+
+
+class CustomOpModule(torch.nn.Module):
+    def __init__(self, shape):
+        super().__init__()
+        self.shape = shape
+        self.weight = torch.nn.Parameter(torch.ones(self.shape))
+
+    def forward(self, x):
+        return CustomOpFunc.apply(x, self.weight, self.shape)
+
+
 class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
     # Sound behaviors, tested for working capture
     def test_autograd_function_equivalence(self):
@@ -527,18 +592,29 @@ def forward(self, L_x_: "f32[]", L_z_: "f32[]", L_weird_b: "f32[]", L_weird_c: "

 class GraphModule(torch.nn.Module):
     def forward(self, ctx, x: "f32[]", z: "f32[]", l_weird_b: "f32[]", l_weird_c: "f32[]"):
-        mul: "f32[]" = l_weird_b * l_weird_c
-        clone: "f32[]" = x.clone(); x = None
+        ctx_1 = ctx
+        x_1 = x
+        z_1 = z
+        l_weird_b_1 = l_weird_b
+        l_weird_c_1 = l_weird_c
+
+        mul: "f32[]" = l_weird_b_1 * l_weird_c_1
+        clone: "f32[]" = x_1.clone(); x_1 = None
         mul_1: "f32[]" = mul * clone; mul = clone = None
-        return (mul_1, [l_weird_b, l_weird_c])
+        return (mul_1, [l_weird_b_1, l_weird_c_1])

 class GraphModule(torch.nn.Module):
     def forward(self, ctx, grad: "f32[]", l_weird_b: "f32[]", l_weird_c: "f32[]"):
+        ctx_1 = ctx
+        grad_1 = grad
+        l_weird_b_1 = l_weird_b
+        l_weird_c_1 = l_weird_c
+
         _set_grad_enabled = torch._C._set_grad_enabled(False)

-        mul: "f32[]" = grad * l_weird_b; l_weird_b = None
-        mul_1: "f32[]" = mul * l_weird_c; mul = l_weird_c = None
-        mul_2: "f32[]" = grad * 2; grad = None
+        mul: "f32[]" = grad_1 * l_weird_b_1; l_weird_b_1 = None
+        mul_1: "f32[]" = mul * l_weird_c_1; mul = l_weird_c_1 = None
+        mul_2: "f32[]" = grad_1 * 2; grad_1 = None

         _set_grad_enabled_1 = torch._C._set_grad_enabled(True)
         return (mul_1, mul_2)
@@ -1103,6 +1179,22 @@ def fn():
         self.assertEqual(cnt.frame_count, 1)
         self.assertEqual(cnt.op_count, 2)

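+    # Compile CustomOpModule with an aot_autograd backend whose fw_compiler
+    # returns the traced graph unchanged, then check that the compiled output
+    # matches eager and that backward runs through the custom op.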
+    def test_custom_op(self):
+        shape = [7]
+        x = torch.rand(128, shape[0])
+        model = CustomOpModule(shape)
+        out = model(x)
+
+        def backend(gm, example_inputs):
+            return aot_module_simplified(
+                gm, example_inputs, fw_compiler=lambda gm, _: gm
+            )
+
+        opt_model = torch.compile(model, backend=backend)
+        opt_out = opt_model(x)
+        opt_out.mean().backward()
+        self.assertEqual(out, opt_out)
+
     @requires_cuda
     def test_triton_kernel_basic(self):
         class Add(torch.autograd.Function):