Commit cc855ff
add linear_swish fusion using MKL for FSI riskfuel after code review. (#551)
* add linear_swish fusion using MKL for FSI riskfuel Co-authored-by: Wang Weihan <[email protected]>
1 parent c411e8e commit cc855ff
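
This commit teaches the IPEX JIT to collapse the eager swish-after-linear pattern, swish(z) = z * sigmoid(z) with z = linear(x, W, b), into a single ipex::linear_swish_customized op. For orientation, here is a minimal ATen sketch of the three-op pattern the rewriter matches (illustrative only, not code from this commit):

    #include <ATen/ATen.h>

    // The unfused pattern that FuseLinearSwishCustomized looks for:
    at::Tensor linear_swish_unfused(
        const at::Tensor& x,
        const at::Tensor& w,
        const at::Tensor& b) {
      auto lin = at::linear(x, w, b); // x @ w^T + b
      auto sig = at::sigmoid(lin);    // sigmoid(lin)
      return at::mul(lin, sig);       // swish: lin * sigmoid(lin)
    }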

11 files changed: +424 −53 lines
Lines changed: 41 additions & 0 deletions

@@ -0,0 +1,41 @@
#include <csrc/jit/cpu/kernels/AddSwish.h>

#if defined(CPU_CAPABILITY_AVX512)
#include "csrc/cpu/vec512/add_swish.h"
#endif

namespace torch_ipex {
namespace cpu {

#if defined(DYN_DISP_BUILD)
namespace {
#endif

at::Tensor add_swish_kernel_impl(
    at::Tensor& x,
    at::Tensor& a,
    const at::Tensor& b,
    const at::Tensor& c) {
#if defined(CPU_CAPABILITY_AVX512)
  if (a.scalar_type() == at::kFloat && c.scalar_type() == at::kFloat) {
    return torch_ipex::cpu::kernel::vec::vec512::dil_add_swish<float>(a, c);
  } else if (
      a.scalar_type() == at::kBFloat16 && c.scalar_type() == at::kBFloat16) {
    return torch_ipex::cpu::kernel::vec::vec512::dil_add_swish<at::BFloat16>(
        a, c);
  }
#endif
  auto lin_res = at::linear(x, b, c);
  auto sigmoid_res = at::sigmoid(lin_res);
  return at::mul(lin_res, sigmoid_res);
}

#if defined(DYN_DISP_BUILD)
} // anonymous namespace

REGISTER_DISPATCH(add_swish_kernel_stub, &add_swish_kernel_impl);

#endif

} // namespace cpu
} // namespace torch_ipex
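
The argument naming above is terse: x is the original input, a the bias-free matmul output, b the weight, and c the bias. With AVX512 and matching float/BFloat16 dtypes, the kernel folds the bias add into the swish and updates a in place; otherwise it falls back to plain ATen ops, recomputing the linear with bias from x. A hypothetical driver for the public AddSwish entry point, under that reading (illustrative, not code from this commit; under DYN_DISP_BUILD the impl itself is file-local, so callers go through AddSwish):

    #include <ATen/ATen.h>
    #include "AddSwish.h" // declares torch_ipex::cpu::AddSwish

    // Sketch: pass a bias-free matmul result so the (+bias) can be fused
    // with the swish on the AVX512 fast path.
    at::Tensor add_swish_driver(
        at::Tensor& x,
        const at::Tensor& weight,
        const at::Tensor& bias) {
      auto mm_output = at::linear(x, weight); // x @ weight^T, no bias
      return torch_ipex::cpu::AddSwish(x, mm_output, weight, bias);
    }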
Lines changed: 109 additions & 0 deletions

@@ -0,0 +1,109 @@
#pragma once

#include <immintrin.h>

#include <ATen/ATen.h>
#include <ATen/ExpandUtils.h>
#include <ATen/Parallel.h>
#include <c10/util/SmallVector.h>
#include <limits>
#include "add_softmax.h"
#include "utils.h"

namespace torch_ipex {
namespace cpu {
namespace kernel {
namespace vec {
namespace vec512 {

template <typename scalar_t>
inline void _dil_add_swish_fusion_kernel(
    scalar_t* a,
    const scalar_t* b,
    const int& size) {
  auto vec_ps_min = _mm512_set1_ps(std::numeric_limits<float>::min());
  auto vec_ps_1 = _mm512_set1_ps(1.0);
  __m512 vec_a, vec_b;
  __m512 vec_add_tmp, vec_addone_tmp;

  int i = 0;

  // load tensor<float> a & b
  // assume the same size, no need to broadcast
  for (; i <= size - 16; i += 16) {
    // a is the first operand of the add, b is the bias
    vec_a = _loadu(a + i);
    vec_b = _loadu(b + i);

    // add bias
    vec_a = _mm512_add_ps(vec_a, vec_b);
    vec_add_tmp =
        vec_a; // keep the intermediate result for later use in the mul

    // calculate sigmoid: e^x / (1 + e^x)
    vec_a = _dil_exp_kernel(vec_a);
    vec_addone_tmp = _mm512_add_ps(vec_a, vec_ps_1);
    vec_a = _mm512_div_ps(vec_a, vec_addone_tmp);
    vec_a = _mm512_mul_ps(vec_a, vec_add_tmp);

    _storeu(a + i, vec_a);
  }

  // 512 tail
  if (i < size) {
    // mask load
    __mmask16 mask = (1 << (size - i)) - 1;
    vec_a = _maskz_loadu(a + i, mask);
    vec_b = _maskz_loadu(b + i, mask);

    // add bias
    vec_a = _mm512_add_ps(vec_a, vec_b);
    vec_add_tmp =
        vec_a; // keep the intermediate result for later use in the second mul

    // calculate sigmoid: e^x / (1 + e^x)
    vec_a = _dil_exp_kernel(vec_a);
    vec_addone_tmp = _mm512_add_ps(vec_a, vec_ps_1);
    vec_a = _mm512_div_ps(vec_a, vec_addone_tmp);

    vec_a = _mm512_mul_ps(vec_a, vec_add_tmp);

    // mask store
    _mask_storeu(a + i, vec_a, mask);
  }
}

template <typename scalar_t>
at::Tensor dil_add_swish(const at::Tensor& mm_output, const at::Tensor& bias) {
  scalar_t* mm_output_data_base = mm_output.data_ptr<scalar_t>();
  scalar_t* bias_data_base = bias.data_ptr<scalar_t>();

  auto infered_size = mm_output.sizes().vec();
  int64_t dim_size = infered_size[infered_size.size() - 1];
  int64_t outer_size = 1;
  // The last dim is the loop unit; start at infered_size.size() - 2
  // (the second-to-last dimension) to exclude it.
  for (int64_t i = infered_size.size() - 2; i >= 0; i--) {
    // Calculate the outer loop count.
    outer_size *= infered_size[i];
  }

  int64_t grain_size = at::internal::GRAIN_SIZE / (16 * dim_size);
  if (grain_size < 1)
    grain_size = 1;

  at::parallel_for(0, outer_size, grain_size, [&](int64_t begin, int64_t end) {
    for (int64_t i = begin; i < end; i++) {
      _dil_add_swish_fusion_kernel<scalar_t>(
          mm_output_data_base + i * dim_size, bias_data_base, dim_size);
    }
  });

  return mm_output;
} // dil_add_swish

} // namespace vec512
} // namespace vec
} // namespace kernel
} // namespace cpu
} // namespace torch_ipex
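
As a readability aid, the intrinsics above amount to the following scalar loop per row, with sigmoid written as e^x / (1 + e^x) to mirror _dil_exp_kernel (a hand-written float-only reference, not code from this commit; it ignores whatever range handling the real exp kernel does):

    #include <cmath>

    // Scalar reference for _dil_add_swish_fusion_kernel: bias add followed
    // by swish, in place on a. The real kernel processes 16 floats per
    // iteration and handles the remainder with a mask load/store.
    inline void add_swish_reference(float* a, const float* b, int size) {
      for (int i = 0; i < size; ++i) {
        float x = a[i] + b[i];     // add bias
        float e = std::exp(x);     // e^x
        float s = e / (1.0f + e);  // sigmoid(x) = e^x / (1 + e^x)
        a[i] = x * s;              // swish(x) = x * sigmoid(x)
      }
    }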
Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
#include "AddSwish.h"
#include <ATen/Context.h>
#include <ATen/InferSize.h>
#include <c10/util/Exception.h>
#include <c10/util/Logging.h>
#include <torch/csrc/autograd/function.h>
#include <iostream>

#include <limits>

namespace torch_ipex {
namespace cpu {
DEFINE_DISPATCH(add_swish_kernel_stub);

// Currently we only support a 1-D bias tensor (the operand of the add).
at::Tensor AddSwish(
    at::Tensor& x,
    at::Tensor& mm_output,
    const at::Tensor& weight,
    const at::Tensor& bias) {
#if defined(DYN_DISP_BUILD)
  return add_swish_kernel_stub(kCPU, x, mm_output, weight, bias);
#else
  return add_swish_kernel_impl(x, mm_output, weight, bias);
#endif
}

} // namespace cpu
} // namespace torch_ipex
Lines changed: 38 additions & 0 deletions

@@ -0,0 +1,38 @@
#pragma once

#include <ATen/ATen.h>
#include <csrc/dyndisp/DispatchStub.h>

namespace torch_ipex {
namespace cpu {

// Currently we only support a 1-D bias tensor (the operand of the add).
at::Tensor AddSwish(
    at::Tensor& x,
    at::Tensor& mm_output,
    const at::Tensor& weight,
    const at::Tensor& bias);

#if defined(DYN_DISP_BUILD)
namespace {
#endif

at::Tensor add_swish_kernel_impl(
    at::Tensor& x,
    at::Tensor& a,
    const at::Tensor& b,
    const at::Tensor& c);

#if defined(DYN_DISP_BUILD)
}
#endif

using add_swish_kernel_fn = at::Tensor (*)(
    at::Tensor&,
    at::Tensor&,
    const at::Tensor&,
    const at::Tensor&);
DECLARE_DISPATCH(add_swish_kernel_fn, add_swish_kernel_stub);

} // namespace cpu
} // namespace torch_ipex
Lines changed: 32 additions & 0 deletions

@@ -0,0 +1,32 @@
#include "LinearSwishCustomized.h"
#include "AddSwish.h"

#include <ATen/Context.h>
#include <ATen/InferSize.h>
#include <c10/util/Exception.h>
#include <c10/util/Logging.h>
#include <torch/csrc/autograd/function.h>

#include <limits>

#include "csrc/cpu/ideep/ideep.hpp"
#include "csrc/utils/ipex_op_profile.h"

namespace torch_ipex {
namespace cpu {

at::Tensor dil_linear_swish_customized(
    at::Tensor& x,
    const at::Tensor& weight,
    const at::Tensor& bias) {
#if defined(IPEX_PROFILE_OP)
  RECORD_FUNCTION("dil_linear_swish_customized", std::vector<c10::IValue>({}));
#endif

  // at::linear w/o bias
  auto linear_res = at::linear(x, weight);
  return AddSwish(x, linear_res, weight, bias);
}

} // namespace cpu
} // namespace torch_ipex
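
Note the design choice here: the linear is computed without its bias, and AddSwish folds the (+bias) into the same vectorized pass that evaluates the swish, saving one full sweep over the output. The identity being exploited, as an illustrative ATen sketch (not code from this commit):

    #include <ATen/ATen.h>

    // swish(linear(x, W, b)) == swish(linear(x, W) + b): deferring the bias
    // lets add + sigmoid + mul happen in one pass over memory.
    at::Tensor fused_path_reference(
        const at::Tensor& x,
        const at::Tensor& w,
        const at::Tensor& b) {
      auto mm = at::linear(x, w); // no bias here
      auto z = mm + b;            // fused into the AVX512 loop in practice
      return z * at::sigmoid(z);
    }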
Lines changed: 18 additions & 0 deletions

@@ -0,0 +1,18 @@
#pragma once

#include <ATen/Tensor.h>

#include <c10/core/Scalar.h>
#include <torch/csrc/jit/runtime/custom_operator.h>

#include "csrc/cpu/ideep/ideep.hpp"

namespace torch_ipex {
namespace cpu {

at::Tensor dil_linear_swish_customized(
    at::Tensor& x,
    const at::Tensor& weight,
    const at::Tensor& bias);
} // namespace cpu
} // namespace torch_ipex

intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite.cpp

Lines changed: 18 additions & 0 deletions

@@ -513,6 +513,24 @@ void FuseConcatBnRelu(std::shared_ptr<Graph>& graph) {
   rewriter_concatbnrelu.runOnGraph(graph, fusion_filter);
 }
 
+void FuseLinearSwishCustomized(std::shared_ptr<Graph>& graph) {
+  std::string linear_swish = R"(
+    graph(%x, %weight, %bias):
+      %_linear_res = aten::linear(%x, %weight, %bias)
+      %_sigmod_res = aten::sigmoid(%_linear_res)
+      %_mul_res2 = aten::mul(%_linear_res, %_sigmod_res)
+      return (%_mul_res2) )";
+
+  std::string linear_swish_fusion = R"(
+    graph(%x, %weight, %bias):
+      %_res = ipex::linear_swish_customized(%x, %weight, %bias)
+      return (%_res) )";
+
+  SubgraphRewriter ls_fusion;
+  ls_fusion.RegisterRewritePattern(linear_swish, linear_swish_fusion);
+  ls_fusion.runOnGraph(graph);
+}
+
 } // namespace graph_rewrite
 } // namespace jit
 } // namespace torch

intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite.h

Lines changed: 1 addition & 0 deletions

@@ -25,6 +25,7 @@ c10::optional<IValue> getIValue(
 
 void FuseShuffle(std::shared_ptr<Graph>& graph);
 void FuseMHAScoreCalc(std::shared_ptr<Graph>& graph);
+void FuseLinearSwishCustomized(std::shared_ptr<Graph>& graph);
 void replaceAtenMaxPool2dWithIpexMaxPool2d(std::shared_ptr<Graph>& graph);
 void replaceOpsWithAtenInplaceOps(std::shared_ptr<Graph>& graph);
 void replaceAtenOpsWithIpexInplaceOps(std::shared_ptr<Graph>& graph);

intel_extension_for_pytorch/csrc/jit/cpu/passes/register_dnnl_jit_ops.cpp

Lines changed: 27 additions & 0 deletions

@@ -10,6 +10,7 @@
 #include "csrc/jit/cpu/kernels/Embeddingbag.h"
 #include "csrc/jit/cpu/kernels/Interaction.h"
 #include "csrc/jit/cpu/kernels/LinearPacked.h"
+#include "csrc/jit/cpu/kernels/LinearSwishCustomized.h"
 #include "csrc/jit/cpu/kernels/Matmul.h"
 #include "csrc/jit/cpu/kernels/MaxPool2D.h"
 #include "csrc/jit/cpu/kernels/Mha.h"

@@ -521,6 +522,18 @@ RegisterOperators op({
         },
         aliasAnalysisFromSchema()),
 
+    Operator(
+        "ipex::linear_swish_customized(Tensor x, Tensor weight, Tensor ? bias) -> Tensor",
+        [](Stack& stack) {
+          auto result = dil_linear_swish_customized(
+              peek(stack, 0, 3).toTensor(),
+              peek(stack, 1, 3).toTensor(),
+              toOptionalTensor(std::move(peek(stack, 2, 3))));
+          drop(stack, 3);
+          pack(stack, std::move(result));
+        },
+        aliasAnalysisFromSchema()),
+
     Operator(
         "ipex::distil_mha_scores_calc(Tensor q, Tensor k, Tensor mask_qk, "
         "int[] mask_qk_reshp, int transpose_dim_a, int transpose_dim_b, "

@@ -539,6 +552,20 @@ RegisterOperators op({
               peek(stack, 8, 10).toInt(),
               peek(stack, 9, 10));
           drop(stack, 10);
+
+          pack(stack, std::move(result));
+        },
+        aliasAnalysisFromSchema()),
+
+    Operator(
+        "ipex::linear_swish_customized(Tensor x, Tensor weight, Tensor ? bias) -> Tensor",
+        [](Stack& stack) {
+          auto result = dil_linear_swish_customized(
+              peek(stack, 0, 3).toTensor(),
+              peek(stack, 1, 3).toTensor(),
+              toOptionalTensor(std::move(peek(stack, 2, 3))));
+          drop(stack, 3);
+
           pack(stack, std::move(result));
         },
         aliasAnalysisFromSchema()),

intel_extension_for_pytorch/csrc/jit/fusion_pass.cpp

Lines changed: 1 addition & 0 deletions

@@ -377,6 +377,7 @@ void IPEXFusionPass(std::shared_ptr<Graph>& graph) {
   graph_rewrite::fuseLinearWithEltwise(graph);
   graph_rewrite::fuseLinearAddRelu(graph);
 
+  graph_rewrite::FuseLinearSwishCustomized(graph);
   // fuse add+layernorm
   graph_rewrite::FuseAddLayerNorm(graph);
   // deconvolution fusion
