Skip to content

Commit 72191f7

Browse files
Zain Jaffal authored and fhahn committed
[InstCombine] Matrix multiplication negation optimisation
If one of the operands in a matrix multiplication is negated, we can optimise the computation by moving the negation to whichever of the other operand or the result has the fewest elements. Reviewed By: spatel, fhahn. Differential Revision: https://reviews.llvm.org/D133300
1 parent 99422f8 commit 72191f7

File tree

2 files changed

+113
-29
lines changed

2 files changed

+113
-29
lines changed

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1758,6 +1758,8 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
17581758
break;
17591759
}
17601760
case Intrinsic::matrix_multiply: {
1761+
// Optimize negation in matrix multiplication.
1762+
17611763
// -A * -B -> A * B
17621764
Value *A, *B;
17631765
if (match(II->getArgOperand(0), m_FNeg(m_Value(A))) &&
@@ -1766,6 +1768,54 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
17661768
replaceOperand(*II, 1, B);
17671769
return II;
17681770
}
1771+
1772+
Value *Op0 = II->getOperand(0);
1773+
Value *Op1 = II->getOperand(1);
1774+
Value *OpNotNeg, *NegatedOp;
1775+
unsigned NegatedOpArg, OtherOpArg;
1776+
if (match(Op0, m_FNeg(m_Value(OpNotNeg)))) {
1777+
NegatedOp = Op0;
1778+
NegatedOpArg = 0;
1779+
OtherOpArg = 1;
1780+
} else if (match(Op1, m_FNeg(m_Value(OpNotNeg)))) {
1781+
NegatedOp = Op1;
1782+
NegatedOpArg = 1;
1783+
OtherOpArg = 0;
1784+
} else
1785+
// Multiplication doesn't have a negated operand.
1786+
break;
1787+
1788+
// Only optimize if the negated operand has only one use.
1789+
if (!NegatedOp->hasOneUse())
1790+
break;
1791+
1792+
Value *OtherOp = II->getOperand(OtherOpArg);
1793+
VectorType *RetTy = cast<VectorType>(II->getType());
1794+
VectorType *NegatedOpTy = cast<VectorType>(NegatedOp->getType());
1795+
VectorType *OtherOpTy = cast<VectorType>(OtherOp->getType());
1796+
ElementCount NegatedCount = NegatedOpTy->getElementCount();
1797+
ElementCount OtherCount = OtherOpTy->getElementCount();
1798+
ElementCount RetCount = RetTy->getElementCount();
1799+
// (-A) * B -> A * (-B), if it is cheaper to negate B and vice versa.
1800+
if (ElementCount::isKnownGT(NegatedCount, OtherCount) &&
1801+
ElementCount::isKnownLT(OtherCount, RetCount)) {
1802+
Value *InverseOtherOp = Builder.CreateFNeg(OtherOp);
1803+
replaceOperand(*II, NegatedOpArg, OpNotNeg);
1804+
replaceOperand(*II, OtherOpArg, InverseOtherOp);
1805+
return II;
1806+
}
1807+
// (-A) * B -> -(A * B), if it is cheaper to negate the result
1808+
if (ElementCount::isKnownGT(NegatedCount, RetCount)) {
1809+
SmallVector<Value *, 5> NewArgs(II->args());
1810+
SmallVector<Type *, 5> Types;
1811+
// Only add overloaded types
1812+
Types.push_back(II->getType());
1813+
Types.push_back(NewArgs[0]->getType());
1814+
Types.push_back(NewArgs[1]->getType());
1815+
NewArgs[NegatedOpArg] = OpNotNeg;
1816+
Instruction *NewMul = Builder.CreateIntrinsic(IID, Types, NewArgs, II);
1817+
return replaceInstUsesWith(*II, Builder.CreateFNegFMF(NewMul, II));
1818+
}
17691819
break;
17701820
}
17711821
case Intrinsic::fmuladd: {

llvm/test/Transforms/InstCombine/matrix-multiplication-negation.ll

Lines changed: 63 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
; The result has the fewest vector elements between the result and the two operands so the negation can be moved there
55
define <2 x double> @test_negation_move_to_result(<6 x double> %a, <3 x double> %b) {
66
; CHECK-LABEL: @test_negation_move_to_result(
7-
; CHECK-NEXT: [[A_NEG:%.*]] = fneg <6 x double> [[A:%.*]]
8-
; CHECK-NEXT: [[RES:%.*]] = tail call <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> [[A_NEG]], <3 x double> [[B:%.*]], i32 2, i32 3, i32 1)
9-
; CHECK-NEXT: ret <2 x double> [[RES]]
7+
; CHECK-NEXT: [[TMP1:%.*]] = call <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> [[A:%.*]], <3 x double> [[B:%.*]], i32 2, i32 3, i32 1)
8+
; CHECK-NEXT: [[TMP2:%.*]] = fneg <2 x double> [[TMP1]]
9+
; CHECK-NEXT: ret <2 x double> [[TMP2]]
1010
;
1111
%a.neg = fneg <6 x double> %a
1212
%res = tail call <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> %a.neg, <3 x double> %b, i32 2, i32 3, i32 1)
@@ -17,20 +17,53 @@ define <2 x double> @test_negation_move_to_result(<6 x double> %a, <3 x double>
1717
; Fast flag should be preserved
1818
define <2 x double> @test_negation_move_to_result_with_fastflags(<6 x double> %a, <3 x double> %b) {
1919
; CHECK-LABEL: @test_negation_move_to_result_with_fastflags(
20-
; CHECK-NEXT: [[A_NEG:%.*]] = fneg <6 x double> [[A:%.*]]
21-
; CHECK-NEXT: [[RES:%.*]] = tail call fast <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> [[A_NEG]], <3 x double> [[B:%.*]], i32 2, i32 3, i32 1)
22-
; CHECK-NEXT: ret <2 x double> [[RES]]
20+
; CHECK-NEXT: [[TMP1:%.*]] = call fast <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> [[A:%.*]], <3 x double> [[B:%.*]], i32 2, i32 3, i32 1)
21+
; CHECK-NEXT: [[TMP2:%.*]] = fneg fast <2 x double> [[TMP1]]
22+
; CHECK-NEXT: ret <2 x double> [[TMP2]]
2323
;
2424
%a.neg = fneg <6 x double> %a
2525
%res = tail call fast <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> %a.neg, <3 x double> %b, i32 2, i32 3, i32 1)
2626
ret <2 x double> %res
2727
}
2828

