Skip to content

Commit f696116

Browse files
committed
*Added logic to ensure the indexing_map attribute can be dropped for collapsed contraction op.
*Refactored some tests and methods for better naming, comments and readability.
1 parent 318eeca commit f696116

File tree

6 files changed

+42
-64
lines changed

6 files changed

+42
-64
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -710,7 +710,10 @@ def LinalgStructuredInterface
710710
>,
711711
InterfaceMethod<
712712
/*desc=*/[{
713-
Return true if the user has supplied an explicit indexing maps for this op.
713+
Returns true if the user has supplied explicit indexing maps that are
714+
different from default indexing maps for this op. Returns `false` otherwise.
715+
Note, if the user define maps that are identical to the default maps,
716+
this method returns `false`.
714717
}],
715718
/*retTy=*/"bool",
716719
/*methodName=*/"hasUserDefinedMaps",

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -674,8 +674,7 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
674674
static unsigned getNumRegionArgs();
675675
std::string getLibraryCallName();
676676
bool hasDynamicIndexingMaps();
677-
/// Check if the op has broadcast and/or transpose semantic. Returns true if the
678-
/// user defined indexing maps are not equal to default map.
677+
/// Returns true if the user defined indexing maps are not equal to default maps.
679678
bool hasUserDefinedMaps();
680679
}];
681680
}
@@ -933,8 +932,7 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz
933932
static unsigned getNumRegionArgs();
934933
bool hasDynamicIndexingMaps() { return true; }
935934
std::string getLibraryCallName();
936-
/// Check if the op has broadcast and/or transpose semantic. Returns true if the
937-
/// user defined indexing maps are not equal to default map.
935+
/// Returns true if the user defined indexing maps are not equal to default maps.
938936
bool hasUserDefinedMaps();
939937
}];
940938
}

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 24 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3426,11 +3426,10 @@ Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
34263426
return arith::ConstantOp::materialize(builder, value, type, loc);
34273427
}
34283428

3429-
/// Returns true if the result AffineExpr of the \p explicitMap is same as \p
3430-
/// defaultMap.
3431-
static bool isValidResultDimExprs(AffineMap explictMap, AffineMap defaultMap) {
3432-
auto explicitRange = explictMap.getResults();
3433-
auto defaultRange = defaultMap.getResults();
3429+
// Returns true if the result expression of `subMap` are a subset of `fullMap`.
3430+
static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap) {
3431+
auto explicitRange = subMap.getResults();
3432+
auto defaultRange = fullMap.getResults();
34343433
DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
34353434
DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
34363435
llvm::set_union(explicitSet, defaultSet);
@@ -3455,7 +3454,7 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
34553454
auto opIndexingMap = opIndexingMaps[opIndex];
34563455
auto defaultIndexingMap = defaultIndexingMaps[opIndex];
34573456
// Check general validity of indexing map results.
3458-
if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
3457+
if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
34593458
return matmulOp->emitOpError()
34603459
<< "Unexpected dim expression in map result.";
34613460

@@ -3470,44 +3469,31 @@ static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
34703469
return success();
34713470
}
34723471

