Skip to content

Commit bc0014b

Browse files
Authored by Haixin Huang
[Transform] Introduce microkernel dialect optimization passes (#296)
1 parent 23269d7 commit bc0014b

16 files changed

+2279
-62
lines changed

include/gc/Transforms/Microkernel/MicrokernelPasses.td

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,4 +76,35 @@ def ConvertMicrokernelToDnnlFunc: Pass<"convert-microkernel-to-dnnl-func", "::ml
7676
"microkernel::MicrokernelDialect"];
7777
}
7878

79+
def EarlyDispatchMicrokernel: Pass<"early-dispatch-microkernel", "::mlir::ModuleOp"> {
80+
let summary = "Early dispatch microkernel during compile time";
81+
let description = [{
82+
Early dispatch microkernel during compile time.
83+
}];
84+
let dependentDialects = ["func::FuncDialect",
85+
"memref::MemRefDialect",
86+
"LLVM::LLVMDialect",
87+
"microkernel::MicrokernelDialect"];
88+
}
89+
90+
def MergeBranchMicrokernelContext: Pass<"merge-branch-microkernel-context", "::mlir::ModuleOp"> {
91+
let summary = "Find and merge identical microkernel context operations in branches into one";
92+
let description = [{
93+
Find and merge identical microkernel context operations in branches into one.
94+
}];
95+
let dependentDialects = ["func::FuncDialect",
96+
"memref::MemRefDialect"];
97+
}
98+
99+
def MicrokernelInvariantCodeMotion: Pass<"microkernel-invariant-code-motion", "::mlir::ModuleOp"> {
100+
let summary = "Hoist invariant microkernel code to avoid redundant execution";
101+
let description = [{
102+
Hoist invariant microkernel code to avoid redundant execution.
103+
}];
104+
let dependentDialects = ["func::FuncDialect",
105+
"memref::MemRefDialect",
106+
"LLVM::LLVMDialect",
107+
"microkernel::MicrokernelDialect"];
108+
}
109+
79110
#endif // GC_DIALECT_MICROKERNELPASSES

lib/gc/Dialect/Microkernel/MicrokernelOps.cpp

Lines changed: 53 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,42 @@ static LogicalResult verifyBrgemmFlags(ArrayAttr flags, Operation *op,
213213
return success();
214214
}
215215

216+
static bool isTypeSupported(Type outType, Type operandAType,
217+
Type operandBType) {
218+
if (!outType.isF32() && !outType.isSignedInteger(32))
219+
return false;
220+
221+
if (outType.isF32()) {
222+
if (!(operandAType.isF32() && operandBType.isF32()) &&
223+
!(operandAType.isBF16() && operandBType.isBF16()))
224+
return false;
225+
}
226+
if (outType.isSignedInteger(32)) {
227+
if (!(operandAType.isSignedInteger(8) ||
228+
operandAType.isUnsignedInteger(8)) &&
229+
(operandBType.isSignedInteger(8) || operandBType.isUnsignedInteger(8)))
230+
return false;
231+
}
232+
return true;
233+
}
234+
235+
// TODO(haixin): could use compiler-wide VNNI utils?
236+
static bool isInVnniLayout(ShapedType type) {
237+
if (!type.getElementType().isBF16() &&
238+
!type.getElementType().isSignedInteger(8) &&
239+
!type.getElementType().isUnsignedInteger(8))
240+
return false;
241+
242+
auto blockingFactor = 0;
243+
if (type.getElementType().isBF16())
244+
blockingFactor = 2;
245+
else if (type.getElementType().isSignedInteger(8) ||
246+
type.getElementType().isUnsignedInteger(8))
247+
blockingFactor = 4;
248+
249+
return type.getShape().back() == blockingFactor;
250+
}
251+
216252
/////////////////////////////////////////////////////
217253
// Start of BrgemmOp
218254

