Commit d1379aa

Wenzhe/bmm add (#407)

* add bmm_add fusion and test
* allow get_params to take binary kind; refactor fuse_binary
* add filter for bmm_add

1 parent b5f7770 commit d1379aa

File tree

8 files changed: +124 -0 lines changed


intel_extension_for_pytorch/csrc/cpu/ideep/ideep/attributes.hpp

Lines changed: 16 additions & 0 deletions
@@ -65,6 +65,14 @@ struct attr_t : public dnnl::primitive_attr {
     return attr;
   }

+  static attr_t fuse_binary(algorithm alg, memory::desc src_desc) {
+    attr_t attr;
+    post_ops po;
+    po.append_binary(alg, src_desc);
+    attr.set_post_ops(po);
+    return attr;
+  }
+
   static attr_t fuse_relu(
       float scale = 1.0,
       float alpha = 0.f,
@@ -162,6 +170,7 @@ struct attr_t : public dnnl::primitive_attr {

   algorithm alg = algorithm::undef;
   float scale = 1.0, alpha = 1.0, beta = 0.0;
+  memory::desc binary_src_desc;

   auto akind = po.kind(index);
   switch (akind) {
@@ -171,6 +180,9 @@ struct attr_t : public dnnl::primitive_attr {
     case kind::eltwise:
       po.get_params_eltwise(index, scale, alg, alpha, beta);
       break;
+    case kind::binary:
+      po.get_params_binary(index, alg, binary_src_desc);
+      break;
     default:
       error::wrap_c_api(dnnl_invalid_arguments, "could not get params");
       break;
@@ -243,6 +255,10 @@ struct attr_t : public dnnl::primitive_attr {
       utils::to_bytes(bytes, beta);
       bytes.append(1, '.');
       utils::to_bytes(bytes, alg);
+    case kind::binary:
+      utils::to_bytes(bytes, akind);
+      bytes.append(1, '.');
+      utils::to_bytes(bytes, alg);
     default:
       break;
   }
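The new fuse_binary helper attaches a oneDNN binary post-op to the primitive attributes, so the elementwise add runs inside the same primitive as the matmul instead of as a separate pass over memory. A minimal numerical sketch of what a binary_add post-op computes, written in plain PyTorch rather than the ideep API (shapes are illustrative):

import torch

# The fused primitive's result is equivalent to a bmm followed by an
# elementwise add of the post-op source tensor (here: `bias`).
batch1 = torch.randn(10, 3, 4)
batch2 = torch.randn(10, 4, 5)
bias = torch.randn(10, 3, 5)

unfused = torch.bmm(batch1, batch2) + bias       # two ops, two memory passes
fused_ref = torch.baddbmm(bias, batch1, batch2)  # single-op reference result
assert torch.allclose(unfused, fused_ref, atol=1e-6)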

intel_extension_for_pytorch/csrc/jit/cpu/kernels/Matmul.cpp

Lines changed: 23 additions & 0 deletions
@@ -121,5 +121,28 @@ at::Tensor dil_matmul_div(
   }
 }

+at::Tensor dil_bmm_add(
+    const at::Tensor& input,
+    const at::Tensor& batch1,
+    const at::Tensor& batch2,
+    const c10::Scalar& alpha) {
+#if defined(IPEX_PROFILE_OP)
+  RECORD_FUNCTION("dil_bmm_add", std::vector<c10::IValue>({}));
+#endif
+  auto batch1_dim = batch1.dim();
+  auto batch2_dim = batch2.dim();
+  if (batch1_dim == batch2_dim && batch1_dim >= 3) {
+    auto _input = input.is_contiguous() ? input : input.contiguous();
+    ideep::tensor onednn_input = itensor_view_from_dense(_input);
+
+    auto op_attr = ideep::attr_t::fuse_binary(
+        dnnl::algorithm::binary_add, onednn_input.get_desc());
+    return bmm_impl(
+        batch1, batch2, at::Tensor(), op_attr, {onednn_input}, 1.0f);
+  } else {
+    return at::baddbmm(input, batch1, batch2);
+  }
+}
+
 } // namespace cpu
 } // namespace torch_ipex
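dil_bmm_add takes the fused oneDNN path only when both batch tensors have the same rank of at least 3; anything else falls back to at::baddbmm. A hedged Python sketch of that dispatch (bmm_add_reference is a hypothetical name, not part of the commit; like the kernel above, it does not apply alpha on the fused path):

import torch

def bmm_add_reference(input, batch1, batch2, alpha=1):
    # Hypothetical reference mirroring the dispatch in dil_bmm_add.
    if batch1.dim() == batch2.dim() and batch1.dim() >= 3:
        # Stand-in for bmm_impl with the binary_add post-op; alpha is
        # accepted but, as in the kernel, not applied on this path.
        return torch.bmm(batch1, batch2) + input
    # Fallback path, as in the kernel.
    return torch.baddbmm(input, batch1, batch2)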

intel_extension_for_pytorch/csrc/jit/cpu/kernels/Matmul.h

Lines changed: 7 additions & 0 deletions
@@ -15,6 +15,7 @@ namespace jit {
 // So we fake some op namespaces to workaround that.
 namespace ipex {
 static auto matmul_div = Symbol::fromQualString("ipex::matmul_div");
+static auto bmm_add = Symbol::fromQualString("ipex::bmm_add");

 } // namespace ipex

@@ -36,5 +37,11 @@ at::Tensor dil_matmul_div(
     at::Tensor out_opt,
     const c10::Scalar& div_input);

+at::Tensor dil_bmm_add(
+    const at::Tensor& input,
+    const at::Tensor& batch1,
+    const at::Tensor& batch2,
+    const c10::Scalar& alpha);
+
 } // namespace cpu
 } // namespace torch_ipex

intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite.cpp

Lines changed: 36 additions & 0 deletions
@@ -432,6 +432,42 @@ void replaceInteractionWithQInteraction(std::shared_ptr<Graph>& graph) {
   }
 }

+void fuseBmmAdd(std::shared_ptr<Graph>& graph) {
+  std::array<std::string, 2> add_operators = {"add", "add_"};
+
+  auto bmm_add_rstring_v1 = R"(
+    graph(%input, %batch1, %batch2, %alpha):
+        %x = aten::bmm(%batch1, %batch2)
+        %res = aten::add(%x, %input, %alpha)
+        return (%res))";
+  std::string bmm_add_fused = R"(
+    graph(%input, %batch1, %batch2, %alpha):
+        %res = ipex::bmm_add(%input, %batch1, %batch2, %alpha)
+        return (%res))";
+  // filter out the unsupported cases
+  auto fusion_filter = [](const Match& match,
+                          const std::unordered_map<std::string, Value*>& vmap) {
+    Node* node = match.anchor;
+    const auto& match_vmap = match.values_map;
+
+    auto batch1 = node->input(1)->type()->cast<TensorType>();
+    auto batch2 = node->input(2)->type()->cast<TensorType>();
+    if (batch1->dim() != batch2->dim()) {
+      return false;
+    }
+
+    if (batch1->dim().value() < 3) {
+      return false;
+    }
+
+    return true;
+  };
+
+  SubgraphRewriter rewriter_add_v1;
+  rewriter_add_v1.RegisterRewritePattern(bmm_add_rstring_v1, bmm_add_fused);
+  rewriter_add_v1.runOnGraph(graph, fusion_filter);
+}
+
 void FuseConcatBnRelu(std::shared_ptr<Graph>& graph) {
   std::string aten_concat_bn_relu = R"(
     graph(%input : Tensor[], %dim:int, %weight, %bias, %running_mean, %running_var, %training, %momentum, %eps, %cudnn_enabled):
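The rewrite pattern matches the aten::bmm -> aten::add chain that tracing produces, and fusion_filter rejects matches whose batch tensors differ in rank or are below 3-D. One way to see the exact IR the rewriter looks for is to print a traced graph before the pass runs (sketch; shapes illustrative):

import torch
import torch.nn as nn

class BmmAdd(nn.Module):
    def forward(self, input, batch1, batch2):
        return torch.add(torch.bmm(batch1, batch2), input)

traced = torch.jit.trace(
    BmmAdd(),
    (torch.randn(10, 3, 5), torch.randn(10, 3, 4), torch.randn(10, 4, 5)))
# The printed IR contains the aten::bmm -> aten::add chain matched by
# bmm_add_rstring_v1; after IPEXFusionPass runs, it should collapse into
# a single ipex::bmm_add node.
print(traced.graph)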

intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite.h

Lines changed: 2 additions & 0 deletions
@@ -27,6 +27,8 @@ void FuseShuffle(std::shared_ptr<Graph>& graph);
 void FuseMHAScoreCalc(std::shared_ptr<Graph>& graph);
 void FuseLinearSwishCustomized(std::shared_ptr<Graph>& graph);
 void replaceAtenMaxPool2dWithIpexMaxPool2d(std::shared_ptr<Graph>& graph);
+void fuseBmmAdd(std::shared_ptr<Graph>& graph);
+
 void replaceOpsWithAtenInplaceOps(std::shared_ptr<Graph>& graph);
 void replaceAtenOpsWithIpexInplaceOps(std::shared_ptr<Graph>& graph);
 void replaceAtenSoftmaxWithIpexSoftmax(std::shared_ptr<Graph>& graph);

intel_extension_for_pytorch/csrc/jit/cpu/passes/register_dnnl_jit_ops.cpp

Lines changed: 17 additions & 0 deletions
@@ -550,6 +550,23 @@ RegisterOperators op({
     },
     aliasAnalysisFromSchema()),

+    Operator(
+        "ipex::bmm_add(Tensor input, Tensor batch1, Tensor batch2, Scalar alpha) -> "
+        "Tensor",
+        [](const Node* node) -> Operation {
+          return [](Stack* stack) {
+            auto result = dil_bmm_add(
+                (std::move(peek(stack, 0, 4))).toTensor(),
+                (std::move(peek(stack, 1, 4))).toTensor(),
+                (std::move(peek(stack, 2, 4))).toTensor(),
+                (std::move(peek(stack, 3, 4))).toScalar());
+            drop(stack, 4);
+            pack(stack, std::move(result));
+            return 0;
+          };
+        },
+        aliasAnalysisFromSchema()),
+
     Operator(
         "ipex::mha_scores_calc(Tensor q, Tensor k, Tensor rel_qk, Scalar "
         "alpha, "

intel_extension_for_pytorch/csrc/jit/fusion_pass.cpp

Lines changed: 3 additions & 0 deletions
@@ -355,6 +355,9 @@ void IPEXFusionPass(std::shared_ptr<Graph>& graph) {
   // Multi-Head-Attention
   graph_rewrite::FuseMHAScoreCalc(graph);

+  // Fuse bmm + add for bmm_add
+  graph_rewrite::fuseBmmAdd(graph);
+
   // Replace _convolution with conv2d or conv3d
   graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);

tests/cpu/test_jit.py

Lines changed: 20 additions & 0 deletions
@@ -621,6 +621,15 @@ def forward(self, x):
         else:
             return mm_res.div(torch.ones(mm_res_shape,dtype=x.dtype)+1)

+class BmmAdd(nn.Module):
+    def __init__(self):
+        super(BmmAdd, self).__init__()
+
+    def forward(self, input, batch1, batch2):
+        bmm_res = torch.bmm(batch1, batch2)
+        res = torch.add(bmm_res, input)
+        return res
+
 class MHAScoresCalculation(nn.Module):
     def __init__(self, dim_per_head, softmax_dim=-1):
         super(MHAScoresCalculation, self).__init__()
@@ -2476,6 +2485,17 @@ def test_matmul_div(self):
                           kind_not_in_graph=None,
                           prec=5e-3)

+    def test_bmm_add(self):
+        M = torch.randn(10, 3, 5)
+        batch1 = torch.randn(10, 3, 4)
+        batch2 = torch.randn(10, 4, 5)
+        mod = BmmAdd()
+        traced_mod = torch.jit.trace(mod, (M, batch1, batch2))
+        fused_mod = traced_mod.graph_for(M, batch1, batch2)
+        out = traced_mod(M, batch1, batch2)
+        expected = torch.baddbmm(M, batch1, batch2)
+        self.assertTrue(torch.allclose(out, expected))
+
     def test_ipex_softmax(self):
         self._test_output(
             AtenSoftmaxRepalce(),
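test_bmm_add checks numerical equivalence against torch.baddbmm but never asserts that the fusion actually fired. A possible follow-up check, not part of this commit, would scan the optimized graph returned by graph_for for the fused node (this assumes the pass has rewritten the traced graph by then):

# Hypothetical extension of test_bmm_add: verify the rewrite happened.
fused_graph = traced_mod.graph_for(M, batch1, batch2)
found = any(n.kind() == "ipex::bmm_add" for n in fused_graph.nodes())
assert found, "aten::bmm + aten::add was not fused into ipex::bmm_add"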