3473-
/// Checks if the given AffineMap represents a valid batch dimension.
3474-
/// It checks if the first result dimension is a function of the first
3475-
/// dimension.
3476-
static bool isValidBatchDim(AffineMap bcastMap) {
3477-
AffineExpr exp = bcastMap.getResult(0);
3478-
return exp.isFunctionOfDim(0);
3479-
}
3480-
3481-
/// Checks if the given AffineMap's result dimensions are valid output result
3482-
/// dimensions.
3483-
static bool isValidOutputResultDim(AffineMap outputMap) {
3484-
enum Indices { batchPos, mPos, nPos };
3485-
AffineExpr exp0 = outputMap.getResult(batchPos);
3486-
AffineExpr exp1 = outputMap.getResult(mPos);
3487-
AffineExpr exp2 = outputMap.getResult(nPos);
3488-
return exp0.isFunctionOfDim(batchPos) && exp1.isFunctionOfDim(mPos) &&
3489-
exp2.isFunctionOfDim(nPos);
3490-
}
3491-
34923472
// Check general validity of input indexing map.
34933473
static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp,
34943474
AffineMap opIndexingMap,
34953475
AffineMap defaultIndexingMap, bool isLHS) {
34963476
// Check the result dims are valid.
3497-
if (!isValidResultDimExprs(opIndexingMap, defaultIndexingMap))
3477+
if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
34983478
return batchMatmulOp->emitOpError()
3499-
<< "Unexpected dim expression in map result.";
3479+
<< "Unexpected result dim expression (outside the set of default "
3480+
"result dims).";
35003481

35013482
// Check for valid number of result dims of input maps.
35023483
if (opIndexingMap.getNumResults() > 3)
35033484
return batchMatmulOp->emitOpError()
3504-
<< "no. of result dim expression cannot exceed 3.";
3485+
<< "no. of result dim expressions exceeds 3.";
3486+
3487+
auto hasValidBatchDim = [](AffineMap map) {
3488+
AffineExpr batchDim = map.getResult(0);
3489+
return batchDim.isFunctionOfDim(0);
3490+
};
35053491

35063492
// Check if the requested broadcast is valid.
35073493
if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
35083494
if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
35093495
return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
3510-
} else if (!isValidBatchDim(opIndexingMap)) {
3496+
} else if (!hasValidBatchDim(opIndexingMap)) {
35113497
return batchMatmulOp->emitOpError()
35123498
<< "Invalid batch dimension expression.";
35133499
}
@@ -3524,7 +3510,13 @@ static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp,
35243510
<< "expects 3 dims, but got (" << opIndexingMap.getNumResults()
35253511
<< ").";
35263512

3527-
if (!isValidOutputResultDim(opIndexingMap))
3513+
auto areValidOutputResultDim = [](AffineMap outputMap) {
3514+
return outputMap.getResult(0).isFunctionOfDim(0) &&
3515+
outputMap.getResult(1).isFunctionOfDim(1) &&
3516+
outputMap.getResult(2).isFunctionOfDim(2);
3517+
};
3518+
3519+
if (!areValidOutputResultDim(opIndexingMap))
35283520
return batchMatmulOp->emitOpError()
35293521
<< "Invalid output map result dimension.";
35303522

