
Commit dc6ed1a

enable Conv+LeakyRelu fusion (#589)
1 parent 424360e commit dc6ed1a

5 files changed: +133 -23 lines

intel_extension_for_pytorch/csrc/jit/cpu/kernels/ConvPacked.cpp

Lines changed: 10 additions & 0 deletions

@@ -59,6 +59,16 @@ at::Tensor convolution_relu_run(
   return op_context->run(input, ideep::attr_t::fuse_relu());
 }
 
+at::Tensor convolution_leaky_relu_run(
+    const at::Tensor& input,
+    at::Scalar alpha,
+    const c10::intrusive_ptr<ConvolutionOpContext>& op_context) {
+  IPEX_RECORD_FUNCTION(
+      "ipex_prepack::convolution_leaky_relu_run", std::vector<c10::IValue>({}));
+  auto alpha_value = alpha.to<float>();
+  return op_context->run(input, ideep::attr_t::fuse_relu(1.0, alpha_value));
+}
+
 at::Tensor convolution_sigmoid_run(
     const at::Tensor& input,
     const c10::intrusive_ptr<ConvolutionOpContext>& op_context) {
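Note: the new kernel lowers LeakyReLU onto oneDNN's eltwise ReLU post-op with a non-zero alpha, i.e. ideep::attr_t::fuse_relu(1.0, alpha_value) computes x for x >= 0 and alpha * x otherwise. A minimal Python sketch of that equivalence (illustrative only, not part of this commit):

    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 8)
    alpha = 0.1

    # What a ReLU-with-alpha ("negative slope") post-op computes: identity for
    # non-negative inputs, alpha-scaled for negative inputs.
    relu_with_alpha = torch.where(x >= 0, x, alpha * x)

    # PyTorch's LeakyReLU is defined the same way, so fusing it as
    # fuse_relu(1.0, alpha) preserves the result.
    assert torch.allclose(F.leaky_relu(x, negative_slope=alpha), relu_with_alpha)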

intel_extension_for_pytorch/csrc/jit/cpu/kernels/ConvPacked.h

Lines changed: 5 additions & 0 deletions

@@ -32,6 +32,11 @@ at::Tensor convolution_relu_run(
     const at::Tensor& input,
     const c10::intrusive_ptr<ConvolutionOpContext>& op_context);
 
+at::Tensor convolution_leaky_relu_run(
+    const at::Tensor& input,
+    at::Scalar alpha,
+    const c10::intrusive_ptr<ConvolutionOpContext>& op_context);
+
 at::Tensor convolution_sigmoid_run(
     const at::Tensor& input,
     const c10::intrusive_ptr<ConvolutionOpContext>& op_context);

intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite_conv.cpp

Lines changed: 24 additions & 1 deletion

@@ -106,13 +106,15 @@ void insertPrePackedConvOp(std::shared_ptr<Graph>& graph) {
 
 void fuseConvWithEltwise(std::shared_ptr<Graph>& graph) {
   SubgraphRewriter rewriter_relu, rewriter_sigmoid, rewriter_hardtanh,
-      rewriter_elu, rewriter_swish, rewriter_silu;
+      rewriter_elu, rewriter_swish, rewriter_silu, rewriter_leaky_relu;
   std::array<std::string, 2> relu_operators = {"relu", "relu_"};
   std::array<std::string, 2> sigmoid_operators = {"sigmoid", "sigmoid_"};
   std::array<std::string, 2> hardtanh_operators = {"hardtanh", "hardtanh_"};
   std::array<std::string, 2> elu_operators = {"elu", "elu_"};
   std::array<std::string, 2> mul_operators = {"mul", "mul_"};
   std::array<std::string, 2> silu_operators = {"silu", "silu_"};
+  std::array<std::string, 2> leaky_relu_operators = {
+      "leaky_relu", "leaky_relu_"};
 
   auto conv_relu_rstring = CodeTemplate(R"(
     graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %kernel_size:int[], %groups:int, %output_channel:int, %weight_is_channels_last:bool, %weight_is_prepacked:bool, %input_size:int[]):

@@ -187,6 +189,19 @@ void fuseConvWithEltwise(std::shared_ptr<Graph>& graph) {
         %res = ipex_prepack::convolution_swish_run(%input, %packed_weight)
         return (%res))";
 
+  auto conv_leaky_relu_rstring = CodeTemplate(R"(
+    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %kernel_size:int[], %groups:int, %output_channel:int, %weight_is_channels_last:bool, %weight_is_prepacked:bool, %input_size:int[], %alpha):
+        %packed_weight : __torch__.torch.classes.ipex_prepack.ConvolutionOpContext = ipex_prepack::convolution_prepack(%weight, %bias, %stride, %padding, %dilation, %kernel_size, %groups, %output_channel, %weight_is_channels_last, %weight_is_prepacked, %input_size)
+        %x = ipex_prepack::convolution_run(%input, %packed_weight)
+        %res = aten::${leaky_relu}(%x, %alpha)
+        return (%res))");
+
+  std::string conv_leaky_relu_fused = R"(
+    graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %kernel_size:int[], %groups:int, %output_channel:int, %weight_is_channels_last:bool, %weight_is_prepacked:bool, %input_size:int[], %alpha):
+        %packed_weight : __torch__.torch.classes.ipex_prepack.ConvolutionOpContext = ipex_prepack::convolution_leaky_relu_prepack(%weight, %bias, %stride, %padding, %dilation, %kernel_size, %groups, %output_channel, %weight_is_channels_last, %weight_is_prepacked, %input_size, %alpha)
+        %res = ipex_prepack::convolution_leaky_relu_run(%input, %alpha, %packed_weight)
+        return (%res))";
+
   for (const auto& relu : relu_operators) {
     TemplateEnv env;
     env.s("relu", relu);

@@ -238,12 +253,20 @@ void fuseConvWithEltwise(std::shared_ptr<Graph>& graph) {
     return no_input_scale;
   };
 
+  for (const auto& leaky_relu : leaky_relu_operators) {
+    TemplateEnv env;
+    env.s("leaky_relu", leaky_relu);
+    rewriter_leaky_relu.RegisterRewritePattern(
+        conv_leaky_relu_rstring.format(env), conv_leaky_relu_fused);
+  }
+
   rewriter_relu.runOnGraph(graph);
   rewriter_sigmoid.runOnGraph(graph);
   rewriter_hardtanh.runOnGraph(graph);
   rewriter_elu.runOnGraph(graph, filter_conv2d_elu);
   rewriter_swish.runOnGraph(graph);
   rewriter_silu.runOnGraph(graph);
+  rewriter_leaky_relu.runOnGraph(graph);
 }
 
 void fuseConvAddRelu(std::shared_ptr<Graph>& graph) {
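The rewrite registered above matches the unfused subgraph (prepacked convolution followed by aten::leaky_relu or aten::leaky_relu_) and swaps in the single fused convolution_leaky_relu op. To see the kind of graph the pattern targets, one can trace a small Conv + LeakyReLU model and print its graph; this sketch is illustrative only and the module name is made up:

    import torch
    import torch.nn as nn

    class ConvLeakyRelu(nn.Module):  # hypothetical example module
        def __init__(self):
            super(ConvLeakyRelu, self).__init__()
            self.conv = nn.Conv2d(3, 16, kernel_size=3, bias=False)
            self.act = nn.LeakyReLU(0.1)

        def forward(self, x):
            return self.act(self.conv(x))

    model = ConvLeakyRelu().eval()
    x = torch.randn(1, 3, 32, 32)
    traced = torch.jit.trace(model, x)

    # Before the IPEX fusion pass runs, the traced graph still contains a
    # separate aten::leaky_relu node, which is what conv_leaky_relu_rstring matches.
    print(traced.graph)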

intel_extension_for_pytorch/csrc/jit/cpu/passes/register_dnnl_jit_ops.cpp

Lines changed: 64 additions & 19 deletions

@@ -182,25 +182,6 @@ RegisterOperators op({
          };
        },
        aliasAnalysisFromSchema()),
-    Operator(
-        "ipex_prepack::convolution_hardtanh_run(Tensor input, Scalar "
-        "lower_bound, Scalar upper_bound, "
-        "__torch__.torch.classes.ipex_prepack.ConvolutionOpContext "
-        "W_prepack) -> Tensor",
-        [](const Node* node) -> Operation {
-          return [](Stack* stack) {
-            auto result = convolution_hardtanh_run(
-                (std::move(peek(stack, 0, 4))).toTensor(),
-                (std::move(peek(stack, 1, 4))).toScalar(),
-                (std::move(peek(stack, 2, 4))).toScalar(),
-                (std::move(peek(stack, 3, 4)))
-                    .toCustomClass<ConvolutionOpContext>());
-            drop(stack, 4);
-            pack(stack, std::move(result));
-            return 0;
-          };
-        },
-        aliasAnalysisFromSchema()),
     Operator(
         "ipex_prepack::convolution_elu_prepack(" CONV_PREPACK_ARGS
         ", Scalar alpha, Scalar scale, Scalar input_scale) "

@@ -233,6 +214,52 @@ RegisterOperators op({
          };
        },
        aliasAnalysisFromSchema()),
+    Operator(
+        "ipex_prepack::convolution_leaky_relu_prepack(" CONV_PREPACK_ARGS
+        ", Scalar alpha) "
+        "-> __torch__.torch.classes.ipex_prepack.ConvolutionOpContext",
+        [](const Node* node) -> Operation {
+          return [](Stack* stack) {
+            auto alpha_value =
+                (std::move(peek(stack, 11, 12))).toScalar().to<float>();
+            auto result = IpexConvolutionOpContext::create_context(
+                std::move((std::move(peek(stack, 0, 12))).toTensor()),
+                std::move(toOptionalTensor(std::move(peek(stack, 1, 12)))),
+                std::move((std::move(peek(stack, 2, 12))).toIntVector()),
+                std::move((std::move(peek(stack, 3, 12))).toIntVector()),
+                std::move((std::move(peek(stack, 4, 12))).toIntVector()),
+                std::move((std::move(peek(stack, 5, 12))).toIntVector()),
+                (std::move(peek(stack, 6, 12))).toInt(),
+                (std::move(peek(stack, 7, 12))).toInt(),
+                (std::move(peek(stack, 8, 12))).toBool(),
+                (std::move(peek(stack, 9, 12))).toBool(),
+                std::move((std::move(peek(stack, 10, 12))).toIntVector()),
+                ideep::attr_t::fuse_relu(1.0, alpha_value));
+            drop(stack, 12);
+            pack(stack, std::move(result));
+            return 0;
+          };
+        },
+        aliasAnalysisFromSchema()),
+    Operator(
+        "ipex_prepack::convolution_hardtanh_run(Tensor input, Scalar "
+        "lower_bound, Scalar upper_bound, "
+        "__torch__.torch.classes.ipex_prepack.ConvolutionOpContext "
+        "W_prepack) -> Tensor",
+        [](const Node* node) -> Operation {
+          return [](Stack* stack) {
+            auto result = convolution_hardtanh_run(
+                (std::move(peek(stack, 0, 4))).toTensor(),
+                (std::move(peek(stack, 1, 4))).toScalar(),
+                (std::move(peek(stack, 2, 4))).toScalar(),
+                (std::move(peek(stack, 3, 4)))
+                    .toCustomClass<ConvolutionOpContext>());
+            drop(stack, 4);
+            pack(stack, std::move(result));
+            return 0;
+          };
+        },
+        aliasAnalysisFromSchema()),
     Operator(
         "ipex_prepack::convolution_elu_run(Tensor input, Scalar alpha, "
         "Scalar scale, Scalar input_scale, "

@@ -253,6 +280,24 @@ RegisterOperators op({
          };
        },
        aliasAnalysisFromSchema()),
+    Operator(
+        "ipex_prepack::convolution_leaky_relu_run(Tensor input, Scalar alpha, "
+        "__torch__.torch.classes.ipex_prepack.ConvolutionOpContext "
+        "W_prepack) -> Tensor",
+        [](const Node* node) -> Operation {
+          return [](Stack* stack) {
+            auto result = convolution_leaky_relu_run(
+                (std::move(peek(stack, 0, 3))).toTensor(),
+                (std::move(peek(stack, 1, 3))).toScalar(),
+                (std::move(peek(stack, 2, 3)))
+                    .toCustomClass<ConvolutionOpContext>());
+            drop(stack, 3);
+            pack(stack, std::move(result));
+            return 0;
+          };
+        },
+        aliasAnalysisFromSchema()),
+
     Operator(
         "ipex_prepack::convolution_bottleneck_run(Tensor(a!) input, "
         "__torch__.torch.classes.ipex_prepack.ConvolutionOpContext W_prepack1, "

tests/cpu/test_jit.py

Lines changed: 30 additions & 3 deletions

@@ -154,6 +154,20 @@ def __init__(self, dim, in_channels, out_channels, **kwargs):
     def forward(self, x):
         return F.relu(self.conv(x), inplace=True)
 
+class ConvLeakyRelu_Fixed(nn.Module):
+    def __init__(self, dim, in_channels, out_channels, **kwargs):
+        super(ConvLeakyRelu_Fixed, self).__init__()
+        seed = 2018
+        torch.manual_seed(seed)
+        self.conv = conv_module[dim](in_channels, out_channels, bias=False, **kwargs)
+        self.leaky_relu = nn.LeakyReLU(0.1)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.leaky_relu(x)
+        return x
+
 class Conv_Relu_Add(nn.Module):
     def __init__(self, dim, in_channels, out_channels, **kwargs):
         super(Conv_Relu_Add, self).__init__()

@@ -717,7 +731,7 @@ def forward(self, x):
 
 class LinearSwishNaive(nn.Module):
     def __init__(self, in_feature, out_feature):
-            super(LinearSwishNaive, self).__init__()
+        super(LinearSwishNaive, self).__init__()
         self.linear = nn.Linear(in_feature, out_feature)
         self.sigmoid = nn.Sigmoid()
 

@@ -1824,11 +1838,24 @@ def test_output_conv_relu(self):
         self._test_output(
             ConvRelu_Fixed(dim, in_channels, out_channels, kernel_size=kernel_size, stride=1),
             x,
-            kind_in_graph="ipex_prepack::convolution_relu_run")
+            kind_in_graph="ipex_prepack::convolution_relu_run",
+            kind_not_in_graph="ipex_prepack::convolution_relu_prepack")
         self._test_output_bf16(
             ConvRelu_Fixed(dim, in_channels, out_channels, kernel_size=kernel_size, stride=1),
             x,
             kind_in_graph="ipex_prepack::convolution_relu_run",
+            kind_not_in_graph="ipex_prepack::convolution_relu_prepack",
+            prec=0.02)
+        self._test_output(
+            ConvLeakyRelu_Fixed(dim, in_channels, out_channels, kernel_size=kernel_size, stride=1),
+            x,
+            kind_in_graph="ipex_prepack::convolution_leaky_relu_run",
+            kind_not_in_graph="ipex_prepack::convolution_leaky_relu_prepack")
+        self._test_output_bf16(
+            ConvLeakyRelu_Fixed(dim, in_channels, out_channels, kernel_size=kernel_size, stride=1),
+            x,
+            kind_in_graph="ipex_prepack::convolution_leaky_relu_run",
+            kind_not_in_graph="ipex_prepack::convolution_leaky_relu_prepack",
             prec=0.02)
 
     def test_output_conv_sum(self):

@@ -2323,7 +2350,7 @@ def _test_onednn_fp32(model, input, kind_in_graph, prec=5e-3):
             res_jit = tr_model(input)
             self.assertEqual(res_ref, res_jit)
             self.assertTrue(any(n.kind() == kind_in_graph for n in trace_graph.nodes()))
-
+
         _test_onednn_fp32(
             LinearSwish_v1(3, 32, bias=True),
             torch.rand(32, 3),
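These tests go through the existing _test_output / _test_output_bf16 helpers, which JIT the model under IPEX and assert that the fused kernel appears in (and the standalone prepack op disappears from) the optimized graph. A rough standalone version of that check might look like the sketch below; the ipex.optimize call and the graph_for inspection are assumptions about the usual IPEX workflow, not code from this commit:

    import torch
    import torch.nn as nn
    import intel_extension_for_pytorch as ipex

    model = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, bias=False),
        nn.LeakyReLU(0.1),
    ).eval()
    x = torch.randn(1, 3, 32, 32)

    model = ipex.optimize(model, dtype=torch.float32)
    with torch.no_grad():
        traced = torch.jit.trace(model, x)
        traced = torch.jit.freeze(traced)
        traced(x)  # warm-up run so the fusion passes have been applied
        graph = traced.graph_for(x)

    # After fusion, the optimized graph should contain the fused op instead of
    # a separate aten::leaky_relu node.
    assert any(n.kind() == "ipex_prepack::convolution_leaky_relu_run"
               for n in graph.nodes())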
