Commit 1fee896

[LinalgToXeGPU] Lower linalg.matmul_transpose_b into xegpu.dpas (#347)
Signed-off-by: dchigarev <[email protected]>
1 parent 199501e commit 1fee896

7 files changed (+557, −19 lines)
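In short, the pass now treats B as transposed in two situations: when the GEMM is a `linalg.matmul_transpose_b`, and when a standalone `linalg.transpose` feeds operand B of a regular matmul (in which case the transpose is folded away). The sketch below condenses the control flow added to `createDPASKernel`; it is a simplified illustration of the diff that follows, not the verbatim code.

```cpp
// Simplified sketch of the transposed-B detection added in createDPASKernel
// (condensed from the diff below; surrounding code and error paths omitted).
bool transposeB = false;
if (isa<linalg::MatmulTransposeBOp>(linalgOp)) {
  // Case 1: the op itself encodes a transposed B operand.
  transposeB = true;
} else {
  // Case 2: a separate linalg.transpose produces operand B; fold it away and
  // read B directly in its original (row-major N x K) layout.
  auto newMatB = findAndReplaceTranspose(matB, /*operandIdx=*/1, rewriter);
  if (!failed(newMatB)) {
    matB = *newMatB;
    transposeB = true;
  }
}
// When transposeB is set, B descriptors are created against the N x K layout
// and loaded with transpose / transpose_bit_width attributes (see below).
```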

lib/gc/Transforms/GPU/LinalgToXeGPU.cpp

Lines changed: 123 additions & 19 deletions
@@ -94,6 +94,7 @@ static bool isDPASCompatible(linalg::LinalgOp linalgOp, int kTile,
                              ArrayRef<int64_t> dpasTile) {
   if (!(isa<linalg::MatmulOp>(linalgOp) ||
         isa<linalg::BatchReduceMatmulOp>(linalgOp) ||
+        isa<linalg::MatmulTransposeBOp>(linalgOp) ||
         isa<linalg::GenericOp>(linalgOp))) {
     return false;
   }
@@ -633,12 +634,11 @@ static SmallVector<Value> updateTilesOffsets(PatternRewriter &rewriter,
 //
 // The descriptor sub-tiles are ordered in row-major fashion with respect to the
 // whole load tile.
-static SmallVector<Value> createDescriptorTiles(PatternRewriter &rewriter,
-                                                Location loc, Value src,
-                                                ArrayRef<int64_t> loadShape,
-                                                ArrayRef<int64_t> loadOffsets,
-                                                ArrayRef<int64_t> descTile,
-                                                int arrayLength = 1) {
+static SmallVector<Value>
+createDescriptorTiles(PatternRewriter &rewriter, Location loc, Value src,
+                      ArrayRef<int64_t> loadShape,
+                      ArrayRef<int64_t> loadOffsets, ArrayRef<int64_t> descTile,
+                      int arrayLength = 1, bool transpose = false) {
   assert(arrayLength == 1 && "Array descriptors are not supported");
 
   auto type = cast<ShapedType>(src.getType());
@@ -669,6 +669,9 @@ static SmallVector<Value> createDescriptorTiles(PatternRewriter &rewriter,
     Value newRowOffs = rewriter.create<arith::ConstantIndexOp>(loc, i);
     for (int j = 0; j < loadShape[1]; j += descTile[1] * arrayLength) {
       Value newColOffs = rewriter.create<arith::ConstantIndexOp>(loc, j);
+      if (transpose) {
+        std::swap(newRowOffs, newColOffs);
+      }
       auto tile = rewriter
                       .create<xegpu::UpdateNdOffsetOp>(
                           loc, descType, rootTile,
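The swap above is what makes the B descriptors walk the transposed layout: the sub-tile that logically covers rows starting at i and columns starting at j of the load tile is created at source offset (j, i) instead of (i, j). A tiny stand-alone model of that mapping follows; the 32x32 load tile and 16x16 descriptor shapes are hypothetical, chosen only for illustration.

```cpp
#include <cstdio>
#include <utility>

// Model of the descriptor-offset computation with and without the transpose
// swap, for a hypothetical 32x32 load tile split into 16x16 descriptor tiles.
int main() {
  const int loadRows = 32, loadCols = 32;
  const int descRows = 16, descCols = 16;
  const bool transpose = true;

  for (int i = 0; i < loadRows; i += descRows) {
    for (int j = 0; j < loadCols; j += descCols) {
      int rowOffs = i, colOffs = j;
      if (transpose)
        std::swap(rowOffs, colOffs); // descriptor now reads the mirrored block
      std::printf("logical sub-tile (%d, %d) -> source offset [%d, %d]\n", i,
                  j, rowOffs, colOffs);
    }
  }
  return 0;
}
```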
@@ -693,7 +696,8 @@ static SmallVector<Value> createDescriptorTiles(PatternRewriter &rewriter,
 static SmallVector<Value> createCoarseDscTiles(PatternRewriter &rewriter,
                                                Location loc, Value src,
                                                ArrayRef<int64_t> sgTile,
-                                               bool isVnni) {
+                                               bool isVnni,
+                                               bool transpose = false) {
   assert(sgTile.size() <= 2 &&
          "Require at most 2D tile size for eltwise lowering");
 
@@ -727,7 +731,8 @@ static SmallVector<Value> createCoarseDscTiles(PatternRewriter &rewriter,
   // NOLINTEND
 
   return createDescriptorTiles(rewriter, loc, src, sgTile2D, {0, 0},
-                               {sgLoadRows, sgLoadCols}, arrayLength);
+                               {sgLoadRows, sgLoadCols}, arrayLength,
+                               transpose);
 }
 
 // Return vector type with specified VNNI shape.
@@ -745,7 +750,8 @@ static SmallVector<Value>
 loadNdDescTiles(PatternRewriter &rewriter, Location loc, ValueRange loadTiles,
                 xegpu::CachePolicyAttr hint,
                 std::optional<VnniConfig> vnniConf = std::nullopt,
-                DenseI64ArrayAttr transpose = nullptr) {
+                DenseI64ArrayAttr transpose = nullptr,
+                IntegerAttr transpose_bit = nullptr) {
   // Assume all tiles have the same shape.
   auto tileType = cast<xegpu::TensorDescType>(loadTiles[0].getType());
   assert(llvm::all_of(loadTiles,
@@ -760,7 +766,6 @@ loadNdDescTiles(PatternRewriter &rewriter, Location loc, ValueRange loadTiles,
                          *vnniConf);
     packedAttr = mlir::UnitAttr::get(rewriter.getContext());
   }
-  IntegerAttr transpose_bit = nullptr;
   SmallVector<Value> loadVec;
   for (auto tile : loadTiles) {
 
@@ -860,13 +865,82 @@ extractVecSubTiles(PatternRewriter &rewriter, Location loc,
   return subTiles;
 }
 
+// Checks whether the given `matmulOperand` is produced by a
+// `linalg::TransposeOp` and ensures that the transpose result is only used by
+// valid operations, such as `linalg::MatmulOp`, `linalg::BatchReduceMatmulOp`,
+// or `linalg::GenericOp`.
+//
+// If a valid transpose operation is found, the function records it for later
+// removal and returns the operand of the transpose operation as the new matrix
+// multiplication operand.
+static FailureOr<Value> findAndReplaceTranspose(const Value &matmulOperand,
+                                                size_t operandIdx,
+                                                PatternRewriter &rewriter) {
+  auto defOp = matmulOperand.getDefiningOp();
+  if (!defOp) {
+    return failure();
+  }
+  linalg::TransposeOp transposeOp = nullptr;
+
+  for (auto x : defOp->getUsers()) {
+    if (isa<linalg::TransposeOp>(x)) {
+      if (transposeOp) {
+        return rewriter.notifyMatchFailure(
+            transposeOp, "Only one transpose operation is allowed");
+      }
+
+      transposeOp = dyn_cast<linalg::TransposeOp>(x);
+
+      auto transposeRes = transposeOp.getDpsInits()[0];
+      // Verify that the transpose result has no users other than our matmul.
+      for (auto trUser : transposeRes.getUsers()) {
+        if (isa<linalg::MatmulOp>(trUser) ||
+            isa<linalg::BatchReduceMatmulOp>(trUser) ||
+            isa<linalg::GenericOp>(trUser)) {
+          auto matmulOp = dyn_cast<linalg::LinalgOp>(trUser);
+          auto actualMatmulOperand = matmulOp.getDpsInputs()[operandIdx];
+          if (actualMatmulOperand != matmulOperand) {
+            return rewriter.notifyMatchFailure(
+                trUser,
+                "Transpose result is used by more than one matmul operation");
+          }
+        } else if (isa<memref::DeallocOp>(trUser)) {
+          // Allow deallocs as users.
+          continue;
+        } else if (isa<linalg::TransposeOp>(trUser)) {
+          // Check whether it is the same transpose we are processing.
+          if (!mlir::OperationEquivalence::isEquivalentTo(trUser, transposeOp,
+                                                          /*flags=*/nullptr)) {
+            return rewriter.notifyMatchFailure(
+                trUser, "Only one transpose operation is allowed");
+          }
+          continue;
+        } else {
+          return rewriter.notifyMatchFailure(
+              trUser,
+              "Transpose result is not allowed to be used by this operation");
+        }
+      }
+    }
+  }
+  if (transposeOp) {
+    auto ret = transposeOp.getDpsInputs()[0];
+    rewriter.eraseOp(transposeOp);
+    return ret;
+  }
+  return rewriter.notifyMatchFailure(
+      defOp, "No transpose operation producing the operand was found");
+}
+
 // Create XeGPU DPAS kernel out of GEMM-like operation.
 static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
                                       ArrayRef<int64_t> dpasTile, int kTile,
                                       int prefetchStages,
                                       PatternRewriter &rewriter) {
   assert((isa<linalg::MatmulOp>(linalgOp) ||
           isa<linalg::BatchReduceMatmulOp>(linalgOp) ||
+          isa<linalg::MatmulTransposeBOp>(linalgOp) ||
           isa<linalg::GenericOp>(linalgOp)) &&
          "Requires a GEMM-like op for DPAS lowering");
 
@@ -877,6 +951,17 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
   auto matB = linalgOp.getDpsInputs()[1];
   auto matC = linalgOp.getDpsInits()[0];
 
+  bool transposeB = false;
+  if (isa<linalg::MatmulTransposeBOp>(linalgOp)) {
+    transposeB = true;
+  } else {
+    auto newMatB = findAndReplaceTranspose(matB, /*operandIdx=*/1, rewriter);
+    if (!failed(newMatB)) {
+      matB = *newMatB;
+      transposeB = true;
+    }
+  }
+
   auto typeA = cast<ShapedType>(matA.getType());
   auto typeC = cast<ShapedType>(matC.getType());
 
@@ -961,7 +1046,8 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
 
   // Create B sub-tiles.
   SmallVector<Value> tilesB =
-      createCoarseDscTiles(rewriter, loc, matB, {kTile, dimN}, /*isVnni=*/true);
+      createCoarseDscTiles(rewriter, loc, matB, {kTile, dimN},
+                           /*isVnni=*/true, transposeB);
 
   // Create input prefetch tiles.
   int64_t numThreads = 1;
@@ -997,7 +1083,9 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
       {dimM, dimN}, kTile);
   auto prefetchDescB = createGemmCoopPrefetchTile(
       rewriter, linalgOp, /*inputPos=*/1, numThreads, {blockRows, blockCols},
-      {dimM, dimN}, kTile);
+      (transposeB) ? std::vector<int32_t>{dimM, dimN}
+                   : std::vector<int32_t>{dimN, dimM},
+      kTile);
 
   if (succeeded(prefetchDescA) && succeeded(prefetchDescB)) {
     prefetchA = prefetchDescA->getResult();
@@ -1012,7 +1100,9 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
       prefetchA = updateTilesOffsets(rewriter, loc, ValueRange{prefetchA},
                                      {0, kTile})[0];
       prefetchB = updateTilesOffsets(rewriter, loc, ValueRange{prefetchB},
-                                     {kTile, 0})[0];
+                                     (transposeB)
+                                         ? std::vector<int64_t>{0, kTile}
+                                         : std::vector<int64_t>{kTile, 0})[0];
     }
   } else {
     // Disable coop prefetching on failure.
@@ -1083,15 +1173,26 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
       loadNdDescTiles(rewriter, loc, tilesA, readCacheHint);
   auto tileTypeA = cast<xegpu::TensorDescType>(tilesA[0].getType());
 
+  DenseI64ArrayAttr transpose = nullptr;
+  IntegerAttr transpose_bit = nullptr;
+
+  if (transposeB) {
+    transpose_bit = rewriter.getIntegerAttr(rewriter.getIntegerType(32), 32);
+    transpose = DenseI64ArrayAttr::get(rewriter.getContext(), {1, 0});
+  }
+
   // Load B sub-tiles.
   SmallVector<Value> loadVecB =
-      loadNdDescTiles(rewriter, loc, tilesB, readCacheHint, vnniConfB);
+      loadNdDescTiles(rewriter, loc, tilesB, readCacheHint, vnniConfB,
+                      transpose, transpose_bit);
   auto tileTypeB = cast<xegpu::TensorDescType>(tilesB[0].getType());
 
   // Update offsets of the input tiles.
   // Shift along the reduction dimension.
   tilesA = updateTilesOffsets(rewriter, loc, tilesA, {0, kTile});
-  tilesB = updateTilesOffsets(rewriter, loc, tilesB, {kTile, 0});
+  tilesB = updateTilesOffsets(rewriter, loc, tilesB,
+                              transposeB ? std::vector<int64_t>{0, kTile}
+                                         : std::vector<int64_t>{kTile, 0});
 
   // Prefetch the next set of input tiles.
   if (isCoopPrefetch) {
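A plausible reading of the transpose_bit_width = 32 choice for f16 operands (a rationale assumed here, not stated in the patch): transposing the loaded tile in 32-bit granules keeps each pair of adjacent 16-bit elements together, so the transposed B data arrives already in the two-element VNNI packing that xegpu.dpas expects. A small self-contained check of that equivalence, using a tiny hypothetical N x K tile:

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // B operand as stored for matmul_transpose_b: N rows x K cols of 16-bit
  // elements (tiny hypothetical sizes, for illustration only).
  const int N = 4, K = 8;
  std::vector<uint16_t> bMem(N * K);
  for (int i = 0; i < N * K; ++i)
    bMem[i] = static_cast<uint16_t>(i);

  // (1) Transpose in 32-bit granules: view each row as K/2 dwords, transpose
  //     the dword matrix, then reinterpret the result as 16-bit elements.
  std::vector<uint16_t> viaDwordTranspose(N * K);
  for (int n = 0; n < N; ++n)
    for (int k2 = 0; k2 < K / 2; ++k2)
      for (int p = 0; p < 2; ++p)
        viaDwordTranspose[(k2 * N + n) * 2 + p] = bMem[n * K + 2 * k2 + p];

  // (2) VNNI packing of logical B^T (K x N) as a DPAS B operand expects it:
  //     vnni[k/2][n][k%2] = B^T[k][n] = bMem[n][k].
  std::vector<uint16_t> vnni(N * K);
  for (int k = 0; k < K; ++k)
    for (int n = 0; n < N; ++n)
      vnni[((k / 2) * N + n) * 2 + (k % 2)] = bMem[n * K + k];

  assert(viaDwordTranspose == vnni);
  std::puts("32-bit-granule transpose of the 16-bit tile == VNNI-packed B^T");
  return 0;
}
```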
@@ -1101,7 +1202,9 @@ static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
     prefetchA =
         updateTilesOffsets(rewriter, loc, ValueRange{prefetchA}, {0, kTile})[0];
     prefetchB =
-        updateTilesOffsets(rewriter, loc, ValueRange{prefetchB}, {kTile, 0})[0];
+        updateTilesOffsets(rewriter, loc, ValueRange{prefetchB},
+                           transposeB ? std::vector<int64_t>{0, kTile}
+                                      : std::vector<int64_t>{kTile, 0})[0];
   } else {
     // Apply naive prefetching for each subgroup separately.
     prefetchTiles(rewriter, loc, tilesA, readCacheHint);
@@ -1288,7 +1391,7 @@ struct ConvertGemmLikeToXeGPU : public OpRewritePattern<LinalgOpTy> {
   // Constrain conversion to the supported GEMM-like ops.
   static_assert(
       llvm::is_one_of<LinalgOpTy, linalg::MatmulOp, linalg::BatchReduceMatmulOp,
-                      linalg::GenericOp>::value);
+                      linalg::GenericOp, linalg::MatmulTransposeBOp>::value);
 
   ConvertGemmLikeToXeGPU(MLIRContext *ctx, LinalgToXeGPUOptions options)
       : OpRewritePattern<LinalgOpTy>(ctx), options(options) {}
@@ -1495,8 +1598,9 @@ struct ConvertMemoryFillToXeGPU : public OpRewritePattern<LinalgOpTy> {
 void populateLinalgGemmToXeGPUPatterns(RewritePatternSet &patterns,
                                        LinalgToXeGPUOptions options) {
   patterns.add<ConvertGemmLikeToXeGPU<linalg::MatmulOp>,
-               ConvertGemmLikeToXeGPU<linalg::GenericOp>>(patterns.getContext(),
-                                                          options);
+               ConvertGemmLikeToXeGPU<linalg::GenericOp>,
+               ConvertGemmLikeToXeGPU<linalg::MatmulTransposeBOp>>(
+      patterns.getContext(), options);
 }
 
 void populateLinalgEltwiseToXeGPUPatterns(RewritePatternSet &patterns,
Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
+// RUN: gc-opt %s -linalg-to-xegpu="dpas-tile=8,16,16 k-tile=16" -canonicalize -split-input-file | FileCheck %s
+
+module {
+  func.func @matmul_transpose_b_sep(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf16>) {
+    %c0 = arith.constant 0 : index
+    %c32 = arith.constant 32 : index
+    %c1024 = arith.constant 1024 : index
+    %alloc = memref.alloc() {alignment = 64 : i64} : memref<1024x1024xf16>
+    scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c1024, %c1024) step (%c32, %c32) {
+      %subview_0 = memref.subview %arg2[%arg3, %arg4] [32, 32] [1, 1] : memref<1024x1024xf16> to memref<32x32xf16, strided<[1024, 1], offset: ?>>
+      %subview_1 = memref.subview %arg0[%arg3, 0] [32, 1024] [1, 1] : memref<1024x1024xf16> to memref<32x1024xf16, strided<[1024, 1], offset: ?>>
+      %subview_2 = memref.subview %arg1[%arg4, 0] [32, 1024] [1, 1] : memref<1024x1024xf16> to memref<32x1024xf16, strided<[1024, 1], offset: ?>>
+      %subview_3 = memref.subview %alloc[0, %arg4] [1024, 32] [1, 1] : memref<1024x1024xf16> to memref<1024x32xf16, strided<[1024, 1], offset: ?>>
+      linalg.transpose ins(%subview_2 : memref<32x1024xf16, strided<[1024, 1], offset: ?>>) outs(%subview_3 : memref<1024x32xf16, strided<[1024, 1], offset: ?>>) permutation = [1, 0]
+      linalg.matmul ins(%subview_1, %subview_3 : memref<32x1024xf16, strided<[1024, 1], offset: ?>>, memref<1024x32xf16, strided<[1024, 1], offset: ?>>) outs(%subview_0 : memref<32x32xf16, strided<[1024, 1], offset: ?>>)
+      scf.reduce
+    }
+    memref.dealloc %alloc : memref<1024x1024xf16>
+    return
+  }
+}
+
+// CHECK-LABEL: func.func @matmul_transpose_b_sep
+// CHECK-SAME: %[[Ap:.+]]: memref<1024x1024xf16>, %[[Bp:.+]]: memref<1024x1024xf16>, %[[Cp:.+]]: memref<1024x1024xf16>
+
+// CHECK-NOT: memref.alloc()
+
+// CHECK: scf.parallel (%[[iter1:.+]], %[[iter2:.+]]) = (%c0, %c0) to (%c1024, %c1024) step (%c32, %c32) {
+// CHECK: %[[C:.+]] = memref.subview %[[Cp]][%[[iter1]], %[[iter2]]] {{.*}}
+// CHECK: %[[A:.+]] = memref.subview %[[Ap]][%[[iter1]], 0] {{.*}}
+// CHECK: %[[B:.+]] = memref.subview %[[Bp]][%[[iter2]], 0] {{.*}}
+
+// CHECK-NOT: linalg.transpose
+
+// Create output initial value load tiles.
+// CHECK-DAG: %[[rootC:.+]] = xegpu.create_nd_tdesc %[[C]]
+// CHECK: %[[tC:.+]] = xegpu.update_nd_offset %[[rootC]], [%c0, %c0]
+// CHECK-COUNT-7: xegpu.update_nd_offset
+
+// Load initial accumulator values.
+// CHECK-DAG: %[[vC:.+]] = xegpu.load_nd %[[tC]]
+// CHECK-COUNT-7: xegpu.load_nd
+
+// Extend the type to match DPAS output precision.
+// CHECK: %[[vC_f32:.+]] = arith.extf %[[vC]]
+// CHECK-COUNT-7: arith.extf
+
+// Create input load tiles.
+// CHECK: %[[rootA:.+]] = xegpu.create_nd_tdesc %[[A]]
+// CHECK: %[[tA:.+]] = xegpu.update_nd_offset %[[rootA]], [%c0, %c0]
+// CHECK: %[[rootB:.+]] = xegpu.create_nd_tdesc %[[B]]
+// CHECK: %[[tB:.+]] = xegpu.update_nd_offset %[[rootB]], [%c0, %c0]
+// CHECK: %[[tB1:.+]] = xegpu.update_nd_offset %[[rootB]], [%c16, %c0]
+
+// Create DPAS computation loop over tiled reduction dimension.
+// CHECK: %[[res:.+]]:11 = scf.for{{.*}}%c0 to %c1024 step %c16
+// CHECK-SAME: iter_args(%[[acc:.+]] = %[[vC_f32]],{{.*}}%[[iterA:.+]] = %[[tA]],{{.*}}%[[iterB:.+]] = %[[tB]],{{.*}}%[[iterB1:.+]] = %[[tB1]]
+// CHECK-SAME: {
+
+// Load input values and update the load tile position.
+// CHECK: %[[vA:.+]] = xegpu.load_nd %[[iterA]]
+// CHECK: %[[vB:.+]] = xegpu.load_nd %[[iterB]] {{.*}}transpose = array<i64: 1, 0>{{.*}}transpose_bit_width = 32 : i32{{.*}}
+// CHECK: %[[vB1:.+]] = xegpu.load_nd %[[iterB1]] {{.*}}transpose = array<i64: 1, 0>, transpose_bit_width = 32 : i32{{.*}}
+
+// CHECK: %[[new_tA:.+]] = xegpu.update_nd_offset %[[iterA]], [%c0, %c16]
+// CHECK: %[[new_tB:.+]] = xegpu.update_nd_offset %[[iterB]], [%c0, %c16]
+// CHECK: %[[new_tB1:.+]] = xegpu.update_nd_offset %[[iterB1]], [%c0, %c16]
+
+// Apply simple prefetching scheme - start loading the next set of input
+// tiles before computation is started.
+// CHECK: xegpu.prefetch_nd %[[new_tA]]
+// CHECK: xegpu.prefetch_nd %[[new_tB]]
+// CHECK: xegpu.prefetch_nd %[[new_tB1]]
+
+// Extract DPAS-sized chunks from larger loaded tile A.
+// Tile B is already in the correct shape.
+// CHECK: %[[vA_flat:.+]] = vector.shape_cast %[[vA]] : vector<32x16xf16> to vector<512xf16>
+// CHECK: %[[vA_dpas_flat:.+]] = vector.extract_strided_slice{{.*}}: vector<512xf16> to vector<128xf16>
+// CHECK: %[[vA_dpas:.+]] = vector.shape_cast %[[vA_dpas_flat]] : vector<128xf16> to vector<8x8x2xf16>
+// CHECK-COUNT-3: vector.extract_strided_slice
+
+// Perform DPAS computation.
+// CHECK: %[[dpas:.+]] = xegpu.dpas %[[vA_dpas]], %[[vB]], %[[acc]]
+// CHECK-COUNT-7: xegpu.dpas
+
+// CHECK-NOT: memref.dealloc()
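For reference, the kernel the test checks implements a transposed-B GEMM, C += A x B^T, with B stored row-major as N x K. A minimal scalar model of that contraction is sketched below; the sizes are arbitrary, and the mixed f16/f32 precision of the real kernel is simplified to float.

```cpp
#include <cstdio>
#include <vector>

// Scalar reference for C += A * B^T with B stored row-major as N x K,
// mirroring linalg.matmul_transpose_b semantics (precision simplified).
int main() {
  const int M = 4, N = 3, K = 5;
  std::vector<float> A(M * K), B(N * K), C(M * N, 0.0f);
  for (int i = 0; i < M * K; ++i) A[i] = 0.01f * i;
  for (int i = 0; i < N * K; ++i) B[i] = 0.02f * i;

  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n) {
      float acc = C[m * N + n];
      for (int k = 0; k < K; ++k)
        acc += A[m * K + k] * B[n * K + k]; // B[n][k] == (B^T)[k][n]
      C[m * N + n] = acc;
    }

  std::printf("C[0][0] = %.4f\n", C[0]);
  return 0;
}
```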
