Skip to content

Commit ebbcbb2

Browse files
committed
[Matrix] Remove redundant transpose with dot product lowering.
Extend dot-product handling to skip transposes of the first operand. As this is a vector, the conversion between column and row vector via the transpose isn't needed. Reviewed By: thegameg Differential Revision: https://reviews.llvm.org/D148428
1 parent 709f59e commit ebbcbb2

File tree

2 files changed

+49
-48
lines changed

2 files changed

+49
-48
lines changed

llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Lines changed: 42 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1359,10 +1359,12 @@ class LowerMatrixIntrinsics {
13591359
return;
13601360

13611361
auto CanBeFlattened = [](Value *Op) {
1362-
return match(Op, m_OneUse(m_CombineOr(
1363-
m_Load(m_Value()),
1364-
m_Intrinsic<Intrinsic::matrix_column_major_load>(
1365-
m_Value(), m_SpecificInt(1)))));
1362+
return match(
1363+
Op, m_OneUse(m_CombineOr(
1364+
m_Load(m_Value()),
1365+
m_CombineOr(m_Intrinsic<Intrinsic::matrix_transpose>(),
1366+
m_Intrinsic<Intrinsic::matrix_column_major_load>(
1367+
m_Value(), m_SpecificInt(1))))));
13661368
};
13671369
// Returns the cost benefit of using \p Op with the dot product lowering. If
13681370
// the returned cost is < 0, the argument is cheaper to use in the
@@ -1374,21 +1376,34 @@ class LowerMatrixIntrinsics {
13741376
FixedVectorType *VecTy = cast<FixedVectorType>(Op->getType());
13751377
Type *EltTy = VecTy->getElementType();
13761378

1377-
if (CanBeFlattened(Op)) {
1378-
if (N == 1)
1379-
return InstructionCost(0);
1379+
if (!CanBeFlattened(Op)) {
1380+
InstructionCost EmbedCost(0);
1381+
// Roughly estimate the cost for embedding the columns into a vector.
1382+
for (unsigned I = 1; I < N; ++I)
1383+
EmbedCost -=
1384+
TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
1385+
std::nullopt, TTI::TCK_RecipThroughput);
1386+
return EmbedCost;
1387+
}
13801388

1381-
return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) -
1382-
N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0);
1389+
if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>())) {
1390+
// The transpose can be skipped for the dot product lowering, roughly
1391+
// estimate the savings as the cost of embedding the columns in a
1392+
// vector.
1393+
InstructionCost EmbedCost(0);
1394+
for (unsigned I = 1; I < N; ++I)
1395+
EmbedCost +=
1396+
TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
1397+
std::nullopt, TTI::TCK_RecipThroughput);
1398+
return EmbedCost;
13831399
}
13841400

1385-
InstructionCost EmbedCost(0);
1386-
// Roughly estimate the cost for embedding the columns into a vector.
1387-
for (unsigned I = 1; I < N; ++I)
1388-
EmbedCost +=
1389-
TTI.getShuffleCost(TTI::SK_Splice, FixedVectorType::get(EltTy, 1),
1390-
std::nullopt, TTI::TCK_RecipThroughput);
1391-
return EmbedCost;
1401+
// Costs for loads.
1402+
if (N == 1)
1403+
return InstructionCost(0);
1404+
1405+
return TTI.getMemoryOpCost(Instruction::Load, VecTy, Align(1), 0) -
1406+
N * TTI.getMemoryOpCost(Instruction::Load, EltTy, Align(1), 0);
13921407
};
13931408
auto LHSCost = GetCostForArg(LHS, LShape.NumColumns);
13941409

@@ -1410,24 +1425,30 @@ class LowerMatrixIntrinsics {
14101425

14111426
FusedInsts.insert(MatMul);
14121427
IRBuilder<> Builder(MatMul);
1413-
auto FlattenArg = [&Builder, &FusedInsts,
1414-
&CanBeFlattened](Value *Op) -> Value * {
1428+
auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,
1429+
this](Value *Op) -> Value * {
14151430
// Matmul must be the only user of loads because we don't use LowerLoad
14161431
// for row vectors (LowerLoad results in scalar loads and shufflevectors
14171432
// instead of single vector load).
14181433
if (!CanBeFlattened(Op))
14191434
return Op;
14201435

14211436
FusedInsts.insert(cast<Instruction>(Op));
1437+
14221438
// If vector uses the builtin load, lower to a LoadInst
1423-
Value *Ptr;
1439+
Value *Arg;
14241440
if (match(Op, m_Intrinsic<Intrinsic::matrix_column_major_load>(
1425-
m_Value(Ptr)))) {
1426-
auto *NewLoad = Builder.CreateLoad(Op->getType(), Ptr);
1441+
m_Value(Arg)))) {
1442+
auto *NewLoad = Builder.CreateLoad(Op->getType(), Arg);
14271443
Op->replaceAllUsesWith(NewLoad);
14281444
cast<Instruction>(Op)->eraseFromParent();
14291445
return NewLoad;
1446+
} else if (match(Op, m_Intrinsic<Intrinsic::matrix_transpose>(
1447+
m_Value(Arg)))) {
1448+
ToRemove.push_back(cast<Instruction>(Op));
1449+
return Arg;
14301450
}
1451+
14311452
return Op;
14321453
};
14331454
LHS = FlattenArg(LHS);

llvm/test/Transforms/LowerMatrixIntrinsics/dot-product-transpose-int.ll

