@@ -1855,7 +1855,7 @@ def run_static_input_param_test(self, fn_eager, num_graphs):
         self.assertEqual(self.get_manager().new_graph_id().id, num_graphs)
 
-    def _module_test(self, mod, name="weight", param_wrapping=True):
+    def _module_test(self, mod):
         with torch.device("cuda"):
 
             def fn(x, mod):
@@ -1878,14 +1878,11 @@ def run_test():
                 self.assertEqual(exp_grad, compiled_grad)
 
             run_test()
-            old_attr = getattr(mod, name)
-            modified_attr = torch.rand_like(old_attr)
-            if param_wrapping:
-                modified_attr = torch.nn.Parameter(modified_attr)
-            setattr(mod, name, modified_attr)
+            old = mod.weight.data
+            mod.weight.data = torch.rand_like(mod.weight.data)
             run_test()
             # Run original version to verify we reuse the other recording
-            setattr(mod, name, old_attr)
+            mod.weight.data = old
             run_test()
 
             # Fwd + bwd graphs for each version of the function => 4 graphs
@@ -1910,18 +1907,6 @@ def test_multi_dispatch_single_compile_builtin_module(self):
         # Note: Linear is a builtin module so we enable that config setting above
         self._module_test(torch.nn.Linear(2, 3, device="cuda"))
 
-    @torch._dynamo.config.patch("error_on_recompile", True)
-    @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
-    def test_multi_dispatch_single_compile_builtin_module_buffers(self):
-        # Verify that we don't recompile when changing the buffer of a builtin module
-        # and that we record another cudagraph
-        self._module_test(
-            torch.nn.BatchNorm1d(2, device="cuda"),
-            name="running_mean",
-            param_wrapping=False,
-        )
-
-    @torch._inductor.config.patch("triton.cudagraphs", True)
     @torch._dynamo.config.patch("error_on_recompile", True)
     @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
     def test_multi_dispatch_custom_module(self):
@@ -1939,30 +1924,6 @@ def forward(self, x):
             TestModule(torch.nn.Parameter(torch.rand([2, 2], device="cuda")))
         )
 
-    @torch._dynamo.config.patch("error_on_recompile", True)
-    @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
-    def test_multi_dispatch_custom_module_buffer(self):
-        # Test that we can correctly dispatch multiple graphs
-        # if buffers of a custom module change
-        class TestModule(torch.nn.Module):
-            def __init__(self, param, buf) -> None:
-                super().__init__()
-                self.weight = param
-                self.register_buffer("buf", buf)
-
-            def forward(self, x):
-                return x * self.weight + self.buf
-
-        self._module_test(
-            TestModule(
-                torch.nn.Parameter(torch.rand([2, 2], device="cuda")),
-                torch.rand([2, 2], device="cuda"),
-            ),
-            name="buf",
-            param_wrapping=False,
-        )
-
-    @torch._inductor.config.patch("triton.cudagraphs", True)
     @torch._dynamo.config.patch("error_on_recompile", True)
     @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
     def test_multi_dispatch_child_node(self):
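For context, a minimal standalone sketch of the behavior the updated _module_test exercises (an illustration written for this note, not part of the diff; it assumes a CUDA device and the same config patches the tests use): reassigning mod.weight.data mutates the existing Parameter in place, so Dynamo sees the same module attribute and does not recompile, while CUDA graph trees records a separate graph for the new values and reuses the first recording once the original data is restored.

    # Hedged sketch; assumes CUDA is available. Names here (fn, mod, x, old)
    # are illustrative, not taken from the test file.
    import torch

    mod = torch.nn.Linear(2, 3, device="cuda")

    @torch.compile(mode="reduce-overhead")
    def fn(x):
        return mod(x)

    x = torch.rand(10, 2, device="cuda")
    fn(x)  # first recording

    old = mod.weight.data
    mod.weight.data = torch.rand_like(old)  # in-place swap: same Parameter object, new values
    fn(x)  # no Dynamo recompile; a second cudagraph is recorded

    mod.weight.data = old  # restore the original tensor
    fn(x)  # reuses the first recording

This is also why the diff drops the name/param_wrapping knobs: with setattr the old test replaced the attribute object itself, whereas swapping .data keeps the Parameter identity stable, which is the case error_on_recompile is meant to guard.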