Skip to content

Commit 55b2bf4

Browse files
committed
[NFC][MLIR][Linalg] Refactor linalg.matmul tablegen ODS and related c++ code.
This commit refactors part of the code in preparation for the migration of other *matmul* variant from OpDSL to ODS. Moves getDefaultIndexingmaps() helper into the MatmulOp class.
1 parent 10b048c commit 55b2bf4

File tree

2 files changed

+19
-29
lines changed

2 files changed

+19
-29
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -622,15 +622,17 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
622622
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
623623
[{
624624
buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
625-
attributes, MatmulOp::getRegionBuilder());
625+
attributes, MatmulOp::getRegionBuilder(),
626+
MatmulOp::getDefaultIndexingMaps($_builder.getContext()));
626627
}]>,
627628
OpBuilder<
628629
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
629630
"ValueRange":$outputs,
630631
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
631632
[{
632633
buildStructuredOp($_builder, $_state, resultTensorTypes,
633-
inputs, outputs, attributes, MatmulOp::getRegionBuilder());
634+
inputs, outputs, attributes, MatmulOp::getRegionBuilder(),
635+
MatmulOp::getDefaultIndexingMaps($_builder.getContext()));
634636
}]>,
635637
OpBuilder<
636638
(ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
@@ -648,7 +650,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
648650
[{
649651
$_state.addAttribute("cast", cast);
650652
buildStructuredOp($_builder, $_state, resultTensorTypes, inputs, outputs,
651-
attributes, MatmulOp::getRegionBuilder());
653+
attributes, MatmulOp::getRegionBuilder(),
654+
MatmulOp::getDefaultIndexingMaps($_builder.getContext()));
652655
}]>
653656

654657
];
@@ -664,7 +667,15 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
664667
Block &block, ArrayRef<NamedAttribute> attrs);
665668

666669
/// Returns a list of AffineMap with the typical matmul indexing charactristic.
667-
SmallVector<AffineMap> getDefaultIndexingMaps();
670+
static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context){
671+
AffineExpr d0, d1, d2;
672+
SmallVector<AffineMap, 3> indexingMaps;
673+
bindDims(context, d0, d1, d2);
674+
indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
675+
indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
676+
indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
677+
return indexingMaps;
678+
}
668679

669680
/// Returns true if the given broadcast map \p bcastMap is valid for this op.
670681
bool isValidLhsRhsBroadcastMap(AffineMap bcastMap);

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -155,23 +155,10 @@ static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
155155
// iterator_types is an auto-generated method.
156156
}
157157

158-
/// Helper to create a typical indexing map for MatmulOp. Returns a list of
159-
/// AffineMap.
160-
static SmallVector<AffineMap, 3>
161-
getDefaultIndexingMapsForMatmul(MLIRContext *context) {
162-
AffineExpr d0, d1, d2;
163-
SmallVector<AffineMap, 3> indexingMaps;
164-
bindDims(context, d0, d1, d2);
165-
indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
166-
indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
167-
indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
168-
return indexingMaps;
169-
}
170-
171158
/// Wrapper to return the typical indexing map array attribute for MatmulOp.
172159
static SmallVector<Attribute> getDefaultIndexingMapAttr(MLIRContext *context) {
173160
return llvm::map_to_vector(
174-
getDefaultIndexingMapsForMatmul(context),
161+
MatmulOp::getDefaultIndexingMaps(context),
175162
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
176163
}
177164

@@ -204,9 +191,6 @@ static void buildStructuredOp(
204191
indexingMapsAttrVal.push_back(AffineMapAttr::get(map));
205192
}
206193
state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
207-
} else {
208-
indexingMapsAttrVal = getDefaultIndexingMapAttr(b.getContext());
209-
state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
210194
}
211195

212196
state.addAttributes(attributes);
@@ -3481,7 +3465,7 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
34813465
unsigned opIndex) {
34823466
SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
34833467
SmallVector<AffineMap, 3> defaultIndexingMaps =
3484-
matmulOp.getDefaultIndexingMaps();
3468+
matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
34853469

34863470
auto opIndexingMap = opIndexingMaps[opIndex];
34873471
auto defaultIndexingMap = defaultIndexingMaps[opIndex];
@@ -3523,7 +3507,8 @@ bool MatmulOp::hasDynamicIndexingMaps() { return true; }
35233507
/// Check if the op has broadcast and/or transpose semantic. Returns true if the
35243508
/// user defined indexing maps are not equal to default map.
35253509
bool MatmulOp::hasUserDefinedMaps() {
3526-
SmallVector<AffineMap, 3> defaultMaps = getDefaultIndexingMaps();
3510+
SmallVector<AffineMap, 3> defaultMaps =
3511+
MatmulOp::getDefaultIndexingMaps(this->getContext());
35273512
SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
35283513
return defaultMaps != explicitMaps;
35293514
}
@@ -3557,12 +3542,6 @@ void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
35573542
helper.yieldOutputs(yields);
35583543
}
35593544

3560-
/// Returns a list of AffineMap with the typical matmul indexing charactristic.
3561-
SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps() {
3562-
MLIRContext *context = this->getContext();
3563-
return getDefaultIndexingMapsForMatmul(context);
3564-
}
3565-
35663545
/// Returns true if the given broadcast map \p bcastMap is valid for this op.
35673546
bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
35683547
assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");

0 commit comments

Comments
 (0)