Commit cad3f82

support ConcatBnRelu for BFloat16 (#647)

* support ConcatBnRelu for BFloat16
* modify some details of ConcatBnRelu

In short: the fused ConcatBnRelu CPU kernel gains a BFloat16 path (the output tensor is BFloat16, accumulation is done in float), the JIT FuseConcatBnRelu pass now accepts BFloat16 inputs as long as the BatchNorm weights are FP32, and the ConcatBnRelu JIT tests are consolidated into a single parametrized test covering 2D/3D, FP32/BF16, O0/O1, and channels-last vs. contiguous inputs.
1 parent 11a982e commit cad3f82

File tree: 4 files changed (+145, -154 lines)

- intel_extension_for_pytorch/csrc/aten/cpu/kernels/ConcatBnReluKrnl.cpp
- intel_extension_for_pytorch/csrc/cpu/vec512/concat_bn_relu.h
- intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite.cpp
- tests/cpu/test_jit.py
intel_extension_for_pytorch/csrc/aten/cpu/kernels/ConcatBnReluKrnl.cpp

Lines changed: 18 additions & 7 deletions
@@ -53,13 +53,24 @@ at::Tensor concat_bn_relu_kernel_impl(
   }
 #if defined(CPU_CAPABILITY_AVX512)
   if (tensor_check) {
-    at::Tensor output = at::empty(
-        output_dim,
-        a[0].options()
-            .dtype(at::kFloat)
-            .memory_format(a[0].suggest_memory_format()));
-    torch_ipex::cpu::kernel::vec::vec512::ConcatBnReluKernelImpl_ChannelsLast<
-        float>(a, bn_scale, bn_beta, output);
+    at::Tensor output;
+    if (a[0].scalar_type() == at::kBFloat16) {
+      output = at::empty(
+          output_dim,
+          a[0].options()
+              .dtype(at::kBFloat16)
+              .memory_format(a[0].suggest_memory_format()));
+      torch_ipex::cpu::kernel::vec::vec512::ConcatBnReluKernelImpl_ChannelsLast<
+          at::BFloat16>(a, bn_scale, bn_beta, output);
+    } else {
+      output = at::empty(
+          output_dim,
+          a[0].options()
+              .dtype(at::kFloat)
+              .memory_format(a[0].suggest_memory_format()));
+      torch_ipex::cpu::kernel::vec::vec512::ConcatBnReluKernelImpl_ChannelsLast<
+          float>(a, bn_scale, bn_beta, output);
+    }
     return output;
   }
 #endif
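The entry point now picks the kernel instantiation from the runtime dtype of the first concatenation input: BFloat16 inputs allocate a BFloat16 output and dispatch to the at::BFloat16 instantiation, everything else keeps the original float path. A minimal, self-contained sketch of the same runtime-dtype-to-template dispatch pattern (hypothetical names, not the IPEX sources):

#include <cstdint>
#include <iostream>
#include <vector>

// Stand-in for the ATen dtype enum; purely illustrative.
enum class ScalarType { Float, BFloat16 };

// 16-bit placeholder for at::BFloat16 so the sketch stays self-contained.
struct BF16 {
  std::uint16_t bits;
};

// Hypothetical kernel template: one instantiation per element type.
template <typename T>
void concat_bn_relu_channels_last(const std::vector<T>& /*inputs*/) {
  std::cout << "kernel instantiated for element size " << sizeof(T) << " bytes\n";
}

// Runtime dtype decides the compile-time template argument, as in the patch.
void dispatch(ScalarType dtype) {
  if (dtype == ScalarType::BFloat16) {
    concat_bn_relu_channels_last<BF16>({});
  } else {
    concat_bn_relu_channels_last<float>({});
  }
}

int main() {
  dispatch(ScalarType::Float);     // prints 4 bytes
  dispatch(ScalarType::BFloat16);  // prints 2 bytes
}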

intel_extension_for_pytorch/csrc/cpu/vec512/concat_bn_relu.h

Lines changed: 72 additions & 19 deletions
@@ -9,6 +9,16 @@
 #include <limits>
 #include "utils.h"
 
+// use float as accumulation type for BFloat16
+template <typename scalar_t>
+struct AccType {
+  using type = scalar_t;
+};
+template <>
+struct AccType<BFloat16> {
+  using type = float;
+};
+
 namespace torch_ipex {
 namespace cpu {
 namespace kernel {
@@ -17,33 +27,75 @@ namespace vec512 {
 
 using Tensor = at::Tensor;
 
-template <typename T>
-void _concat_bn_relu_kernel_channels_last(
+template <typename T, typename ACC_T>
+static void _concat_bn_relu_kernel_channels_last(
     const std::vector<const T*>& in_ptr,
     const std::vector<int64_t>& in_ch,
     T* out_ptr,
-    const T* scale_ptr,
-    const T* beta_ptr,
+    const ACC_T* scale_ptr,
+    const ACC_T* beta_ptr,
+    int64_t total_size_except_channels,
+    int64_t ci,
+    int64_t co) {
+  int64_t i = 0, j = 0, k = 0;
+  auto zero = _mm512_set1_ps(0.0);
+#ifdef _OPENMP
+#if (_OPENMP >= 201307)
+#pragma omp parallel for simd schedule( \
+    static) if (omp_get_max_threads() > 1 && !omp_in_parallel())
+#else
+#pragma omp parallel for schedule( \
+    static) if (omp_get_max_threads() > 1 && !omp_in_parallel())
+#endif
+#endif
+  for (i = 0; i < total_size_except_channels; ++i) {
+    for (j = 0; j < in_ptr.size(); ++j) {
+      auto concat_in_ptr = in_ptr[j] + i * in_ch[j + 1] - (i + 1) * in_ch[j];
+      for (k = in_ch[j]; k < in_ch[j + 1]; k += 16) {
+        auto in = _mm512_loadu_ps(concat_in_ptr + k);
+        auto beta = _mm512_loadu_ps(beta_ptr + k);
+        auto scale = _mm512_loadu_ps(scale_ptr + k);
+        auto bn_out = _mm512_add_ps(beta, _mm512_mul_ps(scale, in));
+        auto out = _mm512_max_ps(zero, bn_out);
+        _mm512_storeu_ps(out_ptr + i * co + k, out);
+      }
+    }
+  }
+}
+
+template <>
+void _concat_bn_relu_kernel_channels_last<at::BFloat16, float>(
+    const std::vector<const at::BFloat16*>& in_ptr,
+    const std::vector<int64_t>& in_ch,
+    at::BFloat16* out_ptr,
+    const float* scale_ptr,
+    const float* beta_ptr,
     int64_t total_size_except_channels,
     int64_t ci,
     int64_t co) {
   int64_t i = 0, j = 0, k = 0;
   auto zero = _mm512_set1_ps(0.0);
-#pragma omp parallel for collapse(2)
+#ifdef _OPENMP
+#if (_OPENMP >= 201307)
+#pragma omp parallel for simd schedule( \
+    static) if (omp_get_max_threads() > 1 && !omp_in_parallel())
+#else
+#pragma omp parallel for schedule( \
+    static) if (omp_get_max_threads() > 1 && !omp_in_parallel())
+#endif
+#endif
   for (i = 0; i < total_size_except_channels; ++i) {
     for (j = 0; j < in_ptr.size(); ++j) {
+      auto concat_in_ptr = in_ptr[j] + i * in_ch[j + 1] - (i + 1) * in_ch[j];
       for (k = in_ch[j]; k < in_ch[j + 1]; k += 16) {
-        _mm512_store_ps(
-            out_ptr + i * co + k,
-            _mm512_max_ps(
-                zero,
-                _mm512_add_ps(
-                    _mm512_load_ps(beta_ptr + k),
-                    _mm512_mul_ps(
-                        _mm512_load_ps(scale_ptr + k),
-                        _mm512_load_ps(
-                            in_ptr[j] + i * (in_ch[j + 1] - in_ch[j]) + k -
-                            in_ch[j])))));
+        auto in =
+            cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(concat_in_ptr + k)));
+        auto beta = _mm512_loadu_ps(beta_ptr + k);
+        auto scale = _mm512_loadu_ps(scale_ptr + k);
+        auto bn_out = _mm512_add_ps(beta, _mm512_mul_ps(scale, in));
+        auto out = _mm512_max_ps(zero, bn_out);
+        _mm256_storeu_si256(
+            (__m256i*)(out_ptr + i * co + k), cvt_fp32_to_bf16(out));
       }
     }
   }
@@ -57,6 +109,7 @@ void ConcatBnReluKernelImpl_ChannelsLast(
     const Tensor& scale,
     const Tensor& beta,
     Tensor& output) {
+  using ACC_T = typename AccType<T>::type;
   int64_t list_length = a.size();
   int64_t total_size_except_channels = 1;
   std::vector<const T*> input_ptr(list_length);
@@ -74,11 +127,11 @@ void ConcatBnReluKernelImpl_ChannelsLast(
     total_size_except_channels *= a[0].size(i);
   }
 
-  const T* scale_data = scale.data_ptr<T>();
-  const T* beta_data = beta.data_ptr<T>();
+  const ACC_T* scale_data = scale.data_ptr<ACC_T>();
+  const ACC_T* beta_data = beta.data_ptr<ACC_T>();
   T* output_data = output.data_ptr<T>();
 
-  _concat_bn_relu_kernel_channels_last<T>(
+  _concat_bn_relu_kernel_channels_last<T, ACC_T>(
     input_ptr,
     input_channels,
     output_data,
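Two pieces carry the BFloat16 support here. The AccType trait maps BFloat16 to a float accumulation type, so the scale/beta pointers stay FP32 even for BF16 activations, and the specialized kernel round-trips each 16-element channel block through cvt_bf16_to_fp32 / cvt_fp32_to_bf16, which are helpers from the repository's utils.h. A hedged sketch of what such conversions commonly look like with AVX-512F intrinsics, assuming plain truncation on the narrowing side (the real helpers may round to nearest-even):

#include <immintrin.h>

// Sketch only; not the IPEX implementation.

// 16 bfloat16 lanes packed as raw uint16 in a __m256i -> 16 fp32 lanes.
static inline __m512 bf16_to_fp32(__m256i src) {
  __m512i widened = _mm512_cvtepu16_epi32(src);      // zero-extend each 16-bit lane
  __m512i shifted = _mm512_slli_epi32(widened, 16);  // bf16 bits are the high half of fp32
  return _mm512_castsi512_ps(shifted);
}

// 16 fp32 lanes -> 16 bfloat16 lanes by dropping the low 16 mantissa bits.
static inline __m256i fp32_to_bf16_truncate(__m512 src) {
  __m512i bits = _mm512_castps_si512(src);
  __m512i shifted = _mm512_srli_epi32(bits, 16);  // keep sign, exponent, top mantissa bits
  return _mm512_cvtepi32_epi16(shifted);          // narrow each 32-bit lane to 16 bits
}

Widening is exact, since every bfloat16 value is representable as a float; only the store side loses precision. That is why the batch-norm multiply-add and the ReLU are performed on FP32 vectors and why AccType maps BFloat16 to float for the scale and beta data.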

intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite.cpp

Lines changed: 11 additions & 1 deletion
@@ -536,7 +536,8 @@ void FuseConcatBnRelu(std::shared_ptr<Graph>& graph) {
     auto tensor1 = listConstruct->input(0)->type()->cast<TensorType>();
     auto check_type_channelsize = [](c10::TensorType tensor) {
       return (
-          tensor.scalarType().value() == at::kFloat &&
+          (tensor.scalarType().value() == at::kFloat ||
+           tensor.scalarType().value() == at::kBFloat16) &&
           tensor.sizes()[1].value() % 16 == 0 && is_channelslast(tensor));
     };
     // Check if the dimension of the first tensor is either 4 or 5.
@@ -562,6 +563,15 @@ void FuseConcatBnRelu(std::shared_ptr<Graph>& graph) {
         }
       }
     }
+    // Check if the BN weights is fp32 datatype.
+    auto bn_node = node->input(0)->node();
+    if (bn_node->namedInput("weight")
+            ->type()
+            ->cast<TensorType>()
+            ->scalarType()
+            .value() != at::kFloat) {
+      return false;
+    }
     return true;
   };
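The pass now fuses BFloat16 activations but still requires the BatchNorm weight tensor to be FP32, which matches the kernel side: scale and beta are read through ACC_T (float) pointers. Presumably bn_scale and bn_beta are the usual folded inference-time BatchNorm parameters; a hypothetical scalar sketch of that folding, kept entirely in float (names are illustrative, not from the IPEX sources):

#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper: fold inference-time BatchNorm statistics into one
// per-channel scale/shift pair, kept in float so a BFloat16 kernel can still
// read FP32 scale/beta, as the rewrite pass above requires.
struct FoldedBn {
  std::vector<float> scale;  // gamma / sqrt(running_var + eps)
  std::vector<float> beta;   // bias - running_mean * scale
};

FoldedBn fold_batch_norm(
    const std::vector<float>& gamma,
    const std::vector<float>& bias,
    const std::vector<float>& running_mean,
    const std::vector<float>& running_var,
    float eps) {
  FoldedBn out;
  out.scale.resize(gamma.size());
  out.beta.resize(gamma.size());
  for (std::size_t c = 0; c < gamma.size(); ++c) {
    out.scale[c] = gamma[c] / std::sqrt(running_var[c] + eps);
    out.beta[c] = bias[c] - running_mean[c] * out.scale[c];
  }
  return out;
}

// The fused kernel then evaluates max(0.f, scale[c] * x + beta[c]) per channel.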

tests/cpu/test_jit.py

Lines changed: 44 additions & 127 deletions
@@ -701,33 +701,14 @@ def forward(self, x, y, z):
         x = x + y + z
         return self.layernorm(x)
 
-class ConcatBnRelu2d(torch.nn.Module):
-    def __init__(self):
-        super(ConcatBnRelu2d, self).__init__()
-        self.bn = torch.nn.BatchNorm2d(96)
-        self.relu = torch.nn.ReLU()
-    def forward(self, x1, x2, x3):
-        x = torch.cat((x1, x2, x3), dim = 1)
-        x = self.bn(x)
-        return self.relu(x)
-
-class ConcatBnRelu2d_v1(torch.nn.Module):
-    def __init__(self):
-        super(ConcatBnRelu2d_v1, self).__init__()
-        self.bn = torch.nn.BatchNorm2d(32)
-        self.relu = torch.nn.ReLU()
-    def forward(self, x1, x2, x3):
-        x = torch.cat((x1, x2, x3), dim = 2)
-        x = self.bn(x)
-        return self.relu(x)
-
-class ConcatBnRelu3d(torch.nn.Module):
-    def __init__(self):
-        super(ConcatBnRelu3d, self).__init__()
-        self.bn = torch.nn.BatchNorm3d(96)
+class ConcatBnRelu(torch.nn.Module):
+    def __init__(self, dim, cat_dim, in_channels, **kwargs):
+        super(ConcatBnRelu, self).__init__()
+        self.bn = bn_module[dim](in_channels)
         self.relu = torch.nn.ReLU()
+        self.cat_dim = cat_dim
     def forward(self, x1, x2, x3):
-        x = torch.cat((x1, x2, x3), dim = 1)
+        x = torch.cat((x1, x2, x3), dim = self.cat_dim)
         x = self.bn(x)
         return self.relu(x)
 
@@ -1010,114 +991,50 @@ def test_add_layernorm(self):
         self.assertTrue(any(n.kind() == node for n in trace_graph.nodes()))
 
     def test_concat_bn_relu(self):
-        a1 = torch.randn(1, 32, 13, 24, dtype=torch.bfloat16).contiguous(memory_format=torch.channels_last)
-        a2 = torch.randn(1, 32, 13, 24, dtype=torch.bfloat16).contiguous(memory_format=torch.channels_last)
-        a3 = torch.randn(1, 32, 13, 24, dtype=torch.bfloat16).contiguous(memory_format=torch.channels_last)
-        model = ConcatBnRelu2d().eval().to(memory_format=torch.channels_last)
-        model = ipex.optimize(model, dtype=torch.bfloat16, level='O0')
-        with torch.no_grad():
-            jit_model = torch.jit.trace(model, (a1, a2, a3)).eval()
-            jit_model = torch.jit.freeze(jit_model)
-            #warmup run
-            for _ in range(2):
-                jit_res = jit_model(a1, a2, a3)
-            ori_res = model(a1, a2, a3)
-            self.assertEqual(jit_res, ori_res)
-
-        a1 = torch.randn(1, 32, 13, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        a2 = torch.randn(1, 32, 13, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        a3 = torch.randn(1, 32, 13, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        model = ConcatBnRelu2d_v1().eval().to(memory_format=torch.channels_last)
-        model = ipex.optimize(model, dtype=torch.float32, level='O0')
-        with torch.no_grad():
-            jit_model = torch.jit.trace(model, (a1, a2, a3)).eval()
-            jit_model = torch.jit.freeze(jit_model)
-            #warmup run
-            for _ in range(2):
-                jit_res = jit_model(a1, a2, a3)
-            ori_res = model(a1, a2, a3)
-            self.assertEqual(jit_res, ori_res)
-
-        model = ConcatBnRelu2d().eval().to(memory_format=torch.channels_last)
-        model = ipex.optimize(model, dtype=torch.float32, level='O0')
-        with torch.no_grad():
-            jit_model = torch.jit.trace(model, (a1, a2, a3)).eval()
-            jit_model = torch.jit.freeze(jit_model)
-            #warmup run
-            for _ in range(2):
-                jit_res = jit_model(a1, a2, a3)
-            ori_res = model(a1, a2, a3)
-            self.assertEqual(jit_res, ori_res)
-
-        a1 = torch.randn(1, 32, 18, 53, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        a2 = torch.randn(1, 32, 18, 53, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        a3 = torch.randn(1, 32, 18, 53, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        with torch.no_grad():
-            jit_res = jit_model(a1, a2, a3)
-            ori_res = model(a1, a2, a3)
-            self.assertEqual(jit_res, ori_res)
-
-        a1 = torch.randn(1, 16, 24, 116, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        a2 = torch.randn(1, 48, 24, 116, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        a3 = torch.randn(1, 32, 24, 116, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        with torch.no_grad():
-            jit_res = jit_model(a1, a2, a3)
-            ori_res = model(a1, a2, a3)
-            self.assertEqual(jit_res, ori_res)
+        batch_size = 3
+        image_size = 16
+        options = itertools.product([2, 3], [[32, 32, 32], [60, 60, 60], [17, 27, 32], [16, 32, 48]], [torch.float32, torch.bfloat16], ['O0', 'O1'], [True, False])
+        for dim, channels, dtype, level, use_channels_last in options:
+            input_size = [
+                [batch_size, channels[0], image_size, image_size],
+                [batch_size, channels[1], image_size, image_size],
+                [batch_size, channels[2], image_size, image_size]
+            ]
+            if dim == 3:
+                for i in range(3):
+                    input_size[i].append(image_size)
+            a1 = torch.randn(input_size[0], dtype=dtype)
+            a2 = torch.randn(input_size[1], dtype=dtype)
+            a3 = torch.randn(input_size[2], dtype=dtype)
+            a = [a1, a2, a3]
 
-        a1 = torch.randn(1, 17, 15, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        a2 = torch.randn(1, 47, 15, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        a3 = torch.randn(1, 32, 15, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        with torch.no_grad():
-            jit_res = jit_model(a1, a2, a3)
-            ori_res = model(a1, a2, a3)
-            self.assertEqual(jit_res, ori_res)
+            in_channels = sum(channels)
+            model = ConcatBnRelu(dim, 1, in_channels).eval()
 
-        a1 = torch.randn(1, 32, 13, 24, dtype=torch.float)
-        a2 = torch.randn(1, 32, 13, 24, dtype=torch.float)
-        a3 = torch.randn(1, 32, 13, 24, dtype=torch.float)
-        with torch.no_grad():
-            jit_res = jit_model(a1, a2, a3)
-            ori_res = model(a1, a2, a3)
-            self.assertEqual(jit_res, ori_res)
+            if use_channels_last:
+                suggest_memory_format = torch.channels_last if dim == 2 else torch.channels_last_3d
+                for i in range(3):
+                    a[i] = a[i].to(memory_format=suggest_memory_format)
+                model = model.to(memory_format=suggest_memory_format)
 
-        a1 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
-        a2 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
-        a3 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
-        model = ConcatBnRelu3d().eval().to(memory_format=torch.channels_last_3d)
-        model = ipex.optimize(model, dtype=torch.float32, level='O0')
-        with torch.no_grad():
-            jit_model = torch.jit.trace(model, (a1, a2, a3)).eval()
-            jit_model = torch.jit.freeze(jit_model)
-            #warmup run
-            for _ in range(2):
-                jit_res = jit_model(a1, a2, a3)
-            ori_res = model(a1, a2, a3)
-            self.assertEqual(jit_res, ori_res)
+            model = ipex.optimize(model, dtype=dtype, level=level)
 
-        a1 = torch.randn(1, 16, 17, 14, 31, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
-        a2 = torch.randn(1, 48, 17, 14, 31, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
-        a3 = torch.randn(1, 32, 17, 14, 31, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
-        with torch.no_grad():
-            jit_res = jit_model(a1, a2, a3)
-            ori_res = model(a1, a2, a3)
-            self.assertEqual(jit_res, ori_res)
+            with torch.cpu.amp.autocast(enabled=True if dtype == torch.bfloat16 else False), torch.no_grad():
+                result = model(a[0], a[1], a[2])
+                trace_model = torch.jit.trace(model, (a[0], a[1], a[2])).eval()
+                trace_model = torch.jit.freeze(trace_model)
 
-        a1 = torch.randn(1, 17, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
-        a2 = torch.randn(1, 47, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
-        a3 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
-        with torch.no_grad():
-            jit_res = jit_model(a1, a2, a3)
-            ori_res = model(a1, a2, a3)
-            self.assertEqual(jit_res, ori_res)
+                tresult = trace_model(a[0], a[1], a[2])
+                trace_graph = trace_model.graph_for(a[0], a[1], a[2])
 
-        a1 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float)
-        a2 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float)
-        a3 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float)
-        with torch.no_grad():
-            jit_res = jit_model(a1, a2, a3)
-            ori_res = model(a1, a2, a3)
-            self.assertEqual(jit_res, ori_res)
+                self.assertEqual(result, tresult)
+                self.assertEqual(tresult.dtype, dtype)
+                if use_channels_last:
+                    self.assertTrue(tresult.is_contiguous(memory_format=suggest_memory_format))
+                if use_channels_last and a1.size(1) % 16 == 0 and a2.size(1) % 16 == 0 and a3.size(1) % 16 == 0 :
+                    self.assertTrue(any(n.kind() == "ipex::concat_bn_relu" for n in trace_graph.nodes()))
+                else:
+                    self.assertTrue(all(n.kind() != "ipex::concat_bn_relu" for n in trace_graph.nodes()))
 
     def test_mha_scores_calculation(self):
         def _check_match_mha(trace_model, mat1, mat2, bias, node = "ipex::mha_scores_calc"):
