[Transform] Refinements on microkernel dialect lowering #324

Merged · 7 commits · Sep 9, 2024
6 changes: 5 additions & 1 deletion include/gc/Transforms/Utils/ValueUtils.h
@@ -20,7 +20,11 @@ bool isValConstZero(Value val);
// Returns true if the op defining `val` represents a zero filled tensor.
bool isZeroTensor(Value val);

// Returns the strides of `val`. The method returns something usefull
// Returns the strides of `val`. The method returns something useful
// only if the `val` type is a strided memref.
FailureOr<SmallVector<int64_t>> getStrides(Value val);

// Returns the strides of `val`. The method returns something useful
// only if the `val` type is a strided memref and the strides are statically
// known.
FailureOr<SmallVector<int64_t>> getStaticStrides(Value val);
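
Both helpers return FailureOr, so callers branch on failed/succeeded rather than inspecting the vector itself. A minimal caller sketch (the surrounding code and the `operand` value are hypothetical; only the two ValueUtils declarations above are real):

// Sketch of the intended caller pattern.
FailureOr<SmallVector<int64_t>> strides = utils::getStaticStrides(operand);
if (failed(strides)) {
  // Either `operand` is not a strided memref, or some stride is dynamic.
  // Callers that can cope with dynamic entries fall back to getStrides:
  strides = utils::getStrides(operand);
  if (failed(strides))
    return failure(); // not a strided memref at all
}
// On the fallback path, `*strides` may contain ShapedType::kDynamic entries
// that need extra handling.
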
27 changes: 11 additions & 16 deletions lib/gc/Transforms/DeepTileContractionOp.cpp
@@ -833,21 +833,16 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
loc, resultOprand.getType(), ValueRange{dataOprand, weightOprand},
resultOprand);
} else {
// TODO: replace liangx brgemm with the generic in the comment when
// microkernel is ready
matmul = rewriter.create<linalgx::BatchReduceMatmulVnniOp>(
loc, resultOprand.getType(), ValueRange{dataOprand, weightOprand},
resultOprand);

// auto inputRange = ValueRange{dataOprand, weightOprand};
// auto resRange = ValueRange{resultOprand};
// auto res = linalgx::makeGenericPackedMatmulOp(
// rewriter, loc, linalgx::PackingType::VNNI_BRMM3D, inputRange,
// resRange);
// if (succeeded(res))
// matmul = *res;
// else
// return failure();
auto inputRange = SmallVector<Value>{dataOprand, weightOprand};
auto resRange = SmallVector<Value>{resultOprand};

auto res = linalgx::makeGenericPackedMatmulOp(
rewriter, loc, linalgx::PackingType::VNNI_BRMM3D, inputRange,
resRange);
if (succeeded(res))
matmul = *res;
else
return failure();
}

Value result = matmul.getOperation()->getResult(0);
@@ -1046,4 +1041,4 @@ struct DeepTileContractionOp

} // namespace
} // namespace gc
} // namespace mlir
} // namespace mlir
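
The named linalgx::BatchReduceMatmulVnniOp is retired here in favor of the generic packed-matmul builder, which is the form the downstream microkernel conversion now matches. A hedged sketch of the new path (variable names follow the diff; the builder's exact return type is hidden behind `auto`):

// Sketch of the new lowering path for the VNNI batch-reduce matmul.
auto res = linalgx::makeGenericPackedMatmulOp(
    rewriter, loc, linalgx::PackingType::VNNI_BRMM3D,
    SmallVector<Value>{dataOprand, weightOprand},
    SmallVector<Value>{resultOprand});
if (failed(res))
  return failure();
matmul = *res; // a linalg.generic in VNNI_BRMM3D packed form
// ConvertLinalgToMicrokernel later accepts exactly this form via
// linalgx::isGenericPackedMatmulOp(op, linalgx::PackingType::VNNI_BRMM3D).
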
14 changes: 9 additions & 5 deletions lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp
@@ -21,7 +21,7 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "gc/Dialect/Linalgx/LinalgxOps.h"
#include "gc/Dialect/Linalgx/Utils.h"
#include "gc/Transforms/Microkernel/MicrokernelPasses.h"
#include "gc/Transforms/Utils/StructuredOpMatcher.h"
#include "gc/Transforms/Utils/ValueUtils.h"
@@ -53,7 +53,8 @@ customInferContractionDims(linalg::LinalgOp linalgOp) {
auto dims = linalg::inferContractionDims(linalgOp);
if (failed(dims))
return failure();
if (llvm::isa<linalgx::BatchReduceMatmulVnniOp>(linalgOp)) {
if (linalgx::isGenericPackedMatmulOp(linalgOp,
linalgx::PackingType::VNNI_BRMM3D)) {
// For VnniOp, the K reduction dims (dim index 3 & 4) cannot be infered by
// linalg utils because they form complex affine in operand A; Manually add
// them here
@@ -338,7 +339,7 @@ static bool checkFusibleFillOp(DenseMap<Value, Value> &replaceMap,
bool fuseFill = false;
Value operandC = op.getDpsInitsMutable()[0].get();
auto defOp = operandC.getDefiningOp();
if (auto fillOp = dyn_cast<linalg::FillOp>(defOp)) {
if (auto fillOp = dyn_cast_or_null<linalg::FillOp>(defOp)) {
auto inputCst = dyn_cast_or_null<arith::ConstantOp>(
fillOp.getInputs()[0].getDefiningOp());
if (isZeroArithConstant(inputCst)) {
@@ -356,6 +357,10 @@ class ConvertContractionOpToBrgemmRewriter
using OpRewritePattern<ContractionOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ContractionOp op,
PatternRewriter &rewriter) const final {
if (!isa<linalg::BatchReduceMatmulOp>(op) &&
!linalgx::isGenericPackedMatmulOp(op,
linalgx::PackingType::VNNI_BRMM3D))
return failure();
if (!op.hasPureTensorSemantics())
return failure();

@@ -384,8 +389,7 @@ class ConvertLinalgToMicrokernel
patterns
.add<ConvertContractionOpToBrgemmRewriter<linalg::BatchReduceMatmulOp>>(
&getContext());
patterns.add<
ConvertContractionOpToBrgemmRewriter<linalgx::BatchReduceMatmulVnniOp>>(
patterns.add<ConvertContractionOpToBrgemmRewriter<linalg::GenericOp>>(
&getContext());
FrozenRewritePatternSet patternSet(std::move(patterns));
if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet)))
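
Two guards are added in this file: the rewriter now bails out early unless the op is either a plain linalg.batch_reduce_matmul or a VNNI-packed linalg.generic, and the fill-op fusion switches from dyn_cast to dyn_cast_or_null because an init operand may have no defining op. A minimal sketch of that second case (only `operandC` and the FillOp cast come from the diff; the rest is illustrative):

// Sketch: when the DPS init operand is a block argument (for example a
// function argument), getDefiningOp() returns nullptr.
Operation *defOp = operandC.getDefiningOp();
// dyn_cast<linalg::FillOp>(defOp) asserts on a null pointer in builds with
// assertions enabled; dyn_cast_or_null simply yields a null FillOp, so the
// fusion is skipped instead of crashing.
if (auto fillOp = dyn_cast_or_null<linalg::FillOp>(defOp)) {
  // ... existing fusible-fill handling ...
}
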
93 changes: 54 additions & 39 deletions lib/gc/Transforms/Microkernel/ExpandMicrokernel.cpp
@@ -51,6 +51,28 @@ struct BrgemmInfo {
BrgemmMode mode;
};

// This method try to retrieve static strides from MemRef, and allow dynamic
// strides if corresponding dims == `1` and they are batch/leading dims. Would
// place `INT_MAX` in corresponding stride position.
static FailureOr<SmallVector<int64_t>>
getCompensatedStrides(ArrayRef<int64_t> shape, Value val, int64_t batchDim,
int64_t leadingDim) {
auto strides = utils::getStrides(val);
if (failed(strides))
return failure();
for (size_t idx = 0; idx < strides->size(); idx++) {
if ((*strides)[idx] == ShapedType::kDynamic) {
if (idx != (size_t)batchDim || idx != (size_t)leadingDim)
return failure();
// We can ignore the stride if dim == 1 (no need to step)
if (shape[idx] != 1)
return failure();
(*strides)[idx] = LONG_MAX;
}
}
return strides;
}

static FailureOr<BrgemmInfo> inferBrgemmInfo(microkernel::BrgemmOp brgemmOp) {
Value operandA = brgemmOp.getOperandA();
Value operandB = brgemmOp.getOperandB();
@@ -82,66 +104,57 @@ static FailureOr<BrgemmInfo> inferBrgemmInfo(microkernel::BrgemmOp brgemmOp) {
return {batchDimSize, leadingDimSize, minorDimSize};
};

auto checkAndGetLdStride = [&](int64_t leadingDim,
Value operand) -> FailureOr<int64_t> {
auto checkAndGetStride =
[&](int64_t batchDim, int64_t leadingDim,
Value operand) -> FailureOr<std::pair<int64_t, int64_t>> {
auto operandShape = checkTypeAndGetShape(operand);
if (failed(operandShape))
return failure();
auto stridesOnOperand = utils::getStaticStrides(operand);
auto stridesOnOperand =
getCompensatedStrides(*operandShape, operand, batchDim, leadingDim);
if (failed(stridesOnOperand))
return failure();
auto leadingDimStride = (*stridesOnOperand)[leadingDim];
if (operandShape->size() == 4)
// Input B VNNI format exists, special treatment to align with non-VNNI
// format
return leadingDimStride / (*operandShape)[3];
return leadingDimStride;
};

auto checkAndGetBatchStride = [&](int64_t batchDim,
Value operand) -> FailureOr<int64_t> {
auto stridesOnOperand = utils::getStaticStrides(operand);
if (failed(stridesOnOperand))
return failure();
return (*stridesOnOperand)[batchDim];
return std::pair<int64_t, int64_t>{(*stridesOnOperand)[batchDim],
leadingDimStride / (*operandShape)[3]};
return std::pair<int64_t, int64_t>{(*stridesOnOperand)[batchDim],
leadingDimStride};
};

// A(m, k)
auto batchDimA = brgemmOp.getBatchDimA();
auto leadingDimA = brgemmOp.getLeadingDimA();
auto [batchA, M, KA] = checkAndGetDimSize(batchDimA, leadingDimA, operandA);
auto lda = checkAndGetLdStride(leadingDimA, operandA);
if (failed(batchA) || failed(M) || failed(KA) || failed(lda))
auto strideA = checkAndGetStride(batchDimA, leadingDimA, operandA);
if (failed(batchA) || failed(M) || failed(KA) || failed(strideA))
return failure();
LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] M, K, Lda for A: " << *M << ", "
<< *KA << ", " << *lda << "\n");
<< *KA << ", " << strideA->first << ", "
<< strideA->second << "\n");

// B(k, n)
auto batchDimB = brgemmOp.getBatchDimB();
auto leadingDimB = brgemmOp.getLeadingDimB();
auto [batchB, KB, N] = checkAndGetDimSize(batchDimB, leadingDimB, operandB);
auto ldb = checkAndGetLdStride(leadingDimB, operandB);
if (failed(batchB) || failed(KB) || failed(N) || failed(ldb))
auto strideB = checkAndGetStride(batchDimB, leadingDimB, operandB);
if (failed(batchB) || failed(KB) || failed(N) || failed(strideB))
return failure();
LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] K, N, Ldb for B: " << *KB
<< ", " << *N << ", " << *ldb << "\n");
<< ", " << *N << ", " << strideB->first << ", "
<< strideB->second << "\n");
assert(*batchA == *batchB && *KA == *KB &&
"Expecting matching shapes of inputs");

// C(m, n)
auto ldc = checkAndGetLdStride(0, operandC);
if (failed(ldc))
return failure();
LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] Ld stride on C: " << ldc
<< "\n");

auto strideA = checkAndGetBatchStride(brgemmOp.getBatchDimA(), operandA);
if (failed(strideA))
return failure();

auto strideB = checkAndGetBatchStride(brgemmOp.getBatchDimB(), operandB);
if (failed(strideB))
// Put irrelevant value in parameter `batchDim` for C as we don't need it
auto strideC = checkAndGetStride(0, 0, operandC);
if (failed(strideC))
return failure();
LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] Ld stride on C: "
<< strideC->second << "\n");

bool isInit = false;
auto flags = brgemmOp.getFlagsAttr();
@@ -157,19 +170,21 @@ static FailureOr<BrgemmInfo> inferBrgemmInfo(microkernel::BrgemmOp brgemmOp) {

LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmInfo] final BrgemmInfo: m(" << *M
<< "), n(" << *N << "), k(" << *KB << "), batch("
<< *batchA << "), lda(" << *lda << "), ldb(" << *ldb
<< "), ldc(" << *ldc << "), strideA(" << *strideA
<< "), strideB(" << *strideB << ")\n");
<< *batchA << "), lda(" << strideA->second
<< "), ldb(" << strideB->second << "), ldc("
<< strideC->second << "), batchStrideA("
<< strideA->first << "), batchStrideB("
<< strideB->first << ")\n");
BrgemmInfo info{*M,
*N,
*KA,
*batchA,
0 /* addrLen useless under stride mode */,
*lda,
*ldb,
*ldc,
*strideA,
*strideB,
strideA->second,
strideB->second,
strideC->second,
strideA->first,
strideB->first,
isInit,
BrgemmInfo::STRIDE_MODE};
return info;
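
getCompensatedStrides is what allows the dynamic-shape case in the new expand-microkernel.mlir test (further down) to lower: a dynamic stride is tolerated when it sits on the batch/leading dimension and that dimension has extent 1, in which case the stride is never used for address stepping and a LONG_MAX placeholder is emitted. A worked illustration under the shapes from that test (simplified; the real helper also verifies the dimension is the batch or leading one):

// The C operand in the test is memref<1x32xf32, strided<[?, 1], offset: ?>>.
SmallVector<int64_t> shape = {1, 32};
SmallVector<int64_t> strides = {ShapedType::kDynamic, 1}; // from getStrides
for (size_t idx = 0; idx < strides.size(); ++idx)
  if (strides[idx] == ShapedType::kDynamic && shape[idx] == 1)
    strides[idx] = LONG_MAX; // a dim of extent 1 is never stepped over
// strides == {9223372036854775807, 1}; that sentinel is the ldc value that
// appears in the microkernel.brgemm.dispatch CHECK lines of the test.
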
17 changes: 11 additions & 6 deletions lib/gc/Transforms/Utils/ValueUtils.cpp
@@ -98,21 +98,26 @@ static bool isZeroOp(Operation *defOp) {
.Default([&](Operation *op) { return false; });
}

FailureOr<SmallVector<int64_t>> getStaticStrides(Value value) {
FailureOr<SmallVector<int64_t>> getStrides(Value value) {
auto valueType = value.getType();
if (!isa<MemRefType>(valueType))
return failure();
auto memrefType = cast<MemRefType>(valueType);
SmallVector<int64_t> strides;
int64_t offset;
if (failed(getStridesAndOffset(memrefType, strides, offset))) {
if (failed(getStridesAndOffset(memrefType, strides, offset)))
return failure();
}
if (llvm::any_of(strides, [](int64_t stride) {
return strides;
}

FailureOr<SmallVector<int64_t>> getStaticStrides(Value value) {
auto strides = getStrides(value);
if (failed(strides))
return failure();
if (llvm::any_of(*strides, [](int64_t stride) {
return stride == ShapedType::kDynamic;
})) {
}))
return failure();
}
return strides;
}

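
With this refactor, getStaticStrides is a thin filter over getStrides. A hedged sketch of the resulting contract, using operand types from the new test (the wrapper function is hypothetical; only the two utils calls are real):

// Sketch: true iff `val` is a strided memref whose strides are all static.
static bool hasFullyStaticStrides(Value val) {
  if (failed(utils::getStrides(val)))
    return false; // not a strided memref at all
  return succeeded(utils::getStaticStrides(val));
}
// memref<64x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>> -> true
// memref<1x32xf32, strided<[?, 1], offset: ?>>                  -> false
//   (getStrides still succeeds here, returning {ShapedType::kDynamic, 1})
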
68 changes: 68 additions & 0 deletions test/mlir/test/gc/Dialect/Microkernel/expand-microkernel.mlir
@@ -167,3 +167,71 @@ func.func @transpose_expand_microkernel_init_vnni() {
// CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS]]) : (i64) -> ()

// -----

#map = affine_map<(d0) -> (-d0 + 344, 11)>
#map1 = affine_map<(d0)[s0] -> (-d0 + s0, 8)>
#map2 = affine_map<()[s0, s1, s2] -> (s0 + s1 + s2)>
#map3 = affine_map<(d0, d1) -> (d0, d1)>
module {
func.func @expand_microkernel_with_dynamic(%arg0: memref<1x128x1x32xbf16>, %arg1: memref<344x128x16x32x2xbf16>, %arg2: memref<1x344x1x32xbf16>) attributes {llvm.emit_c_interface} {
%c1 = arith.constant 1 : index
%c64 = arith.constant 64 : index
%c128 = arith.constant 128 : index
%c8 = arith.constant 8 : index
%c0 = arith.constant 0 : index
scf.forall (%arg3) = (0) to (344) step (11) {
%0 = affine.min #map(%arg3)
%subview = memref.subview %arg2[0, %arg3, 0, 0] [1, %0, 1, 32] [1, 1, 1, 1] : memref<1x344x1x32xbf16> to memref<1x?x1x32xbf16, strided<[11008, 32, 32, 1], offset: ?>>
scf.for %arg4 = %c0 to %0 step %c8 {
%1 = affine.min #map1(%arg4)[%0]
%subview_0 = memref.subview %subview[0, %arg4, 0, 0] [1, %1, 1, 32] [1, 1, 1, 1] : memref<1x?x1x32xbf16, strided<[11008, 32, 32, 1], offset: ?>> to memref<1x?x1x32xbf16, strided<[11008, 32, 32, 1], offset: ?>>
%alloc = memref.alloc(%1) {alignment = 64 : i64} : memref<1x?x1x32xf32>
scf.for %arg5 = %c0 to %c128 step %c64 {
%subview_1 = memref.subview %alloc[0, 0, 0, 0] [1, %1, 1, 32] [1, 1, 1, 1] : memref<1x?x1x32xf32> to memref<1x?x1x32xf32, strided<[?, 32, 32, 1]>>
%subview_2 = memref.subview %arg0[0, %arg5, 0, 0] [1, 64, 1, 32] [1, 1, 1, 1] : memref<1x128x1x32xbf16> to memref<64x1x32xbf16, strided<[32, 32, 1], offset: ?>>
%2 = arith.cmpi eq, %arg5, %c0 : index
%3 = arith.addi %arg5, %c64 : index
%4 = arith.cmpi sge, %3, %c128 : index
scf.for %arg6 = %c0 to %1 step %c1 {
%5 = affine.apply #map2()[%arg3, %arg6, %arg4]
%subview_3 = memref.subview %arg1[%5, %arg5, 0, 0, 0] [1, 64, 16, 32, 2] [1, 1, 1, 1, 1] : memref<344x128x16x32x2xbf16> to memref<64x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>
%subview_4 = memref.subview %subview_1[0, %arg6, 0, 0] [1, 1, 1, 32] [1, 1, 1, 1] : memref<1x?x1x32xf32, strided<[?, 32, 32, 1]>> to memref<1x32xf32, strided<[?, 1], offset: ?>>
%subview_5 = memref.subview %subview_0[0, %arg6, 0, 0] [1, 1, 1, 32] [1, 1, 1, 1] : memref<1x?x1x32xbf16, strided<[11008, 32, 32, 1], offset: ?>> to memref<1x32xbf16, strided<[11008, 1], offset: ?>>
scf.if %2 {
microkernel.brgemm ins(%subview_2, %subview_3 : memref<64x1x32xbf16, strided<[32, 32, 1], offset: ?>>, memref<64x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>) outs(%subview_4 : memref<1x32xf32, strided<[?, 1], offset: ?>>) batch_dims(0, 0) leading_dims(1, 1) flags(beta_0)
} else {
microkernel.brgemm ins(%subview_2, %subview_3 : memref<64x1x32xbf16, strided<[32, 32, 1], offset: ?>>, memref<64x16x32x2xbf16, strided<[1024, 64, 2, 1], offset: ?>>) outs(%subview_4 : memref<1x32xf32, strided<[?, 1], offset: ?>>) batch_dims(0, 0) leading_dims(1, 1) flags()
}
scf.if %4 {
linalg.generic {indexing_maps = [#map3, #map3], iterator_types = ["parallel", "parallel"]} ins(%subview_4 : memref<1x32xf32, strided<[?, 1], offset: ?>>) outs(%subview_5 : memref<1x32xbf16, strided<[11008, 1], offset: ?>>) {
^bb0(%in: f32, %out: bf16):
%6 = arith.truncf %in : f32 to bf16
linalg.yield %6 : bf16
}
}
}
}
memref.dealloc %alloc : memref<1x?x1x32xf32>
}
}
return
}
}

// CHECK-LABEL: expand_microkernel_with_dynamic
// CHECK: scf.forall (%[[ARG:.+]]) = (0) to (344) step (11)
// CHECK: scf.for %[[ARG2:.+]] = %[[CST0:.+]] to %[[AFF:.+]] step %[[CST8:.+]]
// CHECK: scf.for %[[ARG3:.+]] = %[[CST0]] to %[[CST128:.+]] step %[[CST64:.+]]
// CHECK: scf.for %[[ARG4:.+]] = %[[CST0]] to %[[AFF1:.+]] step %[[CST1:.+]]
// CHECK: scf.if
// CHECK: %[[DIS:.+]] = microkernel.brgemm.dispatch [1, 32, 32, 32, 32, 9223372036854775807, 32, 1024] flags(beta_0, stride) data_type(bf16, bf16)
// CHECK-NEXT: microkernel.brgemm.prologue(%[[DIS]]) : (i64) -> ()
// CHECK-NEXT: microkernel.brgemm.execute(%[[DIS]]
// CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS]]) : (i64) -> ()
// CHECK: else
// CHECK: %[[DIS2:.+]] = microkernel.brgemm.dispatch [1, 32, 32, 32, 32, 9223372036854775807, 32, 1024] flags(stride) data_type(bf16, bf16)
// CHECK-NEXT: microkernel.brgemm.prologue(%[[DIS2]]) : (i64) -> ()
// CHECK-NEXT: microkernel.brgemm.execute(%[[DIS]]
// CHECK-NEXT: microkernel.brgemm.epilogue(%[[DIS2]]) : (i64) -> ()

// -----