
Commit 90ac508

jataylo authored and pruthvistony committed
Add skipIfRocmArch decorator for Navi skips (#1356)
1 parent 34d4129 commit 90ac508

File tree: 4 files changed, +74 -2 lines changed

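In short, the commit adds a skipIfRocmArch decorator to torch/testing/_internal/common_utils.py and applies it to inductor tests that currently fail on Navi (gfx1100/gfx1101) GPUs under ROCm. A minimal sketch of the intended usage, assuming a ROCm build of PyTorch; the test class and test body below are hypothetical, not part of the commit:

from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocmArch

# Navi-exclusive skips on ROCm, mirroring the NAVI_ARCH tuple added in the test files.
NAVI_ARCH = ("gfx1100", "gfx1101")

class ExampleInductorTests(TestCase):
    @skipIfRocmArch(NAVI_ARCH)
    def test_known_navi_failure(self):
        # Skipped only when TEST_WITH_ROCM is set and the GPU reports gfx1100/gfx1101;
        # runs normally on CUDA builds and on other ROCm architectures.
        self.assertTrue(True)

if __name__ == "__main__":
    run_tests()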

test/inductor/test_cuda_repro.py

Lines changed: 5 additions & 1 deletion
@@ -23,6 +23,7 @@
 freeze_rng_state,
 IS_FBCODE,
 skipIfRocm,
+skipIfRocmArch,
 TEST_WITH_ASAN,
 )

@@ -42,7 +43,7 @@
 sys.exit(0)
 raise

-
+NAVI_ARCH = ("gfx1100", "gfx1101") # Used for navi exclusive skips on ROCm
 TestCase = test_torchinductor.TestCase
 ToTuple = test_torchinductor.ToTuple
 check_model_cuda = test_torchinductor.check_model_cuda

@@ -307,6 +308,7 @@ def foo(x):
 out_ref.add_(2)
 # self.assertEqual(out_ref, out)

+@skipIfRocmArch(NAVI_ARCH)
 def test_accuracy_issue1(self):
 class Repro(torch.nn.Module):
 def __init__(self):

@@ -343,6 +345,7 @@ def forward(self, start_positions: torch.Tensor, x: torch.Tensor):
 assert same_two_models(mod, opt_mod, args), "Dynamo failed"

 @config.patch(allow_buffer_reuse=False)
+@skipIfRocmArch(NAVI_ARCH)
 def test_issue103461(self):
 def forward(add_1):
 var_mean = torch.ops.aten.var_mean.correction(

@@ -826,6 +829,7 @@ def forward(self, x):
 res2 = jit_func(x)
 self.assertEqual(res1, res2)

+@skipIfRocmArch(NAVI_ARCH)
 def test_issue103481(self):
 def fn(x, y):
 # NOTE: 6 dimensions is important! does not fail for 5 dimensions

test/inductor/test_torchinductor.py

Lines changed: 22 additions & 1 deletion
@@ -72,7 +72,7 @@
 parametrize,
 serialTest,
 skipIfRocm,
-subtest,
+skipIfRocmArch,
 TEST_WITH_ASAN,
 TEST_WITH_ROCM,
 )

@@ -118,6 +118,7 @@
 _desired_test_bases = get_desired_device_type_test_bases()
 RUN_CPU = any(getattr(x, "device_type", "") == "cpu" for x in _desired_test_bases)
 RUN_GPU = any(getattr(x, "device_type", "") == GPU_TYPE for x in _desired_test_bases)
+NAVI_ARCH = ("gfx1100", "gfx1101") # Used for navi exclusive skips on ROCm

 aten = torch.ops.aten
 requires_gpu = functools.partial(unittest.skipIf, not HAS_GPU, "requires gpu")

@@ -1246,6 +1247,7 @@ def fn(x):
 # make sure things also work if they aren't unrolled
 self.common(fn, (torch.randn(8, 3),))

+@skipIfRocmArch(NAVI_ARCH)
 def test_multilayer_sum_low_prec(self):
 # fp16 nyi for cpu
 if self.device == "cpu":

@@ -1256,6 +1258,7 @@ def fn(a):

 self.common(fn, ((torch.rand((10, 3, 352, 352), dtype=torch.float16),)))

+@skipIfRocmArch(NAVI_ARCH)
 def test_multilayer_prime_size(self):
 def fn(a):
 return torch.max(a), torch.sum(a)

@@ -1265,6 +1268,7 @@ def fn(a):
 sample[-1] = 1
 self.common(fn, (sample,))

+@skipIfRocmArch(NAVI_ARCH)
 def test_multilayer_var(self):
 def fn(a):
 return torch.var(a)

@@ -2131,6 +2135,7 @@ def fn(a, b):

 self.common(fn, (torch.randn(8, 8), torch.randn(8, 8)))

+@skipIfRocmArch(NAVI_ARCH)
 def test_large_tensor_reduction(self):
 if not _has_sufficient_memory(self.device, 4.5 * 1024**3): # 4.5 GiB
 raise unittest.SkipTest("insufficient memory")

@@ -2151,6 +2156,7 @@ def fn(a):
 expect = torch.tensor(2, dtype=torch.int8, device=self.device)
 self.assertEqual(actual, expect)

+@skipIfRocmArch(NAVI_ARCH)
 def test_large_broadcast_reduction(self):
 if self.device == "cpu":
 raise unittest.SkipTest("Fails on CPU")

@@ -3204,6 +3210,7 @@ def test_conv2d_channels_last(self):
 check_lowp=False,
 )

+@skipIfRocmArch(NAVI_ARCH)
 def test_conv2d_backward_channels_last(self):
 def fn(grad_output, inp, weight):
 convolution_backward_8 = torch.ops.aten.convolution_backward.default(

@@ -3949,6 +3956,7 @@ def fn(x, y):
 self.assertEqual(a.stride(), c.stride())
 self.assertEqual(c.stride()[2], 1)

+@skipIfRocmArch(NAVI_ARCH)
 def test_std(self):
 def fn(x):
 return (

@@ -3991,6 +3999,7 @@ def test_batch_norm_2d(self):

 # From yolov3
 @with_tf32_off
+@skipIfRocmArch(NAVI_ARCH)
 def test_batch_norm_2d_2(self):
 if self.device == "cpu":
 raise unittest.SkipTest(f"requires {GPU_TYPE}")

@@ -4126,6 +4135,7 @@ def fn(x):

 self.common(fn, (x,))

+@skipIfRocmArch(NAVI_ARCH)
 def test_cauchy(self):
 def fn(x, y):
 return torch.sum(1 / (torch.unsqueeze(x, -1) - y))

@@ -5395,6 +5405,7 @@ def fn(a):
 y = fn_compiled(x)
 self.assertTrue(y is not x)

+@skipIfRocmArch(NAVI_ARCH)
 def test_l1_loss(self):
 def fn(a, b):
 return torch.nn.functional.l1_loss(a, b), torch.nn.functional.mse_loss(a, b)

@@ -5791,6 +5802,7 @@ def fn(x):
 fn, (torch.tensor([1, float("inf"), 2, float("-inf"), float("nan")]),)
 )

+@skipIfRocmArch(NAVI_ARCH)
 def test_any(self):
 def fn(x):
 return (

@@ -6478,6 +6490,8 @@ def fn(a, dim, index, b, reduce):
 ],
 )

+# issue #1150
+@skipIfRocmArch(NAVI_ARCH)
 def test_dense_mask_index(self):
 r"""
 There will be a little difference for reduce order between aten and inductor

@@ -7361,6 +7375,7 @@ def fn(a, b):
 b = torch.rand(2, 2, 1, 4, 1).int()
 self.common(fn, (a, b))

+@skipIfRocmArch(NAVI_ARCH)
 def test_argmax_argmin1(self):
 def fn(x):
 return (aten.argmax(x), aten.argmin(x))

@@ -7372,6 +7387,7 @@ def fn(x):
 ],
 )

+@skipIfRocmArch(NAVI_ARCH)
 def test_argmax_argmin2(self):
 def fn(x):
 return (

@@ -7383,6 +7399,7 @@ def fn(x):

 self.common(fn, (torch.randn([144, 144]),))

+@skipIfRocmArch(NAVI_ARCH)
 def test_argmax_argmin_with_duplicates(self):
 def fn(x):
 return (

@@ -7404,6 +7421,7 @@ def fn(x):
 t1 = torch.randint(8, size=(1028, 1028))
 self.common(fn, (t1,))

+@skipIfRocmArch(NAVI_ARCH)
 def test_argmax_argmin_with_nan(self):
 def fn(x):
 return (

@@ -7536,6 +7554,7 @@ def fn(x):
 ],
 )

+@skipIfRocmArch(NAVI_ARCH)
 def test_tmp_not_defined_issue1(self):
 def forward(
 primals_3,

@@ -7930,6 +7949,7 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
 else:
 self.assertEqual(len(inps), 0)

+@skipIfRocmArch(NAVI_ARCH)
 def test_dtype_mismatch_issue(self):
 def fn(x):
 attn = torch.nn.functional.pad(x, [0, 1])

@@ -10414,6 +10434,7 @@ def test_rnn_compile_safe(self):

 class NanCheckerTest(TestCase):
 @config.patch("nan_asserts", True)
+@skipIfRocmArch(NAVI_ARCH)
 def test_nan_checker_pass(self):
 def f(x):
 return torch.softmax(x, dim=-1)

test/inductor/test_torchinductor_opinfo.py

Lines changed: 14 additions & 0 deletions
@@ -31,6 +31,7 @@
 from torch.testing._internal.common_methods_invocations import op_db, skipOps
 from torch.testing._internal.common_utils import (
 dtype_abbrs,
+IS_NAVI,
 IS_MACOS,
 IS_X86,
 skipCUDAMemoryLeakCheckIf,

@@ -203,6 +204,19 @@ def format_op(op):
 # Tensors are not alike
 inductor_skips["cuda"]["logcumsumexp"] = {f32}
 inductor_skips["cuda"]["special.modified_bessel_i1"] = {f64}
+if IS_NAVI:
+    inductor_skips["cuda"]["aminmax"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["dist"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["kron"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["masked.std"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["masked.var"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"][("max", "reduction_no_dim")] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"][("min", "reduction_no_dim")] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["nn.functional.conv_transpose3d"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["std"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["std_mean"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["var"] = {b8, f16, f32, f64, i32, i64}
+    inductor_skips["cuda"]["var_mean"] = {b8, f16, f32, f64, i32, i64}

 inductor_expected_failures_single_sample = defaultdict(dict)
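In the opinfo suite the Navi skips are data rather than decorators: entries in inductor_skips["cuda"] keyed either by an op name string or by an (op, variant) tuple such as ("max", "reduction_no_dim"), each mapped to the set of dtypes to skip. A small sketch of how such a table can be consulted; the should_skip helper below is hypothetical and only illustrates the key structure, it is not the mechanism used by the test file:

import torch
from collections import defaultdict

b8, f16, f32 = torch.bool, torch.float16, torch.float32

inductor_skips = defaultdict(dict)
inductor_skips["cuda"]["dist"] = {b8, f16, f32}                       # all variants of dist
inductor_skips["cuda"][("max", "reduction_no_dim")] = {b8, f16, f32}  # one specific variant

def should_skip(device, op_name, variant_name, dtype):
    # Prefer a variant-specific entry, then fall back to the op-wide entry.
    skips = inductor_skips[device]
    dtypes = skips.get((op_name, variant_name)) or skips.get(op_name) or set()
    return dtype in dtypes

print(should_skip("cuda", "max", "reduction_no_dim", torch.float16))    # True
print(should_skip("cuda", "max", "reduction_with_dim", torch.float16))  # False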
torch/testing/_internal/common_utils.py

Lines changed: 33 additions & 0 deletions
@@ -1177,6 +1177,13 @@ def printErrors(self) -> None:
 IS_X86 = platform.machine() in ('x86_64', 'i386')
 IS_ARM64 = platform.machine() in ('arm64', 'aarch64')

+IS_NAVI=False
+if torch.cuda.is_available():
+    prop = torch.cuda.get_device_properties(0)
+    gfx_arch = prop.gcnArchName.split(":")[0]
+    if gfx_arch in ["gfx1100", "gfx1101", "gfx1102"]:
+        IS_NAVI = True
+
 def is_avx512_vnni_supported():
 if sys.platform != 'linux':
 return False

@@ -1560,6 +1567,19 @@ def wrapper(*args, **kwargs):
 return dec_fn(func)
 return dec_fn

+def skipIfRocmArch(arch: Tuple[str, ...]):
+    def dec_fn(fn):
+        @wraps(fn)
+        def wrap_fn(self, *args, **kwargs):
+            if TEST_WITH_ROCM:
+                prop = torch.cuda.get_device_properties(0)
+                if prop.gcnArchName.split(":")[0] in arch:
+                    reason = f"skipIfRocm: test skipped on {arch}"
+                    raise unittest.SkipTest(reason)
+            return fn(self, *args, **kwargs)
+        return wrap_fn
+    return dec_fn
+
 def runOnRocm(fn):
 @wraps(fn)
 def wrapper(*args, **kwargs):

@@ -1569,6 +1589,19 @@ def wrapper(*args, **kwargs):
 raise unittest.SkipTest("test currently only works on the ROCm stack")
 return wrapper

+def runOnRocmArch(arch: Tuple[str, ...]):
+    def dec_fn(fn):
+        @wraps(fn)
+        def wrap_fn(self, *args, **kwargs):
+            if TEST_WITH_ROCM:
+                prop = torch.cuda.get_device_properties(0)
+                if prop.gcnArchName.split(":")[0] not in arch:
+                    reason = f"skipIfRocm: test skipped on {arch}"
+                    raise unittest.SkipTest(reason)
+            return fn(self, *args, **kwargs)
+        return wrap_fn
+    return dec_fn
+
 def skipIfXpu(func=None, *, msg="test doesn't currently work on the XPU stack"):
 def dec_fn(fn):
 reason = f"skipIfXpu: {msg}"
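Both new helpers in common_utils.py key off torch.cuda.get_device_properties(0).gcnArchName, which on ROCm builds reports the GPU architecture, possibly with feature flags after a colon (for example "gfx90a:sramecc+:xnack-"), hence the split(":")[0]. A small sketch for inspecting that value by hand, assuming a ROCm build of PyTorch with at least one visible GPU:

import torch

if torch.version.hip is not None and torch.cuda.is_available():
    prop = torch.cuda.get_device_properties(0)
    # Compare only the part before the first colon, as the decorators do.
    gfx_arch = prop.gcnArchName.split(":")[0]
    print(f"Detected ROCm arch: {gfx_arch}")
    print("Matches NAVI_ARCH skips:", gfx_arch in ("gfx1100", "gfx1101"))
else:
    print("Not a ROCm GPU build; skipIfRocmArch never triggers a skip here.")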
