Skip to content

Commit 866d2ee

Browse files
author
Haixin Huang
authored
[Transform] Refinements on microkernel dialect lowering (#324)
* [WIP] minor fix & improvements * fix allow dynamic stride * change linalgx frontend to linalg.generic * fix clang-tidy * fix deepTileContractionOp * fix clang-tidy
1 parent 73412a2 commit 866d2ee

File tree

8 files changed

+225
-80
lines changed

8 files changed

+225
-80
lines changed

include/gc/Transforms/Utils/ValueUtils.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@ bool isValConstZero(Value val);
2020
// Returns true if the op defining `val` represents a zero filled tensor.
2121
bool isZeroTensor(Value val);
2222

23-
// Returns the strides of `val`. The method returns something usefull
23+
// Returns the strides of `val`. The method returns something useful
24+
// only if the `val` type is a strided memref.
25+
FailureOr<SmallVector<int64_t>> getStrides(Value val);
26+
27+
// Returns the strides of `val`. The method returns something useful
2428
// only if the `val` type is a strided memref and the strides are statically
2529
// known.
2630
FailureOr<SmallVector<int64_t>> getStaticStrides(Value val);

lib/gc/Transforms/DeepTileContractionOp.cpp

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -833,21 +833,16 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
833833
loc, resultOprand.getType(), ValueRange{dataOprand, weightOprand},
834834
resultOprand);
835835
} else {
836-
// TODO: replace liangx brgemm with the generic in the comment when
837-
// microkernel is ready
838-
matmul = rewriter.create<linalgx::BatchReduceMatmulVnniOp>(
839-
loc, resultOprand.getType(), ValueRange{dataOprand, weightOprand},
840-
resultOprand);
841-
842-
// auto inputRange = ValueRange{dataOprand, weightOprand};
843-
// auto resRange = ValueRange{resultOprand};
844-
// auto res = linalgx::makeGenericPackedMatmulOp(
845-
// rewriter, loc, linalgx::PackingType::VNNI_BRMM3D, inputRange,
846-
// resRange);
847-
// if (succeeded(res))
848-
// matmul = *res;
849-
// else
850-
// return failure();
836+
auto inputRange = SmallVector<Value>{dataOprand, weightOprand};
837+
auto resRange = SmallVector<Value>{resultOprand};
838+
839+
auto res = linalgx::makeGenericPackedMatmulOp(
840+
rewriter, loc, linalgx::PackingType::VNNI_BRMM3D, inputRange,
841+
resRange);
842+
if (succeeded(res))
843+
matmul = *res;
844+
else
845+
return failure();
851846
}
852847

853848
Value result = matmul.getOperation()->getResult(0);
@@ -1046,4 +1041,4 @@ struct DeepTileContractionOp
10461041

10471042
} // namespace
10481043
} // namespace gc
1049-
} // namespace mlir
1044+
} // namespace mlir

lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
#include "mlir/Support/LogicalResult.h"
2222
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2323

