@@ -94,6 +94,7 @@ static bool isDPASCompatible(linalg::LinalgOp linalgOp, int kTile,
                              ArrayRef<int64_t> dpasTile) {
   if (!(isa<linalg::MatmulOp>(linalgOp) ||
         isa<linalg::BatchReduceMatmulOp>(linalgOp) ||
+        isa<linalg::MatmulTransposeBOp>(linalgOp) ||
         isa<linalg::GenericOp>(linalgOp))) {
     return false;
   }
@@ -633,12 +634,11 @@ static SmallVector<Value> updateTilesOffsets(PatternRewriter &rewriter,
 //
 // The descriptor sub-tiles are ordered in row-major fashion with respect to the
 // whole load tile.
-static SmallVector<Value> createDescriptorTiles(PatternRewriter &rewriter,
-                                                Location loc, Value src,
-                                                ArrayRef<int64_t> loadShape,
-                                                ArrayRef<int64_t> loadOffsets,
-                                                ArrayRef<int64_t> descTile,
-                                                int arrayLength = 1) {
+static SmallVector<Value>
+createDescriptorTiles(PatternRewriter &rewriter, Location loc, Value src,
+                      ArrayRef<int64_t> loadShape,
+                      ArrayRef<int64_t> loadOffsets, ArrayRef<int64_t> descTile,
+                      int arrayLength = 1, bool transpose = false) {
   assert(arrayLength == 1 && "Array descriptors are not supported");

   auto type = cast<ShapedType>(src.getType());
@@ -669,6 +669,9 @@ static SmallVector<Value> createDescriptorTiles(PatternRewriter &rewriter,
     Value newRowOffs = rewriter.create<arith::ConstantIndexOp>(loc, i);
     for (int j = 0; j < loadShape[1]; j += descTile[1] * arrayLength) {
       Value newColOffs = rewriter.create<arith::ConstantIndexOp>(loc, j);
+      if (transpose) {
+        std::swap(newRowOffs, newColOffs);
+      }
       auto tile = rewriter
                       .create<xegpu::UpdateNdOffsetOp>(
                           loc, descType, rootTile,
@@ -693,7 +696,8 @@ static SmallVector<Value> createDescriptorTiles(PatternRewriter &rewriter,
 static SmallVector<Value> createCoarseDscTiles(PatternRewriter &rewriter,
                                                Location loc, Value src,
                                                ArrayRef<int64_t> sgTile,
-                                               bool isVnni) {
+                                               bool isVnni,
+                                               bool transpose = false) {
   assert(sgTile.size() <= 2 &&
          "Require at most 2D tile size for eltwise lowering");
@@ -727,7 +731,8 @@ static SmallVector<Value> createCoarseDscTiles(PatternRewriter &rewriter,
   // NOLINTEND

   return createDescriptorTiles(rewriter, loc, src, sgTile2D, {0, 0},
-                               {sgLoadRows, sgLoadCols}, arrayLength);
+                               {sgLoadRows, sgLoadCols}, arrayLength,
+                               transpose);
 }

 // Return vector type with specified VNNI shape.
@@ -745,7 +750,8 @@ static SmallVector<Value>
 loadNdDescTiles(PatternRewriter &rewriter, Location loc, ValueRange loadTiles,
                 xegpu::CachePolicyAttr hint,
                 std::optional<VnniConfig> vnniConf = std::nullopt,
-                DenseI64ArrayAttr transpose = nullptr) {
+                DenseI64ArrayAttr transpose = nullptr,
+                IntegerAttr transpose_bit = nullptr) {
   // Assume all tiles have the same shape.
   auto tileType = cast<xegpu::TensorDescType>(loadTiles[0].getType());
   assert(llvm::all_of(loadTiles,
@@ -760,7 +766,6 @@ loadNdDescTiles(PatternRewriter &rewriter, Location loc, ValueRange loadTiles,
                                             *vnniConf);
     packedAttr = mlir::UnitAttr::get(rewriter.getContext());
   }
-  IntegerAttr transpose_bit = nullptr;
   SmallVector<Value> loadVec;
   for (auto tile : loadTiles) {
@@ -860,13 +865,82 @@ extractVecSubTiles(PatternRewriter &rewriter, Location loc,
   return subTiles;
 }

+// Checks whether the given `matmulOperand` is produced by a
+// `linalg::TransposeOp` and ensures that the transpose result is only used by
+// valid operations, such as `linalg::MatmulOp`, `linalg::BatchReduceMatmulOp`,
+// or `linalg::GenericOp`.
+//
+// If a valid transpose operation is found, the function erases it through the
+// rewriter and returns the operand of the transpose operation as the new
+// matrix multiplication operand.
+static FailureOr<Value> findAndReplaceTranspose(const Value &matmulOperand,
+                                                size_t operandIdx,
+                                                PatternRewriter &rewriter) {
+  auto defOp = matmulOperand.getDefiningOp();
+  if (!defOp) {
+    return failure();
+  }
+  linalg::TransposeOp transposeOp = nullptr;
+
+  for (auto x : defOp->getUsers()) {
+    if (isa<linalg::TransposeOp>(x)) {
+      if (transposeOp) {
+        return rewriter.notifyMatchFailure(
+            transposeOp, "Only one transpose operation is allowed");
+      }
+
+      transposeOp = dyn_cast<linalg::TransposeOp>(x);
+
+      auto transposeRes = transposeOp.getDpsInits()[0];
+      // Verify that the transpose result has no users
+      // other than our matmul.
+      for (auto trUser : transposeRes.getUsers()) {
+        if (isa<linalg::MatmulOp>(trUser) ||
+            isa<linalg::BatchReduceMatmulOp>(trUser) ||
+            isa<linalg::GenericOp>(trUser)) {
+          auto matmulOp = dyn_cast<linalg::LinalgOp>(trUser);
+          auto actualMatmulOperand = matmulOp.getDpsInputs()[operandIdx];
+          if (actualMatmulOperand != matmulOperand) {
+            return rewriter.notifyMatchFailure(
+                trUser,
+                "Transpose result is used by more than one matmul operation");
+          }
+        } else if (isa<memref::DeallocOp>(trUser)) {
+          // Allow deallocs as users.
+          continue;
+        } else if (isa<linalg::TransposeOp>(trUser)) {
+          // Check whether it is the same transpose we are processing.
+          if (!mlir::OperationEquivalence::isEquivalentTo(trUser, transposeOp,
+                                                          /*flags=*/nullptr)) {
+            return rewriter.notifyMatchFailure(
+                trUser, "Only one transpose operation is allowed");
+          }
+          continue;
+        } else {
+          return rewriter.notifyMatchFailure(
+              trUser,
+              "Transpose result is not allowed to be used by this operation");
+        }
+      }
+    }
+  }
+  if (transposeOp) {
+    auto ret = transposeOp.getDpsInputs()[0];
+    rewriter.eraseOp(transposeOp);
+    return ret;
+  }
+  return rewriter.notifyMatchFailure(
+      defOp, "No transpose operation producing the operand was found");
+}
+
 // Create XeGPU DPAS kernel out of GEMM-like operation.
 static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
                                       ArrayRef<int64_t> dpasTile, int kTile,
                                       int prefetchStages,
                                       PatternRewriter &rewriter) {
   assert((isa<linalg::MatmulOp>(linalgOp) ||
           isa<linalg::BatchReduceMatmulOp>(linalgOp) ||
+          isa<linalg::MatmulTransposeBOp>(linalgOp) ||
           isa<linalg::GenericOp>(linalgOp)) &&
          "Requires a GEMM-like op for DPAS lowering");
@@ -877,6 +951,17 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
   auto matB = linalgOp.getDpsInputs()[1];
   auto matC = linalgOp.getDpsInits()[0];

+  bool transposeB = false;
+  if (isa<linalg::MatmulTransposeBOp>(linalgOp)) {
+    transposeB = true;
+  } else {
+    auto newMatB = findAndReplaceTranspose(matB, /*operandIdx=*/1, rewriter);
+    if (!failed(newMatB)) {
+      matB = *newMatB;
+      transposeB = true;
+    }
+  }
+
   auto typeA = cast<ShapedType>(matA.getType());
   auto typeC = cast<ShapedType>(matC.getType());
@@ -961,7 +1046,8 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,

   // Create B sub-tiles.
   SmallVector<Value> tilesB =
-      createCoarseDscTiles(rewriter, loc, matB, {kTile, dimN}, /*isVnni=*/true);
+      createCoarseDscTiles(rewriter, loc, matB, {kTile, dimN},
+                           /*isVnni=*/true, transposeB);

   // Create input prefetch tiles.
   int64_t numThreads = 1;
@@ -997,7 +1083,9 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
                                           {dimM, dimN}, kTile);
     auto prefetchDescB = createGemmCoopPrefetchTile(
         rewriter, linalgOp, /*inputPos=*/1, numThreads, {blockRows, blockCols},
-        {dimM, dimN}, kTile);
+        (transposeB) ? std::vector<int32_t>{dimM, dimN}
+                     : std::vector<int32_t>{dimN, dimM},
+        kTile);

     if (succeeded(prefetchDescA) && succeeded(prefetchDescB)) {
       prefetchA = prefetchDescA->getResult();
@@ -1012,7 +1100,9 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
       prefetchA = updateTilesOffsets(rewriter, loc, ValueRange{prefetchA},
                                      {0, kTile})[0];
       prefetchB = updateTilesOffsets(rewriter, loc, ValueRange{prefetchB},
-                                     {kTile, 0})[0];
+                                     (transposeB)
+                                         ? std::vector<int64_t>{0, kTile}
+                                         : std::vector<int64_t>{kTile, 0})[0];
     }
   } else {
     // Disable coop prefetching on failure.
@@ -1083,15 +1173,26 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
         loadNdDescTiles(rewriter, loc, tilesA, readCacheHint);
     auto tileTypeA = cast<xegpu::TensorDescType>(tilesA[0].getType());

+    DenseI64ArrayAttr transpose = nullptr;
+    IntegerAttr transpose_bit = nullptr;
+
+    if (transposeB) {
+      transpose_bit = rewriter.getIntegerAttr(rewriter.getIntegerType(32), 32);
+      transpose = DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0});
+    }
+
     // Load B sub-tiles.
     SmallVector<Value> loadVecB =
-        loadNdDescTiles(rewriter, loc, tilesB, readCacheHint, vnniConfB);
+        loadNdDescTiles(rewriter, loc, tilesB, readCacheHint, vnniConfB,
+                        transpose, transpose_bit);
     auto tileTypeB = cast<xegpu::TensorDescType>(tilesB[0].getType());

     // Update offsets of the input tiles.
     // Shift along the reduction dimension.
     tilesA = updateTilesOffsets(rewriter, loc, tilesA, {0, kTile});
-    tilesB = updateTilesOffsets(rewriter, loc, tilesB, {kTile, 0});
+    tilesB = updateTilesOffsets(rewriter, loc, tilesB,
+                                transposeB ? std::vector<int64_t>{0, kTile}
+                                           : std::vector<int64_t>{kTile, 0});

     // Prefetch the next set of input tiles.
     if (isCoopPrefetch) {
@@ -1101,7 +1202,9 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
       prefetchA =
          updateTilesOffsets(rewriter, loc, ValueRange{prefetchA}, {0, kTile})[0];
       prefetchB =
-         updateTilesOffsets(rewriter, loc, ValueRange{prefetchB}, {kTile, 0})[0];
+         updateTilesOffsets(rewriter, loc, ValueRange{prefetchB},
+                            transposeB ? std::vector<int64_t>{0, kTile}
+                                       : std::vector<int64_t>{kTile, 0})[0];
     } else {
       // Apply naive prefetching for each subgroup separately.
       prefetchTiles(rewriter, loc, tilesA, readCacheHint);
@@ -1288,7 +1391,7 @@ struct ConvertGemmLikeToXeGPU : public OpRewritePattern<LinalgOpTy> {
   // Constrain conversion to the supported GEMM-like ops.
   static_assert(
       llvm::is_one_of<LinalgOpTy, linalg::MatmulOp, linalg::BatchReduceMatmulOp,
-                      linalg::GenericOp>::value);
+                      linalg::GenericOp, linalg::MatmulTransposeBOp>::value);

   ConvertGemmLikeToXeGPU(MLIRContext *ctx, LinalgToXeGPUOptions options)
       : OpRewritePattern<LinalgOpTy>(ctx), options(options) {}
@@ -1495,8 +1598,9 @@ struct ConvertMemoryFillToXeGPU : public OpRewritePattern<LinalgOpTy> {
 void populateLinalgGemmToXeGPUPatterns(RewritePatternSet &patterns,
                                        LinalgToXeGPUOptions options) {
   patterns.add<ConvertGemmLikeToXeGPU<linalg::MatmulOp>,
-               ConvertGemmLikeToXeGPU<linalg::GenericOp>>(patterns.getContext(),
-                                                          options);
+               ConvertGemmLikeToXeGPU<linalg::GenericOp>,
+               ConvertGemmLikeToXeGPU<linalg::MatmulTransposeBOp>>(
+      patterns.getContext(), options);
 }

 void populateLinalgEltwiseToXeGPUPatterns(RewritePatternSet &patterns,
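
The core idea behind several hunks above is that a transposed B operand flips which axis carries the reduction dimension, so every per-iteration tile-offset update switches from {kTile, 0} to {0, kTile}. The standalone C++ sketch below is not part of the patch and the helper name `reductionStep` is hypothetical; it only mirrors the ternary used in the `updateTilesOffsets` calls, under the assumption that plain B is stored K x N and transposed B is stored N x K.

// Illustrative sketch only; mirrors the reduction-step selection in the patch.
#include <array>
#include <cstdint>
#include <iostream>

// Hypothetical helper: per-iteration descriptor offset step along K.
// Plain (K x N) B advances along rows of the stored tile; transposed (N x K)
// B advances along columns instead.
static std::array<std::int64_t, 2> reductionStep(std::int64_t kTile,
                                                 bool transposeB) {
  return transposeB ? std::array<std::int64_t, 2>{0, kTile}
                    : std::array<std::int64_t, 2>{kTile, 0};
}

int main() {
  for (bool transposeB : {false, true}) {
    auto step = reductionStep(/*kTile=*/32, transposeB);
    std::cout << (transposeB ? "transposed B" : "plain B") << " step: {"
              << step[0] << ", " << step[1] << "}\n";
  }
  return 0;
}

Running it prints {32, 0} for the plain case and {0, 32} for the transposed case, which is exactly the swap applied to the B tiles and B prefetch tiles in the diff.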