@@ -9,6 +9,7 @@
 import warnings
 import pickle
 import re
+import os
 from copy import deepcopy
 from itertools import product
 from functools import partial
@@ -4928,6 +4929,54 @@ def run_test(input, grad_output):
         grad = grad.permute(0, 2, 1, 3)
         run_test(input, grad)
 
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
+    def test_batchnorm_nhwc_miopen(self):
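+        # Compares output and gradients of the channels_last (NHWC) batchnorm
+        # against a reference module run with the cuDNN/MIOpen backend disabled.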
+        def run_test(input, grad_output):
+            c = input.size(1)
+            mod = nn.BatchNorm2d(c).cuda().float()
+            mod.weight.data.uniform_()
+            mod.bias.data.uniform_()
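+            # Reference copies of the data and an identically initialized module.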
+            ref_input = input.detach().clone(memory_format=torch.preserve_format).requires_grad_(True)
+            ref_grad = grad_output.detach().clone(memory_format=torch.preserve_format)
+            ref_mod = nn.BatchNorm2d(c).cuda().float()
+            ref_mod.load_state_dict(mod.state_dict())
+            out = mod(input)
+            out.backward(grad_output)
+            with torch.backends.cudnn.flags(enabled=False):  # force the native NHWC batchnorm implementation
+                ref_out = ref_mod(ref_input)
+                ref_out.backward(ref_grad)
+            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
+            self.assertTrue(ref_out.is_contiguous(memory_format=torch.channels_last))
+            self.assertEqual(out, ref_out)
+            self.assertEqual(mod.weight.grad, ref_mod.weight.grad)
+            self.assertEqual(mod.bias.grad, ref_mod.bias.grad)
+            self.assertEqual(input.grad, ref_input.grad)
+
+        # TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
+        PYTORCH_MIOPEN_SUGGEST_NHWC = "PYTORCH_MIOPEN_SUGGEST_NHWC"
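+        # Remember any pre-existing value so it can be restored after the test.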
+        prev_val = os.getenv(PYTORCH_MIOPEN_SUGGEST_NHWC)
+        try:
+            os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC] = "1"
+            input = torch.randint(1, 10, (4, 8, 2, 2), dtype=torch.float32, device="cuda")
+            input = input.contiguous(memory_format=torch.channels_last).detach().requires_grad_()
+
+            grad = torch.randint(1, 10, (4, 8, 2, 2), dtype=torch.float32, device="cuda")
+            grad = grad.contiguous(memory_format=torch.channels_last)
+            run_test(input, grad)
+            # see #42588: grad is channels_last contiguous, but grad.suggest_memory_format (rightly)
+            # returns "contiguous", not channels_last
+            input = torch.randint(1, 10, (2, 8, 8, 1), dtype=torch.float32, device="cuda")
+            input = input.contiguous(memory_format=torch.channels_last).detach().requires_grad_()
+            grad = torch.randint(1, 10, (2, 8, 8, 1), dtype=torch.float32, device="cuda")
+            grad = grad.permute(0, 2, 1, 3)
+            run_test(input, grad)
+        finally:
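+            # Restore the environment variable so subsequent tests see the original setting.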
+            if prev_val is None:
+                del os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC]
+            else:
+                os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC] = prev_val
+
     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
     def test_batchnorm_cudnn_half(self):
         # THNN
@@ -13023,4 +13072,14 @@ def __init__(self) -> None:
 
 if __name__ == '__main__':
     TestCase._default_dtype_check_enabled = True
-    run_tests()
+    # TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
+    PYTORCH_MIOPEN_SUGGEST_NHWC = "PYTORCH_MIOPEN_SUGGEST_NHWC"
+    prev_val = os.getenv(PYTORCH_MIOPEN_SUGGEST_NHWC)
+    try:
+        os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC] = "1"
+        run_tests()
+    finally:
+        if prev_val is None:
+            del os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC]
+        else:
+            os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC] = prev_val