24-
#include "gc/Dialect/Linalgx/LinalgxOps.h"
24+
#include "gc/Dialect/Linalgx/Utils.h"
2525
#include "gc/Transforms/Microkernel/MicrokernelPasses.h"
2626
#include "gc/Transforms/Utils/StructuredOpMatcher.h"
2727
#include "gc/Transforms/Utils/ValueUtils.h"
@@ -53,7 +53,8 @@ customInferContractionDims(linalg::LinalgOp linalgOp) {
5353
auto dims = linalg::inferContractionDims(linalgOp);
5454
if (failed(dims))
5555
return failure();
56-
if (llvm::isa<linalgx::BatchReduceMatmulVnniOp>(linalgOp)) {
56+
if (linalgx::isGenericPackedMatmulOp(linalgOp,
57+
linalgx::PackingType::VNNI_BRMM3D)) {
5758
// For VnniOp, the K reduction dims (dim index 3 & 4) cannot be inferred by
5859
// linalg utils because they form complex affine in operand A; Manually add
5960
// them here
@@ -338,7 +339,7 @@ static bool checkFusibleFillOp(DenseMap<Value, Value> &replaceMap,
338339
bool fuseFill = false;
339340
Value operandC = op.getDpsInitsMutable()[0].get();
340341
auto defOp = operandC.getDefiningOp();
341-
if (auto fillOp = dyn_cast<linalg::FillOp>(defOp)) {
342+
if (auto fillOp = dyn_cast_or_null<linalg::FillOp>(defOp)) {
342343
auto inputCst = dyn_cast_or_null<arith::ConstantOp>(
343344
fillOp.getInputs()[0].getDefiningOp());
344345
if (isZeroArithConstant(inputCst)) {
@@ -356,6 +357,10 @@ class ConvertContractionOpToBrgemmRewriter
356357
using OpRewritePattern<ContractionOp>::OpRewritePattern;
357358
LogicalResult matchAndRewrite(ContractionOp op,
358359
PatternRewriter &rewriter) const final {
360+
if (!isa<linalg::BatchReduceMatmulOp>(op) &&
361+
!linalgx::isGenericPackedMatmulOp(op,
362+
linalgx::PackingType::VNNI_BRMM3D))
363+
return failure();
359364
if (!op.hasPureTensorSemantics())
360365
return failure();
361366

@@ -384,8 +389,7 @@ class ConvertLinalgToMicrokernel
384389
patterns
385390
.add<ConvertContractionOpToBrgemmRewriter<linalg::BatchReduceMatmulOp>>(
386391
&getContext());
387-
patterns.add<
388-
ConvertContractionOpToBrgemmRewriter<linalgx::BatchReduceMatmulVnniOp>>(
392+
patterns.add<ConvertContractionOpToBrgemmRewriter<linalg::GenericOp>>(
389393
&getContext());
390394
FrozenRewritePatternSet patternSet(std::move(patterns));
391395
if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet)))

lib/gc/Transforms/Microkernel/ExpandMicrokernel.cpp

Lines changed: 54 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,28 @@ struct BrgemmInfo {
5151
BrgemmMode mode;
5252
};
5353

54+
// This method try to retrieve static strides from MemRef, and allow dynamic
55+
// strides if corresponding dims == `1` and they are batch/leading dims. Would
56+
// place `INT_MAX` in corresponding stride position.
57+
static FailureOr<SmallVector<int64_t>>
58+
getCompensatedStrides(ArrayRef<int64_t> shape, Value val, int64_t batchDim,
59+
int64_t leadingDim) {
60+
auto strides = utils::getStrides(val);
61+
if (failed(strides))
62+
return failure();
63+
for (size_t idx = 0; idx < strides->size(); idx++) {
64+
if ((*strides)[idx] == ShapedType::kDynamic) {
65+
if (idx != (size_t)batchDim || idx != (size_t)leadingDim)
66+
return failure();
67+
// We can ignore the stride if dim == 1 (no need to step)
68+
if (shape[idx] != 1)
69+
return failure();
70+
(*strides)[idx] = LONG_MAX;
71+
}
72+
}
73+
return strides;
74+
}
75+
5476
static FailureOr<BrgemmInfo> inferBrgemmInfo(microkernel::BrgemmOp brgemmOp) {
5577
Value operandA = brgemmOp.getOperandA();
5678
Value operandB = brgemmOp.getOperandB();
@@ -82,66 +104,57 @@ static FailureOr<BrgemmInfo> inferBrgemmInfo(microkernel::BrgemmOp brgemmOp) {
82104
return {batchDimSize, leadingDimSize, minorDimSize};
83105
};
84106

85-
auto checkAndGetLdStride = [&](int64_t leadingDim,
86-
Value operand) -> FailureOr<int64_t> {
107+
auto checkAndGetStride =
108+
[&](int64_t batchDim, int64_t leadingDim,
109+
Value operand) -> FailureOr<std::pair<int64_t, int64_t>> {
87110
auto operandShape = checkTypeAndGetShape(operand);
88111
if (failed(operandShape))
89112
return failure();
90-
auto stridesOnOperand = utils::getStaticStrides(operand);
113+
auto stridesOnOperand =
114+
getCompensatedStrides(*operandShape, operand, batchDim, leadingDim);
91115
if (failed(stridesOnOperand))
92116
return failure();
93117
auto leadingDimStride = (*stridesOnOperand)[leadingDim];
94118
if (operandShape->size() == 4)
95119
// Input B VNNI format exists, special treatment to align with non-VNNI
96120
// format
97-
return leadingDimStride / (*operandShape)[3];
98-
return leadingDimStride;
99-
};
100-
101-
auto checkAndGetBatchStride = [&](int64_t batchDim,
102-
Value operand) -> FailureOr<int64_t> {
103-
auto stridesOnOperand = utils::getStaticStrides(operand);
104-
if (failed(stridesOnOperand))
105-
return failure();
106-
return (*stridesOnOperand)[batchDim];
121+
return std::pair<int64_t, int64_t>{(*stridesOnOperand)[batchDim],
122+
leadingDimStride / (*operandShape)[3]};
123+
return std::pair<int64_t, int64_t>{(*stridesOnOperand)[batchDim],
124+
leadingDimStride};
107125
};
108126

109127
// A(m, k)
110128
auto batchDimA = brgemmOp.getBatchDimA();
111129
auto leadingDimA = brgemmOp.getLeadingDimA();
112130
auto [batchA, M, KA] = checkAndGetDimSize(batchDimA, leadingDimA, operandA);
113-
auto lda = checkAndGetLdStride(leadingDimA, operandA);
114-
if (failed(batchA) || failed(M) || failed(KA) || failed(lda))
131+
auto strideA = checkAndGetStride(batchDimA, leadingDimA, operandA);
132+
if (failed(batchA) || failed(M) || failed(KA) || failed(strideA))
115133
return failure();
116134
LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] M, K, Lda for A: " << *M << ", "
117-
<< *KA << ", " << *lda << "\n");
135+
<< *KA << ", " << strideA->first << ", "
136+
<< strideA->second << "\n");
118137

119138
// B(k, n)
120139
auto batchDimB = brgemmOp.getBatchDimB();
121140
auto leadingDimB = brgemmOp.getLeadingDimB();
122141
auto [batchB, KB, N] = checkAndGetDimSize(batchDimB, leadingDimB, operandB);
123-
auto ldb = checkAndGetLdStride(leadingDimB, operandB);
124-
if (failed(batchB) || failed(KB) || failed(N) || failed(ldb))
142+
auto strideB = checkAndGetStride(batchDimB, leadingDimB, operandB);
143+
if (failed(batchB) || failed(KB) || failed(N) || failed(strideB))
125144
return failure();
126145
LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] K, N, Ldb for B: " << *KB
127-
<< ", " << *N << ", " << *ldb << "\n");
146+
<< ", " << *N << ", " << strideB->first << ", "
147+
<< strideB->second << "\n");
128148
assert(*batchA == *batchB && *KA == *KB &&
129149
"Expecting matching shapes of inputs");
130150

131151
// C(m, n)
132-
auto ldc = checkAndGetLdStride(0, operandC);
133-
if (failed(ldc))
134-
return failure();
135-
LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] Ld stride on C: " << ldc
136-
<< "\n");
137-
138-
auto strideA = checkAndGetBatchStride(brgemmOp.getBatchDimA(), operandA);
139-
if (failed(strideA))
140-
return failure();
141-
142-
auto strideB = checkAndGetBatchStride(brgemmOp.getBatchDimB(), operandB);
143-
if (failed(strideB))
152+
// Put irrelevant value in parameter `batchDim` for C as we don't need it
153+
auto strideC = checkAndGetStride(0, 0, operandC);
154+
if (failed(strideC))
144155
return failure();
156+
LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] Ld stride on C: "
157+
<< strideC->second << "\n");
145158

