Commit 9726c26

jataylo authored and dnikolaev-amd committed
Add skipIfRocmArch decorator for Navi skips (#1356)
1 parent 3120778 commit 9726c26

File tree

4 files changed: +77 −1 lines changed

test/inductor/test_cuda_repro.py
test/inductor/test_torchinductor.py
test/inductor/test_torchinductor_opinfo.py
torch/testing/_internal/common_utils.py

test/inductor/test_cuda_repro.py

Lines changed: 5 additions & 1 deletion
@@ -27,6 +27,7 @@
     freeze_rng_state,
     IS_FBCODE,
     skipIfRocm,
+    skipIfRocmArch,
     TEST_WITH_ASAN,
 )

@@ -48,7 +49,7 @@
         sys.exit(0)
     raise

-
+NAVI_ARCH = ("gfx1100", "gfx1101")  # Used for navi exclusive skips on ROCm
 TestCase = test_torchinductor.TestCase
 ToTuple = test_torchinductor.ToTuple
 check_model_cuda = test_torchinductor.check_model_cuda
@@ -313,6 +314,7 @@ def foo(x):
         out_ref.add_(2)
         # self.assertEqual(out_ref, out)

+    @skipIfRocmArch(NAVI_ARCH)
     def test_accuracy_issue1(self):
         class Repro(torch.nn.Module):
             def __init__(self):
@@ -349,6 +351,7 @@ def forward(self, start_positions: torch.Tensor, x: torch.Tensor):
         assert same_two_models(mod, opt_mod, args), "Dynamo failed"

     @config.patch(allow_buffer_reuse=False)
+    @skipIfRocmArch(NAVI_ARCH)
     def test_issue103461(self):
         def forward(add_1):
             var_mean = torch.ops.aten.var_mean.correction(
@@ -832,6 +835,7 @@ def forward(self, x):
         res2 = jit_func(x)
         self.assertEqual(res1, res2)

+    @skipIfRocmArch(NAVI_ARCH)
     def test_issue103481(self):
         def fn(x, y):
             # NOTE: 6 dimensions is important! does not fail for 5 dimensions
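For context, NAVI_ARCH is matched against the prefix of the device's gcnArchName. A standalone sketch (not part of this commit, and assuming a ROCm build of PyTorch where gcnArchName is populated) to check what a local GPU reports:

import torch

# Standalone sketch, not part of this commit: print the arch string that the
# new skipIfRocmArch decorator compares against. On ROCm, gcnArchName looks
# like "gfx1100:sramecc+:xnack-"; only the part before the first ":" is used.
if torch.cuda.is_available():
    prop = torch.cuda.get_device_properties(0)
    gfx_arch = prop.gcnArchName.split(":")[0]
    print(f"arch = {gfx_arch}, matches NAVI_ARCH: {gfx_arch in ('gfx1100', 'gfx1101')}")
else:
    print("no GPU visible")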

test/inductor/test_torchinductor.py

Lines changed: 25 additions & 0 deletions
@@ -87,6 +87,7 @@
     skipIfRocm,
     skipIfXpu,
     subtest,
+    skipIfRocmArch,
     TEST_WITH_ASAN,
     TEST_WITH_ROCM,
 )
@@ -129,6 +130,10 @@
 )

 HAS_AVX2 = "fbgemm" in torch.backends.quantized.supported_engines
+_desired_test_bases = get_desired_device_type_test_bases()
+RUN_CPU = any(getattr(x, "device_type", "") == "cpu" for x in _desired_test_bases)
+RUN_GPU = any(getattr(x, "device_type", "") == GPU_TYPE for x in _desired_test_bases)
+NAVI_ARCH = ("gfx1100", "gfx1101")  # Used for navi exclusive skips on ROCm

 aten = torch.ops.aten

@@ -1648,6 +1653,7 @@ def fn(x):
         # make sure things also work if they aren't unrolled
         self.common(fn, (torch.randn(8, 3),))

+    @skipIfRocmArch(NAVI_ARCH)
     def test_multilayer_sum_low_prec(self):
         # fp16 nyi for cpu
         if self.device == "cpu":
@@ -1658,6 +1664,7 @@ def fn(a):

         self.common(fn, ((torch.rand((10, 3, 352, 352), dtype=torch.float16),)))

+    @skipIfRocmArch(NAVI_ARCH)
     def test_multilayer_prime_size(self):
         def fn(a):
             return torch.max(a), torch.sum(a)
@@ -1668,6 +1675,7 @@ def fn(a):
         self.common(fn, (sample,))

     @skipCPUIf(IS_MACOS, "fails on macos")
+    @skipIfRocmArch(NAVI_ARCH)
     def test_multilayer_var(self):
         def fn(a):
             return torch.var(a)
@@ -2667,6 +2675,7 @@ def fn(a, b):

         self.common(fn, (torch.randn(8, 8), torch.randn(8, 8)))

+    @skipIfRocmArch(NAVI_ARCH)
     def test_large_tensor_reduction(self):
         if not _has_sufficient_memory(self.device, 4.5 * 1024**3):  # 4.5 GiB
             raise unittest.SkipTest("insufficient memory")
@@ -2687,6 +2696,7 @@ def fn(a):
         expect = torch.tensor(2, dtype=torch.int8, device=self.device)
         self.assertEqual(actual, expect)

+    @skipIfRocmArch(NAVI_ARCH)
     def test_large_broadcast_reduction(self):
         if self.device == "cpu":
             raise unittest.SkipTest("Fails on CPU")
@@ -3761,6 +3771,7 @@ def test_conv2d_channels_last(self):
             check_lowp=False,
         )

+    @skipIfRocmArch(NAVI_ARCH)
     def test_conv2d_backward_channels_last(self):
         def fn(grad_output, inp, weight):
             convolution_backward_8 = torch.ops.aten.convolution_backward.default(
@@ -4520,6 +4531,7 @@ def fn(x, y):
         self.assertEqual(a.stride(), c.stride())
         self.assertEqual(c.stride()[2], 1)

+    @skipIfRocmArch(NAVI_ARCH)
     def test_std(self):
         def fn(x):
             return (
@@ -4562,6 +4574,7 @@ def test_batch_norm_2d(self):

     # From yolov3
     @with_tf32_off
+    @skipIfRocmArch(NAVI_ARCH)
     def test_batch_norm_2d_2(self):
         if self.device == "cpu":
             raise unittest.SkipTest(f"requires {GPU_TYPE}")
@@ -4697,6 +4710,7 @@ def fn(x):

         self.common(fn, (x,))

+    @skipIfRocmArch(NAVI_ARCH)
     def test_cauchy(self):
         def fn(x, y):
             return torch.sum(1 / (torch.unsqueeze(x, -1) - y))
@@ -6034,6 +6048,7 @@ def fn(a):
         y = fn_compiled(x)
         self.assertTrue(y is not x)

+    @skipIfRocmArch(NAVI_ARCH)
     def test_l1_loss(self):
         def fn(a, b):
             return torch.nn.functional.l1_loss(a, b), torch.nn.functional.mse_loss(a, b)
@@ -6430,6 +6445,7 @@ def fn(x):
             fn, (torch.tensor([1, float("inf"), 2, float("-inf"), float("nan")]),)
         )

+    @skipIfRocmArch(NAVI_ARCH)
     def test_any(self):
         def fn(x):
             return (
@@ -7177,6 +7193,8 @@ def fn(a, dim, index, b, reduce):
             check_lowp=check_lowp,
         )

+    # issue #1150
+    @skipIfRocmArch(NAVI_ARCH)
     def test_dense_mask_index(self):
         r"""
         There will be a little difference for reduce order between aten and inductor
@@ -8152,6 +8170,7 @@ def fn(a, b):
         b = torch.rand(2, 2, 1, 4, 1).int()
         self.common(fn, (a, b))

+    @skipIfRocmArch(NAVI_ARCH)
     def test_argmax_argmin1(self):
         def fn(x):
             return (aten.argmax(x), aten.argmin(x))
@@ -8163,6 +8182,7 @@ def fn(x):
             ],
         )

+    @skipIfRocmArch(NAVI_ARCH)
     def test_argmax_argmin2(self):
         def fn(x):
             return (
@@ -8174,6 +8194,7 @@ def fn(x):

         self.common(fn, (torch.randn([144, 144]),))

+    @skipIfRocmArch(NAVI_ARCH)
     def test_argmax_argmin_with_duplicates(self):
         def fn(x):
             return (
@@ -8195,6 +8216,7 @@ def fn(x):
         t1 = torch.randint(8, size=(1028, 1028))
         self.common(fn, (t1,))

+    @skipIfRocmArch(NAVI_ARCH)
     def test_argmax_argmin_with_nan(self):
         def fn(x):
             return (
@@ -8327,6 +8349,7 @@ def fn(x):
             ],
         )

+    @skipIfRocmArch(NAVI_ARCH)
     def test_tmp_not_defined_issue1(self):
         def forward(
             primals_3,
@@ -8721,6 +8744,7 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         else:
             self.assertEqual(len(inps), 0)

+    @skipIfRocmArch(NAVI_ARCH)
     def test_dtype_mismatch_issue(self):
         def fn(x):
             attn = torch.nn.functional.pad(x, [0, 1])
@@ -11388,6 +11412,7 @@ def test_rnn_compile_safe(self):

 class NanCheckerTest(TestCase):
     @config.patch("nan_asserts", True)
+    @skipIfRocmArch(NAVI_ARCH)
     def test_nan_checker_pass(self):
         def f(x):
             return torch.softmax(x, dim=-1)

test/inductor/test_torchinductor_opinfo.py

Lines changed: 14 additions & 0 deletions
@@ -31,6 +31,7 @@
 from torch.testing._internal.common_methods_invocations import op_db, skipOps
 from torch.testing._internal.common_utils import (
     dtype_abbrs,
+    IS_NAVI,
     IS_MACOS,
     IS_X86,
     skipCUDAMemoryLeakCheckIf,
@@ -203,6 +204,19 @@ def format_op(op):
 # Tensors are not alike
 inductor_skips["cuda"]["logcumsumexp"] = {f32}
 inductor_skips["cuda"]["special.modified_bessel_i1"] = {f64}
+if IS_NAVI:
+    inductor_skips["cuda"]["aminmax"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["dist"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["kron"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["masked.std"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["masked.var"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"][("max", "reduction_no_dim")] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"][("min", "reduction_no_dim")] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["nn.functional.conv_transpose3d"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["std"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["std_mean"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["var"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["var_mean"] = {b8, f16, f32, f64, i32, i64}

 inductor_expected_failures_single_sample = defaultdict(dict)

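The skip table above is keyed by device, then by op name or an (op, variant) tuple, and maps to the set of dtype markers to skip. A small illustrative sketch of that lookup shape (the helper name and values are hypothetical, not from the test harness):

from collections import defaultdict

# Hypothetical sketch mirroring the shape of inductor_skips above:
# device -> op name or (op, variant) tuple -> set of dtype markers.
f16, f32, f64 = "f16", "f32", "f64"  # stand-ins for the real dtype objects

skips = defaultdict(dict)
skips["cuda"]["std"] = {f16, f32, f64}
skips["cuda"][("max", "reduction_no_dim")] = {f16, f32, f64}

def is_skipped(device, op, variant, dtype):
    # Variant-specific entries use the (op, variant) key; plain entries use op.
    table = skips[device]
    dtypes = table.get((op, variant), table.get(op, set()))
    return dtype in dtypes

print(is_skipped("cuda", "std", "", f32))                  # True
print(is_skipped("cuda", "max", "reduction_no_dim", f16))  # True
print(is_skipped("cuda", "min", "reduction_no_dim", f16))  # False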

torch/testing/_internal/common_utils.py

Lines changed: 33 additions & 0 deletions
@@ -1180,6 +1180,13 @@ def printErrors(self) -> None:
 IS_X86 = platform.machine() in ('x86_64', 'i386')
 IS_ARM64 = platform.machine() in ('arm64', 'aarch64')

+IS_NAVI=False
+if torch.cuda.is_available():
+    prop = torch.cuda.get_device_properties(0)
+    gfx_arch = prop.gcnArchName.split(":")[0]
+    if gfx_arch in ["gfx1100", "gfx1101", "gfx1102"]:
+        IS_NAVI = True
+
 def is_avx512_vnni_supported():
     if sys.platform != 'linux':
         return False
@@ -1590,6 +1597,19 @@ def wrapper(*args, **kwargs):
         return dec_fn(func)
     return dec_fn

+def skipIfRocmArch(arch: Tuple[str, ...]):
+    def dec_fn(fn):
+        @wraps(fn)
+        def wrap_fn(self, *args, **kwargs):
+            if TEST_WITH_ROCM:
+                prop = torch.cuda.get_device_properties(0)
+                if prop.gcnArchName.split(":")[0] in arch:
+                    reason = f"skipIfRocm: test skipped on {arch}"
+                    raise unittest.SkipTest(reason)
+            return fn(self, *args, **kwargs)
+        return wrap_fn
+    return dec_fn
+
 def runOnRocm(fn):
     @wraps(fn)
     def wrapper(*args, **kwargs):
@@ -1599,6 +1619,19 @@ def wrapper(*args, **kwargs):
             raise unittest.SkipTest("test currently only works on the ROCm stack")
     return wrapper

+def runOnRocmArch(arch: Tuple[str, ...]):
+    def dec_fn(fn):
+        @wraps(fn)
+        def wrap_fn(self, *args, **kwargs):
+            if TEST_WITH_ROCM:
+                prop = torch.cuda.get_device_properties(0)
+                if prop.gcnArchName.split(":")[0] not in arch:
+                    reason = f"skipIfRocm: test skipped on {arch}"
+                    raise unittest.SkipTest(reason)
+            return fn(self, *args, **kwargs)
+        return wrap_fn
+    return dec_fn
+
 def skipIfXpu(func=None, *, msg="test doesn't currently work on the XPU stack"):
     def dec_fn(fn):
         reason = f"skipIfXpu: {msg}"
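A minimal usage sketch of the two new helpers (the test class and test names below are hypothetical; both decorators only take effect when TEST_WITH_ROCM is set):

import torch
from torch.testing._internal.common_utils import (
    run_tests,
    runOnRocmArch,
    skipIfRocmArch,
    TestCase,
)

NAVI_ARCH = ("gfx1100", "gfx1101")  # same tuple the inductor tests define

class MyRocmArchTests(TestCase):  # hypothetical test class, not from this commit
    @skipIfRocmArch(NAVI_ARCH)
    def test_skipped_on_navi(self):
        # Skipped only when running under ROCm on a gfx1100/gfx1101 GPU.
        self.assertEqual(torch.ones(2).sum().item(), 2.0)

    @runOnRocmArch(NAVI_ARCH)
    def test_skipped_on_non_navi_rocm(self):
        # Under ROCm, skipped unless the GPU is gfx1100/gfx1101;
        # on non-ROCm builds the decorator is a no-op.
        self.assertEqual(torch.zeros(2).sum().item(), 0.0)

if __name__ == "__main__":
    run_tests()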