@@ -308,9 +344,8 @@ static inline ArrayRef<int64_t> getShapedValueShape(Value val) {
308344
assert((llvm::isa<TensorType>(val.getType()) ||
309345
llvm::isa<MemRefType>(val.getType())) &&
310346
"Expecting shaped value");
311-
if (auto tensorTy = dyn_cast_or_null<TensorType>(val.getType())) {
347+
if (auto tensorTy = dyn_cast_or_null<TensorType>(val.getType()))
312348
return tensorTy.getShape();
313-
}
314349
auto memrefTy = dyn_cast_or_null<MemRefType>(val.getType());
315350
return memrefTy.getShape();
316351
}
@@ -331,15 +366,27 @@ LogicalResult BrgemmOp::verify() {
331366
return op.emitOpError()
332367
<< "expect inputs and its related info to be size 2\n";
333368

369+
auto elemTypeA = getElementTypeOrSelf(ins[0]);
370+
auto elemTypeB = getElementTypeOrSelf(ins[1]);
371+
auto elemTypeC = getElementTypeOrSelf(out);
372+
if (!isTypeSupported(elemTypeC, elemTypeA, elemTypeB))
373+
return op.emitOpError() << "unsupported input matrix types\n";
374+
334375
ArrayRef<int64_t> dimA = getShapedValueShape(ins[0]);
335376
ArrayRef<int64_t> dimB = getShapedValueShape(ins[1]);
336377
ArrayRef<int64_t> dimC = getShapedValueShape(out);
337378
if (dimA.size() != 3)
338379
return op.emitOpError() << "expect input A to be 3D\n";
339-
if (dimB.size() != 3 && dimB.size() != 4)
340-
return op.emitOpError() << "expect input B to be 3D or 4D\n";
341-
if (dimB.size() == 4 && (dimB[3] != 2 && dimB[3] != 4))
342-
return op.emitOpError() << "expect input B vnni step to be 2 or 4\n";
380+
if (!elemTypeB.isF32()) {
381+
if (dimB.size() != 4 ||
382+
!isInVnniLayout(dyn_cast<ShapedType>(ins[1].getType())))
383+
return op.emitOpError()
384+
<< "expect a 4d VNNI input B for non-F32 operand: " << ins[1];
385+
} else {
386+
if (dimB.size() != 3)
387+
return op.emitOpError()
388+
<< "expect a 3d input B for F32 operand: " << ins[1];
389+
}
343390
if (dimC.size() != 2)
344391
return op.emitOpError() << "expect input C to be 2D\n";
345392
for (auto dim : batchDims)
@@ -558,42 +605,6 @@ LogicalResult BrgemmDispatchOp::verify() {
558605
/////////////////////////////////////////////////////
559606
// Start of BrgemmExecuteOp
560607

561-
// TODO(haixin): could use compiler-wide VNNI utils?
562-
static bool isInVnniLayout(MemRefType memref) {
563-
if (!memref.getElementType().isBF16() &&
564-
!memref.getElementType().isSignedInteger(8) &&
565-
!memref.getElementType().isUnsignedInteger(8))
566-
return false;
567-
568-
auto blockingFactor = 0;
569-
if (memref.getElementType().isBF16())
570-
blockingFactor = 2;
571-
else if (memref.getElementType().isSignedInteger(8) ||
572-
memref.getElementType().isUnsignedInteger(8))
573-
blockingFactor = 4;
574-
575-
return memref.getShape().back() == blockingFactor;
576-
}
577-
578-
static bool isTypeSupported(Type outType, Type operandAType,
579-
Type operandBType) {
580-
if (!outType.isF32() && !outType.isSignedInteger(32))
581-
return false;
582-
583-
if (outType.isF32()) {
584-
if (!(operandAType.isF32() && operandBType.isF32()) &&
585-
!(operandAType.isBF16() && operandBType.isBF16()))
586-
return false;
587-
}
588-
if (outType.isSignedInteger(32)) {
589-
if (!(operandAType.isSignedInteger(8) ||
590-
operandAType.isUnsignedInteger(8)) &&
591-
(operandBType.isSignedInteger(8) || operandBType.isUnsignedInteger(8)))
592-
return false;
593-
}
594-
return true;
595-
}
596-
597608
LogicalResult BrgemmExecuteOp::verify() {
598609
BrgemmExecuteOp &brgemmOp = *this;
599610

lib/gc/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ add_subdirectory(Utils)
33
gc_set_mlir_link_components(MLIR_LINK_COMPONENTS
44
MLIRIR
55
MLIRSupport
6+
MLIRMicrokernelTransforms
67
MLIRBufferizationToMemRef
78
MLIRBufferizationPipelines)
89

lib/gc/Transforms/Microkernel/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1-
gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRIR)
1+
gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRIR MLIRMicrokernel)
22

33
include(onednn)
44

55
gc_add_mlir_dialect_library(MLIRMicrokernelTransforms
66
ConvertLinalgToMicrokernel.cpp
77
ExpandMicrokernel.cpp
88
ConvertMicrokernelToDnnlFunc.cpp
9+
EarlyDispatchMicrokernel.cpp
10+
MicrokernelInvariantCodeMotion.cpp
11+
MergeBranchMicrokernelContext.cpp
912

1013
ADDITIONAL_HEADER_DIRS
1114
${PROJECT_SOURCE_DIR}/include/

lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,23 @@ static FailureOr<BrgemmDims> inferBrgemmDims(linalg::LinalgOp linalgOp) {
157157
else
158158
return failure();
159159

160+
OpOperand *operandA = linalgOp.getDpsInputOperands()[0];
161+
OpOperand *operandB = linalgOp.getDpsInputOperands()[1];
162+
Type operandBElemType = getElementTypeOrSelf(operandB->get());
163+
if (operandBElemType.isF32()) {
164+
if (kAffinePos.size() == 2) {
165+
LLVM_DEBUG(llvm::dbgs() << "[checkStructure] Wrong dimensions for input "
166+
"B, should be non-VNNI\n");
167+
return failure();
168+
}
169+
} else {
170+
if (kAffinePos.size() == 1) {
171+
LLVM_DEBUG(llvm::dbgs() << "[checkStructure] Wrong dimensions for input "
172+
"B, should be VNNI\n");
173+
return failure();
174+
}
175+
}
176+
160177
LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmDims] Candidate dims: "
161178
<< "\n");
162179
LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmDims] m pos in affine: " << mAffinePos
@@ -169,9 +186,6 @@ static FailureOr<BrgemmDims> inferBrgemmDims(linalg::LinalgOp linalgOp) {
169186
LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmDims] batch pos in affine: "
170187
<< batchAffinePos << "\n");
171188

172-
OpOperand *operandA = linalgOp.getDpsInputOperands()[0];
173-
OpOperand *operandB = linalgOp.getDpsInputOperands()[1];
174-
175189
BrgemmDims brgemmDims;
176190

177191
#define CHECK_GET_POS_IN_DOMAIN(dim, dimPos, operand) \

0 commit comments

Comments (0)