146159
bool isInit = false;
147160
auto flags = brgemmOp.getFlagsAttr();
@@ -157,19 +170,21 @@ static FailureOr<BrgemmInfo> inferBrgemmInfo(microkernel::BrgemmOp brgemmOp) {
157170

158171
LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] final BrgemmInfo: m(" << *M
159172
<< "), n(" << *N << "), k(" << *KB << "), batch("
160-
<< *batchA << "), lda(" << *lda << "), ldb(" << *ldb
161-
<< "), ldc(" << *ldc << "), strideA(" << *strideA
162-
<< "), strideB(" << *strideB << ")\n");
173+
<< *batchA << "), lda(" << strideA->second
174+
<< "), ldb(" << strideB->second << "), ldc("
175+
<< strideC->second << "), batchStrideA("
176+
<< strideA->first << "), batchStrideB("
177+
<< strideB->first << ")\n");
163178
BrgemmInfo info{*M,
164179
*N,
165180
*KA,
166181
*batchA,
167182
0 /* addrLen useless under stride mode */,
168-
*lda,
169-
*ldb,
170-
*ldc,
171-
*strideA,
172-
*strideB,
183+
strideA->second,
184+
strideB->second,
185+
strideC->second,
186+
strideA->first,
187+
strideB->first,
173188
isInit,
174189
BrgemmInfo::STRIDE_MODE};
175190
return info;

