Skip to content

Commit 432b200

Browse files
[rocm6.4_internal_testing] Temporary disable NHWC batchnorm tests in ROCm6.4 (#1849)
Remove definition of PYTORCH_MIOPEN_SUGGEST_NHWC environment variable to disable NHWC batchnorm on MIOpen in ROCm6.4. NHWC can be enabled back by defining PYTORCH_MIOPEN_SUGGEST_NHWC=1 if needed Skip related NHWC batchnorm tests Waiting for NHWC batchnorm support from MIOpen (SWDEV-510757) @pruthvistony should I remove NHWC batchnorm completely or is it enough to disable it by default?
1 parent 15a2179 commit 432b200

File tree

1 file changed

+5
-24
lines changed

1 file changed

+5
-24
lines changed

test/test_nn.py

Lines changed: 5 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import warnings
1111
import pickle
1212
import re
13-
import os
1413
from copy import deepcopy
1514
from itertools import product
1615
from functools import partial
@@ -5098,7 +5097,7 @@ def test_batchnorm_nhwc_eval(self, mixed, dtype):
50985097
if TEST_WITH_ROCM and not mixed and dtype in (torch.half, torch.bfloat16):
50995098
self.skipTest("pure mode not supported for bf16/fp16 on ROCm")
51005099
if TEST_WITH_ROCM:
5101-
self.skipTest("MIOpen SolverNotFound for FP32/Fp16/BF16 NHWC batchnorm SWDEV-509640")
5100+
self.skipTest("NHWC batchnorm disabled on ROCm6.4 SWDEV-510757 SWDEV-509640")
51025101

51035102
(N, C, H, W) = 2, 64, 50, 50
51045103
model = torch.nn.BatchNorm2d(C, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
@@ -5159,6 +5158,8 @@ def _run_test(input, grad_output, mixed: bool):
51595158
grad = grad.permute(0, 2, 1, 3)
51605159
_run_test(input, grad, mixed)
51615160

5161+
if TEST_WITH_ROCM and layout == "NHWC":
5162+
self.skipTest("NHWC batchnorm disabled on ROCm6.4 SWDEV-510757")
51625163
if mixed and dtype == torch.float:
51635164
self.skipTest("mixed precision is useless for float32")
51645165
if TEST_WITH_ROCM and not mixed and dtype in (torch.half, torch.bfloat16):
@@ -5167,17 +5168,7 @@ def _run_test(input, grad_output, mixed: bool):
51675168
self.skipTest("MIOpen tolerance issue for NCHW BF16 mixed batchnorm SWDEV-507600")
51685169

51695170
memory_format = torch.contiguous_format if layout == "NCHW" else torch.channels_last
5170-
# TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
5171-
PYTORCH_MIOPEN_SUGGEST_NHWC = "PYTORCH_MIOPEN_SUGGEST_NHWC"
5172-
prev_val = os.getenv(PYTORCH_MIOPEN_SUGGEST_NHWC)
5173-
try:
5174-
os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC] = "1"
5175-
_batchnorm2d_helper(dtype, memory_format=memory_format, mixed=mixed)
5176-
finally:
5177-
if prev_val is None:
5178-
del os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC]
5179-
else:
5180-
os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC] = prev_val
5171+
_batchnorm2d_helper(dtype, memory_format=memory_format, mixed=mixed)
51815172

51825173
def test_batchnorm_load_state_dict(self):
51835174
bn = torch.nn.BatchNorm2d(3)
@@ -13203,14 +13194,4 @@ def __init__(self) -> None:
1320313194

1320413195
if __name__ == '__main__':
1320513196
TestCase._default_dtype_check_enabled = True
13206-
# TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
13207-
PYTORCH_MIOPEN_SUGGEST_NHWC = "PYTORCH_MIOPEN_SUGGEST_NHWC"
13208-
prev_val = os.getenv(PYTORCH_MIOPEN_SUGGEST_NHWC)
13209-
try:
13210-
os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC] = "1"
13211-
run_tests()
13212-
finally:
13213-
if prev_val is None:
13214-
del os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC]
13215-
else:
13216-
os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC] = prev_val
13197+
run_tests()

0 commit comments

Comments
 (0)