Skip to content

Commit 0b1cbee

Browse files
committed
Resolve conflict.
1 parent d9eda6b commit 0b1cbee

File tree

2 files changed

+24
-27
lines changed

2 files changed

+24
-27
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -622,15 +622,17 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
622622
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
623623
[{
624624
buildMatmulOp($_builder, $_state, std::nullopt, inputs, outputs,
625-
attributes, MatmulOp::getRegionBuilder());
625+
attributes, MatmulOp::getRegionBuilder(),
626+
MatmulOp::getDefaultIndexingMaps($_builder.getContext()));
626627
}]>,
627628
OpBuilder<
628629
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
629630
"ValueRange":$outputs,
630631
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
631632
[{
632633
buildMatmulOp($_builder, $_state, resultTensorTypes,
633-
inputs, outputs, attributes, MatmulOp::getRegionBuilder());
634+
inputs, outputs, attributes, MatmulOp::getRegionBuilder(),
635+
MatmulOp::getDefaultIndexingMaps($_builder.getContext()));
634636
}]>,
635637
OpBuilder<
636638
(ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
@@ -648,7 +650,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
648650
[{
649651
$_state.addAttribute("cast", cast);
650652
buildMatmulOp($_builder, $_state, resultTensorTypes, inputs, outputs,
651-
attributes, MatmulOp::getRegionBuilder());
653+
attributes, MatmulOp::getRegionBuilder(),
654+
MatmulOp::getDefaultIndexingMaps($_builder.getContext()));
652655
}]>
653656

654657
];
@@ -664,7 +667,15 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
664667
Block &block, ArrayRef<NamedAttribute> attrs);
665668

666669
/// Returns a list of AffineMap with the typical matmul indexing charactristic.
667-
SmallVector<AffineMap> getDefaultIndexingMaps();
670+
static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context){
671+
AffineExpr d0, d1, d2;
672+
SmallVector<AffineMap, 3> indexingMaps;
673+
bindDims(context, d0, d1, d2);
674+
indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
675+
indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
676+
indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
677+
return indexingMaps;
678+
}
668679

669680
/// Returns true if the given broadcast map \p bcastMap is valid for this op.
670681
bool isValidLhsRhsBroadcastMap(AffineMap bcastMap);

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -168,14 +168,6 @@ getDefaultIndexingMapsForMatmul(MLIRContext *context) {
168168
return indexingMaps;
169169
}
170170

171-
/// Wrapper to return the typical indexing map array attribute for MatmulOp.
172-
static SmallVector<Attribute>
173-
getDefaultMatmulIndexingMapAttr(MLIRContext *context) {
174-
return llvm::map_to_vector(
175-
getDefaultIndexingMapsForMatmul(context),
176-
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
177-
}
178-
179171
/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
180172
/// The result types are derived automatically if `resultTensorTypes` is none.
181173
/// The body of the operation is filled using `regionBuilder`. All ods-gen
@@ -222,9 +214,6 @@ buildMatmulOp(OpBuilder &b, OperationState &state,
222214
indexingMapsAttrVal.push_back(AffineMapAttr::get(map));
223215
}
224216
state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
225-
} else {
226-
indexingMapsAttrVal = getDefaultMatmulIndexingMapAttr(b.getContext());
227-
state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
228217
}
229218
return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
230219
attributes, regionBuilder);
@@ -3457,7 +3446,7 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
34573446
unsigned opIndex) {
34583447
SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
34593448
SmallVector<AffineMap, 3> defaultIndexingMaps =
3460-
matmulOp.getDefaultIndexingMaps();
3449+
matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
34613450

34623451
auto opIndexingMap = opIndexingMaps[opIndex];
34633452
auto defaultIndexingMap = defaultIndexingMaps[opIndex];
@@ -3501,7 +3490,8 @@ bool MatmulOp::hasDynamicIndexingMaps() { return true; }
35013490
/// Check if the op has broadcast and/or transpose semantic. Returns true if
35023491
/// the user defined indexing maps are not equal to default map.
35033492
bool MatmulOp::hasUserDefinedMaps() {
3504-
SmallVector<AffineMap, 3> defaultMaps = getDefaultIndexingMaps();
3493+
SmallVector<AffineMap, 3> defaultMaps =
3494+
getDefaultIndexingMaps(this->getContext());
35053495
SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
35063496
return defaultMaps != explicitMaps;
35073497
}
@@ -3535,13 +3525,6 @@ void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
35353525
helper.yieldOutputs(yields);
35363526
}
35373527

3538-
/// Returns a list of AffineMap with the typical matmul indexing
3539-
/// charactristic.
3540-
SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps() {
3541-
MLIRContext *context = this->getContext();
3542-
return getDefaultIndexingMapsForMatmul(context);
3543-
}
3544-
35453528
/// Returns true if the given broadcast map \p bcastMap is valid for this op.
35463529
bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
35473530
assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
@@ -3578,7 +3561,9 @@ ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
35783561
}
35793562
// Initialize indexingMaps, if not supplied explicitly.
35803563
if (indexingMapsAttr.empty()) {
3581-
indexingMapsAttr = getDefaultMatmulIndexingMapAttr(result.getContext());
3564+
indexingMapsAttr = llvm::map_to_vector(
3565+
MatmulOp::getDefaultIndexingMaps(parser.getContext()),
3566+
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
35823567
}
35833568
result.addAttribute("indexing_maps",
35843569
parser.getBuilder().getArrayAttr(indexingMapsAttr));
@@ -3592,8 +3577,9 @@ void MatmulOp::print(OpAsmPrinter &p) {
35923577
printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
35933578
elidedAttrs);
35943579

3595-
SmallVector<Attribute, 3> indexingMaps =
3596-
getDefaultMatmulIndexingMapAttr(getContext());
3580+
SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
3581+
MatmulOp::getDefaultIndexingMaps(getContext()),
3582+
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
35973583
if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
35983584
p << " indexing_maps = [";
35993585
llvm::interleaveComma(getIndexingMaps(), p,

0 commit comments

Comments
 (0)