Skip to content

Commit d768940

Browse files
Add test_batchnorm_nhwc_miopen_cuda_float32 (#1561)
New tests introduced for testing NHWC and NCHW batchnorm on MIOpen : - test_batchnorm_nhwc_miopen_cuda_float32 - test_batchnorm_nchw_miopen_cuda_float32 This test verifies weight and bias gradients, running_mean and running_var We can add other dtypes later How to run: `MIOPEN_ENABLE_LOGGING_CMD=1 python -u test/test_nn.py -v -k test_batchnorm_nhwc_miopen_cuda_float32` There is a difference in running_variance for NHWC batchnorm fp32 between MIOpen and native ``` MIOPEN_ENABLE_LOGGING_CMD=1 python -u test/test_nn.py -v -k test_batchnorm_nhwc_miopen_cuda_float32 ... self.assertEqual(mod.running_var, ref_mod.running_var) AssertionError: Tensor-likes are not close! Mismatched elements: 8 / 8 (100.0%) Greatest absolute difference: 0.05455732345581055 at index (5,) (up to 1e-05 allowed) Greatest relative difference: 0.030772637575864792 at index (5,) (up to 1.3e-06 allowed) ```
1 parent bfdb3cd commit d768940

File tree

1 file changed

+58
-0
lines changed

1 file changed

+58
-0
lines changed

test/test_nn.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8238,6 +8238,64 @@ def test_affine_3d_rotateRandom(self, device):
82388238

82398239
self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary))
82408240

8241+
def batchnorm2d_miopen(self, dtype, memory_format):
    """Compare BatchNorm2d on the MIOpen path against the native CUDA kernel.

    Runs a forward + backward pass through ``nn.BatchNorm2d`` twice — once on
    the default (MIOpen-eligible) path and once with cuDNN/MIOpen disabled to
    force the native implementation — and asserts that outputs, input/weight/
    bias gradients, and the running statistics all match, and that the output
    preserves ``memory_format``.

    Args:
        dtype: tensor dtype to test (e.g. ``torch.float``).
        memory_format: ``torch.channels_last`` (NHWC) or
            ``torch.contiguous_format`` (NCHW).
    """
    def run_test(input, grad_output):
        c = input.size(1)
        mod = nn.BatchNorm2d(c).cuda().to(dtype=input.dtype)
        # Randomize affine parameters so the comparison is non-trivial.
        mod.weight.data.uniform_()
        mod.bias.data.uniform_()
        ref_input = input.detach().clone(memory_format=torch.preserve_format).requires_grad_(True)
        # BUGFIX: clone the grad_output parameter, not the enclosing-scope
        # `grad` variable — they only coincided for the existing call sites.
        ref_grad = grad_output.detach().clone(memory_format=torch.preserve_format)
        # Reference module starts from identical parameters/buffers.
        ref_mod = nn.BatchNorm2d(c).cuda().to(dtype=input.dtype)
        ref_mod.load_state_dict(mod.state_dict())
        out = mod(input)
        out.backward(grad_output)
        with torch.backends.cudnn.flags(enabled=False):  # force to use native nhwc batchnorm
            ref_out = ref_mod(ref_input)
            ref_out.backward(ref_grad)
        self.assertTrue(out.is_contiguous(memory_format=memory_format))
        self.assertTrue(ref_out.is_contiguous(memory_format=memory_format))
        self.assertEqual(out, ref_out)
        self.assertEqual(mod.weight.grad, ref_mod.weight.grad)
        self.assertEqual(mod.bias.grad, ref_mod.bias.grad)
        self.assertEqual(mod.running_mean, ref_mod.running_mean)
        self.assertEqual(mod.running_var, ref_mod.running_var)
        self.assertEqual(input.grad, ref_input.grad)

    # Integer-valued inputs keep the comparison exact in low precision.
    size = (4, 8, 2, 2)
    input = torch.randint(1, 10, size=size, dtype=dtype, device="cuda")
    input = input.contiguous(memory_format=memory_format).detach().requires_grad_()
    grad = torch.randint(1, 10, size=size, dtype=dtype, device="cuda")
    grad = grad.contiguous(memory_format=memory_format)
    run_test(input, grad)
    # see #42588, grad is channels_last contiguous, but grad.suggest_memory_format (rightly) return "contiguous"
    # not channels_last
    input = torch.randint(1, 10, (2, 8, 8, 1), dtype=dtype, device="cuda")
    input = input.contiguous(memory_format=memory_format).detach().requires_grad_()
    grad = torch.randint(1, 10, (2, 8, 8, 1), dtype=dtype, device="cuda")
    grad = grad.permute(0, 2, 1, 3)
    run_test(input, grad)
@onlyCUDA
@dtypes(torch.float)
def test_batchnorm_nhwc_miopen(self, dtype):
    """Run the NHWC (channels_last) BatchNorm2d MIOpen-vs-native comparison.

    Temporarily opts in to NHWC MIOpen batchnorm via the environment and
    restores the previous environment state afterwards, even on failure.
    """
    # TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
    env_key = "PYTORCH_MIOPEN_SUGGEST_NHWC"
    saved_value = os.environ.get(env_key)
    try:
        os.environ[env_key] = "1"
        self.batchnorm2d_miopen(dtype, torch.channels_last)
    finally:
        # Restore the caller's environment exactly: drop the key if it was
        # absent before, otherwise put the original value back.
        if saved_value is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = saved_value
@onlyCUDA
@dtypes(torch.float)
def test_batchnorm_nchw_miopen(self, dtype):
    """Run the NCHW (contiguous) BatchNorm2d MIOpen-vs-native comparison."""
    self.batchnorm2d_miopen(dtype, memory_format=torch.contiguous_format)
82418299

82428300
@onlyCUDA
82438301
@dtypes(torch.float, torch.half)

0 commit comments

Comments
 (0)