Skip to content

Commit 4380b15

Browse files
committed
Enable NHWC batchnorm for miopen (#1400)
* Enable batchnorm NHWC for MIOpen * cleanup * test to compare NHWC MIOpen batchnorm with CPU * fix 'use_miopen' condition for nhwc miopen * fix includes * use native nhwc batchnorm to verify miopen * remove extra spaces * remove empty lines * set PYTORCH_MIOPEN_SUGGEST_NHWC=1 for all test_nn.py test
1 parent 90df487 commit 4380b15

File tree

3 files changed

+77
-8
lines changed

3 files changed

+77
-8
lines changed

aten/src/ATen/native/Normalization.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,11 @@ BatchNormBackend _select_batch_norm_backend(
510510
return BatchNormBackend::Cudnn;
511511
}
512512

513+
// TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
514+
// See #64427
515+
// non static variable is used to be able to change environment variable in runtime for testing
516+
bool PYTORCH_MIOPEN_SUGGEST_NHWC = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC").value_or(false);
517+
513518
if (
514519
input.is_cuda()
515520
&& input.dim() <= MIOPEN_DIM_MAX
@@ -522,8 +527,8 @@ BatchNormBackend _select_batch_norm_backend(
522527
&& (input.dim() >= 3)
523528
&& detail::getCUDAHooks().compiledWithMIOpen()
524529
&& cudnn_enabled
525-
&& input.suggest_memory_format() != MemoryFormat::ChannelsLast
526-
&& input.suggest_memory_format() != MemoryFormat::ChannelsLast3d
530+
&& (input.suggest_memory_format() == MemoryFormat::Contiguous
531+
|| (input.suggest_memory_format() == MemoryFormat::ChannelsLast && PYTORCH_MIOPEN_SUGGEST_NHWC))
527532
) {
528533
return BatchNormBackend::Miopen;
529534
}
@@ -603,7 +608,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
603608
if (backend == BatchNormBackend::Miopen) {
604609
return std::tuple_cat(
605610
at::miopen_batch_norm(
606-
input.contiguous(), weight.contiguous(), bias.contiguous(),
611+
input.contiguous(input.suggest_memory_format()), weight.contiguous(), bias.contiguous(),
607612
running_mean.defined() ? running_mean.contiguous() : running_mean,
608613
running_var.defined() ? running_var.contiguous() : running_var,
609614
training, momentum, eps),

aten/src/ATen/native/miopen/BatchNorm_miopen.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
100100
mode = miopenBNSpatial;
101101
}
102102

103-
auto output_t = at::empty(input->sizes(), input->options());
103+
auto output_t = at::empty(input->sizes(), input->options(), input->suggest_memory_format());
104104
TensorArg output{ output_t, "output", 0 };
105105

106106
auto handle = getMiopenHandle();
@@ -177,8 +177,10 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
177177
const Tensor& save_var_t =
178178
c10::value_or_else(save_var_t_opt, [] { return Tensor(); });
179179

180+
auto grad_output_contig =
181+
grad_output_t.contiguous(input_t.suggest_memory_format());
180182
TensorArg input{ input_t, "input", 1 },
181-
grad_output{ grad_output_t, "grad_output", 2 },
183+
grad_output{ grad_output_contig, "grad_output", 2 },
182184
weight{ weight_t, "weight", 3 },
183185
save_mean{ save_mean_t, "save_mean", 4 },
184186
save_var{ save_var_t, "save_var", 5 };
@@ -193,7 +195,9 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
193195
}
194196
checkAllSameType(c, {input, grad_output});
195197
checkAllSameType(c, {weight, save_mean, save_var});
196-
checkAllContiguous(c, {input, grad_output, save_mean, save_var});
198+
checkAllContiguous(c, {save_mean, save_var});
199+
TORCH_CHECK(input->is_contiguous(input->suggest_memory_format()));
200+
TORCH_CHECK(grad_output->is_contiguous(input->suggest_memory_format()));
197201
checkDimRange(c, input, 2, 6 /* exclusive */);
198202
checkSameSize(c, input, grad_output);
199203
auto num_features = input->size(1);
@@ -208,7 +212,8 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
208212
mode = miopenBNSpatial;
209213
}
210214

211-
auto grad_input_t = at::empty(input->sizes(), input->options());
215+
auto grad_input_t = at::empty(
216+
input->sizes(), input->options(), input->suggest_memory_format());
212217
auto grad_weight_t = at::empty(weight->sizes(), weight->options());
213218
auto grad_bias_t = at::empty(weight->sizes(), weight->options());
214219

test/test_nn.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import warnings
1010
import pickle
1111
import re
12+
import os
1213
from copy import deepcopy
1314
from itertools import product
1415
from functools import partial
@@ -4860,6 +4861,54 @@ def run_test(input, grad_output):
48604861
grad = grad.permute(0, 2, 1, 3)
48614862
run_test(input, grad)
48624863

4864+
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
4865+
@unittest.skipIf(not TEST_CUDNN, "needs cudnn")
4866+
def test_batchnorm_nhwc_miopen(self):
4867+
def run_test(input, grad_output):
4868+
c = input.size(1)
4869+
mod = nn.BatchNorm2d(c).cuda().float()
4870+
mod.weight.data.uniform_()
4871+
mod.bias.data.uniform_()
4872+
ref_input = input.detach().clone(memory_format=torch.preserve_format).requires_grad_(True)
4873+
ref_grad = grad.detach().clone(memory_format=torch.preserve_format)
4874+
ref_mod = nn.BatchNorm2d(c).cuda().float()
4875+
ref_mod.load_state_dict(mod.state_dict())
4876+
out = mod(input)
4877+
out.backward(grad_output)
4878+
with torch.backends.cudnn.flags(enabled=False): # force to use native nhwc batchnorm
4879+
ref_out = ref_mod(ref_input)
4880+
ref_out.backward(ref_grad)
4881+
self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
4882+
self.assertTrue(ref_out.is_contiguous(memory_format=torch.channels_last))
4883+
self.assertEqual(out, ref_out)
4884+
self.assertEqual(mod.weight.grad, ref_mod.weight.grad)
4885+
self.assertEqual(mod.bias.grad, ref_mod.bias.grad)
4886+
self.assertEqual(input.grad, ref_input.grad)
4887+
4888+
# TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
4889+
PYTORCH_MIOPEN_SUGGEST_NHWC = "PYTORCH_MIOPEN_SUGGEST_NHWC"
4890+
prev_val = os.getenv(PYTORCH_MIOPEN_SUGGEST_NHWC)
4891+
try:
4892+
os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC] = "1"
4893+
input = torch.randint(1, 10, (4, 8, 2, 2), dtype=torch.float32, device="cuda")
4894+
input = input.contiguous(memory_format=torch.channels_last).detach().requires_grad_()
4895+
4896+
grad = torch.randint(1, 10, (4, 8, 2, 2), dtype=torch.float32, device="cuda")
4897+
grad = grad.contiguous(memory_format=torch.channels_last)
4898+
run_test(input, grad)
4899+
# see #42588, grad is channels_last contiguous, but grad.suggest_memory_format (rightly) return "contiguous"
4900+
# not channels_last
4901+
input = torch.randint(1, 10, (2, 8, 8, 1), dtype=torch.float32, device="cuda")
4902+
input = input.contiguous(memory_format=torch.channels_last).detach().requires_grad_()
4903+
grad = torch.randint(1, 10, (2, 8, 8, 1), dtype=torch.float32, device="cuda")
4904+
grad = grad.permute(0, 2, 1, 3)
4905+
run_test(input, grad)
4906+
finally:
4907+
if prev_val is None:
4908+
del os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC]
4909+
else:
4910+
os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC] = prev_val
4911+
48634912
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
48644913
def test_batchnorm_cudnn_half(self):
48654914
# THNN
@@ -12838,4 +12887,14 @@ def __init__(self):
1283812887

1283912888
if __name__ == '__main__':
1284012889
TestCase._default_dtype_check_enabled = True
12841-
run_tests()
12890+
# TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
12891+
PYTORCH_MIOPEN_SUGGEST_NHWC = "PYTORCH_MIOPEN_SUGGEST_NHWC"
12892+
prev_val = os.getenv(PYTORCH_MIOPEN_SUGGEST_NHWC)
12893+
try:
12894+
os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC] = "1"
12895+
run_tests()
12896+
finally:
12897+
if prev_val is None:
12898+
del os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC]
12899+
else:
12900+
os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC] = prev_val

0 commit comments

Comments
 (0)