 HAS_TORCHVISION = False
 skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
 
+bn_m = {1 : nn.BatchNorm1d, 2 : nn.BatchNorm2d, 3 : nn.BatchNorm3d}
+
 class CPUOPsTester(TestCase):
 
     def test_channelshuffle(self):

@@ -142,38 +144,52 @@ def test_pixel_shuffle_nhwc_cpu(self):
         self.assertEqual(input.grad, ref_input.grad)
 
     def test_batch_norm(self):
-        m = nn.BatchNorm2d(100)
-        x = torch.randn(20, 100, 35, 45)
-        x1 = x.clone().detach().requires_grad_()
-        y1 = m(x1)
-        y1.mean().backward()
-
-        # test channels last
-        x2 = x.clone().detach().to(memory_format=torch.channels_last).requires_grad_()
-        y2 = m(x2)
-        y2.mean().backward()
-        self.assertTrue(y2.is_contiguous(memory_format=torch.channels_last))
-        self.assertEqual(y1, y2)
-        self.assertTrue(x2.grad.is_contiguous(memory_format=torch.channels_last))
-        self.assertEqual(x1.grad, x2.grad)
-
-        # test bfloat16
-        x3 = x.clone().detach().bfloat16().requires_grad_()
-        y3 = m(x3)
-        y3.mean().backward()
-        self.assertTrue(y3.dtype == torch.bfloat16)
-        self.assertEqual(y1, y3, prec=0.1)
-        self.assertTrue(x3.grad.dtype == torch.bfloat16)
-        self.assertEqual(x1.grad, x3.grad)
-
-        # test autocast
-        with torch.cpu.amp.autocast():
-            for datatype in (torch.bfloat16, torch.float32):
-                x4 = x.clone().detach().to(datatype).requires_grad_()
-                y4 = m(x4)
-                y4.mean().backward()
-                self.assertTrue(y4.dtype == datatype)
-                self.assertTrue(x4.grad.dtype == datatype)
+        for dim in [2, 3]:
+            m = bn_m[dim](10)
+            input_size = [3, 10, 25, 25]
+            if dim == 3:
+                input_size.append(25)
+            x = torch.randn(input_size)
+            x1 = x.clone().detach().requires_grad_()
+            y1 = m(x1)
+            y1.mean().backward()
+
+            # test channels last
+            suggest_memory_format = torch.channels_last if dim == 2 else torch.channels_last_3d
+            x2 = x.clone().detach().to(memory_format=suggest_memory_format).requires_grad_()
+
+            y2 = m(x2)
+            y2.mean().backward()
+            self.assertTrue(y2.is_contiguous(memory_format=suggest_memory_format))
+            self.assertEqual(y1, y2)
+            self.assertTrue(x2.grad.is_contiguous(memory_format=suggest_memory_format))
+            self.assertEqual(x1.grad, x2.grad)
+
+            # test bfloat16
+            x3 = x.clone().detach().bfloat16().requires_grad_()
+            y3 = m(x3)
+            y3.mean().backward()
+            self.assertTrue(y3.dtype == torch.bfloat16)
+            self.assertEqual(y1, y3, prec=0.1)
+            self.assertTrue(x3.grad.dtype == torch.bfloat16)
+            self.assertEqual(x1.grad, x3.grad)
+
+            # test autocast
+            with torch.cpu.amp.autocast():
+                for datatype in (torch.bfloat16, torch.float32):
+                    x4 = x.clone().detach().to(datatype).requires_grad_()
+                    y4 = m(x4)
+                    y4.mean().backward()
+                    self.assertTrue(y4.dtype == datatype)
+                    self.assertTrue(x4.grad.dtype == datatype)
+
+                    x5 = x.clone().detach().to(datatype).to(memory_format=suggest_memory_format).requires_grad_()
+                    y5 = m(x5)
+                    y5.mean().backward()
+                    self.assertTrue(y5.dtype == datatype)
+                    self.assertTrue(x5.grad.dtype == datatype)
+                    self.assertTrue(y5.is_contiguous(memory_format=suggest_memory_format))
+                    self.assertTrue(x5.grad.is_contiguous(memory_format=suggest_memory_format))
 
     def test_adaptive_avg_pool2d(self):
         m = nn.AdaptiveAvgPool2d((5,7))
|