Skip to content

Commit 91125f1

Browse files
BLOrange-AMDdnikolaev-amd
authored andcommitted
Converted NAVI check as a function (#1364)
* Moved NAVI check to the test file * Revised NAVI check as a function
1 parent 9726c26 commit 91125f1

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

test/inductor/test_torchinductor_opinfo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@
3131
from torch.testing._internal.common_methods_invocations import op_db, skipOps
3232
from torch.testing._internal.common_utils import (
3333
dtype_abbrs,
34-
IS_NAVI,
3534
IS_MACOS,
3635
IS_X86,
36+
is_navi_arch,
3737
skipCUDAMemoryLeakCheckIf,
3838
skipIfCrossRef,
3939
skipIfTorchDynamo,
@@ -204,7 +204,7 @@ def format_op(op):
204204
# Tensors are not alike
205205
inductor_skips["cuda"]["logcumsumexp"] = {f32}
206206
inductor_skips["cuda"]["special.modified_bessel_i1"] = {f64}
207-
if IS_NAVI:
207+
if is_navi_arch():
208208
inductor_skips["cuda"]["aminmax"] = {b8, f16, f32, f64, i32, i64}
209209
inductor_skips["cuda"]["dist"] = {b8, f16, f32, f64, i32, i64}
210210
inductor_skips["cuda"]["kron"] = {b8, f16, f32, f64, i32, i64}

torch/testing/_internal/common_utils.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1180,12 +1180,13 @@ def printErrors(self) -> None:
11801180
IS_X86 = platform.machine() in ('x86_64', 'i386')
11811181
IS_ARM64 = platform.machine() in ('arm64', 'aarch64')
11821182

1183-
IS_NAVI=False
1184-
if torch.cuda.is_available():
1185-
prop = torch.cuda.get_device_properties(0)
1186-
gfx_arch = prop.gcnArchName.split(":")[0]
1187-
if gfx_arch in ["gfx1100", "gfx1101", "gfx1102"]:
1188-
IS_NAVI = True
1183+
def is_navi_arch():
1184+
if torch.cuda.is_available():
1185+
prop = torch.cuda.get_device_properties(0)
1186+
gfx_arch = prop.gcnArchName.split(":")[0]
1187+
if gfx_arch in ["gfx1100", "gfx1101", "gfx1102"]:
1188+
return True
1189+
return False
11891190

11901191
def is_avx512_vnni_supported():
11911192
if sys.platform != 'linux':

0 commit comments

Comments
 (0)