
Commit ae268ac

Fix channel last 3d support for batch_norm (#642)
1 parent a61732e commit ae268ac
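For context, here is a minimal sketch of the user-facing case this commit fixes: BatchNorm3d applied to a 5D (NCDHW) input held in channels_last_3d memory format, with the layout preserved through forward and backward. It assumes intel_extension_for_pytorch is installed and imported so the extension's CPU batch_norm kernels are dispatched; the shapes are arbitrary.

import torch
import torch.nn as nn
import intel_extension_for_pytorch  # noqa: F401  (registers the extension's CPU kernels)

# 5D NCDHW input converted to channels_last_3d; sizes are arbitrary.
m = nn.BatchNorm3d(10)
x = torch.randn(3, 10, 25, 25, 25).to(memory_format=torch.channels_last_3d).requires_grad_()

y = m(x)
y.mean().backward()

# With this fix, the channels-last-3d layout is preserved for both the
# output and the input gradient instead of falling back to a contiguous path.
assert y.is_contiguous(memory_format=torch.channels_last_3d)
assert x.grad.is_contiguous(memory_format=torch.channels_last_3d)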

3 files changed: 62 additions, 37 deletions

intel_extension_for_pytorch/csrc/aten/cpu/Normalization.cpp

Lines changed: 5 additions & 2 deletions
@@ -68,7 +68,8 @@ struct Var {
 };
 
 static inline bool is_contiguous(const at::Tensor& t) {
-  return t.is_contiguous() || t.is_contiguous(at::MemoryFormat::ChannelsLast);
+  return t.is_contiguous() || t.is_contiguous(at::MemoryFormat::ChannelsLast) ||
+      t.is_contiguous(at::MemoryFormat::ChannelsLast3d);
 }
 
 // For some ambiguous cases, it is possible a channels last contiguous
@@ -78,7 +79,9 @@ static inline bool is_contiguous(const at::Tensor& t) {
 static inline at::MemoryFormat suggest_memory_format_contig(
     const at::Tensor& t) {
   return t.is_contiguous() ? at::MemoryFormat::Contiguous
-                           : at::MemoryFormat::ChannelsLast;
+                           : (t.is_contiguous(at::MemoryFormat::ChannelsLast3d)
+                                  ? at::MemoryFormat::ChannelsLast3d
+                                  : at::MemoryFormat::ChannelsLast);
 }
 
 template <typename scalar_t, typename param_t>
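The ChannelsLast3d branch added to suggest_memory_format_contig matters because a tensor can report as contiguous under more than one memory format at once (for example when the channel dimension has size 1), so the helper must pick a single layout deterministically. Below is a rough Python analogue of that selection logic using stock PyTorch tensors; the Python function name simply mirrors the C++ helper above and is not part of any public API.

import torch

def suggest_memory_format_contig(t: torch.Tensor) -> torch.memory_format:
    # Mirrors the C++ helper: prefer plain contiguous, then
    # channels-last-3d, and fall back to channels-last.
    if t.is_contiguous():
        return torch.contiguous_format
    if t.is_contiguous(memory_format=torch.channels_last_3d):
        return torch.channels_last_3d
    return torch.channels_last

# Ambiguous case: with C == 1 the same strides satisfy both the default
# and the channels-last-3d contiguity checks.
x = torch.randn(2, 1, 4, 4, 4)
print(x.is_contiguous())                                      # True
print(x.is_contiguous(memory_format=torch.channels_last_3d))  # True
print(suggest_memory_format_contig(x))                        # torch.contiguous_format

# An unambiguous channels-last-3d tensor resolves to ChannelsLast3d.
y = torch.randn(2, 3, 4, 4, 4).to(memory_format=torch.channels_last_3d)
print(suggest_memory_format_contig(y))                        # torch.channels_last_3d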

intel_extension_for_pytorch/csrc/aten/cpu/kernels/NormalizationKrnl.cpp

Lines changed: 9 additions & 3 deletions
@@ -1267,7 +1267,9 @@ void batch_norm_cpu_kernel_impl(
               eps);
         }
       });
-  } else if (input.is_contiguous(at::MemoryFormat::ChannelsLast)) {
+  } else if (
+      input.is_contiguous(at::MemoryFormat::ChannelsLast) ||
+      input.is_contiguous(at::MemoryFormat::ChannelsLast3d)) {
     AT_DISPATCH_FLOATING_TYPES_AND(
         at::ScalarType::BFloat16,
         input.scalar_type(),
@@ -1338,7 +1340,9 @@ void batch_norm_cpu_collect_stats_kernel_impl(
           }
         }
       });
-  } else if (input.is_contiguous(at::MemoryFormat::ChannelsLast)) {
+  } else if (
+      input.is_contiguous(at::MemoryFormat::ChannelsLast) ||
+      input.is_contiguous(at::MemoryFormat::ChannelsLast3d)) {
     AT_DISPATCH_FLOATING_TYPES_AND(
         at::ScalarType::BFloat16,
         input.scalar_type(),
@@ -1445,7 +1449,9 @@ void batch_norm_cpu_backward_kernel_impl(
          }
        }
      });
-  } else if (input.is_contiguous(at::MemoryFormat::ChannelsLast)) {
+  } else if (
+      input.is_contiguous(at::MemoryFormat::ChannelsLast) ||
+      input.is_contiguous(at::MemoryFormat::ChannelsLast3d)) {
     AT_DISPATCH_FLOATING_TYPES_AND(
         at::ScalarType::BFloat16,
         input.scalar_type(),

tests/cpu/test_cpu_ops.py

Lines changed: 48 additions & 32 deletions
@@ -13,6 +13,8 @@
 HAS_TORCHVISION = False
 skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
 
+bn_m = {1 : nn.BatchNorm1d, 2 : nn.BatchNorm2d, 3 : nn.BatchNorm3d}
+
 class CPUOPsTester(TestCase):
 
     def test_channelshuffle(self):
@@ -142,38 +144,52 @@ def test_pixel_shuffle_nhwc_cpu(self):
         self.assertEqual(input.grad, ref_input.grad)
 
     def test_batch_norm(self):
-        m = nn.BatchNorm2d(100)
-        x = torch.randn(20, 100, 35, 45)
-        x1 = x.clone().detach().requires_grad_()
-        y1 = m(x1)
-        y1.mean().backward()
-
-        # test channels last
-        x2 = x.clone().detach().to(memory_format=torch.channels_last).requires_grad_()
-        y2 = m(x2)
-        y2.mean().backward()
-        self.assertTrue(y2.is_contiguous(memory_format=torch.channels_last))
-        self.assertEqual(y1, y2)
-        self.assertTrue(x2.grad.is_contiguous(memory_format=torch.channels_last))
-        self.assertEqual(x1.grad, x2.grad)
-
-        # test bfloat16
-        x3 = x.clone().detach().bfloat16().requires_grad_()
-        y3 = m(x3)
-        y3.mean().backward()
-        self.assertTrue(y3.dtype == torch.bfloat16)
-        self.assertEqual(y1, y3, prec=0.1)
-        self.assertTrue(x3.grad.dtype == torch.bfloat16)
-        self.assertEqual(x1.grad, x3.grad)
-
-        # test autocast
-        with torch.cpu.amp.autocast():
-            for datatype in (torch.bfloat16, torch.float32):
-                x4 = x.clone().detach().to(datatype).requires_grad_()
-                y4 = m(x4)
-                y4.mean().backward()
-                self.assertTrue(y4.dtype == datatype)
-                self.assertTrue(x4.grad.dtype == datatype)
+        for dim in [2, 3]:
+            m = bn_m[dim](10)
+            input_size = [3, 10, 25, 25]
+            if dim == 3:
+                input_size.append(25)
+            x = torch.randn(input_size)
+            x1 = x.clone().detach().requires_grad_()
+            y1 = m(x1)
+            y1.mean().backward()
+
+            # test channels last
+            suggest_memory_format = torch.channels_last if dim == 2 else torch.channels_last_3d
+            x2 = x.clone().detach().to(memory_format=suggest_memory_format).requires_grad_()
+
+            y2 = m(x2)
+            y2.mean().backward()
+            self.assertTrue(y2.is_contiguous(memory_format=suggest_memory_format))
+            self.assertEqual(y1, y2)
+            self.assertTrue(x2.grad.is_contiguous(memory_format=suggest_memory_format))
+            self.assertEqual(x1.grad, x2.grad)
+
+            # test bfloat16
+            x3 = x.clone().detach().bfloat16().requires_grad_()
+            y3 = m(x3)
+            y3.mean().backward()
+            self.assertTrue(y3.dtype == torch.bfloat16)
+            self.assertEqual(y1, y3, prec=0.1)
+            self.assertTrue(x3.grad.dtype == torch.bfloat16)
+            self.assertEqual(x1.grad, x3.grad)
+
+            # test autocast
+            with torch.cpu.amp.autocast():
+                for datatype in (torch.bfloat16, torch.float32):
+                    x4 = x.clone().detach().to(datatype).requires_grad_()
+                    y4 = m(x4)
+                    y4.mean().backward()
+                    self.assertTrue(y4.dtype == datatype)
+                    self.assertTrue(x4.grad.dtype == datatype)
+
+                    x5 = x.clone().detach().to(datatype).to(memory_format=suggest_memory_format).requires_grad_()
+                    y5 = m(x5)
+                    y5.mean().backward()
+                    self.assertTrue(y5.dtype == datatype)
+                    self.assertTrue(x5.grad.dtype == datatype)
+                    self.assertTrue(y5.is_contiguous(memory_format=suggest_memory_format))
+                    self.assertTrue(x5.grad.is_contiguous(memory_format=suggest_memory_format))
 
     def test_adaptive_avg_pool2d(self):
         m = nn.AdaptiveAvgPool2d((5,7))