29+
define <2 x double> @test_negation_move_to_result_with_nnan_flag(<6 x double> %a, <3 x double> %b) {
30+
; CHECK-LABEL: @test_negation_move_to_result_with_nnan_flag(
31+
; CHECK-NEXT: [[TMP1:%.*]] = call nnan <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> [[A:%.*]], <3 x double> [[B:%.*]], i32 2, i32 3, i32 1)
32+
; CHECK-NEXT: [[TMP2:%.*]] = fneg nnan <2 x double> [[TMP1]]
33+
; CHECK-NEXT: ret <2 x double> [[TMP2]]
34+
;
35+
%a.neg = fneg <6 x double> %a
36+
%res = tail call nnan <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> %a.neg, <3 x double> %b, i32 2, i32 3, i32 1)
37+
ret <2 x double> %res
38+
}
39+
40+
define <2 x double> @test_negation_move_to_result_with_nsz_flag(<6 x double> %a, <3 x double> %b) {
41+
; CHECK-LABEL: @test_negation_move_to_result_with_nsz_flag(
42+
; CHECK-NEXT: [[TMP1:%.*]] = call nsz <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> [[A:%.*]], <3 x double> [[B:%.*]], i32 2, i32 3, i32 1)
43+
; CHECK-NEXT: [[TMP2:%.*]] = fneg nsz <2 x double> [[TMP1]]
44+
; CHECK-NEXT: ret <2 x double> [[TMP2]]
45+
;
46+
%a.neg = fneg <6 x double> %a
47+
%res = tail call nsz <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> %a.neg, <3 x double> %b, i32 2, i32 3, i32 1)
48+
ret <2 x double> %res
49+
}
50+
51+
define <2 x double> @test_negation_move_to_result_with_fastflag_on_negation(<6 x double> %a, <3 x double> %b) {
52+
; CHECK-LABEL: @test_negation_move_to_result_with_fastflag_on_negation(
53+
; CHECK-NEXT: [[TMP1:%.*]] = call <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> [[A:%.*]], <3 x double> [[B:%.*]], i32 2, i32 3, i32 1)
54+
; CHECK-NEXT: [[TMP2:%.*]] = fneg <2 x double> [[TMP1]]
55+
; CHECK-NEXT: ret <2 x double> [[TMP2]]
56+
;
57+
%a.neg = fneg fast<6 x double> %a
58+
%res = tail call <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> %a.neg, <3 x double> %b, i32 2, i32 3, i32 1)
59+
ret <2 x double> %res
60+
}
61+
2962
; %b has the fewest vector elements between the result and the two operands so the negation can be moved there
3063
define <9 x double> @test_move_negation_to_second_operand(<27 x double> %a, <3 x double> %b) {
3164
; CHECK-LABEL: @test_move_negation_to_second_operand(
32-
; CHECK-NEXT: [[A_NEG:%.*]] = fneg <27 x double> [[A:%.*]]
33-
; CHECK-NEXT: [[RES:%.*]] = tail call <9 x double> @llvm.matrix.multiply.v9f64.v27f64.v3f64(<27 x double> [[A_NEG]], <3 x double> [[B:%.*]], i32 9, i32 3, i32 1)
65+
; CHECK-NEXT: [[TMP1:%.*]] = fneg <3 x double> [[B:%.*]]
66+
; CHECK-NEXT: [[RES:%.*]] = tail call <9 x double> @llvm.matrix.multiply.v9f64.v27f64.v3f64(<27 x double> [[A:%.*]], <3 x double> [[TMP1]], i32 9, i32 3, i32 1)
3467
; CHECK-NEXT: ret <9 x double> [[RES]]
3568
;
3669
%a.neg = fneg <27 x double> %a
@@ -42,8 +75,8 @@ define <9 x double> @test_move_negation_to_second_operand(<27 x double> %a, <3 x
4275
; Fast flag should be preserved
4376
define <9 x double> @test_move_negation_to_second_operand_with_fast_flags(<27 x double> %a, <3 x double> %b) {
4477
; CHECK-LABEL: @test_move_negation_to_second_operand_with_fast_flags(
45-
; CHECK-NEXT: [[A_NEG:%.*]] = fneg <27 x double> [[A:%.*]]
46-
; CHECK-NEXT: [[RES:%.*]] = tail call fast <9 x double> @llvm.matrix.multiply.v9f64.v27f64.v3f64(<27 x double> [[A_NEG]], <3 x double> [[B:%.*]], i32 9, i32 3, i32 1)
78+
; CHECK-NEXT: [[TMP1:%.*]] = fneg <3 x double> [[B:%.*]]
79+
; CHECK-NEXT: [[RES:%.*]] = tail call fast <9 x double> @llvm.matrix.multiply.v9f64.v27f64.v3f64(<27 x double> [[A:%.*]], <3 x double> [[TMP1]], i32 9, i32 3, i32 1)
4780
; CHECK-NEXT: ret <9 x double> [[RES]]
4881
;
4982
%a.neg = fneg <27 x double> %a
@@ -54,9 +87,9 @@ define <9 x double> @test_move_negation_to_second_operand_with_fast_flags(<27 x
5487
; The result has the fewest vector elements between the result and the two operands so the negation can be moved there
5588
define <2 x double> @test_negation_move_to_result_from_second_operand(<3 x double> %a, <6 x double> %b){
5689
; CHECK-LABEL: @test_negation_move_to_result_from_second_operand(
57-
; CHECK-NEXT: [[B_NEG:%.*]] = fneg <6 x double> [[B:%.*]]
58-
; CHECK-NEXT: [[RES:%.*]] = tail call <2 x double> @llvm.matrix.multiply.v2f64.v3f64.v6f64(<3 x double> [[A:%.*]], <6 x double> [[B_NEG]], i32 1, i32 3, i32 2)
59-
; CHECK-NEXT: ret <2 x double> [[RES]]
90+
; CHECK-NEXT: [[TMP1:%.*]] = call <2 x double> @llvm.matrix.multiply.v2f64.v3f64.v6f64(<3 x double> [[A:%.*]], <6 x double> [[B:%.*]], i32 1, i32 3, i32 2)
91+
; CHECK-NEXT: [[TMP2:%.*]] = fneg <2 x double> [[TMP1]]
92+
; CHECK-NEXT: ret <2 x double> [[TMP2]]
6093
;
6194
%b.neg = fneg <6 x double> %b
6295
%res = tail call <2 x double> @llvm.matrix.multiply.v2f64.v3f64.v6f64(<3 x double> %a, <6 x double> %b.neg, i32 1, i32 3, i32 2)
@@ -66,8 +99,8 @@ define <2 x double> @test_negation_move_to_result_from_second_operand(<3 x doubl
6699
; %a has the fewest vector elements between the result and the two operands so the negation can be moved there
67100
define <9 x double> @test_move_negation_to_first_operand(<3 x double> %a, <27 x double> %b) {
68101
; CHECK-LABEL: @test_move_negation_to_first_operand(
69-
; CHECK-NEXT: [[B_NEG:%.*]] = fneg <27 x double> [[B:%.*]]
70-
; CHECK-NEXT: [[RES:%.*]] = tail call <9 x double> @llvm.matrix.multiply.v9f64.v3f64.v27f64(<3 x double> [[A:%.*]], <27 x double> [[B_NEG]], i32 1, i32 3, i32 9)
102+
; CHECK-NEXT: [[TMP1:%.*]] = fneg <3 x double> [[A:%.*]]
103+
; CHECK-NEXT: [[RES:%.*]] = tail call <9 x double> @llvm.matrix.multiply.v9f64.v3f64.v27f64(<3 x double> [[TMP1]], <27 x double> [[B:%.*]], i32 1, i32 3, i32 9)
71104
; CHECK-NEXT: ret <9 x double> [[RES]]
72105
;
73106
%b.neg = fneg <27 x double> %b
@@ -172,9 +205,10 @@ define <4 x double> @matrix_multiply_two_operands_negated_with_same_size(<2 x do
172205

173206
define <2 x double> @matrix_multiply_two_operands_with_multiple_uses(<6 x double> %a, <3 x double> %b) {
174207
; CHECK-LABEL: @matrix_multiply_two_operands_with_multiple_uses(
175-
; CHECK-NEXT: [[RES:%.*]] = tail call <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> [[A:%.*]], <3 x double> [[B:%.*]], i32 2, i32 3, i32 1)
176-
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <6 x double> [[A]], <6 x double> poison, <2 x i32> <i32 0, i32 1>
177-
; CHECK-NEXT: [[RES_3:%.*]] = fsub <2 x double> [[RES]], [[TMP1]]
208+
; CHECK-NEXT: [[A_NEG:%.*]] = fneg <6 x double> [[A:%.*]]
209+
; CHECK-NEXT: [[RES:%.*]] = tail call <2 x double> @llvm.matrix.multiply.v2f64.v6f64.v3f64(<6 x double> [[A]], <3 x double> [[B:%.*]], i32 2, i32 3, i32 1)
210+
; CHECK-NEXT: [[RES_2:%.*]] = shufflevector <6 x double> [[A_NEG]], <6 x double> undef, <2 x i32> <i32 0, i32 1>
211+
; CHECK-NEXT: [[RES_3:%.*]] = fadd <2 x double> [[RES_2]], [[RES]]
178212
; CHECK-NEXT: ret <2 x double> [[RES_3]]
179213
;
180214
%a.neg = fneg <6 x double> %a
@@ -234,8 +268,8 @@ define <12 x double> @fneg_with_multiple_uses_2(<15 x double> %a, <20 x double>
234268
; negation should be moved to the second operand given it has the smallest element count
235269
define <72 x double> @chain_of_matrix_mutliplies(<27 x double> %a, <3 x double> %b, <8 x double> %c) {
236270
; CHECK-LABEL: @chain_of_matrix_mutliplies(
237-
; CHECK-NEXT: [[A_NEG:%.*]] = fneg <27 x double> [[A:%.*]]
238-
; CHECK-NEXT: [[RES:%.*]] = tail call <9 x double> @llvm.matrix.multiply.v9f64.v27f64.v3f64(<27 x double> [[A_NEG]], <3 x double> [[B:%.*]], i32 9, i32 3, i32 1)
271+
; CHECK-NEXT: [[TMP1:%.*]] = fneg <3 x double> [[B:%.*]]
272+
; CHECK-NEXT: [[RES:%.*]] = tail call <9 x double> @llvm.matrix.multiply.v9f64.v27f64.v3f64(<27 x double> [[A:%.*]], <3 x double> [[TMP1]], i32 9, i32 3, i32 1)
239273
; CHECK-NEXT: [[RES_2:%.*]] = tail call <72 x double> @llvm.matrix.multiply.v72f64.v9f64.v8f64(<9 x double> [[RES]], <8 x double> [[C:%.*]], i32 9, i32 1, i32 8)
240274
; CHECK-NEXT: ret <72 x double> [[RES_2]]
241275
;
@@ -249,11 +283,11 @@ define <72 x double> @chain_of_matrix_mutliplies(<27 x double> %a, <3 x double>
249283
; second negation should be moved to the result of the second multiplication
250284
define <6 x double> @chain_of_matrix_mutliplies_with_two_negations(<3 x double> %a, <5 x double> %b, <10 x double> %c) {
251285
; CHECK-LABEL: @chain_of_matrix_mutliplies_with_two_negations(
252-
; CHECK-NEXT: [[B_NEG:%.*]] = fneg <5 x double> [[B:%.*]]
253-
; CHECK-NEXT: [[RES:%.*]] = tail call <15 x double> @llvm.matrix.multiply.v15f64.v3f64.v5f64(<3 x double> [[A:%.*]], <5 x double> [[B_NEG]], i32 3, i32 1, i32 5)
254-
; CHECK-NEXT: [[RES_NEG:%.*]] = fneg <15 x double> [[RES]]
255-
; CHECK-NEXT: [[RES_2:%.*]] = tail call <6 x double> @llvm.matrix.multiply.v6f64.v15f64.v10f64(<15 x double> [[RES_NEG]], <10 x double> [[C:%.*]], i32 3, i32 5, i32 2)
256-
; CHECK-NEXT: ret <6 x double> [[RES_2]]
286+
; CHECK-NEXT: [[TMP1:%.*]] = fneg <3 x double> [[A:%.*]]
287+
; CHECK-NEXT: [[RES:%.*]] = tail call <15 x double> @llvm.matrix.multiply.v15f64.v3f64.v5f64(<3 x double> [[TMP1]], <5 x double> [[B:%.*]], i32 3, i32 1, i32 5)
288+
; CHECK-NEXT: [[TMP2:%.*]] = call <6 x double> @llvm.matrix.multiply.v6f64.v15f64.v10f64(<15 x double> [[RES]], <10 x double> [[C:%.*]], i32 3, i32 5, i32 2)
289+
; CHECK-NEXT: [[TMP3:%.*]] = fneg <6 x double> [[TMP2]]
290+
; CHECK-NEXT: ret <6 x double> [[TMP3]]
257291
;
258292
%b.neg = fneg <5 x double> %b
259293
%res = tail call <15 x double> @llvm.matrix.multiply.v15f64.v3f64.v5f64(<3 x double> %a, <5 x double> %b.neg, i32 3, i32 1, i32 5)
@@ -265,10 +299,10 @@ define <6 x double> @chain_of_matrix_mutliplies_with_two_negations(<3 x double>
265299
; negation should be propagated to the result of the second matrix multiplication
266300
define <6 x double> @chain_of_matrix_mutliplies_propagation(<15 x double> %a, <20 x double> %b, <8 x double> %c){
267301
; CHECK-LABEL: @chain_of_matrix_mutliplies_propagation(
268-
; CHECK-NEXT: [[A_NEG:%.*]] = fneg <15 x double> [[A:%.*]]
269-
; CHECK-NEXT: [[RES:%.*]] = tail call <12 x double> @llvm.matrix.multiply.v12f64.v15f64.v20f64(<15 x double> [[A_NEG]], <20 x double> [[B:%.*]], i32 3, i32 5, i32 4)
270-
; CHECK-NEXT: [[RES_2:%.*]] = tail call <6 x double> @llvm.matrix.multiply.v6f64.v12f64.v8f64(<12 x double> [[RES]], <8 x double> [[C:%.*]], i32 3, i32 4, i32 2)
271-
; CHECK-NEXT: ret <6 x double> [[RES_2]]
302+
; CHECK-NEXT: [[TMP1:%.*]] = call <12 x double> @llvm.matrix.multiply.v12f64.v15f64.v20f64(<15 x double> [[A:%.*]], <20 x double> [[B:%.*]], i32 3, i32 5, i32 4)
303+
; CHECK-NEXT: [[TMP2:%.*]] = call <6 x double> @llvm.matrix.multiply.v6f64.v12f64.v8f64(<12 x double> [[TMP1]], <8 x double> [[C:%.*]], i32 3, i32 4, i32 2)
304+
; CHECK-NEXT: [[TMP3:%.*]] = fneg <6 x double> [[TMP2]]
305+
; CHECK-NEXT: ret <6 x double> [[TMP3]]
272306
;
273307
%a.neg = fneg <15 x double> %a
274308
%res = tail call <12 x double> @llvm.matrix.multiply.v12f64.v15f64.v20f64(<15 x double> %a.neg, <20 x double> %b, i32 3, i32 5, i32 4)

0 commit comments

Comments
 (0)