@@ -1359,10 +1359,12 @@ class LowerMatrixIntrinsics {
1359
1359
return ;
1360
1360
1361
1361
auto CanBeFlattened = [](Value *Op) {
1362
- return match (Op, m_OneUse (m_CombineOr (
1363
- m_Load (m_Value ()),
1364
- m_Intrinsic<Intrinsic::matrix_column_major_load>(
1365
- m_Value (), m_SpecificInt (1 )))));
1362
+ return match (
1363
+ Op, m_OneUse (m_CombineOr (
1364
+ m_Load (m_Value ()),
1365
+ m_CombineOr (m_Intrinsic<Intrinsic::matrix_transpose>(),
1366
+ m_Intrinsic<Intrinsic::matrix_column_major_load>(
1367
+ m_Value (), m_SpecificInt (1 ))))));
1366
1368
};
1367
1369
// Returns the cost benefit of using \p Op with the dot product lowering. If
1368
1370
// the returned cost is < 0, the argument is cheaper to use in the
@@ -1374,21 +1376,34 @@ class LowerMatrixIntrinsics {
1374
1376
FixedVectorType *VecTy = cast<FixedVectorType>(Op->getType ());
1375
1377
Type *EltTy = VecTy->getElementType ();
1376
1378
1377
- if (CanBeFlattened (Op)) {
1378
- if (N == 1 )
1379
- return InstructionCost (0 );
1379
+ if (!CanBeFlattened (Op)) {
1380
+ InstructionCost EmbedCost (0 );
1381
+ // Roughly estimate the cost for embedding the columns into a vector.
1382
+ for (unsigned I = 1 ; I < N; ++I)
1383
+ EmbedCost -=
1384
+ TTI.getShuffleCost (TTI::SK_Splice, FixedVectorType::get (EltTy, 1 ),
1385
+ std::nullopt, TTI::TCK_RecipThroughput);
1386
+ return EmbedCost;
1387
+ }
1380
1388
1381
- return TTI.getMemoryOpCost (Instruction::Load, VecTy, Align (1 ), 0 ) -
1382
- N * TTI.getMemoryOpCost (Instruction::Load, EltTy, Align (1 ), 0 );
1389
+ if (match (Op, m_Intrinsic<Intrinsic::matrix_transpose>())) {
1390
+ // The transpose can be skipped for the dot product lowering, roughly
1391
+ // estimate the savings as the cost of embedding the columns in a
1392
+ // vector.
1393
+ InstructionCost EmbedCost (0 );
1394
+ for (unsigned I = 1 ; I < N; ++I)
1395
+ EmbedCost +=
1396
+ TTI.getShuffleCost (TTI::SK_Splice, FixedVectorType::get (EltTy, 1 ),
1397
+ std::nullopt, TTI::TCK_RecipThroughput);
1398
+ return EmbedCost;
1383
1399
}
1384
1400
1385
- InstructionCost EmbedCost (0 );
1386
- // Roughly estimate the cost for embedding the columns into a vector.
1387
- for (unsigned I = 1 ; I < N; ++I)
1388
- EmbedCost +=
1389
- TTI.getShuffleCost (TTI::SK_Splice, FixedVectorType::get (EltTy, 1 ),
1390
- std::nullopt, TTI::TCK_RecipThroughput);
1391
- return EmbedCost;
1401
+ // Costs for loads.
1402
+ if (N == 1 )
1403
+ return InstructionCost (0 );
1404
+
1405
+ return TTI.getMemoryOpCost (Instruction::Load, VecTy, Align (1 ), 0 ) -
1406
+ N * TTI.getMemoryOpCost (Instruction::Load, EltTy, Align (1 ), 0 );
1392
1407
};
1393
1408
auto LHSCost = GetCostForArg (LHS, LShape.NumColumns );
1394
1409
@@ -1410,24 +1425,30 @@ class LowerMatrixIntrinsics {
1410
1425
1411
1426
FusedInsts.insert (MatMul);
1412
1427
IRBuilder<> Builder (MatMul);
1413
- auto FlattenArg = [&Builder, &FusedInsts,
1414
- &CanBeFlattened ](Value *Op) -> Value * {
1428
+ auto FlattenArg = [&Builder, &FusedInsts, &CanBeFlattened,
1429
+ this ](Value *Op) -> Value * {
1415
1430
// Matmul must be the only user of loads because we don't use LowerLoad
1416
1431
// for row vectors (LowerLoad results in scalar loads and shufflevectors
1417
1432
// instead of single vector load).
1418
1433
if (!CanBeFlattened (Op))
1419
1434
return Op;
1420
1435
1421
1436
FusedInsts.insert (cast<Instruction>(Op));
1437
+
1422
1438
// If vector uses the builtin load, lower to a LoadInst
1423
- Value *Ptr ;
1439
+ Value *Arg ;
1424
1440
if (match (Op, m_Intrinsic<Intrinsic::matrix_column_major_load>(
1425
- m_Value (Ptr )))) {
1426
- auto *NewLoad = Builder.CreateLoad (Op->getType (), Ptr );
1441
+ m_Value (Arg )))) {
1442
+ auto *NewLoad = Builder.CreateLoad (Op->getType (), Arg );
1427
1443
Op->replaceAllUsesWith (NewLoad);
1428
1444
cast<Instruction>(Op)->eraseFromParent ();
1429
1445
return NewLoad;
1446
+ } else if (match (Op, m_Intrinsic<Intrinsic::matrix_transpose>(
1447
+ m_Value (Arg)))) {
1448
+ ToRemove.push_back (cast<Instruction>(Op));
1449
+ return Arg;
1430
1450
}
1451
+
1431
1452
return Op;
1432
1453
};
1433
1454
LHS = FlattenArg (LHS);
0 commit comments