Skip to content

Commit b24cc53

Browse files
Cherry pick (1) linear + tanh fusion (2) fix dim size 1 issue (3) fix shuffle2d (#711)
* Add linear+tanh fusion for inference (#685) * init linear tanh fusion * add ut * Check tensor stride if it contains size 1 before passing to OneDNN (#689) * init fix * add condition for channelslast contiguous * add comments and refine the code * rebase linear pattern * Fix shufflenet reg with dynamic shape context shuffle2d pattern (#724) * fix shufflenet reg with dynamic shape context pattern * refine code * refine filter code and add no match ut Co-authored-by: chunyuan-w <[email protected]>
1 parent d2cce99 commit b24cc53

File tree

9 files changed

+275
-17
lines changed

9 files changed

+275
-17
lines changed

intel_extension_for_pytorch/csrc/cpu/ideep/IDeepConversions.cpp

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,53 @@ using IDeepTensorWrapperPtr = c10::intrusive_ptr<IDeepTensorWrapper>;
3838
using MKLDNNTensorImpl = at::OpaqueTensorImpl<IDeepTensorWrapperPtr>;
3939
using MKLDNNTensor = at::Tensor;
4040

41+
dnnl::memory::dims get_stride_with_size_1_fix(const at::Tensor& tensor) {
  // Returns the tensor's strides, rewriting the stride entries of size-1
  // dims so they follow a strict contiguous (or channels-last contiguous)
  // layout. PyTorch leaves strides of size-1 dims unconstrained (e.g. a
  // [1, 768] view can carry stride [1536, 1]), which can push oneDNN onto
  // its slow reference path; see the note at the declaration in
  // IDeepConversions.h.
  auto fixed_strides = tensor.strides().vec();
  const auto ndim = tensor.dim();

  // Decide whether a fix is needed: only when some dim has size 1 and the
  // tensor is contiguous in either the default or a channels-last sense.
  // Only the first size-1 dim determines the decision, matching the break
  // semantics of the original scan.
  bool needs_fix = false;
  bool channels_last = false;
  for (int d = 0; d < ndim; d++) {
    if (tensor.size(d) != 1) {
      continue;
    }
    if (tensor.is_contiguous()) {
      needs_fix = true;
    } else if (
        tensor.is_contiguous(at::MemoryFormat::ChannelsLast) ||
        tensor.is_contiguous(at::MemoryFormat::ChannelsLast3d)) {
      channels_last = true;
      needs_fix = true;
    }
    break;
  }

  if (!needs_fix) {
    return fixed_strides;
  }

  // The unit-stride dim is the last dim for the default contiguous layout,
  // and the channel dim (idx 1) for channels-last layouts.
  const int contiguous_idx = channels_last ? 1 : static_cast<int>(ndim) - 1;
  fixed_strides[contiguous_idx] = 1;

  // Recompute strides of the remaining size-1 dims from last to first, so
  // each one can be derived from a stride that is already correct.
  for (int d = static_cast<int>(ndim) - 1; d >= 0; d--) {
    if (tensor.size(d) != 1 || d == contiguous_idx) {
      continue;
    }
    if (channels_last && d == ndim - 1) {
      // Last spatial dim under channels-last: its neighbor in memory is
      // the channel dim.
      fixed_strides[d] =
          tensor.size(contiguous_idx) * fixed_strides[contiguous_idx];
    } else if (channels_last && d == 0) {
      // Batch dim under channels-last: derived from the first spatial dim.
      fixed_strides[d] = tensor.size(2) * fixed_strides[2];
    } else {
      // Default contiguous relation: stride(d) = size(d+1) * stride(d+1),
      // computed in last-to-first order.
      fixed_strides[d] = tensor.size(d + 1) * fixed_strides[d + 1];
    }
  }
  return fixed_strides;
}
87+
4188
ideep::tensor::data_type get_mkldnn_dtype(at::ScalarType type) {
4289
switch (type) {
4390
case at::ScalarType::Float:
@@ -76,10 +123,11 @@ ideep::tensor itensor_view_from_dense(const at::Tensor& tensor) {
76123
tensor.scalar_type() == at::ScalarType::Float ||
77124
tensor.scalar_type() == at::ScalarType::BFloat16,
78125
"itensor_view_from_dense expects float tensor input");
126+
79127
return {
80128
{tensor.sizes().vec(),
81129
get_mkldnn_dtype(tensor.scalar_type()),
82-
tensor.strides().vec()},
130+
get_stride_with_size_1_fix(tensor)},
83131
tensor.data_ptr()};
84132
}
85133

intel_extension_for_pytorch/csrc/cpu/ideep/IDeepConversions.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,5 +38,22 @@ at::Tensor empty_aten_tensor_from_desc(
3838
const ideep::tensor::desc& desc,
3939
const at::TensorOptions& options);
4040

41+
// ##Background##
42+
// This function returns the input tensor's stride with a workaround that checks
43+
// (and fixes) the stride when the input tensor has dim size 1. Currently oneDNN
44+
// does not expect the behavior that, with dim size 1, a PyTorch tensor's stride
45+
// is meaningless and may not follow a strict contiguous layout, which may make
46+
// oneDNN go into ref path (perf drop). For example: A tensor with shape [1,
47+
// 768] and stride [1536, 1] is not expected by current oneDNN, though PyTorch
48+
// will think it is contiguous since dim0 is size 1. Such a Tensor can be
49+
// constructed by slice [:,0,:] from another tensor with shape [1, 2, 768] and
50+
// stride [1536, 768, 1], and it is a real case in Albert model pooler layer.
51+
// ##Performance Impact##
52+
// It takes ~0.05us on average for calling this function when creating a mkldnn
53+
// tensor.
54+
// ##TODO##
55+
// Will remove this workaround after oneDNN's fix.
56+
dnnl::memory::dims get_stride_with_size_1_fix(const at::Tensor& tensor);
57+
4158
} // namespace cpu
4259
} // namespace torch_ipex

