Skip to content

Commit 4c94122

Browse files
committed
Enable NHWC batchnorm for miopen (#1400)
* Enable batchnorm NHWC for MIOpen * cleanup * test to compare NHWC MIOpen batchnorm with CPU * fix 'use_miopen' condition for nhwc miopen * fix includes * use native nhwc batchnorm to verify miopen * remove extra spaces * remove empty lines * set PYTORCH_MIOPEN_SUGGEST_NHWC=1 for all test_nn.py test
1 parent 4b8aea1 commit 4c94122

File tree

3 files changed

+77
-8
lines changed

3 files changed

+77
-8
lines changed

aten/src/ATen/native/Normalization.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,11 @@ BatchNormBackend _select_batch_norm_backend(
516516
return BatchNormBackend::Cudnn;
517517
}
518518

519+
// TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
520+
// See #64427
521+
// non static variable is used to be able to change environment variable in runtime for testing
522+
bool PYTORCH_MIOPEN_SUGGEST_NHWC = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC").value_or(false);
523+
519524
if (
520525
input.is_cuda()
521526
&& input.dim() <= MIOPEN_DIM_MAX
@@ -528,8 +533,8 @@ BatchNormBackend _select_batch_norm_backend(
528533
&& (input.dim() >= 3)
529534
&& detail::getCUDAHooks().compiledWithMIOpen()
530535
&& cudnn_enabled
531-
&& input.suggest_memory_format() != MemoryFormat::ChannelsLast
532-
&& input.suggest_memory_format() != MemoryFormat::ChannelsLast3d
536+
&& (input.suggest_memory_format() == MemoryFormat::Contiguous
537+
|| (input.suggest_memory_format() == MemoryFormat::ChannelsLast && PYTORCH_MIOPEN_SUGGEST_NHWC))
533538
) {
534539
return BatchNormBackend::Miopen;
535540
}
@@ -609,7 +614,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
609614
if (backend == BatchNormBackend::Miopen) {
610615
return std::tuple_cat(
611616
at::miopen_batch_norm(
612-
input.contiguous(), weight.contiguous(), bias.contiguous(),
617+
input.contiguous(input.suggest_memory_format()), weight.contiguous(), bias.contiguous(),
613618
running_mean.defined() ? running_mean.contiguous() : running_mean,
614619
running_var.defined() ? running_var.contiguous() : running_var,
615620
training, momentum, eps),

aten/src/ATen/native/miopen/BatchNorm_miopen.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm(
100100
mode = miopenBNSpatial;
101101
}
102102

103-
auto output_t = at::empty(input->sizes(), input->options());
103+
auto output_t = at::empty(input->sizes(), input->options(), input->suggest_memory_format());
104104
TensorArg output{ output_t, "output", 0 };
105105

106106
auto handle = getMiopenHandle();
@@ -177,8 +177,10 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
177177
const Tensor& save_var_t =
178178
c10::value_or_else(save_var_t_opt, [] { return Tensor(); });
179179

180+
auto grad_output_contig =
181+
grad_output_t.contiguous(input_t.suggest_memory_format());
180182
TensorArg input{ input_t, "input", 1 },
181-
grad_output{ grad_output_t, "grad_output", 2 },
183+
grad_output{ grad_output_contig, "grad_output", 2 },
182184
weight{ weight_t, "weight", 3 },
183185
save_mean{ save_mean_t, "save_mean", 4 },
184186
save_var{ save_var_t, "save_var", 5 };
@@ -193,7 +195,9 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
193195
}
194196
checkAllSameType(c, {input, grad_output});
195197
checkAllSameType(c, {weight, save_mean, save_var});
196-
checkAllContiguous(c, {input, grad_output, save_mean, save_var});
198+
checkAllContiguous(c, {save_mean, save_var});
199+
TORCH_CHECK(input->is_contiguous(input->suggest_memory_format()));
200+
TORCH_CHECK(grad_output->is_contiguous(input->suggest_memory_format()));
197201
checkDimRange(c, input, 2, 6 /* exclusive */);
198202
checkSameSize(c, input, grad_output);
199203
auto num_features = input->size(1);
@@ -208,7 +212,8 @@ std::tuple<Tensor, Tensor, Tensor> miopen_batch_norm_backward(
208212
mode = miopenBNSpatial;
209213
}
210214

211-
auto grad_input_t = at::empty(input->sizes(), input->options());
215+
auto grad_input_t = at::empty(
216+
input->sizes(), input->options(), input->suggest_memory_format());
212217
auto grad_weight_t = at::empty(weight->sizes(), weight->options());
213218
auto grad_bias_t = at::empty(weight->sizes(), weight->options());
214219

test/test_nn.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import warnings
1010
import pickle
1111
import re
12+
import os
1213
from copy import deepcopy
1314
from itertools import product
1415
from functools import partial
@@ -4928,6 +4929,54 @@ def run_test(input, grad_output):
49284929
grad = grad.permute(0, 2, 1, 3)
49294930
run_test(input, grad)
49304931

4932+
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
4933+
@unittest.skipIf(not TEST_CUDNN, "needs cudnn")
4934+
def test_batchnorm_nhwc_miopen(self):
4935+
def run_test(input, grad_output):
4936+
c = input.size(1)
4937+
mod = nn.BatchNorm2d(c).cuda().float()
4938+
mod.weight.data.uniform_()
4939+
mod.bias.data.uniform_()
4940+
ref_input = input.detach().clone(memory_format=torch.preserve_format).requires_grad_(True)
4941+
ref_grad = grad.detach().clone(memory_format=torch.preserve_format)
4942+
ref_mod = nn.BatchNorm2d(c).cuda().float()
4943+
ref_mod.load_state_dict(mod.state_dict())
4944+
out = mod(input)
4945+
out.backward(grad_output)
4946+
with torch.backends.cudnn.flags(enabled=False): # force to use native nhwc batchnorm
4947+
ref_out = ref_mod(ref_input)
4948+
ref_out.backward(ref_grad)
4949+
self.assertTrue(out.is_contiguous(memory_format=torch.channels_last))
4950+
self.assertTrue(ref_out.is_contiguous(memory_format=torch.channels_last))
4951+
self.assertEqual(out, ref_out)
4952+
self.assertEqual(mod.weight.grad, ref_mod.weight.grad)
4953+
self.assertEqual(mod.bias.grad, ref_mod.bias.grad)
4954+
self.assertEqual(input.grad, ref_input.grad)
4955+
4956+
# TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
4957+
PYTORCH_MIOPEN_SUGGEST_NHWC = "PYTORCH_MIOPEN_SUGGEST_NHWC"
4958+
prev_val = os.getenv(PYTORCH_MIOPEN_SUGGEST_NHWC)
4959+
try:
4960+
os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC] = "1"
4961+
input = torch.randint(1, 10, (4, 8, 2, 2), dtype=torch.float32, device="cuda")
4962+
input = input.contiguous(memory_format=torch.channels_last).detach().requires_grad_()
4963+
4964+
grad = torch.randint(1, 10, (4, 8, 2, 2), dtype=torch.float32, device="cuda")
4965+
grad = grad.contiguous(memory_format=torch.channels_last)
4966+
run_test(input, grad)
4967+
# see #42588, grad is channels_last contiguous, but grad.suggest_memory_format (rightly) return "contiguous"
4968+
# not channels_last
4969+
input = torch.randint(1, 10, (2, 8, 8, 1), dtype=torch.float32, device="cuda")
4970+
input = input.contiguous(memory_format=torch.channels_last).detach().requires_grad_()
4971+
grad = torch.randint(1, 10, (2, 8, 8, 1), dtype=torch.float32, device="cuda")
4972+
grad = grad.permute(0, 2, 1, 3)
4973+
run_test(input, grad)
4974+
finally:
4975+
if prev_val is None:
4976+
del os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC]
4977+
else:
4978+
os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC] = prev_val
4979+
49314980
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
49324981
def test_batchnorm_cudnn_half(self):
49334982
# THNN
@@ -13023,4 +13072,14 @@ def __init__(self) -> None:
1302313072

1302413073
if __name__ == '__main__':
1302513074
TestCase._default_dtype_check_enabled = True
13026-
run_tests()
13075+
# TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
13076+
PYTORCH_MIOPEN_SUGGEST_NHWC = "PYTORCH_MIOPEN_SUGGEST_NHWC"
13077+
prev_val = os.getenv(PYTORCH_MIOPEN_SUGGEST_NHWC)
13078+
try:
13079+
os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC] = "1"
13080+
run_tests()
13081+
finally:
13082+
if prev_val is None:
13083+
del os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC]
13084+
else:
13085+
os.environ[PYTORCH_MIOPEN_SUGGEST_NHWC] = prev_val

0 commit comments

Comments
 (0)