@@ -1825,7 +1825,7 @@ def run_static_input_param_test(self, fn_eager, num_graphs):
         self.assertEqual(self.get_manager().new_graph_id().id, num_graphs)
 
-    def _module_test(self, mod):
+    def _module_test(self, mod, name="weight", param_wrapping=True):
         with torch.device("cuda"):
 
             def fn(x, mod):
@@ -1848,11 +1848,14 @@ def run_test():
             self.assertEqual(exp_grad, compiled_grad)
 
         run_test()
-        old = mod.weight.data
-        mod.weight.data = torch.rand_like(mod.weight.data)
+        old_attr = getattr(mod, name)
+        modified_attr = torch.rand_like(old_attr)
+        if param_wrapping:
+            modified_attr = torch.nn.Parameter(modified_attr)
+        setattr(mod, name, modified_attr)
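+        # Note: buffers can be swapped for a plain tensor, but anything in
+        # mod._parameters must be re-wrapped in nn.Parameter first, since
+        # nn.Module.__setattr__ rejects raw tensors for registered parameters.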
         run_test()
         # Run original version to verify we reuse the other recording
-        mod.weight.data = old
+        setattr(mod, name, old_attr)
         run_test()
 
         # Fwd + bwd graphs for each version of the function => 4 graphs
@@ -1877,6 +1880,18 @@ def test_multi_dispatch_single_compile_builtin_module(self):
         # Note: Linear is a builtin module so we enable that config setting above
         self._module_test(torch.nn.Linear(2, 3, device="cuda"))
 
+    @torch._dynamo.config.patch("error_on_recompile", True)
+    @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
+    def test_multi_dispatch_single_compile_builtin_module_buffers(self):
+        # Verify that we don't recompile when changing the buffer of a builtin module
+        # and that we record another cudagraph
+        self._module_test(
+            torch.nn.BatchNorm1d(2, device="cuda"),
+            name="running_mean",
+            param_wrapping=False,
+        )
+
+    @torch._inductor.config.patch("triton.cudagraphs", True)
     @torch._dynamo.config.patch("error_on_recompile", True)
     @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
     def test_multi_dispatch_custom_module(self):
@@ -1894,6 +1909,30 @@ def forward(self, x):
             TestModule(torch.nn.Parameter(torch.rand([2, 2], device="cuda")))
         )
 
+    @torch._dynamo.config.patch("error_on_recompile", True)
+    @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
+    def test_multi_dispatch_custom_module_buffer(self):
+        # Test that we can correctly dispatch multiple graphs
+        # if buffers of a custom module change
+        class TestModule(torch.nn.Module):
+            def __init__(self, param, buf) -> None:
+                super().__init__()
+                self.weight = param
+                self.register_buffer("buf", buf)
+
+            def forward(self, x):
+                return x * self.weight + self.buf
+
+        self._module_test(
+            TestModule(
+                torch.nn.Parameter(torch.rand([2, 2], device="cuda")),
+                torch.rand([2, 2], device="cuda"),
+            ),
+            name="buf",
+            param_wrapping=False,
+        )
+
+    @torch._inductor.config.patch("triton.cudagraphs", True)
     @torch._dynamo.config.patch("error_on_recompile", True)
     @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
     def test_multi_dispatch_child_node(self):
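
In short, these tests exercise the pattern below: with `inline_inbuilt_nn_modules` enabled, swapping out a module's buffer (or parameter) should dispatch to a freshly recorded cudagraph rather than trigger a recompile. A minimal standalone sketch of that behavior, assuming an illustrative `BatchNorm1d` in eval mode (the module and setup here are not from the diff):

```python
# Minimal sketch, not part of the PR: mutating a registered buffer should
# record a second cudagraph instead of forcing a Dynamo recompile.
import torch

torch._dynamo.config.error_on_recompile = True
torch._dynamo.config.inline_inbuilt_nn_modules = True

mod = torch.nn.BatchNorm1d(2, device="cuda").eval()

@torch.compile(mode="reduce-overhead")
def fn(x):
    return mod(x)

x = torch.ones(8, 2, device="cuda")
fn(x)

# Buffers accept a plain tensor; a parameter would need nn.Parameter wrapping.
mod.running_mean = torch.rand_like(mod.running_mean)
fn(x)  # no recompile error; dispatches to a new recording for the new buffer
```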