
Commit 80236dc

mlazos authored and pytorchmergebot committed
Add buffer static input tests to cudagraph trees (pytorch#130402)
Pull Request resolved: pytorch#130402
Approved by: https://github.com/eellison
ghstack dependencies: pytorch#130391, pytorch#130392, pytorch#130503, pytorch#130393
1 parent 69a7738 commit 80236dc
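For orientation, here is a minimal standalone sketch (not part of this commit) of the behavior the new tests exercise: with inline_inbuilt_nn_modules enabled, swapping a module buffer for a fresh tensor should not trigger a Dynamo recompile (error_on_recompile would raise), and cudagraph trees should instead record a second graph for the new static input. The use of torch.compile with mode="reduce-overhead" and the direct buffer assignment are assumptions of this sketch; the commit itself drives this through the test suite's _module_test harness and the cudagraph manager's graph counter.

import torch

@torch._dynamo.config.patch("error_on_recompile", True)
@torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
def check_buffer_swap():
    # mode="reduce-overhead" enables cudagraph trees in inductor
    mod = torch.nn.BatchNorm1d(2, device="cuda")
    compiled = torch.compile(lambda x: mod(x), mode="reduce-overhead")
    x = torch.rand([8, 2], device="cuda")

    compiled(x)  # first cudagraph recording
    # Replace the running_mean buffer; this must not force a recompile,
    # it should only lead to another cudagraph recording.
    mod.running_mean = torch.rand_like(mod.running_mean)
    compiled(x)

if torch.cuda.is_available():
    check_buffer_swap()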

File tree: 1 file changed, +43 -4 lines

test/inductor/test_cudagraph_trees.py

Lines changed: 43 additions & 4 deletions
@@ -1825,7 +1825,7 @@ def run_static_input_param_test(self, fn_eager, num_graphs):
 
         self.assertEqual(self.get_manager().new_graph_id().id, num_graphs)
 
-    def _module_test(self, mod):
+    def _module_test(self, mod, name="weight", param_wrapping=True):
         with torch.device("cuda"):
 
             def fn(x, mod):
@@ -1848,11 +1848,14 @@ def run_test():
             self.assertEqual(exp_grad, compiled_grad)
 
         run_test()
-        old = mod.weight.data
-        mod.weight.data = torch.rand_like(mod.weight.data)
+        old_attr = getattr(mod, name)
+        modified_attr = torch.rand_like(old_attr)
+        if param_wrapping:
+            modified_attr = torch.nn.Parameter(modified_attr)
+        setattr(mod, name, modified_attr)
         run_test()
         # Run original version to verify we reuse the other recording
-        mod.weight.data = old
+        setattr(mod, name, old_attr)
         run_test()
 
         # Fwd + bwd graphs for each version of the function => 4 graphs
@@ -1877,6 +1880,18 @@ def test_multi_dispatch_single_compile_builtin_module(self):
         # Note: Linear is a builtin module so we enable that config setting above
         self._module_test(torch.nn.Linear(2, 3, device="cuda"))
 
+    @torch._dynamo.config.patch("error_on_recompile", True)
+    @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
+    def test_multi_dispatch_single_compile_builtin_module_buffers(self):
+        # Verify that we don't recompile when changing the buffer of a builtin module
+        # and that we record another cudagraph
+        self._module_test(
+            torch.nn.BatchNorm1d(2, device="cuda"),
+            name="running_mean",
+            param_wrapping=False,
+        )
+
+    @torch._inductor.config.patch("triton.cudagraphs", True)
     @torch._dynamo.config.patch("error_on_recompile", True)
     @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
     def test_multi_dispatch_custom_module(self):
@@ -1894,6 +1909,30 @@ def forward(self, x):
             TestModule(torch.nn.Parameter(torch.rand([2, 2], device="cuda")))
         )
 
+    @torch._dynamo.config.patch("error_on_recompile", True)
+    @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
+    def test_multi_dispatch_custom_module_buffer(self):
+        # Test that we can correctly dispatch multiple graphs
+        # if buffers of a custom module change
+        class TestModule(torch.nn.Module):
+            def __init__(self, param, buf) -> None:
+                super().__init__()
+                self.weight = param
+                self.register_buffer("buf", buf)
+
+            def forward(self, x):
+                return x * self.weight + self.buf
+
+        self._module_test(
+            TestModule(
+                torch.nn.Parameter(torch.rand([2, 2], device="cuda")),
+                torch.rand([2, 2], device="cuda"),
+            ),
+            name="buf",
+            param_wrapping=False,
+        )
+
+    @torch._inductor.config.patch("triton.cudagraphs", True)
     @torch._dynamo.config.patch("error_on_recompile", True)
     @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
     def test_multi_dispatch_child_node(self):
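In short, the generalized _module_test helper follows a swap-and-restore pattern around the targeted attribute. A condensed sketch of that pattern (outside the test file; the function name is illustrative, and run_test stands in for the compiled-vs-eager comparison the real helper performs):

import torch

def swap_and_restore(mod: torch.nn.Module, run_test, name="weight", param_wrapping=True):
    run_test()                          # records the first cudagraph
    old_attr = getattr(mod, name)
    modified_attr = torch.rand_like(old_attr)
    if param_wrapping:
        # parameters must stay wrapped in nn.Parameter; buffers are plain tensors
        modified_attr = torch.nn.Parameter(modified_attr)
    setattr(mod, name, modified_attr)
    run_test()                          # new static input tensor -> a second recording, no recompile
    setattr(mod, name, old_attr)
    run_test()                          # original tensor -> the first recording is reused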
