
Commit eae477f

Add filter to linear mul fusion (#2704)
1 parent f57307d commit eae477f

3 files changed (+82, -2 lines)


csrc/cpu/jit/passes/graph_rewrite_linear.cpp

Lines changed: 19 additions & 2 deletions
@@ -451,6 +451,23 @@ void fuseLinearMulAdd(std::shared_ptr<Graph>& graph) {
         %res = ipex_prepack::linear_mul_run(%input, %operand, %packed_weight)
         return (%res))";
 
+  auto filter_scalar = [](const Match& match,
+                          const std::unordered_map<std::string, Value*>& vmap) {
+    Node* node = match.anchor;
+    if (utils::is_scalar(node->input(1)) || utils::is_scalar(node->input(0))) {
+      return false;
+    }
+    if (node->input(1)->type()->cast<TensorType>()->dim().has_value() &&
+        node->input(1)->type()->cast<TensorType>()->dim().value() == 0) {
+      return false;
+    }
+    if (node->input(0)->type()->cast<TensorType>()->dim().has_value() &&
+        node->input(0)->type()->cast<TensorType>()->dim().value() == 0) {
+      return false;
+    }
+    return true;
+  };
+
   for (const auto& mul : mul_operators) {
     TemplateEnv env;
     env.s("mul", mul);
@@ -460,8 +477,8 @@ void fuseLinearMulAdd(std::shared_ptr<Graph>& graph) {
         linear_mul_operand_on_the_left_rstring.format(env), linear_mul_fused);
   }
 
-  rewriter_mul_operand_on_the_right.runOnGraph(graph);
-  rewriter_mul_operand_on_the_left.runOnGraph(graph);
+  rewriter_mul_operand_on_the_right.runOnGraph(graph, filter_scalar);
+  rewriter_mul_operand_on_the_left.runOnGraph(graph, filter_scalar);
 
   // linear + mul + add
   // linear_mul Y
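The filter above makes the linear+mul rewrite bail out whenever either operand of the multiplication is a Python scalar or a 0-dim tensor, and both runOnGraph calls now pass it. Below is a minimal sketch of the pattern the filter targets, mirroring the new LinearMul_v2 test further down; the module name and the ipex.optimize / graph_for usage are assumptions for illustration, not part of the commit.

import torch
import torch.nn as nn
import intel_extension_for_pytorch as ipex  # assumed available

class ScalarMulAfterLinear(nn.Module):  # hypothetical name, mirrors LinearMul_v2
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=False)
        self.mul_tensor = torch.tensor(1)  # 0-dim tensor operand
        self.mul_scalar = 0.5              # Python scalar operand

    def forward(self, x):
        return self.mul_scalar * self.linear(x) * self.mul_tensor

m = ipex.optimize(ScalarMulAfterLinear(4, 4).eval())
x = torch.ones(2, 4)
with torch.no_grad():
    traced = torch.jit.freeze(torch.jit.trace(m, x).eval())
    traced(x)  # warm-up run so the JIT produces the optimized graph
    graph = traced.graph_for(x)  # graph_for usage assumed from the IPEX JIT tests
    # With filter_scalar in place, the scalar / 0-dim multiplies stay as
    # aten::mul instead of being folded into ipex_prepack::linear_mul_run.
    print(any(n.kind() == "ipex_prepack::linear_mul_run" for n in graph.nodes()))

This matches what the new tests assert: for the v2 modules the optimized graph keeps aten::linear / linear_run and never contains linear_mul_run or linear_mul_add_run.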

examples/cpu/inference/python/llm/single_instance/run_quantization.py

Lines changed: 1 addition & 0 deletions
@@ -653,6 +653,7 @@ def calib_func(prepared_model):
             op_type_dict=op_type_dict,
             smoothquant_args=smoothquant_args
         )
+        pathlib.Path(args.output_dir).mkdir(parents=True, exist_ok=True)
         prepared_model.save_qconf_summary(args.output_dir + "/best_configure.json")
 
     else:
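The one-line change above creates args.output_dir (including parent directories) before the qconf summary is saved, so writing best_configure.json cannot fail just because the directory does not exist yet. The same pattern in isolation; the path below is a placeholder, not taken from the script.

import pathlib

output_dir = "saved_results"  # placeholder for args.output_dir
pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)  # no-op if it already exists
summary_path = pathlib.Path(output_dir) / "best_configure.json"
summary_path.write_text("{}")  # stand-in for prepared_model.save_qconf_summary(...)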

tests/cpu/test_jit.py

Lines changed: 62 additions & 0 deletions
@@ -815,6 +815,19 @@ def forward(self, input):
         return x_l
 
 
+class LinearMulAdd_v2(nn.Module):
+    def __init__(self, in_features, out_features):
+        super(LinearMulAdd_v2, self).__init__()
+        self.linear = torch.nn.Linear(in_features, out_features, bias=False)
+        self.mul_tensor = torch.tensor(1)
+        self.mul_scalar = 0.5
+
+    def forward(self, input):
+        x_add = input
+        result = self.mul_tensor * self.linear(input) * self.mul_scalar
+        return result + (x_add).to(result.dtype)
+
+
 class LinearMul(nn.Module):
     def __init__(self, in_features, num_layers, low_rank):
         super(LinearMul, self).__init__()
@@ -841,6 +854,17 @@ def forward(self, input):
         return x_l
 
 
+class LinearMul_v2(nn.Module):
+    def __init__(self, in_features, out_features):
+        super(LinearMul_v2, self).__init__()
+        self.linear = torch.nn.Linear(in_features, out_features, bias=False)
+        self.mul_tensor = torch.tensor(1)
+        self.mul_scalar = 0.5
+
+    def forward(self, input):
+        return self.mul_scalar * self.linear(input) * self.mul_tensor
+
+
 class Linear_Reshape_Relu(nn.Module):
     def __init__(self, in_channels, out_channels, dest_shape, **kwargs):
         super(Linear_Reshape_Relu, self).__init__()
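For contrast with the two v2 modules above: a multiplication whose operand is a genuine tensor with at least one dimension still passes filter_scalar, so that kind of model remains eligible for the linear_mul fusion. A small sketch with a hypothetical module (not one of the test classes):

import torch
import torch.nn as nn

class TensorMulAfterLinear(nn.Module):  # illustrative only
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=False)
        self.scale = torch.ones(out_features)  # 1-D operand, dim() == 1, not a scalar

    def forward(self, x):
        # Neither mul input is a Python scalar or a 0-dim tensor, so the
        # rewrite's filter_scalar check returns true for this pattern.
        return self.linear(x) * self.scale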
@@ -4536,6 +4560,25 @@ def test_output_linear_mul_add(self):
             prec=5e-2,
         )
 
+    def test_output_linear_mul_add_v2(self):
+        m = LinearMulAdd_v2(4, 4)
+        x = torch.ones(2, 4)
+        self._test_output(
+            m,
+            x,
+            kind_in_graph="aten::linear",
+            kind_not_in_graph="ipex_prepack::linear_mul_add_run",
+        )
+        self._test_mkl_fp32(m, x, kind_in_graph="ipex_prepack::mkl_sgemm_run")
+        self._test_dnnl_fp32(m, x, kind_in_graph="ipex_prepack::linear_run")
+        self._test_output_lowp(
+            m,
+            x,
+            kind_in_graph="ipex_prepack::linear_run",
+            kind_not_in_graph="ipex_prepack::linear_mul_add_run",
+            prec=5e-2,
+        )
+
     def test_output_linear_mul(self):
         m = LinearMul(4, 2, 8)
         x = torch.ones(2, 4)
@@ -4549,6 +4592,25 @@ def test_output_linear_mul(self):
             prec=5e-2,
         )
 
+    def test_output_linear_mul_v2(self):
+        m = LinearMul_v2(4, 4)
+        x = torch.ones(2, 4)
+        self._test_output(
+            m,
+            x,
+            kind_in_graph="aten::linear",
+            kind_not_in_graph="ipex_prepack::linear_mul_run",
+        )
+        self._test_mkl_fp32(m, x, kind_in_graph="ipex_prepack::mkl_sgemm_run")
+        self._test_dnnl_fp32(m, x, kind_in_graph="ipex_prepack::linear_run")
+        self._test_output_lowp(
+            m,
+            x,
+            kind_in_graph="ipex_prepack::linear_run",
+            kind_not_in_graph="ipex_prepack::linear_mul_run",
+            prec=5e-2,
+        )
+
     def test_output_linear_reshape_relu(self):
         self._test_output(
             Linear_Reshape_Relu(3, 32, (64, 16), bias=True),