Lines changed: 7 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,9 @@
55
define void @transposed_multiply_feeding_dot_product_v4i322(<4 x i32> %a, <4 x i32> %b) {
66
; CHECK-LABEL: @transposed_multiply_feeding_dot_product_v4i322(
77
; CHECK-NEXT: entry:
8-
; CHECK-NEXT: [[SPLIT:%.*]] = shufflevector <4 x i32> [[A:%.*]], <4 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
9-
; CHECK-NEXT: [[TMP0:%.*]] = extractelement <4 x i32> [[SPLIT]], i64 0
10-
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <1 x i32> poison, i32 [[TMP0]], i64 0
11-
; CHECK-NEXT: [[TMP2:%.*]] = extractelement <4 x i32> [[SPLIT]], i64 1
12-
; CHECK-NEXT: [[TMP3:%.*]] = insertelement <1 x i32> poison, i32 [[TMP2]], i64 0
13-
; CHECK-NEXT: [[TMP4:%.*]] = extractelement <4 x i32> [[SPLIT]], i64 2
14-
; CHECK-NEXT: [[TMP5:%.*]] = insertelement <1 x i32> poison, i32 [[TMP4]], i64 0
15-
; CHECK-NEXT: [[TMP6:%.*]] = extractelement <4 x i32> [[SPLIT]], i64 3
16-
; CHECK-NEXT: [[TMP7:%.*]] = insertelement <1 x i32> poison, i32 [[TMP6]], i64 0
17-
; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <1 x i32> [[TMP1]], <1 x i32> [[TMP3]], <2 x i32> <i32 0, i32 1>
18-
; CHECK-NEXT: [[TMP9:%.*]] = shufflevector <1 x i32> [[TMP5]], <1 x i32> [[TMP7]], <2 x i32> <i32 0, i32 1>
19-
; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <2 x i32> [[TMP8]], <2 x i32> [[TMP9]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
20-
; CHECK-NEXT: [[TMP11:%.*]] = mul <4 x i32> [[TMP10]], [[B:%.*]]
21-
; CHECK-NEXT: [[TMP12:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP11]])
22-
; CHECK-NEXT: [[TMP13:%.*]] = insertelement <1 x i32> poison, i32 [[TMP12]], i64 0
8+
; CHECK-NEXT: [[TMP0:%.*]] = mul <4 x i32> [[A:%.*]], [[B:%.*]]
9+
; CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP0]])
10+
; CHECK-NEXT: [[TMP2:%.*]] = insertelement <1 x i32> poison, i32 [[TMP1]], i64 0
2311
; CHECK-NEXT: ret void
2412
;
2513
entry:
@@ -61,18 +49,10 @@ define void @transposed_multiply_feeding_dot_produc_v4i32(<4 x i32> %a, <4 x i32
6149
; CHECK-NEXT: [[TMP11:%.*]] = add <2 x i32> [[TMP8]], [[TMP10]]
6250
; CHECK-NEXT: [[TMP12:%.*]] = shufflevector <2 x i32> [[TMP11]], <2 x i32> poison, <2 x i32> <i32 0, i32 1>
6351
; CHECK-NEXT: [[TMP13:%.*]] = shufflevector <2 x i32> undef, <2 x i32> [[TMP12]], <2 x i32> <i32 2, i32 3>
64-
; CHECK-NEXT: [[TMP14:%.*]] = extractelement <2 x i32> [[TMP6]], i64 0
65-
; CHECK-NEXT: [[TMP15:%.*]] = insertelement <2 x i32> poison, i32 [[TMP14]], i64 0
66-
; CHECK-NEXT: [[TMP16:%.*]] = extractelement <2 x i32> [[TMP13]], i64 0
67-
; CHECK-NEXT: [[TMP17:%.*]] = insertelement <2 x i32> [[TMP15]], i32 [[TMP16]], i64 1
68-
; CHECK-NEXT: [[TMP18:%.*]] = extractelement <2 x i32> [[TMP6]], i64 1
69-
; CHECK-NEXT: [[TMP19:%.*]] = insertelement <2 x i32> poison, i32 [[TMP18]], i64 0
70-
; CHECK-NEXT: [[TMP20:%.*]] = extractelement <2 x i32> [[TMP13]], i64 1
71-
; CHECK-NEXT: [[TMP21:%.*]] = insertelement <2 x i32> [[TMP19]], i32 [[TMP20]], i64 1
72-
; CHECK-NEXT: [[TMP22:%.*]] = shufflevector <2 x i32> [[TMP17]], <2 x i32> [[TMP21]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
73-
; CHECK-NEXT: [[TMP23:%.*]] = mul <4 x i32> [[TMP22]], [[C:%.*]]
74-
; CHECK-NEXT: [[TMP24:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP23]])
75-
; CHECK-NEXT: [[TMP25:%.*]] = insertelement <1 x i32> poison, i32 [[TMP24]], i64 0
52+
; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <2 x i32> [[TMP6]], <2 x i32> [[TMP13]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
53+
; CHECK-NEXT: [[TMP15:%.*]] = mul <4 x i32> [[TMP14]], [[C:%.*]]
54+
; CHECK-NEXT: [[TMP16:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP15]])
55+
; CHECK-NEXT: [[TMP17:%.*]] = insertelement <1 x i32> poison, i32 [[TMP16]], i64 0
7656
; CHECK-NEXT: ret void
7757
;
7858
entry:

0 commit comments

Comments
 (0)