intel_extension_for_pytorch/csrc/cpu/ideep/ideep/attributes.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,17 @@ struct attr_t : public dnnl::primitive_attr {
101101
return attr;
102102
}
103103

104+
static attr_t fuse_tanh(
    float scale = 1.0,
    float alpha = 0.f,
    float beta = 0.f) {
  // Build a primitive attribute whose post-op applies an element-wise
  // tanh to the primitive output (parallel to fuse_relu / fuse_elu).
  attr_t result;
  post_ops ops;
  ops.append_eltwise(scale, algorithm::eltwise_tanh, alpha, beta);
  result.set_post_ops(ops);
  return result;
}
114+
104115
static attr_t fuse_elu(
105116
float scale = 1.0,
106117
float alpha = 0.f,

intel_extension_for_pytorch/csrc/jit/cpu/kernels/LinearPacked.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,15 @@ at::Tensor linear_gelu_run(
6969
input, ideep::attr_t::fuse_gelu(1.0, 0.f, 0.f, gelu_type));
7070
}
7171

72+
at::Tensor linear_tanh_run(
    const at::Tensor& input,
    const c10::intrusive_ptr<LinearOpContext>& op_context) {
  // Runs the prepacked linear kernel with a fused tanh post-op attribute.
  IPEX_RECORD_FUNCTION(
      "ipex_prepack::linear_tanh_run", std::vector<c10::IValue>({}));

  const auto tanh_attr = ideep::attr_t::fuse_tanh();
  return op_context->run(input, tanh_attr);
}
80+
7281
at::Tensor linear_sigmoid_run(
7382
const at::Tensor& input,
7483
const c10::intrusive_ptr<LinearOpContext>& op_context) {

intel_extension_for_pytorch/csrc/jit/cpu/kernels/LinearPacked.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ at::Tensor linear_gelu_run(
2929
const c10::intrusive_ptr<LinearOpContext>& op_context,
3030
c10::string_view approximate);
3131

32+
at::Tensor linear_tanh_run(
33+
const at::Tensor& input,
34+
const c10::intrusive_ptr<LinearOpContext>& op_context);
35+
3236
at::Tensor linear_sigmoid_run(
3337
const at::Tensor& input,
3438
const c10::intrusive_ptr<LinearOpContext>& op_context);

intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite.cpp

Lines changed: 76 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,25 +24,55 @@ c10::optional<IValue> getIValue(
2424
return toIValue(getValue(name, match_vmap, vmap));
2525
}
2626

27+
// FuseShuffle is matching the channelshuffle pattern, where:
28+
// (1) the first view is [n, c, h, w] => [n, groups, c // groups, h, w]
29+
// (2) the transpose is for groups => [n, c // groups, groups, h, w]
30+
// (3) the output view shape should be the same as the input tensor shape
2731
void FuseShuffle(std::shared_ptr<Graph>& graph) {
28-
std::string shuffle = R"(
32+
// below is channelshuffle for static view shape pattern
33+
std::string channelshuffle_with_static_shape = R"(
2934
graph(%input, %view_shape:int[], %trans_dim0:int, %trans_dim1:int, %mem_format:int, %flattern_shape:int[]):
3035
%r1 = aten::view(%input, %view_shape)
3136
%r2 = aten::transpose(%r1, %trans_dim0, %trans_dim1)
3237
%r3 = aten::contiguous(%r2, %mem_format)
3338
%r4 = aten::view(%r3, %flattern_shape)
3439
return (%r4) )";
3540

36-
std::string shuffle_2d_fusion = R"(
41+
std::string shuffle_2d_fusion_with_static_shape = R"(
3742
graph(%input, %view_shape:int[], %trans_dim0:int, %trans_dim1:int, %mem_format:int, %flattern_shape:int[]):
3843
%r = ipex::shuffle_2d(%input, %view_shape, %trans_dim0, %trans_dim1)
3944
return (%r) )";
4045

