
Commit 53cf46b

RabbitWhite1 authored and zou3519 committed
Fix names conflict when lifting (pytorch#129817)
## Bug description

When pending args that may be lifted ([here](https://github.com/pytorch/pytorch/blob/58f346c874a8a982679b4d4f3876602cc05d66d4/torch/_dynamo/output_graph.py#L1866)) share the same base name, e.g. `contiguous` and `contiguous_1`, the call into [create_graph_input](https://github.com/pytorch/pytorch/blob/58f346c874a8a982679b4d4f3876602cc05d66d4/torch/_dynamo/output_graph.py#L2081) can end up creating a name ([here](https://github.com/pytorch/pytorch/blob/58f346c874a8a982679b4d4f3876602cc05d66d4/torch/fx/graph.py#L1008)) that overwrites one of the args still waiting to be lifted, producing a wrong graph output.
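To make the clash concrete, below is a minimal, self-contained sketch of an fx-style "append `_N` until the name is free" deduplication rule. The `create_name` helper is a simplified illustration, not `torch.fx`'s actual namespace code: once a node named `contiguous` already exists in the graph, lifting a pending arg whose base name is also `contiguous` hands out exactly `contiguous_1`, i.e. the name of the other arg that is still waiting to be lifted.

```python
# Simplified illustration of the naming clash. create_name is a
# hypothetical stand-in for an fx-style "append _N until unused"
# deduplication rule, not the actual torch.fx Namespace code.

def create_name(candidate: str, used: set) -> str:
    if candidate not in used:
        used.add(candidate)
        return candidate
    i = 1
    while f"{candidate}_{i}" in used:
        i += 1
    name = f"{candidate}_{i}"
    used.add(name)
    return name


# Two saved tensors are waiting to be lifted as graph inputs, while a
# node named "contiguous" (from grad_output.contiguous()) already exists.
pending = ["contiguous", "contiguous_1"]
used = {"contiguous"}

lifted = create_name(pending[0], used)
print(lifted)                # 'contiguous_1'
assert lifted == pending[1]  # the new input shadows the other pending arg
```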
## Reproducing

Below is a reproducible example:

```python
import logging
from typing import List

import torch
from functorch.compile import aot_module_simplified, make_boxed_func


@torch.library.custom_op("mylib::somefunc_forward", mutates_args=())
def somefunc_forward(
    input_: torch.Tensor,
    weight: torch.Tensor,
    shape: List[int],
) -> torch.Tensor:
    return torch.ones_like(input_)


@somefunc_forward.register_fake
def _(input_, shape, weight):
    return torch.empty_like(input_)


@torch.library.custom_op("mylib::somefunc_backward", mutates_args=())
def somefunc_backward(
    grad_output: torch.Tensor,
    input_: torch.Tensor,
    weight: torch.Tensor,
    shape: List[int],
) -> torch.Tensor:
    print(f"backward.{grad_output.shape=}")
    print(f"backward.{input_.shape=}")
    print(f"backward.{weight.shape=}")
    print(f"backward.{shape=}")
    assert list(weight.shape) == shape
    return torch.ones_like(weight)


@somefunc_backward.register_fake
def _(grad_output, input_, weight, shape):
    return torch.empty_like(weight)


def a_func(grad_output, input_, weight_, shape):
    return torch.ones_like(input_.sum() * weight_)


class SomeFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, normalized_shape):
        ctx.normalized_shape = normalized_shape
        input_ = input.contiguous()
        weight_ = weight.contiguous()
        output = somefunc_forward(input_, weight_, ctx.normalized_shape)
        ctx.save_for_backward(input_, weight_)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, weight_ = ctx.saved_tensors
        # grad_weight = a_func(grad_output, input_, weight_, ctx.normalized_shape)
        grad_weight = somefunc_backward(
            grad_output.contiguous(),
            input_,
            weight_,
            ctx.normalized_shape,
        )
        return None, grad_weight, None


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(7))

    def forward(self, x):
        return SomeFunc.apply(x, self.weight, [7])


model = MyModel()

torch._logging.set_logs(dynamo=logging.DEBUG, aot=logging.DEBUG, graph_code=True)


def aot_print_backend(gm, sample_inputs):
    # Forward compiler capture
    def fw(gm, sample_inputs):
        print(f"----- fw")
        gm.print_readable()
        return make_boxed_func(gm.forward)

    # Backward compiler capture
    def bw(gm, sample_inputs):
        print(f"----- bw")
        gm.print_readable()
        return make_boxed_func(gm.forward)

    # Call AOTAutograd
    gm_forward = aot_module_simplified(
        gm, sample_inputs, fw_compiler=fw, bw_compiler=bw
    )
    return gm_forward


model = torch.compile(
    model,
    backend=aot_print_backend,
    dynamic=False,
)
out = model(torch.rand((128, 4, 7)))
out.mean().backward()
```

The log shows the calls into `create_graph_input`:

```log
V0629 02:08:46.839914 8200981504 torch/_dynamo/output_graph.py:2042] [0/0] create_graph_input contiguous (none)
V0629 02:08:46.839998 8200981504 torch/_dynamo/output_graph.py:2042] [0/0] create_graph_input contiguous_1 (none)
```

The generated backward graph looks like:

```log
class GraphModule(torch.nn.Module):
    def forward(self, function_ctx, somefunc_forward_default: "f32[128, 4, 7]", contiguous: "f32[128, 4, 7]", contiguous_1: "f32[7]"):
        contiguous_1 = contiguous
        contiguous_2 = contiguous_1

        # No stacktrace found for following nodes
        _set_grad_enabled = torch._C._set_grad_enabled(False)

        # File: /Users/bytedance/testtorch/test_custom_op_bug.py:61 in backward, code: grad_output.contiguous(),
        contiguous: "f32[128, 4, 7]" = somefunc_forward_default.contiguous(); somefunc_forward_default = None

        # File: /opt/tiger/pytorch/torch/_library/custom_ops.py:506 in __call__, code: return self._opoverload(*args, **kwargs)
        somefunc_backward_default: "f32[7]" = torch.ops.mylib.somefunc_backward.default(contiguous, contiguous_1, contiguous_2, [7]); contiguous = contiguous_1 = contiguous_2 = None

        # No stacktrace found for following nodes
        _set_grad_enabled_1 = torch._C._set_grad_enabled(True)
        return (None, somefunc_backward_default)
```

The original `somefunc_backward` takes `grad_output`, `input_`, `weight`, and `shape` as inputs, where `weight` should have shape `torch.Size([7])`. However, in the graph above, `contiguous_1` and `contiguous_2` are assigned from `contiguous`, which triggers the assertion I added in `somefunc_backward`.

## Environment

```log
Collecting environment information...
PyTorch version: 2.5.0a0+git0b7e8df
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 14.5 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: version 3.26.4
Libc version: N/A
Python version: 3.9.19 (main, May 6 2024, 14:39:30) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-14.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU: Apple M3 Pro
Versions of relevant libraries:
[pip3] numpy==2.0.0
[pip3] optree==0.11.0
[pip3] torch==2.5.0a0+git0b7e8df
[pip3] torchgraph==0.0.1
[conda] numpy       2.0.0               pypi_0    pypi
[conda] optree      0.11.0              pypi_0    pypi
[conda] torch       2.5.0a0+git0b7e8df  dev_0     <develop>
[conda] torchgraph  0.0.1               dev_0     <develop>
```

## How to fix?

I put up a naive fix that adds the potential args to lift into the used names (a simplified sketch of the idea follows below). It touches private variables; I will clean that up if this issue makes sense to you. @zou3519 @oulgen

Co-authored-by: rzou <[email protected]>
Pull Request resolved: pytorch#129817
Approved by: https://github.com/zou3519
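For illustration only, the sketch below applies the fix idea from the "How to fix?" section to the same simplified model used earlier. The `create_name` helper is again a hypothetical stand-in rather than the actual change to `torch/_dynamo/output_graph.py`; the point is that reserving every pending-arg name before any graph input is created stops the deduplication step from handing out a name that shadows an arg that has not been lifted yet.

```python
# Sketch of the fix idea on the simplified model above (hypothetical
# helper; not the actual Dynamo/fx code touched by this PR).

def create_name(candidate: str, used: set) -> str:
    # Same "append _N until unused" rule as in the earlier sketch.
    if candidate not in used:
        used.add(candidate)
        return candidate
    i = 1
    while f"{candidate}_{i}" in used:
        i += 1
    name = f"{candidate}_{i}"
    used.add(name)
    return name


pending = ["contiguous", "contiguous_1"]  # args still waiting to be lifted
used = {"contiguous"}                     # taken by grad_output.contiguous()

# Fix idea: mark every pending-arg name as used up front.
used.update(pending)

new_name = create_name("contiguous", used)
print(new_name)                  # 'contiguous_2'
assert new_name not in pending   # no pending arg gets shadowed anymore
```

This renaming-away behaviour is consistent with the updated expected graphs in the test diffs below, where lifted placeholders keep their original names (`ctx`, `x`, `grad`, `s0`, ...) and the graph body switches to freshly deduplicated locals such as `ctx_1 = ctx` and `s0_1 = s0`.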
1 parent b4b64f7 commit 53cf46b

9 files changed (+198 −64 lines)


test/dynamo/test_autograd_function.py

Lines changed: 98 additions & 6 deletions
```diff
@@ -4,12 +4,14 @@
 import math
 
 from dataclasses import dataclass
+from typing import List
 
 import torch
 
 import torch._dynamo.test_case
 import torch._dynamo.testing
 import torch._dynamo.utils
+from functorch.compile import aot_module_simplified
 from torch.testing._internal.triton_utils import HAS_CUDA, requires_cuda
 
 if HAS_CUDA:
@@ -223,6 +225,69 @@ def forward(self, x):
         return self.f(x)
 
 
+@torch.library.custom_op("_torch_testing::custom_op_forward", mutates_args=())
+def custom_op_forward(
+    foo: torch.Tensor,
+    bar: torch.Tensor,
+    shape: List[int],
+) -> torch.Tensor:
+    return torch.ones_like(foo)
+
+
+@custom_op_forward.register_fake
+def _(foo, bar, weight):
+    return torch.empty_like(foo)
+
+
+@torch.library.custom_op("_torch_testing::custom_op_backward", mutates_args=())
+def custom_op_backward(
+    grad_output: torch.Tensor,
+    foo: torch.Tensor,
+    bar: torch.Tensor,
+    shape: List[int],
+) -> torch.Tensor:
+    assert list(bar.shape) == shape
+    return torch.ones_like(bar)
+
+
+@custom_op_backward.register_fake
+def _(grad_output, foo, bar, shape):
+    return torch.empty_like(bar)
+
+
+class CustomOpFunc(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, weight, normalized_shape):
+        ctx.normalized_shape = normalized_shape
+        input_ = input.contiguous()
+        weight_ = weight.contiguous()
+        output = custom_op_forward(input_, weight_, ctx.normalized_shape)
+        ctx.save_for_backward(input_, weight_)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input_, weight_ = ctx.saved_tensors
+        # grad_weight = a_func(grad_output, input_, weight_, ctx.normalized_shape)
+        grad_weight = custom_op_backward(
+            grad_output.contiguous(),
+            input_,
+            weight_,
+            ctx.normalized_shape,
+        )
+        return None, grad_weight, None
+
+
+class CustomOpModule(torch.nn.Module):
+    def __init__(self, shape):
+        super().__init__()
+        self.shape = shape
+        self.weight = torch.nn.Parameter(torch.ones(self.shape))
+
+    def forward(self, x):
+        return CustomOpFunc.apply(x, self.weight, self.shape)
+
+
 class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
     # Sound behaviors, tested for working capture
     def test_autograd_function_equivalence(self):
@@ -527,18 +592,29 @@ def forward(self, L_x_: "f32[]", L_z_: "f32[]", L_weird_b: "f32[]", L_weird_c: "
 
 class GraphModule(torch.nn.Module):
     def forward(self, ctx, x: "f32[]", z: "f32[]", l_weird_b: "f32[]", l_weird_c: "f32[]"):
-        mul: "f32[]" = l_weird_b * l_weird_c
-        clone: "f32[]" = x.clone(); x = None
+        ctx_1 = ctx
+        x_1 = x
+        z_1 = z
+        l_weird_b_1 = l_weird_b
+        l_weird_c_1 = l_weird_c
+
+        mul: "f32[]" = l_weird_b_1 * l_weird_c_1
+        clone: "f32[]" = x_1.clone(); x_1 = None
         mul_1: "f32[]" = mul * clone; mul = clone = None
-        return (mul_1, [l_weird_b, l_weird_c])
+        return (mul_1, [l_weird_b_1, l_weird_c_1])
 
 class GraphModule(torch.nn.Module):
     def forward(self, ctx, grad: "f32[]", l_weird_b: "f32[]", l_weird_c: "f32[]"):
+        ctx_1 = ctx
+        grad_1 = grad
+        l_weird_b_1 = l_weird_b
+        l_weird_c_1 = l_weird_c
+
         _set_grad_enabled = torch._C._set_grad_enabled(False)
 
-        mul: "f32[]" = grad * l_weird_b; l_weird_b = None
-        mul_1: "f32[]" = mul * l_weird_c; mul = l_weird_c = None
-        mul_2: "f32[]" = grad * 2; grad = None
+        mul: "f32[]" = grad_1 * l_weird_b_1; l_weird_b_1 = None
+        mul_1: "f32[]" = mul * l_weird_c_1; mul = l_weird_c_1 = None
+        mul_2: "f32[]" = grad_1 * 2; grad_1 = None
 
         _set_grad_enabled_1 = torch._C._set_grad_enabled(True)
         return (mul_1, mul_2)
@@ -1103,6 +1179,22 @@ def fn():
         self.assertEqual(cnt.frame_count, 1)
         self.assertEqual(cnt.op_count, 2)
 
+    def test_custom_op(self):
+        shape = [7]
+        x = torch.rand(128, shape[0])
+        model = CustomOpModule(shape)
+        out = model(x)
+
+        def backend(gm, example_inputs):
+            return aot_module_simplified(
+                gm, example_inputs, fw_compiler=lambda gm, _: gm
+            )
+
+        opt_model = torch.compile(model, backend=backend)
+        opt_out = opt_model(x)
+        opt_out.mean().backward()
+        self.assertEqual(out, opt_out)
+
     @requires_cuda
     def test_triton_kernel_basic(self):
         class Add(torch.autograd.Function):
```

test/dynamo/test_export.py

Lines changed: 16 additions & 16 deletions
```diff
@@ -1894,16 +1894,16 @@ def forward(self, x):
             out_graph.cond_true_0.code.strip(),
             """\
 def forward(self, l_x_):
-    l_x__1 = l_x_
-    add = l_x__1 + l_x__1; l_x__1 = None
+    l_x__2 = l_x_
+    add = l_x__2 + l_x__2; l_x__2 = None
     return (add,)""",
         )
         self.assertExpectedInline(
             out_graph.cond_false_0.code.strip(),
             """\
 def forward(self, l_x_):
-    l_x__1 = l_x_
-    getitem = l_x__1[slice(None, 2, None)]; l_x__1 = None
+    l_x__2 = l_x_
+    getitem = l_x__2[slice(None, 2, None)]; l_x__2 = None
     return (getitem,)""",
         )
         with self.assertRaisesRegex(
@@ -3947,13 +3947,13 @@ def forward(self, pred, x):
             out_graph.cond_true_0.code.strip(),
             """\
 def forward(self, a, b, l_x_, d_true_branch, c_false_branch):
-    a_1 = a
-    b_1 = b
-    l_x__1 = l_x_
-    add = l_x__1 + l_x__1; l_x__1 = None
-    cos = a_1.cos(); a_1 = None
+    a_2 = a
+    b_2 = b
+    l_x__2 = l_x_
+    add = l_x__2 + l_x__2; l_x__2 = None
+    cos = a_2.cos(); a_2 = None
     add_1 = add + cos; add = cos = None
-    cos_1 = b_1.cos(); b_1 = None
+    cos_1 = b_2.cos(); b_2 = None
     add_2 = add_1 + cos_1; add_1 = cos_1 = None
     cos_2 = d_true_branch.cos(); d_true_branch = None
     add_3 = add_2 + cos_2; add_2 = cos_2 = None
@@ -3964,13 +3964,13 @@ def forward(self, a, b, l_x_, d_true_branch, c_false_branch):
             out_graph.cond_false_0.code.strip(),
             """\
 def forward(self, a, b, l_x_, d_true_branch, c_false_branch):
-    a_1 = a
-    b_1 = b
-    l_x__1 = l_x_
-    mul = l_x__1 * l_x__1; l_x__1 = None
-    sin = a_1.sin(); a_1 = None
+    a_2 = a
+    b_2 = b
+    l_x__2 = l_x_
+    mul = l_x__2 * l_x__2; l_x__2 = None
+    sin = a_2.sin(); a_2 = None
     add = mul + sin; mul = sin = None
-    sin_1 = b_1.sin(); b_1 = None
+    sin_1 = b_2.sin(); b_2 = None
     add_1 = add + sin_1; add = sin_1 = None
     sin_2 = c_false_branch.sin(); c_false_branch = None
     add_2 = add_1 + sin_2; add_1 = sin_2 = None
```

test/dynamo/test_functions.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -1964,6 +1964,7 @@ def forward(self, L_lambda0_keywords_y_: "f32[2, 2]"):
             """\
 class GraphModule(torch.nn.Module):
     def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"):
+        s0_1 = s0
         l_lambda0_keywords_y_ = L_lambda0_keywords_y_
 
         mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
@@ -2012,6 +2013,7 @@ def forward(self, L_lambda0_keywords_y_: "f32[2, 2]"):
             """\
 class GraphModule(torch.nn.Module):
     def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"):
+        s0_1 = s0
         l_lambda0_keywords_y_ = L_lambda0_keywords_y_
 
         mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
@@ -2063,6 +2065,7 @@ def forward(self, L_lambda0_keywords_y_: "f32[2, 2]"):
             """\
 class GraphModule(torch.nn.Module):
     def forward(self, s0: "Sym(s0)", L_lambda0_keywords_y_: "f32[s0, s0]"):
+        s0_1 = s0
         l_lambda0_keywords_y_ = L_lambda0_keywords_y_
 
         mul: "f32[s0, s0]" = l_lambda0_keywords_y_ * l_lambda0_keywords_y_
@@ -2111,6 +2114,7 @@ def forward(self, L_x_: "f32[2, 2]"):
             """\
 class GraphModule(torch.nn.Module):
     def forward(self, s0: "Sym(s0)", L_x_: "f32[s0, s0]"):
+        s0_1 = s0
         l_x_ = L_x_
 
         mul: "f32[s0, s0]" = l_x_ * 4
```
