Skip to content

Commit a1e8b0e

Browse files
change torch.equal() to self.assetrEqual() while comparing NHWC and NCHW batchnorm output (#1600)
`self.assertTrue(torch.equal(out1, out2))` assumes a compete match But we have a slight difference (~1e-7) with fp32 NHWC and NCHW batchnorm output `self.assertEqual(out1, out2)` allows for tolerance
1 parent d28d7ff commit a1e8b0e

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

test/test_nn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5071,7 +5071,7 @@ def test_batchnorm_nhwc_cuda(self):
50715071
inp2 = inp1.contiguous(memory_format=torch.channels_last)
50725072
out1 = model(inp1)
50735073
out2 = model(inp2)
5074-
self.assertTrue(torch.equal(out1, out2))
5074+
self.assertEqual(out1, out2)
50755075

50765076
def test_batchnorm_load_state_dict(self):
50775077
bn = torch.nn.BatchNorm2d(3)

0 commit comments

Comments
 (0)