Skip to content

Commit 87f2d0b

Browse files
dnikolaev-amd authored and jithunnair-amd committed
Add test_batchnorm_nhwc_miopen_cuda_float32 (#1561)
New tests introduced for testing NHWC and NCHW batchnorm on MIOpen : - test_batchnorm_nhwc_miopen_cuda_float32 - test_batchnorm_nchw_miopen_cuda_float32 This test verifies weight and bias gradients, running_mean and running_var We can add other dtypes later How to run: `MIOPEN_ENABLE_LOGGING_CMD=1 python -u test/test_nn.py -v -k test_batchnorm_nhwc_miopen_cuda_float32` There is a difference in running_variance for NHWC batchnorm fp32 between MIOpen and native ``` MIOPEN_ENABLE_LOGGING_CMD=1 python -u test/test_nn.py -v -k test_batchnorm_nhwc_miopen_cuda_float32 ... self.assertEqual(mod.running_var, ref_mod.running_var) AssertionError: Tensor-likes are not close! Mismatched elements: 8 / 8 (100.0%) Greatest absolute difference: 0.05455732345581055 at index (5,) (up to 1e-05 allowed) Greatest relative difference: 0.030772637575864792 at index (5,) (up to 1.3e-06 allowed) ```
1 parent 031b577 commit 87f2d0b

File tree

1 file changed

+58
-0
lines changed

1 file changed

+58
-0
lines changed

test/test_nn.py

Lines changed: 58 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -8269,6 +8269,64 @@ def test_affine_3d_rotateRandom(self, device):
82698269

82708270
self.assertEqual(scipy_ary, gridsample_ary.reshape_as(scipy_ary))
82718271

8272+
def batchnorm2d_miopen(self, dtype, memory_format):
    """Compare MIOpen BatchNorm2d against the native CUDA implementation.

    Runs a forward/backward pass through ``nn.BatchNorm2d`` on the default
    (MIOpen-enabled) path and again with cudnn/miopen disabled (native
    path), then checks that outputs, weight/bias gradients, running_mean,
    running_var, and input gradients agree for the given ``memory_format``
    (``torch.channels_last`` for NHWC, ``torch.contiguous_format`` for NCHW).

    Args:
        dtype: tensor dtype to test (e.g. ``torch.float``).
        memory_format: memory layout for the inputs and expected outputs.
    """
    def run_test(input, grad_output):
        c = input.size(1)
        mod = nn.BatchNorm2d(c).cuda().to(dtype=input.dtype)
        mod.weight.data.uniform_()
        mod.bias.data.uniform_()
        ref_input = input.detach().clone(memory_format=torch.preserve_format).requires_grad_(True)
        # Bug fix: clone the `grad_output` parameter, not the enclosing-scope
        # `grad` variable -- the closure capture only worked by accident
        # because both call sites happened to pass `grad` itself.
        ref_grad = grad_output.detach().clone(memory_format=torch.preserve_format)
        ref_mod = nn.BatchNorm2d(c).cuda().to(dtype=input.dtype)
        ref_mod.load_state_dict(mod.state_dict())
        out = mod(input)
        out.backward(grad_output)
        with torch.backends.cudnn.flags(enabled=False):  # force to use native nhwc batchnorm
            # Both forward and backward run under the disabled flag so the
            # reference gradients also come from the native kernel.
            ref_out = ref_mod(ref_input)
            ref_out.backward(ref_grad)
        self.assertTrue(out.is_contiguous(memory_format=memory_format))
        self.assertTrue(ref_out.is_contiguous(memory_format=memory_format))
        self.assertEqual(out, ref_out)
        self.assertEqual(mod.weight.grad, ref_mod.weight.grad)
        self.assertEqual(mod.bias.grad, ref_mod.bias.grad)
        self.assertEqual(mod.running_mean, ref_mod.running_mean)
        self.assertEqual(mod.running_var, ref_mod.running_var)
        self.assertEqual(input.grad, ref_input.grad)

    size = (4, 8, 2, 2)
    input = torch.randint(1, 10, size=size, dtype=dtype, device="cuda")
    input = input.contiguous(memory_format=memory_format).detach().requires_grad_()
    grad = torch.randint(1, 10, size=size, dtype=dtype, device="cuda")
    grad = grad.contiguous(memory_format=memory_format)
    run_test(input, grad)
    # see #42588, grad is channels_last contiguous, but grad.suggest_memory_format (rightly) return "contiguous"
    # not channels_last
    input = torch.randint(1, 10, (2, 8, 8, 1), dtype=dtype, device="cuda")
    input = input.contiguous(memory_format=memory_format).detach().requires_grad_()
    grad = torch.randint(1, 10, (2, 8, 8, 1), dtype=dtype, device="cuda")
    grad = grad.permute(0, 2, 1, 3)
    run_test(input, grad)
8311+
@onlyCUDA
@dtypes(torch.float)
def test_batchnorm_nhwc_miopen(self, dtype):
    """Exercise NHWC (channels_last) batchnorm through MIOpen vs native."""
    # TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
    env_key = "PYTORCH_MIOPEN_SUGGEST_NHWC"
    saved = os.environ.get(env_key)
    try:
        os.environ[env_key] = "1"
        self.batchnorm2d_miopen(dtype, torch.channels_last)
    finally:
        # Restore the caller's environment exactly as we found it.
        if saved is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = saved
8326+
@onlyCUDA
@dtypes(torch.float)
def test_batchnorm_nchw_miopen(self, dtype):
    """Exercise NCHW (contiguous) batchnorm through MIOpen vs native."""
    self.batchnorm2d_miopen(dtype, torch.contiguous_format)
82728330

82738331
@onlyCUDA
82748332
@dtypes(torch.float, torch.half)

0 commit comments

Comments
 (0)