41-
// this filter passes only for the following conditions:
42-
// (1) the first view is [n, c, h, w] => [n, groups, c // groups, h, w]
43-
// (2) the tranpose is for groups => [n, c // groups, grpups, h, w]
44-
// (3) the output view shape should be the same as the input tensor shape
45-
auto filter_shuffle_2d_fusion =
46+
// below is channelshuffle for dynamic view shape pattern
47+
std::string dynamic_shape_input = R"(
48+
graph(%input, %idx_0:int, %idx_1:int, %idx_2:int, %idx_3:int, %div_g, %g:int, %type, %flattern_c):
49+
%n_ = aten::size(%input, %idx_0)
50+
%c_ = aten::size(%input, %idx_1)
51+
%tensor_c_ = prim::NumToTensor(%c_)
52+
%h_ = aten::size(%input, %idx_2)
53+
%w_ = aten::size(%input, %idx_3)
54+
%c_div_g_ = aten::div(%tensor_c_, %div_g, %type)
55+
%int_c_div_g_ = aten::Int(%c_div_g_)
56+
%view_shape:int[] = prim::ListConstruct(%n_, %g, %int_c_div_g_, %h_, %w_) )";
57+
58+
std::string channelshuffle_for_dynamic_shape = R"(
59+
%r1 = aten::view(%input, %view_shape)
60+
%r2 = aten::transpose(%r1, %idx_1, %idx_2)
61+
%r3 = aten::contiguous(%r2, %idx_0)
62+
%flattern_shape:int[] = prim::ListConstruct(%n_, %flattern_c, %h_, %w_)
63+
%r4 = aten::view(%r3, %flattern_shape)
64+
return (%r4) )";
65+
66+
std::string shuffle_2d_fusion_for_dynamic_shape = R"(
67+
%r = ipex::shuffle_2d(%input, %view_shape, %idx_1, %idx_2)
68+
return (%r) )";
69+
70+
std::string channelshuffle_with_dynamic_shape =
71+
dynamic_shape_input + channelshuffle_for_dynamic_shape;
72+
std::string shuffle_2d_fusion_with_dynamic_shape =
73+
dynamic_shape_input + shuffle_2d_fusion_for_dynamic_shape;
74+
75+
auto filter_shuffle_2d_static_fusion =
4676
[](const Match& match,
4777
const std::unordered_map<std::string, Value*>& vmap) {
4878
const auto& match_vmap = match.values_map;
@@ -86,11 +116,12 @@ void FuseShuffle(std::shared_ptr<Graph>& graph) {
86116
return false;
87117
}
88118

89-
// if the view shape and flattern shape is not set
119+
// if the view shape or flattern shape is not set
90120
if (!toIValue(view_shape_).has_value() ||
91121
!toIValue(flattern_shape_).has_value()) {
92122
return false;
93123
}
124+
94125
auto view_shape_list = toIValue(view_shape_).value().toIntVector();
95126
auto flattern_shape_list =
96127
toIValue(flattern_shape_).value().toIntVector();
@@ -134,10 +165,43 @@ void FuseShuffle(std::shared_ptr<Graph>& graph) {
134165
return true;
135166
};
136167

137-
SubgraphRewriter rewriter_shuffle_2d;
138-
rewriter_shuffle_2d.RegisterRewritePattern(shuffle, shuffle_2d_fusion);
139-
rewriter_shuffle_2d.runOnGraph(graph, filter_shuffle_2d_fusion);
140-
}
168+
auto filter_shuffle_2d_dynamic_fusion =
169+
[](const Match& match,
170+
const std::unordered_map<std::string, Value*>& vmap) {
171+
const auto& match_vmap = match.values_map;
172+
173+
auto n_idx = getIValue("idx_0", match_vmap, vmap);
174+
auto c_idx = getIValue("idx_1", match_vmap, vmap);
175+
auto h_idx = getIValue("idx_2", match_vmap, vmap);
176+
auto w_idx = getIValue("idx_3", match_vmap, vmap);
177+
if (!n_idx.has_value() || !c_idx.has_value() || !h_idx.has_value() ||
178+
!w_idx.has_value()) {
179+
return false;
180+
}
181+
182+
auto n_idx_ = n_idx.value().toInt();
183+
auto c_idx_ = c_idx.value().toInt();
184+
auto h_idx_ = h_idx.value().toInt();
185+
auto w_idx_ = w_idx.value().toInt();
186+
187+
if ((n_idx_ != 0) || (c_idx_ != 1) || (h_idx_ != 2) || (w_idx_ != 3)) {
188+
return false;
189+
}
190+
191+
return true;
192+
};
193+
194+
SubgraphRewriter rewriter_shuffle_2d_dynamic;
195+
rewriter_shuffle_2d_dynamic.RegisterRewritePattern(
196+
channelshuffle_with_dynamic_shape, shuffle_2d_fusion_with_dynamic_shape);
197+
rewriter_shuffle_2d_dynamic.runOnGraph(
198+
graph, filter_shuffle_2d_dynamic_fusion);
199+
SubgraphRewriter rewriter_shuffle_2d_static;
200+
rewriter_shuffle_2d_static.RegisterRewritePattern(
201+
channelshuffle_with_static_shape, shuffle_2d_fusion_with_static_shape);
202+
rewriter_shuffle_2d_static.runOnGraph(graph, filter_shuffle_2d_static_fusion);
203+
204+
} // FuseShuffle
141205

