
Commit e23a500

Add skipIfRocmArch decorator for Navi skips (#1356)
1 parent ef4fc56 commit e23a500

4 files changed: +73 −1 lines changed

test/inductor/test_cuda_repro.py

Lines changed: 5 additions & 1 deletion
@@ -21,6 +21,7 @@
     freeze_rng_state,
     IS_FBCODE,
     skipIfRocm,
+    skipIfRocmArch,
     TEST_WITH_ASAN,
 )
 
@@ -40,7 +41,7 @@
         sys.exit(0)
     raise
 
-
+NAVI_ARCH = ("gfx1100", "gfx1101")  # Used for navi exclusive skips on ROCm
 TestCase = test_torchinductor.TestCase
 ToTuple = test_torchinductor.ToTuple
 check_model_cuda = test_torchinductor.check_model_cuda
@@ -305,6 +306,7 @@ def foo(x):
         out_ref.add_(2)
         # self.assertEqual(out_ref, out)
 
+    @skipIfRocmArch(NAVI_ARCH)
     def test_accuracy_issue1(self):
         class Repro(torch.nn.Module):
             def __init__(self):
@@ -341,6 +343,7 @@ def forward(self, start_positions: torch.Tensor, x: torch.Tensor):
         assert same_two_models(mod, opt_mod, args), "Dynamo failed"
 
     @config.patch(allow_buffer_reuse=False)
+    @skipIfRocmArch(NAVI_ARCH)
     def test_issue103461(self):
         def forward(add_1):
             var_mean = torch.ops.aten.var_mean.correction(
@@ -828,6 +831,7 @@ def forward(self, x):
         res2 = jit_func(x)
         self.assertEqual(res1, res2)
 
+    @skipIfRocmArch(NAVI_ARCH)
     def test_issue103481(self):
         def fn(x, y):
             # NOTE: 6 dimensions is important! does not fail for 5 dimensions

test/inductor/test_torchinductor.py

Lines changed: 21 additions & 0 deletions
@@ -65,6 +65,7 @@
     IS_WINDOWS,
     IS_X86,
     skipIfRocm,
+    skipIfRocmArch,
     TEST_WITH_ASAN,
     TEST_WITH_ROCM,
     TestCase as TorchTestCase,
@@ -104,6 +105,7 @@
 _desired_test_bases = get_desired_device_type_test_bases()
 RUN_CPU = any(getattr(x, "device_type", "") == "cpu" for x in _desired_test_bases)
 RUN_GPU = any(getattr(x, "device_type", "") == GPU_TYPE for x in _desired_test_bases)
+NAVI_ARCH = ("gfx1100", "gfx1101")  # Used for navi exclusive skips on ROCm
 
 aten = torch.ops.aten
 requires_gpu = functools.partial(unittest.skipIf, not HAS_GPU, "requires gpu")
@@ -1120,6 +1122,7 @@ def fn(x):
         # make sure things also work if they aren't unrolled
         self.common(fn, (torch.randn(8, 3),))
 
+    @skipIfRocmArch(NAVI_ARCH)
     def test_multilayer_sum_low_prec(self):
         # fp16 nyi for cpu
         if self.device == "cpu":
@@ -1130,6 +1133,7 @@ def fn(a):
 
         self.common(fn, ((torch.rand((10, 3, 352, 352), dtype=torch.float16),)))
 
+    @skipIfRocmArch(NAVI_ARCH)
     def test_multilayer_prime_size(self):
         def fn(a):
             return torch.max(a), torch.sum(a)
@@ -1139,6 +1143,7 @@ def fn(a):
         sample[-1] = 1
         self.common(fn, (sample,))
 
+    @skipIfRocmArch(NAVI_ARCH)
     def test_multilayer_var(self):
         def fn(a):
             return torch.var(a)
@@ -1838,6 +1843,7 @@ def fn(a, b):
 
         self.common(fn, (torch.randn(8, 8), torch.randn(8, 8)))
 
+    @skipIfRocmArch(NAVI_ARCH)
     def test_large_tensor_reduction(self):
         if not _has_sufficient_memory(self.device, 4.5 * 1024**3):  # 4.5 GiB
             raise unittest.SkipTest("insufficient memory")
@@ -1858,6 +1864,7 @@ def fn(a):
         expect = torch.tensor(2, dtype=torch.int8, device=self.device)
         self.assertEqual(actual, expect)
 
+    @skipIfRocmArch(NAVI_ARCH)
     def test_large_broadcast_reduction(self):
         if self.device == "cpu":
             raise unittest.SkipTest("Fails on CPU")
@@ -2868,6 +2875,7 @@ def test_conv2d_channels_last(self):
             check_lowp=False,
         )
 
+    @skipIfRocmArch(NAVI_ARCH)
     def test_conv2d_backward_channels_last(self):
         def fn(grad_output, inp, weight):
             convolution_backward_8 = torch.ops.aten.convolution_backward.default(
@@ -3535,6 +3543,7 @@ def fn(x, y):
         self.assertEqual(a.stride(), c.stride())
         self.assertEqual(c.stride()[2], 1)
 
+    @skipIfRocmArch(NAVI_ARCH)
     def test_std(self):
         def fn(x):
             return (
@@ -3577,6 +3586,7 @@ def test_batch_norm_2d(self):
 
     # From yolov3
     @with_tf32_off
+    @skipIfRocmArch(NAVI_ARCH)
     def test_batch_norm_2d_2(self):
         if self.device == "cpu":
             raise unittest.SkipTest(f"requires {GPU_TYPE}")
@@ -3712,6 +3722,7 @@ def fn(x):
 
         self.common(fn, (x,))
 
+    @skipIfRocmArch(NAVI_ARCH)
     def test_cauchy(self):
         def fn(x, y):
             return torch.sum(1 / (torch.unsqueeze(x, -1) - y))
@@ -4868,6 +4879,7 @@ def fn(a):
         y = fn_compiled(x)
         self.assertTrue(y is not x)
 
+    @skipIfRocmArch(NAVI_ARCH)
     def test_l1_loss(self):
         def fn(a, b):
             return torch.nn.functional.l1_loss(a, b), torch.nn.functional.mse_loss(a, b)
@@ -5255,6 +5267,7 @@ def fn(x):
             fn, (torch.tensor([1, float("inf"), 2, float("-inf"), float("nan")]),)
         )
 
+    @skipIfRocmArch(NAVI_ARCH)
     def test_any(self):
         def fn(x):
             return (
@@ -5890,6 +5903,7 @@ def fn(a, dim, index, b, reduce):
         )
 
     # issue #1150
+    @skipIfRocmArch(NAVI_ARCH)
     def test_dense_mask_index(self):
         if self.device == "cpu":
             raise unittest.SkipTest(
@@ -6607,6 +6621,7 @@ def fn(a, b):
         b = torch.rand(2, 2, 1, 4, 1).int()
         self.common(fn, (a, b))
 
+    @skipIfRocmArch(NAVI_ARCH)
     def test_argmax_argmin1(self):
         def fn(x):
             return (aten.argmax(x), aten.argmin(x))
@@ -6618,6 +6633,7 @@ def fn(x):
             ],
         )
 
+    @skipIfRocmArch(NAVI_ARCH)
     def test_argmax_argmin2(self):
         def fn(x):
             return (
@@ -6629,6 +6645,7 @@ def fn(x):
 
         self.common(fn, (torch.randn([144, 144]),))
 
+    @skipIfRocmArch(NAVI_ARCH)
     def test_argmax_argmin_with_duplicates(self):
         def fn(x):
             return (
@@ -6650,6 +6667,7 @@ def fn(x):
         t1 = torch.randint(8, size=(1028, 1028))
         self.common(fn, (t1,))
 
+    @skipIfRocmArch(NAVI_ARCH)
     def test_argmax_argmin_with_nan(self):
         def fn(x):
             return (
@@ -6782,6 +6800,7 @@ def fn(x):
             ],
         )
 
+    @skipIfRocmArch(NAVI_ARCH)
     def test_tmp_not_defined_issue1(self):
         def forward(
             primals_3,
@@ -7062,6 +7081,7 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         else:
             self.assertEqual(len(inps), 0)
 
+    @skipIfRocmArch(NAVI_ARCH)
     def test_dtype_mismatch_issue(self):
         def fn(x):
             attn = torch.nn.functional.pad(x, [0, 1])
@@ -9040,6 +9060,7 @@ def test_rnn_compile_safe(self):
 
 class NanCheckerTest(TestCase):
     @config.patch("nan_asserts", True)
+    @skipIfRocmArch(NAVI_ARCH)
     def test_nan_checker_pass(self):
         def f(x):
             return torch.softmax(x, dim=-1)

test/inductor/test_torchinductor_opinfo.py

Lines changed: 14 additions & 0 deletions
@@ -31,6 +31,7 @@
 from torch.testing._internal.common_methods_invocations import op_db, skipOps
 from torch.testing._internal.common_utils import (
     dtype_abbrs,
+    IS_NAVI,
     IS_MACOS,
     IS_X86,
     skipCUDAMemoryLeakCheckIf,
@@ -200,6 +201,19 @@ def format_op(op):
 if TEST_WITH_ROCM:
     # Tensors are not alike
     inductor_skips["cuda"]["logcumsumexp"] = {f32}
+    if IS_NAVI:
+        inductor_skips["cuda"]["aminmax"] = {b8, f16, f32, f64, i32, i64}
+        inductor_skips["cuda"]["dist"] = {b8, f16, f32, f64, i32, i64}
+        inductor_skips["cuda"]["kron"] = {b8, f16, f32, f64, i32, i64}
+        inductor_skips["cuda"]["masked.std"] = {b8, f16, f32, f64, i32, i64}
+        inductor_skips["cuda"]["masked.var"] = {b8, f16, f32, f64, i32, i64}
+        inductor_skips["cuda"][("max", "reduction_no_dim")] = {b8, f16, f32, f64, i32, i64}
+        inductor_skips["cuda"][("min", "reduction_no_dim")] = {b8, f16, f32, f64, i32, i64}
+        inductor_skips["cuda"]["nn.functional.conv_transpose3d"] = {b8, f16, f32, f64, i32, i64}
+        inductor_skips["cuda"]["std"] = {b8, f16, f32, f64, i32, i64}
+        inductor_skips["cuda"]["std_mean"] = {b8, f16, f32, f64, i32, i64}
+        inductor_skips["cuda"]["var"] = {b8, f16, f32, f64, i32, i64}
+        inductor_skips["cuda"]["var_mean"] = {b8, f16, f32, f64, i32, i64}
 
 inductor_expected_failures_single_sample = defaultdict(dict)
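(A note on the skip table above: a plain string key names an OpInfo op, while a tuple key such as ("max", "reduction_no_dim") selects a specific variant of an op; the value is the set of dtypes, e.g. f16, for which the test is skipped.)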

torch/testing/_internal/common_utils.py

Lines changed: 33 additions & 0 deletions
@@ -1176,6 +1176,13 @@ def printErrors(self) -> None:
 IS_X86 = platform.machine() in ('x86_64', 'i386')
 IS_ARM64 = platform.machine() in ('arm64', 'aarch64')
 
+IS_NAVI = False
+if torch.cuda.is_available():
+    prop = torch.cuda.get_device_properties(0)
+    gfx_arch = prop.gcnArchName.split(":")[0]
+    if gfx_arch in ["gfx1100", "gfx1101", "gfx1102"]:
+        IS_NAVI = True
+
 def is_avx512_vnni_supported():
     if sys.platform != 'linux':
         return False
@@ -1559,6 +1566,19 @@ def wrapper(*args, **kwargs):
         return dec_fn(func)
     return dec_fn
 
+def skipIfRocmArch(arch: Tuple[str, ...]):
+    def dec_fn(fn):
+        @wraps(fn)
+        def wrap_fn(self, *args, **kwargs):
+            if TEST_WITH_ROCM:
+                prop = torch.cuda.get_device_properties(0)
+                if prop.gcnArchName.split(":")[0] in arch:
+                    reason = f"skipIfRocmArch: test skipped on {arch}"
+                    raise unittest.SkipTest(reason)
+            return fn(self, *args, **kwargs)
+        return wrap_fn
+    return dec_fn
+
 def runOnRocm(fn):
     @wraps(fn)
     def wrapper(*args, **kwargs):
@@ -1568,6 +1588,19 @@ def wrapper(*args, **kwargs):
             raise unittest.SkipTest("test currently only works on the ROCm stack")
     return wrapper
 
+def runOnRocmArch(arch: Tuple[str, ...]):
+    def dec_fn(fn):
+        @wraps(fn)
+        def wrap_fn(self, *args, **kwargs):
+            if TEST_WITH_ROCM:
+                prop = torch.cuda.get_device_properties(0)
+                if prop.gcnArchName.split(":")[0] not in arch:
+                    reason = f"runOnRocmArch: test only runs on {arch}"
+                    raise unittest.SkipTest(reason)
+            return fn(self, *args, **kwargs)
+        return wrap_fn
+    return dec_fn
+
 def skipIfMps(fn):
     @wraps(fn)
     def wrapper(*args, **kwargs):
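runOnRocmArch is added above but is not exercised anywhere in this commit; a minimal sketch of its intended use, by analogy with the existing runOnRocm (the test class and body are hypothetical):

from torch.testing._internal.common_utils import run_tests, runOnRocmArch, TestCase

class NaviOnlyTests(TestCase):  # hypothetical test class
    @runOnRocmArch(("gfx1100", "gfx1101"))
    def test_navi_specific_path(self):
        # On ROCm, runs only when device 0 reports one of the listed gfx archs.
        # Note: as written, the decorator is a no-op on non-ROCm builds.
        pass

if __name__ == "__main__":
    run_tests()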
