Skip to content

Commit a0c063b

Browse files
authored
Enable Conv1d channels last and Conv1d+gelu fusion in jit path (#657)
- Enable Conv1d channels last and conv+gelu fusion in JIT mode
- Add weight prepack for conv1d
- Fix UT error
1 parent e10d5e5 commit a0c063b

File tree

13 files changed

+330
-43
lines changed

13 files changed

+330
-43
lines changed

intel_extension_for_pytorch/csrc/aten/cpu/Conv.cpp

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,18 @@ void convolution_kernel_output(
4848
(IS_CONTIGUOUS_ANY(input)) && (IS_CONTIGUOUS_ANY(output)),
4949
"input and output are need contiguous tensor for "
5050
"convolution_kernel_output");
51-
const ideep::tensor mkldnn_input = itensor_view_from_dense(input);
51+
const ideep::tensor mkldnn_input_ = itensor_view_from_dense(input);
52+
ideep::tensor mkldnn_input = mkldnn_input_;
53+
// The following code forces the 3D input to channels last, which is a
54+
// temporary workaround before channels last 1D is formally supported in
55+
// PyTorch.
56+
if (mkldnn_input_.ndims() == 3 &&
57+
!mkldnn_input_.get_desc().is_channels_last()) {
58+
ideep::tensor mkldnn_input_conv1d{
59+
mkldnn_input_.get_desc().to_format(ideep::format_tag::nwc)};
60+
mkldnn_input_conv1d.feed_from(mkldnn_input_);
61+
mkldnn_input = mkldnn_input_conv1d;
62+
}
5263
auto output_sizes = output.sizes();
5364

5465
ideep::tensor mkldnn_output = itensor_view_from_dense(output);
@@ -109,9 +120,19 @@ at::Tensor convolution_kernel(
109120
std::vector<int64_t> output_sizes =
110121
calc_conv_output_size(input_size, kernel_size, padding, stride, dilation);
111122

112-
auto output = at::empty(
113-
output_sizes,
114-
input.options().memory_format(input.suggest_memory_format()));
123+
at::Tensor output;
124+
if (input.dim() != 3) {
125+
output = at::empty(
126+
output_sizes,
127+
input.options().memory_format(input.suggest_memory_format()));
128+
} else {
129+
// This is a temporary workaround before channels last 1D is formally supported
130+
// in PyTorch. We will force the output to use the nwc format.
131+
std::vector<int64_t> output_strides = {
132+
(output_sizes[1] * output_sizes[2]), 1, output_sizes[1]};
133+
output = at::empty_strided(output_sizes, output_strides, input.options());
134+
}
135+
115136
convolution_kernel_output(
116137
input,
117138
mkldnn_weight,

intel_extension_for_pytorch/csrc/aten/cpu/ParamUtils.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,11 @@ inline std::vector<int64_t> gen_dummy_input_size_for(
3535
std::vector<int64_t> kernel_size;
3636
if (5 == input_dim) {
3737
kernel_size.push_back(weight_sizes[input_dim - 3]);
38+
kernel_size.push_back(weight_sizes[input_dim - 2]);
39+
}
40+
if (4 == input_dim) {
41+
kernel_size.push_back(weight_sizes[input_dim - 2]);
3842
}
39-
kernel_size.push_back(weight_sizes[input_dim - 2]);
4043
kernel_size.push_back(weight_sizes[input_dim - 1]);
4144
std::vector<int64_t> input_sizes;
4245
auto grouped = groups > 1;
@@ -46,11 +49,10 @@ inline std::vector<int64_t> gen_dummy_input_size_for(
4649
auto ic = groups * weights_dims_g[1 + grouped];
4750
input_sizes.push_back(32);
4851
input_sizes.push_back(ic);
52+
input_sizes.push_back(14 * kernel_size[0]);
4953
if (4 == input_dim) {
50-
input_sizes.push_back(14 * kernel_size[0]);
5154
input_sizes.push_back(14 * kernel_size[1]);
52-
} else {
53-
input_sizes.push_back(14 * kernel_size[0]);
55+
} else if (5 == input_dim) {
5456
input_sizes.push_back(14 * kernel_size[1]);
5557
input_sizes.push_back(14 * kernel_size[2]);
5658
}

intel_extension_for_pytorch/csrc/cpu/ideep/ideep/operators/conv.hpp

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,8 @@ struct convolution_forward
225225
const dims& src_dims = dims(),
226226
const attr_t& attr = attr_t(),
227227
const engine& aengine = engine::cpu_engine()) {
228-
auto src_size =
229-
weights_dims.size(); // weights_dims is 4 for conv2d and 5 for conv3d
228+
auto src_size = weights_dims.size(); // weights_dims is 3 for conv1d, 4 for
229+
// conv2d and 5 for conv3d
230230
auto grouped = groups > 1;
231231
auto weights_dims_g =
232232
grouped ? utils::group_dims(weights_dims, groups) : weights_dims;
@@ -244,8 +244,11 @@ struct convolution_forward
244244
auto oc = groups * dims_in[0 + grouped];
245245
if (5 == src_size) {
246246
kernel_size.push_back(dims_in[ndims - 3]);
247+
kernel_size.push_back(dims_in[ndims - 2]);
248+
}
249+
if (4 == src_size) {
250+
kernel_size.push_back(dims_in[ndims - 2]);
247251
}
248-
kernel_size.push_back(dims_in[ndims - 2]);
249252
kernel_size.push_back(dims_in[ndims - 1]);
250253
if (src_dims.empty()) {
251254
// Construct a dummy case, those shapes are from resnet50 model,
@@ -255,11 +258,10 @@ struct convolution_forward
255258
x_dims.push_back(ic);
256259
y_dims.push_back(32);
257260
y_dims.push_back(oc);
261+
x_dims.push_back(14 * kernel_size[0]);
258262
if (4 == src_size) {
259-
x_dims.push_back(14 * kernel_size[0]);
260263
x_dims.push_back(14 * kernel_size[1]);
261-
} else {
262-
x_dims.push_back(14 * kernel_size[0]);
264+
} else if (5 == src_size) {
263265
x_dims.push_back(14 * kernel_size[1]);
264266
x_dims.push_back(14 * kernel_size[2]);
265267
}
@@ -286,8 +288,17 @@ struct convolution_forward
286288
auto src_query = src_desc;
287289
auto dst_query = dst_desc;
288290
if (channels_last) {
289-
src_query = src_desc.to_format(5 == src_size ? tag::ndhwc : tag::nhwc);
290-
dst_query = dst_desc.to_format(5 == src_size ? tag::ndhwc : tag::nhwc);
291+
if (4 == src_size) {
292+
src_query = src_desc.to_format(tag::nhwc);
293+
dst_query = dst_desc.to_format(tag::nhwc);
294+
} else if (5 == src_size) {
295+
src_query = src_desc.to_format(tag::ndhwc);
296+
dst_query = dst_desc.to_format(tag::ndhwc);
297+
}
298+
}
299+
if (3 == src_size) {
300+
src_query = src_desc.to_format(tag::nwc);
301+
dst_query = dst_desc.to_format(tag::nwc);
291302
}
292303

293304
// FIXME: workaround winograd format issue in inference
@@ -345,6 +356,7 @@ struct convolution_forward
345356
auto weights_desc_query = weights_desc;
346357
auto bias_desc_query = with_bias ? bias_desc : tensor::desc();
347358
auto dst_desc_query = dst_desc;
359+
auto src_is_channels_last = src_desc.is_channels_last();
348360
if (!keep_format) {
349361
src_desc_query = src_desc.to_format_any();
350362
weights_desc_query = weights_desc.to_format_any();
@@ -355,9 +367,15 @@ struct convolution_forward
355367
// For nhwc / ndhwc path, weight uses format_tag::any,
356368
// while activation uses format_tag::nhwc / format_tag::ndhwc.
357369
bool channels_last =
358-
src_desc.is_channels_last() || weights_desc.is_channels_last();
370+
src_is_channels_last || weights_desc.is_channels_last();
359371
if (channels_last) {
360-
auto memory_format = src_desc.get_ndims() == 4 ? tag::nhwc : tag::ndhwc;
372+
const auto dim = src_desc.get_ndims();
373+
auto memory_format = tag::nhwc;
374+
if (dim == 3) {
375+
memory_format = tag::nwc;
376+
} else if (dim == 5) {
377+
memory_format = tag::ndhwc;
378+
}
361379
src_desc_query = src_desc.to_format(memory_format);
362380
weights_desc_query = weights_desc.to_format_any();
363381
bias_desc_query = with_bias ? bias_desc.to_format_any() : tensor::desc();

intel_extension_for_pytorch/csrc/cpu/ideep/ideep/tensor.hpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,8 @@ class tensor : public memory {
175175
};
176176

177177
inline bool is_channels_last() const {
178-
if (!is_plain() || !(data.ndims != 4 || data.ndims != 5))
178+
if (!is_plain() ||
179+
!(data.ndims == 4 || data.ndims == 5 || data.ndims == 3))
179180
return false;
180181
const auto& dims = data.dims;
181182
const auto& strides = blocking_strides();
@@ -184,12 +185,16 @@ class tensor : public memory {
184185
return strides[n] == dims[h] * dims[w] * dims[c] &&
185186
strides[h] == dims[w] * dims[c] && strides[w] == dims[c] &&
186187
strides[c] == 1;
187-
} else {
188+
} else if (data.ndims == 5) {
188189
const auto n = 0, c = 1, d = 2, h = 3, w = 4;
189190
return strides[n] == dims[d] * dims[h] * dims[w] * dims[c] &&
190191
strides[d] == dims[h] * dims[w] * dims[c] &&
191192
strides[h] == dims[w] * dims[c] && strides[w] == dims[c] &&
192193
strides[c] == 1;
194+
} else {
195+
const auto n = 0, c = 1, w = 2;
196+
return strides[n] == dims[w] * dims[c] && strides[w] == dims[c] &&
197+
strides[c] == 1;
193198
}
194199
};
195200

@@ -808,8 +813,14 @@ class tensor : public memory {
808813
auto channels_last = old_desc.is_channels_last();
809814
if (channels_last) {
810815
// goihw (abcde) => gohwi (abdec) or goidhw (abcdef) => gohwi (abdefc)
811-
grouped_desc = grouped_desc.to_format(
812-
old_desc.get_ndims() == 4 ? format_tag::abdec : format_tag::abdefc);
816+
auto memory_format = format_tag::abdec;
817+
auto dim = old_desc.get_ndims();
818+
if (dim == 5) {
819+
memory_format = format_tag::abdefc;
820+
} else if (dim == 3) {
821+
memory_format = format_tag::abdc;
822+
}
823+
grouped_desc = grouped_desc.to_format(memory_format);
813824
}
814825
}
815826

intel_extension_for_pytorch/csrc/jit/cpu/kernels/ConvPacked.cpp

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "csrc/aten/cpu/Conv.h"
44
#include "csrc/aten/cpu/ParamUtils.h"
55
#include "csrc/aten/cpu/WeightPack.h"
6+
#include "csrc/aten/cpu/utils/utils.h"
67
#include "csrc/cpu/ideep/IDeepConversions.h"
78
#include "csrc/cpu/ideep/ideep.hpp"
89
#include "csrc/cpu/ideep/ideep/utils.hpp"
@@ -112,6 +113,29 @@ at::Tensor convolution_swish_run(
112113
return op_context->run(input, ideep::attr_t::fuse_swish());
113114
}
114115

116+
at::Tensor convolution_gelu_run(
117+
const at::Tensor& input,
118+
const c10::string_view approximate,
119+
const c10::intrusive_ptr<ConvolutionOpContext>& op_context) {
120+
IPEX_RECORD_FUNCTION(
121+
"ipex_prepack::convolution_gelu_run", std::vector<c10::IValue>({}));
122+
// https://github.com/pytorch/pytorch/pull/61439
123+
// at::gelu can support tanh approximate now and OneDNN also support it
124+
// by changing the algorithm. If another type of approximate is added to
125+
// PyTorch while oneDNN does not support it, we might need a fallback path here.
126+
dnnl::algorithm gelu_type;
127+
if (approximate == "none") {
128+
gelu_type = dnnl::algorithm::eltwise_gelu_erf;
129+
} else if (approximate == "tanh") {
130+
gelu_type = dnnl::algorithm::eltwise_gelu_tanh;
131+
} else {
132+
TORCH_CHECK(
133+
false, "ipex::linear_gelu_run only support tanh approximate now");
134+
}
135+
return op_context->run(
136+
input, ideep::attr_t::fuse_gelu(1.0, 0.f, 0.f, gelu_type));
137+
}
138+
115139
at::Tensor convolution_add_run(
116140
const at::Tensor& input,
117141
at::Tensor& accumu,
@@ -320,13 +344,17 @@ ContextConvolution create(
320344
weight.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d;
321345

322346
auto memory_format = at::MemoryFormat::Contiguous;
323-
auto format_tag = input_size.size() == 4 ? ideep::format_tag::nchw
324-
: ideep::format_tag::ncdhw;
347+
auto format_tag = ideep::format_tag::nchw;
348+
if (input_size.size() == 5) {
349+
format_tag = ideep::format_tag::ncdhw;
350+
} else if (input_size.size() == 3) {
351+
format_tag = ideep::format_tag::nwc;
352+
}
325353
if (weight_is_channels_last_) {
326354
if (input_size.size() == 4) {
327355
memory_format = at::MemoryFormat::ChannelsLast;
328356
format_tag = ideep::format_tag::nhwc;
329-
} else {
357+
} else if (input_size.size() == 5) {
330358
memory_format = at::MemoryFormat::ChannelsLast3d;
331359
format_tag = ideep::format_tag::ndhwc;
332360
}
@@ -451,17 +479,27 @@ at::Tensor run(
451479
if (use_channels_last) {
452480
if (input.dim() == 4) {
453481
memory_format = at::MemoryFormat::ChannelsLast;
454-
} else {
482+
} else if (input.dim() == 5) {
455483
memory_format = at::MemoryFormat::ChannelsLast3d;
456484
}
457485
}
458-
auto input_ = input.contiguous(memory_format);
486+
auto input_ = input;
487+
if (!is_channels_last_1d(input)) {
488+
input_ = input.contiguous(memory_format);
489+
}
459490
if (input_.sizes().vec() == context.conv_params_.pd.src_desc().dims() &&
460491
attr == context.conv_params_.op_attr &&
461492
omp_get_max_threads() == context.conv_params_.pd_use_threads) {
493+
auto output_sizes = context.conv_params_.pd.dst_desc().dims();
462494
auto output = at::empty(
463-
context.conv_params_.pd.dst_desc().dims(),
495+
output_sizes,
464496
input_.options().memory_format(input_.suggest_memory_format()));
497+
if (input.dim() == 3) {
498+
std::vector<int64_t> output_strides = {
499+
(output_sizes[1] * output_sizes[2]), 1, output_sizes[1]};
500+
output =
501+
at::empty_strided(output_sizes, output_strides, input_.options());
502+
}
465503
const ideep::tensor mkldnn_input = itensor_view_from_dense(input_);
466504
ideep::tensor mkldnn_output = itensor_view_from_dense(output);
467505
if (context.bias_.is_empty()) {
@@ -507,11 +545,14 @@ at::Tensor& run(
507545
if (use_channels_last) {
508546
if (input.dim() == 4) {
509547
memory_format = at::MemoryFormat::ChannelsLast;
510-
} else {
548+
} else if (input.dim() == 5) {
511549
memory_format = at::MemoryFormat::ChannelsLast3d;
512550
}
513551
}
514-
auto input_ = input.contiguous(memory_format);
552+
auto input_ = input;
553+
if (!is_channels_last_1d(input)) {
554+
input_ = input.contiguous(memory_format);
555+
}
515556
// always align accumu format with inputs' format.
516557
accumu = accumu.contiguous(memory_format);
517558
if (input_.sizes().vec() == context.conv_params_.pd.src_desc().dims() &&
@@ -608,7 +649,7 @@ at::Tensor unpack(ContextConvolution& context, const at::Tensor& tensor) {
608649
if (context.weight_is_channels_last_) {
609650
if (context.original_desc_.get_ndims() == 4) {
610651
result = result.to(at::MemoryFormat::ChannelsLast);
611-
} else {
652+
} else if (context.original_desc_.get_ndims() == 5) {
612653
result = result.to(at::MemoryFormat::ChannelsLast3d);
613654
}
614655
}

intel_extension_for_pytorch/csrc/jit/cpu/kernels/ConvPacked.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ at::Tensor convolution_swish_run(
5757
const at::Tensor& input,
5858
const c10::intrusive_ptr<ConvolutionOpContext>& op_context);
5959

60+
at::Tensor convolution_gelu_run(
61+
const at::Tensor& input,
62+
c10::string_view approximate,
63+
const c10::intrusive_ptr<ConvolutionOpContext>& op_context);
64+
6065
at::Tensor convolution_add_run(
6166
const at::Tensor& input,
6267
at::Tensor& accumu,

0 commit comments

Comments
 (0)