
Commit dff9d68

Revert "Fix names conflict when lifting (pytorch#129817)"

This reverts commit 53cf46b. Reverted pytorch#129817 on behalf of https://github.com/clee2000 due to failing inductor/test_flex_attention.py https://github.com/pytorch/pytorch/actions/runs/9940532858/job/27478084137 https://hud.pytorch.org/pytorch/pytorch/commit/74da2a467f166e00316aee82ba24835ca563ed87. Sorry for the churn; possibly a land race? ([comment](pytorch#129817 (comment)))

1 parent 78799e8 commit dff9d68

File tree

9 files changed, +64 -198 lines changed

test/dynamo/test_autograd_function.py

Lines changed: 6 additions & 98 deletions
@@ -4,14 +4,12 @@
 import math
 
 from dataclasses import dataclass
-from typing import List
 
 import torch
 
 import torch._dynamo.test_case
 import torch._dynamo.testing
 import torch._dynamo.utils
-from functorch.compile import aot_module_simplified
 from torch.testing._internal.triton_utils import HAS_CUDA, requires_cuda
 
 if HAS_CUDA:
@@ -225,69 +223,6 @@ def forward(self, x):
         return self.f(x)
 
 
-@torch.library.custom_op("_torch_testing::custom_op_forward", mutates_args=())
-def custom_op_forward(
-    foo: torch.Tensor,
-    bar: torch.Tensor,
-    shape: List[int],
-) -> torch.Tensor:
-    return torch.ones_like(foo)
-
-
-@custom_op_forward.register_fake
-def _(foo, bar, weight):
-    return torch.empty_like(foo)
-
-
-@torch.library.custom_op("_torch_testing::custom_op_backward", mutates_args=())
-def custom_op_backward(
-    grad_output: torch.Tensor,
-    foo: torch.Tensor,
-    bar: torch.Tensor,
-    shape: List[int],
-) -> torch.Tensor:
-    assert list(bar.shape) == shape
-    return torch.ones_like(bar)
-
-
-@custom_op_backward.register_fake
-def _(grad_output, foo, bar, shape):
-    return torch.empty_like(bar)
-
-
-class CustomOpFunc(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, input, weight, normalized_shape):
-        ctx.normalized_shape = normalized_shape
-        input_ = input.contiguous()
-        weight_ = weight.contiguous()
-        output = custom_op_forward(input_, weight_, ctx.normalized_shape)
-        ctx.save_for_backward(input_, weight_)
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        input_, weight_ = ctx.saved_tensors
-        # grad_weight = a_func(grad_output, input_, weight_, ctx.normalized_shape)
-        grad_weight = custom_op_backward(
-            grad_output.contiguous(),
-            input_,
-            weight_,
-            ctx.normalized_shape,
-        )
-        return None, grad_weight, None
-
-
-class CustomOpModule(torch.nn.Module):
-    def __init__(self, shape):
-        super().__init__()
-        self.shape = shape
-        self.weight = torch.nn.Parameter(torch.ones(self.shape))
-
-    def forward(self, x):
-        return CustomOpFunc.apply(x, self.weight, self.shape)
-
-
 class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
     # Sound behaviors, tested for working capture
     def test_autograd_function_equivalence(self):
@@ -592,29 +527,18 @@ def forward(self, L_x_: "f32[]", L_z_: "f32[]", L_weird_b: "f32[]", L_weird_c: "
 
     class GraphModule(torch.nn.Module):
         def forward(self, ctx, x: "f32[]", z: "f32[]", l_weird_b: "f32[]", l_weird_c: "f32[]"):
-            ctx_1 = ctx
-            x_1 = x
-            z_1 = z
-            l_weird_b_1 = l_weird_b
-            l_weird_c_1 = l_weird_c
-
-            mul: "f32[]" = l_weird_b_1 * l_weird_c_1
-            clone: "f32[]" = x_1.clone(); x_1 = None
+            mul: "f32[]" = l_weird_b * l_weird_c
+            clone: "f32[]" = x.clone(); x = None
             mul_1: "f32[]" = mul * clone; mul = clone = None
-            return (mul_1, [l_weird_b_1, l_weird_c_1])
+            return (mul_1, [l_weird_b, l_weird_c])
 
     class GraphModule(torch.nn.Module):
         def forward(self, ctx, grad: "f32[]", l_weird_b: "f32[]", l_weird_c: "f32[]"):
-            ctx_1 = ctx
-            grad_1 = grad
-            l_weird_b_1 = l_weird_b
-            l_weird_c_1 = l_weird_c
-
             _set_grad_enabled = torch._C._set_grad_enabled(False)
 
-            mul: "f32[]" = grad_1 * l_weird_b_1; l_weird_b_1 = None
-            mul_1: "f32[]" = mul * l_weird_c_1; mul = l_weird_c_1 = None
-            mul_2: "f32[]" = grad_1 * 2; grad_1 = None
+            mul: "f32[]" = grad * l_weird_b; l_weird_b = None
+            mul_1: "f32[]" = mul * l_weird_c; mul = l_weird_c = None
+            mul_2: "f32[]" = grad * 2; grad = None
 
             _set_grad_enabled_1 = torch._C._set_grad_enabled(True)
             return (mul_1, mul_2)
@@ -1179,22 +1103,6 @@ def fn():
         self.assertEqual(cnt.frame_count, 1)
         self.assertEqual(cnt.op_count, 2)
 
-    def test_custom_op(self):
-        shape = [7]
-        x = torch.rand(128, shape[0])
-        model = CustomOpModule(shape)
-        out = model(x)
-
-        def backend(gm, example_inputs):
-            return aot_module_simplified(
-                gm, example_inputs, fw_compiler=lambda gm, _: gm
-            )
-
-        opt_model = torch.compile(model, backend=backend)
-        opt_out = opt_model(x)
-        opt_out.mean().backward()
-        self.assertEqual(out, opt_out)
-
     @requires_cuda
     def test_triton_kernel_basic(self):
         class Add(torch.autograd.Function):
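Note: the deleted test_custom_op exercised the torch.library.custom_op registration path (an op plus a register_fake shape kernel) through torch.compile. A minimal, self-contained sketch of that pattern follows; the op name "_demo::scale", the backend choice, and the function names are illustrative assumptions, not part of this commit.

import torch

# Define a custom op; mutates_args=() declares it does not mutate its inputs.
@torch.library.custom_op("_demo::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor

# Fake (meta) kernel: propagates shape/dtype so the op can be traced and compiled.
@scale.register_fake
def _(x, factor):
    return torch.empty_like(x)

@torch.compile(backend="aot_eager")
def f(x):
    return scale(x, 2.0)

print(f(torch.randn(4)))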

test/dynamo/test_export.py

Lines changed: 16 additions & 16 deletions
@@ -1894,16 +1894,16 @@ def forward(self, x):
             out_graph.cond_true_0.code.strip(),
             """\
 def forward(self, l_x_):
-    l_x__2 = l_x_
-    add = l_x__2 + l_x__2; l_x__2 = None
+    l_x__1 = l_x_
+    add = l_x__1 + l_x__1; l_x__1 = None
     return (add,)""",
         )
         self.assertExpectedInline(
             out_graph.cond_false_0.code.strip(),
             """\
 def forward(self, l_x_):
-    l_x__2 = l_x_
-    getitem = l_x__2[slice(None, 2, None)]; l_x__2 = None
+    l_x__1 = l_x_
+    getitem = l_x__1[slice(None, 2, None)]; l_x__1 = None
     return (getitem,)""",
         )
         with self.assertRaisesRegex(
@@ -3947,13 +3947,13 @@ def forward(self, pred, x):
             out_graph.cond_true_0.code.strip(),
             """\
 def forward(self, a, b, l_x_, d_true_branch, c_false_branch):
-    a_2 = a
-    b_2 = b
-    l_x__2 = l_x_
-    add = l_x__2 + l_x__2; l_x__2 = None
-    cos = a_2.cos(); a_2 = None
+    a_1 = a
+    b_1 = b
+    l_x__1 = l_x_
+    add = l_x__1 + l_x__1; l_x__1 = None
+    cos = a_1.cos(); a_1 = None
     add_1 = add + cos; add = cos = None
-    cos_1 = b_2.cos(); b_2 = None
+    cos_1 = b_1.cos(); b_1 = None
     add_2 = add_1 + cos_1; add_1 = cos_1 = None
     cos_2 = d_true_branch.cos(); d_true_branch = None
     add_3 = add_2 + cos_2; add_2 = cos_2 = None
@@ -3964,13 +3964,13 @@ def forward(self, a, b, l_x_, d_true_branch, c_false_branch):
             out_graph.cond_false_0.code.strip(),
             """\
 def forward(self, a, b, l_x_, d_true_branch, c_false_branch):
-    a_2 = a
-    b_2 = b
-    l_x__2 = l_x_
-    mul = l_x__2 * l_x__2; l_x__2 = None
-    sin = a_2.sin(); a_2 = None
+    a_1 = a
+    b_1 = b
+    l_x__1 = l_x_
+    mul = l_x__1 * l_x__1; l_x__1 = None
+    sin = a_1.sin(); a_1 = None
     add = mul + sin; mul = sin = None
-    sin_1 = b_2.sin(); b_2 = None
+    sin_1 = b_1.sin(); b_1 = None
     add_1 = add + sin_1; add = sin_1 = None
     sin_2 = c_false_branch.sin(); c_false_branch = None
     add_2 = add_1 + sin_2; add_1 = sin_2 = None
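Note: the renames above (l_x__2 to l_x__1, a_2 to a_1, and so on) are in expected graphs for torch.cond, where each branch is captured as a child GraphModule and its free tensor inputs are lifted with generated names; that lifted naming is what the reverted pytorch#129817 changed. A rough sketch of how such branch graphs can be inspected (an assumed example, not from this commit):

import torch

def true_fn(x):
    return x + x

def false_fn(x):
    return x * x

def f(pred, x):
    return torch.cond(pred, true_fn, false_fn, [x])

# torch._dynamo.export returns (graph_module, guards); the cond branches are
# installed as child modules (cond_true_0 / cond_false_0 in the tests above).
gm, _ = torch._dynamo.export(f)(torch.tensor(True), torch.randn(4))
print(gm.cond_true_0.code)   # lifted input shows up as l_x__<n> = l_x_
print(gm.cond_false_0.code)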

test/dynamo/test_functions.py

Lines changed: 0 additions & 4 deletions
@@ -1964,7 +1964,6 @@ def forward(self, L_lambda0_keywords_y_: "f32[2, 2]"):
             """\
 class GraphModule(torch.nn.Module):
     def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"):
-        s0_1 = s0
         l_lambda0_keywords_y_ = L_lambda0_keywords_y_
 
         mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
@@ -2013,7 +2012,6 @@ def forward(self, L_lambda0_keywords_y_: "f32[2, 2]"):
             """\
 class GraphModule(torch.nn.Module):
     def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"):
-        s0_1 = s0
         l_lambda0_keywords_y_ = L_lambda0_keywords_y_
 
         mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
@@ -2065,7 +2063,6 @@ def forward(self, L_lambda0_keywords_y_: "f32[2, 2]"):
             """\
 class GraphModule(torch.nn.Module):
     def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"):
-        s0_1 = s0
         l_lambda0_keywords_y_ = L_lambda0_keywords_y_
 
         mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
@@ -2114,7 +2111,6 @@ def forward(self, L_x_: "f32[2, 2]"):
             """\
 class GraphModule(torch.nn.Module):
     def forward(self, s0: "Sym(s0)", L_x_: "f32[s0, s0]"):
-        s0_1 = s0
         l_x_ = L_x_
 
         mul: "f32[s0, s0]" = l_x_ * 4
