Commit 9cb23ba

Revert "Add buffer static input tests to cudagraph trees (pytorch#130402)"
This reverts commit 80236dc. Reverted pytorch#130402 on behalf of https://github.com/clee2000 due to broke lint for torch/_functorch/_aot_autograd/subclass_utils.py https://github.com/pytorch/pytorch/actions/runs/9948630877/job/27483551649 https://hud.pytorch.org/pytorch/pytorch/commit/80236dca90b0874cb2b6f9c9fa5f159c55726401 lint was green on PR, probably a landrace ([comment](pytorch#130393 (comment)))
1 parent: c509319; commit: 9cb23ba

File tree

1 file changed: +4, -43 lines changed

test/inductor/test_cudagraph_trees.py

Lines changed: 4 additions & 43 deletions
@@ -1855,7 +1855,7 @@ def run_static_input_param_test(self, fn_eager, num_graphs):
 
         self.assertEqual(self.get_manager().new_graph_id().id, num_graphs)
 
-    def _module_test(self, mod, name="weight", param_wrapping=True):
+    def _module_test(self, mod):
         with torch.device("cuda"):
 
             def fn(x, mod):
@@ -1878,14 +1878,11 @@ def run_test():
                 self.assertEqual(exp_grad, compiled_grad)
 
             run_test()
-            old_attr = getattr(mod, name)
-            modified_attr = torch.rand_like(old_attr)
-            if param_wrapping:
-                modified_attr = torch.nn.Parameter(modified_attr)
-            setattr(mod, name, modified_attr)
+            old = mod.weight.data
+            mod.weight.data = torch.rand_like(mod.weight.data)
             run_test()
             # Run original version to verify we reuse the other recording
-            setattr(mod, name, old_attr)
+            mod.weight.data = old
             run_test()
 
             # Fwd + bwd graphs for each version of the function => 4 graphs
@@ -1910,18 +1907,6 @@ def test_multi_dispatch_single_compile_builtin_module(self):
         # Note: Linear is a builtin module so we enable that config setting above
         self._module_test(torch.nn.Linear(2, 3, device="cuda"))
 
-    @torch._dynamo.config.patch("error_on_recompile", True)
-    @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
-    def test_multi_dispatch_single_compile_builtin_module_buffers(self):
-        # Verify that we don't recompile when changing the buffer of a builtin module
-        # and that we record another cudagraph
-        self._module_test(
-            torch.nn.BatchNorm1d(2, device="cuda"),
-            name="running_mean",
-            param_wrapping=False,
-        )
-
-    @torch._inductor.config.patch("triton.cudagraphs", True)
     @torch._dynamo.config.patch("error_on_recompile", True)
     @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
     def test_multi_dispatch_custom_module(self):
@@ -1939,30 +1924,6 @@ def forward(self, x):
             TestModule(torch.nn.Parameter(torch.rand([2, 2], device="cuda")))
         )
 
-    @torch._dynamo.config.patch("error_on_recompile", True)
-    @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
-    def test_multi_dispatch_custom_module_buffer(self):
-        # Test that we can correctly dispatch multiple graphs
-        # if buffers of a custom module change
-        class TestModule(torch.nn.Module):
-            def __init__(self, param, buf) -> None:
-                super().__init__()
-                self.weight = param
-                self.register_buffer("buf", buf)
-
-            def forward(self, x):
-                return x * self.weight + self.buf
-
-        self._module_test(
-            TestModule(
-                torch.nn.Parameter(torch.rand([2, 2], device="cuda")),
-                torch.rand([2, 2], device="cuda"),
-            ),
-            name="buf",
-            param_wrapping=False,
-        )
-
-    @torch._inductor.config.patch("triton.cudagraphs", True)
     @torch._dynamo.config.patch("error_on_recompile", True)
     @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
     def test_multi_dispatch_child_node(self):
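
For context on the param_wrapping flag removed above: parameters and buffers are swapped differently on an nn.Module, which is what the reverted tests exercised. The following is a minimal illustrative sketch, not part of test_cudagraph_trees.py, and assumes a CUDA device is available:

import torch

# Illustrative sketch only (assumes CUDA is available); mirrors the reverted
# helper's setattr-based attribute swap without the cudagraph assertions.
mod = torch.nn.BatchNorm1d(2, device="cuda")

# Swapping a parameter: the replacement must be re-wrapped in
# torch.nn.Parameter (the param_wrapping=True case).
setattr(mod, "weight", torch.nn.Parameter(torch.rand_like(mod.weight)))

# Swapping a buffer such as running_mean: buffers are plain tensors, so the
# replacement is assigned directly (the param_wrapping=False case).
setattr(mod, "running_mean", torch.rand_like(mod.running_mean))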