@@ -3941,7 +3933,8 @@ bool BatchMatmulOp::hasUserDefinedMaps() {
39413933

39423934
/// Returns true if the given broadcast map \p bcastMap is valid for this op.
39433935
bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
3944-
assert(bcastMap.getNumResults() < 3 && "Expected single result dim expr.");
3936+
assert(bcastMap.getNumResults() < 3 &&
3937+
"Expected less than 3 result dim expr.");
39453938
bool isValid = false;
39463939
enum Indices { batchPos, mPos, nPos, kPos };
39473940
if (bcastMap.getNumResults() == 1) {

mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -907,14 +907,9 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
907907

908908
LogicalResult matchAndRewrite(FromOpTy contractionOp,
909909
PatternRewriter &rewriter) const override {
910-
// Check to not let go the batch_matmul with extended semantic, through this
911-
// transform.
912-
if (std::is_same<FromOpTy, BatchMatmulOp>::value ||
913-
std::is_same<FromOpTy, MatmulOp>::value) {
914-
if (contractionOp.hasUserDefinedMaps()) {
915-
return rewriter.notifyMatchFailure(
916-
contractionOp, "ops with user-defined maps are not supported");
917-
}
910+
if (contractionOp.hasUserDefinedMaps()) {
911+
return rewriter.notifyMatchFailure(
912+
contractionOp, "ops with user-defined maps are not supported");
918913
}
919914

920915
auto loc = contractionOp.getLoc();
@@ -945,21 +940,10 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
945940
loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
946941
ValueRange{collapsedInit});
947942
for (auto attr : contractionOp->getAttrs()) {
948-
if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
943+
if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName ||
944+
attr.getName() == "indexing_maps")
949945
continue;
950-
951-
// Update the indexing_maps attribute for the collapsed MatmulOp.
952-
if (attr.getName() == "indexing_maps" &&
953-
std::is_same<FromOpTy, BatchMatmulOp>::value &&
954-
std::is_same<ToOpTy, MatmulOp>::value) {
955-
SmallVector<Attribute, 3> indexingMapsAttr = llvm::map_to_vector(
956-
MatmulOp::getDefaultIndexingMaps(rewriter.getContext()),
957-
[](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
958-
collapsedOp->setAttr(attr.getName(),
959-
rewriter.getArrayAttr(indexingMapsAttr));
960-
} else {
961-
collapsedOp->setAttr(attr.getName(), attr.getValue());
962-
}
946+
collapsedOp->setAttr(attr.getName(), attr.getValue());
963947
}
964948

965949
auto results = contractionOp.getResults();

mlir/test/Dialect/Linalg/invalid.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1303,7 +1303,7 @@ func.func @missing_indexing_map_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: me
13031303
// -----
13041304

13051305
func.func @invalid_dim_expr_batch_matmul_a(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
1306-
// expected-error @+1 {{Unexpected dim expression in map result}}
1306+
// expected-error @+1 {{Unexpected result dim expression (outside the set of default result dims)}}
13071307
linalg.batch_matmul indexing_maps = [
13081308
affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>,
13091309
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
@@ -1316,7 +1316,7 @@ func.func @invalid_dim_expr_batch_matmul_a(%arg0: memref<?x?x?xf32>, %arg1: memr
13161316
// -----
13171317

13181318
func.func @invalid_dim_expr_batch_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
1319-
// expected-error @+1 {{Unexpected dim expression in map result}}
1319+
// expected-error @+1 {{Unexpected result dim expression (outside the set of default result dims)}}
13201320
linalg.batch_matmul indexing_maps = [
13211321
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
13221322
affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
@@ -1407,7 +1407,7 @@ func.func @invalid_batch_dim_batch_matmul_b(%arg0: memref<?x?x?xf32>, %arg1: mem
14071407
// -----
14081408

14091409
func.func @invalid_A_map_result_num_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
1410-
// expected-error @+1 {{'linalg.batch_matmul' op no. of result dim expression cannot exceed 3.}}
1410+
// expected-error @+1 {{'linalg.batch_matmul' op no. of result dim expressions exceeds 3.}}
14111411
linalg.batch_matmul indexing_maps = [
14121412
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3, d3)>,
14131413
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,
@@ -1421,7 +1421,7 @@ func.func @invalid_A_map_result_num_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1
14211421
// -----
14221422

14231423
func.func @invalid_B_map_result_num_batch_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?xf32>) {
1424-
// expected-error @+1 {{'linalg.batch_matmul' op no. of result dim expression cannot exceed 3.}}
1424+
// expected-error @+1 {{'linalg.batch_matmul' op no. of result dim expressions exceeds 3.}}
14251425
linalg.batch_matmul indexing_maps = [
14261426
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
14271427
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2, d3)>,

mlir/test/Dialect/Linalg/named-ops.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1489,14 +1489,14 @@ func.func @matmul_transpose_b(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %a
14891489
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
14901490
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
14911491

1492-
// CHECK-LABEL: func.func @batch_matmul_bcast_batch_and_m_dim_A(
1492+
// CHECK-LABEL: func.func @batch_matmul_bcast_k_to_fill_missing_dims_A(
14931493
// CHECK-SAME: %[[VAL_0:.*]]: memref<5xf32>,
14941494
// CHECK-SAME: %[[VAL_1:.*]]: memref<2x5x7xf32>,
14951495
// CHECK-SAME: %[[VAL_2:.*]]: memref<2x3x7xf32>) {
14961496
// CHECK: linalg.batch_matmul ins(%[[VAL_0]], %[[VAL_1]] : memref<5xf32>, memref<2x5x7xf32>) outs(%[[VAL_2]] : memref<2x3x7xf32>) indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]]
14971497
// CHECK: return
14981498
// CHECK: }
1499-
func.func @batch_matmul_bcast_batch_and_m_dim_A(%arg0: memref<5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
1499+
func.func @batch_matmul_bcast_k_to_fill_missing_dims_A(%arg0: memref<5xf32>, %arg1: memref<2x5x7xf32>, %arg2: memref<2x3x7xf32>) {
15001500
linalg.batch_matmul indexing_maps = [
15011501
affine_map<(d0, d1, d2, d3) -> (d3)>,
15021502
affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>,

0 commit comments

Comments
 (0)