@@ -9,6 +9,7 @@
 import warnings
 import pickle
 import re
+import os
 from copy import deepcopy
 from itertools import product
 from functools import partial
@@ -4860,6 +4861,54 @@ def run_test(input, grad_output):
         grad = grad.permute(0, 2, 1, 3)
         run_test(input, grad)
 
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    @unittest.skipIf(not TEST_CUDNN, "needs cudnn")
+    def test_batchnorm_nhwc_miopen(self):
+        def run_test(input, grad_output):
+            c = input.size(1)
+            mod = nn.BatchNorm2d(c).cuda().float()
+            mod.weight.data.uniform_()
+            mod.bias.data.uniform_()
+            ref_input = input.detach().clone(memory_format=torch.preserve_format).requires_grad_(True)
+            ref_grad = grad_output.detach().clone(memory_format=torch.preserve_format)
+            ref_mod = nn.BatchNorm2d(c).cuda().float()
+            ref_mod.load_state_dict(mod.state_dict())
+            out = mod(input)
+            out.backward(grad_output)
+            with torch.backends.cudnn.flags(enabled=False):  # force native NHWC batchnorm
+                ref_out = ref_mod(ref_input)
+                ref_out.backward(ref_grad)
+            self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
+            self.assertTrue(ref_out.is_contiguous(memory_format=torch.channels_last))
+            self.assertEqual(out, ref_out)
+            self.assertEqual(mod.weight.grad, ref_mod.weight.grad)
+            self.assertEqual(mod.bias.grad, ref_mod.bias.grad)
+            self.assertEqual(input.grad, ref_input.grad)
+
+        # TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
+        PYTORCH_MIOPEN_SUGGEST_NHWC = "PYTORCH_MIOPEN_SUGGEST_NHWC"
+        prev_val = os.getenv(PYTORCH_MIOPEN_SUGGEST_NHWC)
+        try:
+            os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC] = "1"
+            input = torch.randint(1, 10, (4, 8, 2, 2), dtype=torch.float32, device="cuda")
+            input = input.contiguous(memory_format=torch.channels_last).detach().requires_grad_()
+
+            grad = torch.randint(1, 10, (4, 8, 2, 2), dtype=torch.float32, device="cuda")
+            grad = grad.contiguous(memory_format=torch.channels_last)
+            run_test(input, grad)
+            # see #42588: grad is channels_last contiguous, but for this shape
+            # suggest_memory_format (rightly) returns "contiguous", not channels_last
+            input = torch.randint(1, 10, (2, 8, 8, 1), dtype=torch.float32, device="cuda")
+            input = input.contiguous(memory_format=torch.channels_last).detach().requires_grad_()
+            grad = torch.randint(1, 10, (2, 8, 8, 1), dtype=torch.float32, device="cuda")
+            grad = grad.permute(0, 2, 1, 3)
+            run_test(input, grad)
+        finally:
+            if prev_val is None:
+                del os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC]
+            else:
+                os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC] = prev_val
+
     @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
     def test_batchnorm_cudnn_half(self):
         # THNN
@@ -12838,4 +12887,14 @@ def __init__(self):
 
 if __name__ == '__main__':
     TestCase._default_dtype_check_enabled = True
-    run_tests()
+    # TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
+    PYTORCH_MIOPEN_SUGGEST_NHWC = "PYTORCH_MIOPEN_SUGGEST_NHWC"
+    prev_val = os.getenv(PYTORCH_MIOPEN_SUGGEST_NHWC)
+    try:
+        os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC] = "1"
+        run_tests()
+    finally:
+        if prev_val is None:
+            del os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC]
+        else:
+            os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC] = prev_val
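
Note: the diff repeats the same getenv/try/finally save-and-restore dance around PYTORCH_MIOPEN_SUGGEST_NHWC in two places. A minimal sketch of how that pattern could be factored into a context manager, not part of the change itself; `set_env_var` is a hypothetical helper name, not something defined in this file:

import os
from contextlib import contextmanager

@contextmanager
def set_env_var(name, value):
    # Temporarily set os.environ[name]; on exit, restore the previous value,
    # deleting the variable entirely if it was unset before.
    prev_val = os.getenv(name)
    os.environ[name] = value
    try:
        yield
    finally:
        if prev_val is None:
            del os.environ[name]
        else:
            os.environ[name] = prev_val

# Usage mirroring the __main__ block above:
#     with set_env_var("PYTORCH_MIOPEN_SUGGEST_NHWC", "1"):
#         run_tests()

The standard library's unittest.mock.patch.dict(os.environ, {"PYTORCH_MIOPEN_SUGGEST_NHWC": "1"}) offers equivalent set-and-restore scoping without a custom helper.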