142206
void FuseAddLayerNorm(std::shared_ptr<Graph>& graph) {
143207
std::string aten_add_layernorm = R"(

intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite_linear.cpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,12 @@ void insertPrePackedLinearOp(std::shared_ptr<Graph>& graph) {
125125

126126
void fuseLinearWithEltwise(std::shared_ptr<Graph>& graph) {
127127
SubgraphRewriter rewriter_relu, rewriter_gelu, rewriter_silu,
128-
rewriter_sigmoid, rewriter_swish;
128+
rewriter_sigmoid, rewriter_swish, rewriter_tanh;
129129
std::array<std::string, 2> relu_operators = {"relu", "relu_"};
130130
std::array<std::string, 2> sigmoid_operators = {"sigmoid", "sigmoid_"};
131131
std::array<std::string, 2> silu_operators = {"silu", "silu_"};
132132
std::array<std::string, 2> mul_operators = {"mul", "mul_"};
133+
std::array<std::string, 2> tanh_operators = {"tanh", "tanh_"};
133134

134135
auto linear_relu_rstring = CodeTemplate(R"(
135136
graph(%input, %packed_weight):
@@ -142,6 +143,17 @@ void fuseLinearWithEltwise(std::shared_ptr<Graph>& graph) {
142143
%res = ipex_prepack::linear_relu_run(%input, %packed_weight)
143144
return (%res))";
144145

146+
auto linear_tanh_rstring = CodeTemplate(R"(
147+
graph(%input, %packed_weight):
148+
%x = ipex_prepack::linear_run(%input, %packed_weight)
149+
%res = aten::${tanh}(%x)
150+
return (%res))");
151+
152+
std::string linear_tanh_fused = R"(
153+
graph(%input, %packed_weight):
154+
%res = ipex_prepack::linear_tanh_run(%input, %packed_weight)
155+
return (%res))";
156+
145157
std::string linear_gelu = R"(
146158
graph(%input, %approximate, %packed_weight):
147159
%x = ipex_prepack::linear_run(%input, %packed_weight)
@@ -189,6 +201,13 @@ void fuseLinearWithEltwise(std::shared_ptr<Graph>& graph) {
189201
linear_relu_rstring.format(env), linear_relu_fused);
190202
}
191203

204+
for (const auto& tanh : tanh_operators) {
205+
TemplateEnv env;
206+
env.s("tanh", tanh);
207+
rewriter_tanh.RegisterRewritePattern(
208+
linear_tanh_rstring.format(env), linear_tanh_fused);
209+
}
210+
192211
for (const auto& silu : silu_operators) {
193212
TemplateEnv env;
194213
env.s("silu", silu);
@@ -213,6 +232,7 @@ void fuseLinearWithEltwise(std::shared_ptr<Graph>& graph) {
213232
rewriter_gelu.RegisterRewritePattern(linear_gelu, linear_gelu_fused);
214233

215234
rewriter_relu.runOnGraph(graph);
235+
rewriter_tanh.runOnGraph(graph);
216236
rewriter_gelu.runOnGraph(graph);
217237
}
218238

intel_extension_for_pytorch/csrc/jit/cpu/passes/register_dnnl_jit_ops.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,24 @@ RegisterOperators op({
456456
};
457457
},
458458
aliasAnalysisFromSchema()),
459+
460+
Operator(
461+
"ipex_prepack::linear_tanh_run(Tensor input, "
462+
"__torch__.torch.classes.ipex_prepack.LinearOpContext W_prepack) "
463+
"-> Tensor",
464+
[](const Node* node) -> Operation {
465+
return [](Stack* stack) {
466+
auto result = linear_tanh_run(
467+
(std::move(peek(stack, 0, 2))).toTensor(),
468+
(std::move(peek(stack, 1, 2)))
469+
.toCustomClass<LinearOpContext>());
470+
drop(stack, 2);
471+
pack(stack, std::move(result));
472+
return 0;
473+
};
474+
},
475+
aliasAnalysisFromSchema()),
476+
459477
Operator(
460478
"ipex_prepack::linear_sigmoid_run(Tensor input, "
461479
"__torch__.torch.classes.ipex_prepack.LinearOpContext W_prepack) "

0 commit comments

Comments
 (0)