lib/gc/Transforms/Utils/ValueUtils.cpp

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -98,21 +98,26 @@ static bool isZeroOp(Operation *defOp) {
9898
.Default([&](Operation *op) { return false; });
9999
}
100100

101-
FailureOr<SmallVector<int64_t>> getStaticStrides(Value value) {
101+
FailureOr<SmallVector<int64_t>> getStrides(Value value) {
102102
auto valueType = value.getType();
103103
if (!isa<MemRefType>(valueType))
104104
return failure();
105105
auto memrefType = cast<MemRefType>(valueType);
106106
SmallVector<int64_t> strides;
107107
int64_t offset;
108-
if (failed(getStridesAndOffset(memrefType, strides, offset))) {
108+
if (failed(getStridesAndOffset(memrefType, strides, offset)))
109109
return failure();
110-
}
111-
if (llvm::any_of(strides, [](int64_t stride) {
110+
return strides;
111+
}
112+
113+
FailureOr<SmallVector<int64_t>> getStaticStrides(Value value) {
114+
auto strides = getStrides(value);
115+
if (failed(strides))
116+
return failure();
117+
if (llvm::any_of(*strides, [](int64_t stride) {
112118
return stride == ShapedType::kDynamic;
113-
})) {
119+
}))
114120
return failure();
115-
}
116121
return strides;
117122
}
118123

test/mlir/test/gc/Dialect/Microkernel/expand-microkernel.mlir

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,71 @@ func.func @transpose_expand_microkernel_init_vnni() {
167167
// CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS]]) : (i64) -> ()
168168

169169
// -----
170+
171+
#map = affine_map<(d0) -> (-d0 + 344, 11)>
172+
#map1 = affine_map<(d0)[s0] -> (-d0 + s0, 8)>
173+
#map2 = affine_map<()[s0, s1, s2] -> (s0 + s1 + s2)>
174+
#map3 = affine_map<(d0, d1) -> (d0, d1)>
175+
module {
176+
func.func @expand_microkernel_with_dynamic(%arg0: memref<1x128x1x32xbf16>, %arg1: memref<344x128x16x32x2xbf16>, %arg2: memref<1x344x1x32xbf16>) attributes {llvm.emit_c_interface} {
177+
%c1 = arith.constant 1 : index
178+
%c64 = arith.constant 64 : index
179+
%c128 = arith.constant 128 : index
180+
%c8 = arith.constant 8 : index
181+
%c0 = arith.constant 0 : index
182+
scf.forall (%arg3) = (0) to (344) step (11) {
183+
%0 = affine.min #map(%arg3)
184+
%subview = memref.subview %arg2[0, %arg3, 0, 0] [1, %0, 1, 32] [1, 1, 1, 1] : memref<1x344x1x32xbf16> to memref<1x?x1x32xbf16, strided<[11008, 32, 32, 1], offset: ?>>
185+
scf.for %arg4 = %c0 to %0 step %c8 {
186+
%1 = affine.min #map1(%arg4)[%0]
187+
%subview_0 = memref.subview %subview[0, %arg4, 0, 0] [1, %1, 1, 32] [1, 1, 1, 1] : memref<1x?x1x32xbf16, strided<[11008, 32, 32, 1], offset: ?>> to memref<1x?x1x32xbf16, strided<[11008, 32, 32, 1], offset: ?>>
188+
%alloc = memref.alloc(%1) {alignment = 64 : i64} : memref<1x?x1x32xf32>
189+
scf.for %arg5 = %c0 to %c128 step %c64 {
190+
%subview_1 = memref.subview %alloc[0, 0, 0, 0] [1, %1, 1, 32] [1, 1, 1, 1] : memref<1x?x1x32xf32> to memref<1x?x1x32xf32, strided<[?, 32, 32, 1]>>
191+
%subview_2 = memref.subview %arg0[0, %arg5, 0, 0] [1, 64, 1, 32] [1, 1, 1, 1] : memref<1x128x1x32xbf16> to memref<64x1x32xbf16, strided<[32, 32, 1], offset: ?>>
192+
%2 = arith.cmpi eq, %arg5, %c0 : index
193+
%3 = arith.addi %arg5, %c64 : index
194+
%4 = arith.cmpi sge, %3, %c128 : index
195+
scf.for %arg6 = %c0 to %1 step %c1 {
196+
%5 = affine.apply #map2()[%arg3, %arg6, %arg4]
197+
%subview_3 = memref.subview %arg1[%5, %arg5, 0, 0, 0] [1, 64, 16, 32, 2] [1, 1, 1, 1, 1] : memref<344x128x16x32x2xbf16> to memref<64x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>
198+
%subview_4 = memref.subview %subview_1[0, %arg6, 0, 0] [1, 1, 1, 32] [1, 1, 1, 1] : memref<1x?x1x32xf32, strided<[?, 32, 32, 1]>> to memref<1x32xf32, strided<[?, 1], offset: ?>>
199+
%subview_5 = memref.subview %subview_0[0, %arg6, 0, 0] [1, 1, 1, 32] [1, 1, 1, 1] : memref<1x?x1x32xbf16, strided<[11008, 32, 32, 1], offset: ?>> to memref<1x32xbf16, strided<[11008, 1], offset: ?>>
200+
scf.if %2 {
201+
microkernel.brgemm ins(%subview_2, %subview_3 : memref<64x1x32xbf16, strided<[32, 32, 1], offset: ?>>, memref<64x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>) outs(%subview_4 : memref<1x32xf32, strided<[?, 1], offset: ?>>) batch_dims(0, 0) leading_dims(1, 1) flags(beta_0)
202+
} else {
203+
microkernel.brgemm ins(%subview_2, %subview_3 : memref<64x1x32xbf16, strided<[32, 32, 1], offset: ?>>, memref<64x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>) outs(%subview_4 : memref<1x32xf32, strided<[?, 1], offset: ?>>) batch_dims(0, 0) leading_dims(1, 1) flags()
204+
}
205+
scf.if %4 {
206+
linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel"]} ins(%subview_4 : memref<1x32xf32, strided<[?, 1], offset: ?>>) outs(%subview_5 : memref<1x32xbf16, strided<[11008, 1], offset: ?>>) {
207+
^bb0(%in: f32, %out: bf16):
208+
%6 = arith.truncf %in : f32 to bf16
209+
linalg.yield %6 : bf16
210+
}
211+
}
212+
}
213+
}
214+
memref.dealloc %alloc : memref<1x?x1x32xf32>
215+
}
216+
}
217+
return
218+
}
219+
}
220+
221+
// CHECK-LABEL: expand_microkernel_with_dynamic
222+
// CHECK: scf.forall (%[[ARG:.+]]) = (0) to (344) step (11)
223+
// CHECK: scf.for %[[ARG2:.+]] = %[[CST0:.+]] to %[[AFF:.+]] step %[[CST8:.+]]
224+
// CHECK: scf.for %[[ARG3:.+]] = %[[CST0]] to %[[CST128:.+]] step %[[CST64:.+]]
225+
// CHECK: scf.for %[[ARG4:.+]] = %[[CST0]] to %[[AFF1:.+]] step %[[CST1:.+]]
226+
// CHECK: scf.if
227+
// CHECK: %[[DIS:.+]] = microkernel.brgemm.dispatch [1, 32, 32, 32, 32, 9223372036854775807, 32, 1024] flags(beta_0, stride) data_type(bf16, bf16)
228+
// CHECK-NEXT: microkernel.brgemm.prologue(%[[DIS]]) : (i64) -> ()
229+
// CHECK-NEXT: microkernel.brgemm.execute(%[[DIS]]
230+
// CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS]]) : (i64) -> ()
231+
// CHECK: else
232+
// CHECK: %[[DIS2:.+]] = microkernel.brgemm.dispatch [1, 32, 32, 32, 32, 9223372036854775807, 32, 1024] flags(stride) data_type(bf16, bf16)
233+
// CHECK-NEXT: microkernel.brgemm.prologue(%[[DIS2]]) : (i64) -> ()
234+
// CHECK-NEXT: microkernel.brgemm.execute(%[[DIS]]
235+
// CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS2]]) : (i64) -> ()
236+
237+
// -----

0 commit comments

Comments
 (0)