Skip to content

Commit d10d2fa

Browse files
jithunnair-amddnikolaev-amd
authored andcommitted
skipIfRocm needs msg parameter
1 parent 4c85c6c commit d10d2fa

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

test/test_cuda.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1884,9 +1884,8 @@ def test_graph_capture_oom(self):
18841884
with torch.cuda.graph(torch.cuda.CUDAGraph()):
18851885
torch.zeros(2**40, device="cuda")
18861886

1887-
@unittest.skipIf(
1888-
not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
1889-
)
1887+
@unittest.skipIf(not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs")
1888+
@skipIfRocm(msg="TODO: temp skip on ROCm 6.2")
18901889
@serialTest()
18911890
def test_repeat_graph_capture_cublas_workspace_memory(self):
18921891
(x, y, z) = 1024, 512, 64
@@ -2842,6 +2841,7 @@ def forward(self, input_dict: dict):
28422841
@unittest.skipIf(
28432842
not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs"
28442843
)
2844+
@skipIfRocm(msg="TODO: temp skip on ROCm 6.2")
28452845
def test_graph_make_graphed_callables_same_pool(self):
28462846
torch.manual_seed(5)
28472847
torch.cuda.manual_seed(5)

0 commit comments

Comments
 (0)