Commit 9ed1e58
[mlir][shape] Start a pass that lowers shape constraints.
This pass converts shape.cstr_* ops to eager (side-effecting) error-handling code. After that conversion is done, the witnesses are trivially satisfied and are replaced with `shape.const_witness true`.

Differential Revision: https://reviews.llvm.org/D87941
Parent: c4bacc3

5 files changed: +209, -0 lines

mlir/include/mlir/Conversion/Passes.td

Lines changed: 15 additions & 0 deletions
@@ -242,6 +242,21 @@ def ConvertShapeToStandard : Pass<"convert-shape-to-std", "ModuleOp"> {
   let dependentDialects = ["StandardOpsDialect", "scf::SCFDialect"];
 }

+def ConvertShapeConstraints: Pass<"convert-shape-constraints", "FuncOp"> {
+  let summary = "Convert shape constraint operations to the standard dialect";
+  let description = [{
+    This pass eliminates shape constraints from the program, converting them to
+    eager (side-effecting) error handling code.
+
+    This pass is separate from the regular convert-shape-to-standard, despite
+    converting between the same dialects, because converting shape constraints
+    can happen at a different part of the program than general shape
+    computation lowering.
+  }];
+  let constructor = "mlir::createConvertShapeConstraintsPass()";
+  let dependentDialects = ["StandardOpsDialect", "scf::SCFDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // SPIRVToLLVM
 //===----------------------------------------------------------------------===//
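
Usage note: downstream pipelines can schedule the new pass like any other function-level pass. A minimal sketch, not part of this commit (the pipeline function is hypothetical, and the include paths assume the MLIR tree of this era):

// Hypothetical pipeline setup: run constraint lowering on each function,
// with general shape lowering as a separate module-level step.
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
#include "mlir/IR/Function.h"
#include "mlir/Pass/PassManager.h"

void buildShapeLoweringPipeline(mlir::PassManager &pm) {
  // "convert-shape-constraints" is declared on FuncOp above, so nest it
  // under the module.
  pm.addNestedPass<mlir::FuncOp>(mlir::createConvertShapeConstraintsPass());
  // The general lowering is a ModuleOp pass and can run at a different
  // point in the pipeline, per the rationale in the description.
  pm.addPass(mlir::createConvertShapeToStandardPass());
}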

mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h

Lines changed: 6 additions & 0 deletions
@@ -13,6 +13,7 @@

 namespace mlir {

+class FuncOp;
 class MLIRContext;
 class ModuleOp;
 template <typename T>
@@ -24,6 +25,11 @@ void populateShapeToStandardConversionPatterns(

 std::unique_ptr<OperationPass<ModuleOp>> createConvertShapeToStandardPass();

+void populateConvertShapeConstraintsConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx);
+
+std::unique_ptr<OperationPass<FuncOp>> createConvertShapeConstraintsPass();
+
 } // namespace mlir

 #endif // MLIR_CONVERSION_SHAPETOSTANDARD_SHAPETOSTANDARD_H_
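
Usage note: the populate function lets downstream passes reuse these rewrites without running the full pass. A minimal sketch under that assumption (the helper below is hypothetical; the greedy driver call matches the one the pass itself uses):

// Hypothetical standalone use of the exposed pattern entry point.
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
#include "mlir/IR/Function.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult lowerShapeConstraintsIn(mlir::FuncOp func) {
  mlir::OwningRewritePatternList patterns;
  mlir::populateConvertShapeConstraintsConversionPatterns(patterns,
                                                          func.getContext());
  // Greedily apply the rewrites until a fixed point is reached.
  return mlir::applyPatternsAndFoldGreedily(func, patterns);
}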

mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 add_mlir_conversion_library(MLIRShapeToStandard
+  ConvertShapeConstraints.cpp
   ShapeToStandard.cpp

   ADDITIONAL_HEADER_DIRS
mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp

Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
+//===- ConvertShapeConstraints.cpp - Conversion of shape constraints ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
+
+#include "../PassDetail.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+
+using namespace mlir;
+
+namespace {
+class ConvertCstrBroadcastableOp
+    : public OpRewritePattern<shape::CstrBroadcastableOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getType().isa<shape::ShapeType>() ||
+        op.lhs().getType().isa<shape::ShapeType>() ||
+        op.rhs().getType().isa<shape::ShapeType>()) {
+      return rewriter.notifyMatchFailure(
+          op, "cannot convert error-propagating shapes");
+    }
+
+    auto loc = op.getLoc();
+    Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+    Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+
+    // Find the smaller and greater rank and extent tensor.
+    Value lhsRank = rewriter.create<DimOp>(loc, op.lhs(), zero);
+    Value rhsRank = rewriter.create<DimOp>(loc, op.rhs(), zero);
+    Value lhsSmaller =
+        rewriter.create<CmpIOp>(loc, CmpIPredicate::ule, lhsRank, rhsRank);
+    Type indexTy = rewriter.getIndexType();
+    Type extentTensorTy = op.lhs().getType();
+    auto ifOp = rewriter.create<scf::IfOp>(
+        loc, TypeRange{indexTy, extentTensorTy, indexTy, extentTensorTy},
+        lhsSmaller,
+        [&](OpBuilder &b, Location loc) {
+          b.create<scf::YieldOp>(
+              loc, ValueRange{lhsRank, op.lhs(), rhsRank, op.rhs()});
+        },
+        [&](OpBuilder &b, Location loc) {
+          b.create<scf::YieldOp>(
+              loc, ValueRange{rhsRank, op.rhs(), lhsRank, op.lhs()});
+        });
+    Value lesserRank = ifOp.getResult(0);
+    Value lesserRankOperand = ifOp.getResult(1);
+    Value greaterRank = ifOp.getResult(2);
+    Value greaterRankOperand = ifOp.getResult(3);
+
+    Value rankDiff =
+        rewriter.create<SubIOp>(loc, indexTy, greaterRank, lesserRank);
+
+    // Generate code to compare the shapes extent by extent, and emit errors for
+    // non-broadcast-compatible shapes.
+    // Two extents are broadcast-compatible if
+    // 1. they are equal, or
+    // 2. at least one of them is 1.
+
+    rewriter.create<scf::ForOp>(
+        loc, rankDiff, greaterRank, one, llvm::None,
+        [&](OpBuilder &b, Location loc, Value iv, ValueRange) {
+          Value greaterRankOperandExtent = b.create<ExtractElementOp>(
+              loc, greaterRankOperand, ValueRange{iv});
+          Value ivShifted = b.create<SubIOp>(loc, indexTy, iv, rankDiff);
+          Value lesserRankOperandExtent = b.create<ExtractElementOp>(
+              loc, lesserRankOperand, ValueRange{ivShifted});
+
+          Value greaterRankOperandExtentIsOne = b.create<CmpIOp>(
+              loc, CmpIPredicate::eq, greaterRankOperandExtent, one);
+          Value lesserRankOperandExtentIsOne = b.create<CmpIOp>(
+              loc, CmpIPredicate::eq, lesserRankOperandExtent, one);
+          Value extentsAgree =
+              b.create<CmpIOp>(loc, CmpIPredicate::eq, greaterRankOperandExtent,
+                               lesserRankOperandExtent);
+          auto broadcastIsValid =
+              b.create<OrOp>(loc, b.getI1Type(), extentsAgree,
+                             b.create<OrOp>(loc, greaterRankOperandExtentIsOne,
+                                            lesserRankOperandExtentIsOne));
+          b.create<AssertOp>(loc, broadcastIsValid, "invalid broadcast");
+          b.create<scf::YieldOp>(loc);
+        });
+
+    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class ConvertCstrRequireOp : public OpRewritePattern<shape::CstrRequireOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(shape::CstrRequireOp op,
+                                PatternRewriter &rewriter) const override {
+    rewriter.create<AssertOp>(op.getLoc(), op.pred(), op.msgAttr());
+    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateConvertShapeConstraintsConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  patterns.insert<ConvertCstrBroadcastableOp>(ctx);
+  patterns.insert<ConvertCstrRequireOp>(ctx);
+}
+
+namespace {
+// This pass eliminates shape constraints from the program, converting them to
+// eager (side-effecting) error handling code. After eager error handling code
+// is emitted, witnesses are trivially satisfied, so they are replaced with
+// `shape.const_witness true`.
+class ConvertShapeConstraints
+    : public ConvertShapeConstraintsBase<ConvertShapeConstraints> {
+  void runOnOperation() {
+    auto func = getOperation();
+    auto *context = &getContext();
+
+    OwningRewritePatternList patterns;
+    populateConvertShapeConstraintsConversionPatterns(patterns, context);
+
+    if (failed(applyPatternsAndFoldGreedily(func, patterns)))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createConvertShapeConstraintsPass() {
+  return std::make_unique<ConvertShapeConstraints>();
+}
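
For intuition, the runtime check emitted by ConvertCstrBroadcastableOp corresponds to the following scalar logic, sketched in plain C++ over extent vectors (illustrative only; checkBroadcastable is not an MLIR API):

// Scalar model of the generated scf.for loop: shapes broadcast if, after
// right-aligning the lower-ranked shape, each extent pair is equal or
// contains a 1.
#include <cassert>
#include <cstdint>
#include <vector>

void checkBroadcastable(const std::vector<int64_t> &lhs,
                        const std::vector<int64_t> &rhs) {
  // Order the shapes by rank, mirroring the scf.if in the pattern.
  const auto &lesser = lhs.size() <= rhs.size() ? lhs : rhs;
  const auto &greater = lhs.size() <= rhs.size() ? rhs : lhs;
  size_t rankDiff = greater.size() - lesser.size();
  // Compare extent by extent; the lesser-ranked shape is right-aligned.
  for (size_t i = rankDiff; i < greater.size(); ++i) {
    int64_t g = greater[i];
    int64_t l = lesser[i - rankDiff];
    // Broadcast-compatible if the extents are equal or one of them is 1.
    assert((g == l || g == 1 || l == 1) && "invalid broadcast");
  }
}

For example, checkBroadcastable({4, 1}, {3, 4, 5}) passes (the 1 broadcasts against 5), while checkBroadcastable({2, 3}, {3, 4, 5}) trips the assert on the extent pair (2, 4).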
mlir/test/Conversion/ShapeToStandard/convert-shape-constraints.mlir

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+// RUN: mlir-opt -convert-shape-constraints <%s | FileCheck %s
+
+// There's not very much useful to check here other than pasting the output.
+// CHECK-LABEL: func @cstr_broadcastable(
+// CHECK-SAME: %[[LHS:.*]]: tensor<?xindex>,
+// CHECK-SAME: %[[RHS:.*]]: tensor<?xindex>) -> !shape.witness {
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[RET:.*]] = shape.const_witness true
+// CHECK: %[[LHSRANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
+// CHECK: %[[RHSRANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
+// CHECK: %[[LESSEQUAL:.*]] = cmpi "ule", %[[LHSRANK]], %[[RHSRANK]] : index
+// CHECK: %[[IFRESULTS:.*]]:4 = scf.if %[[LESSEQUAL]] -> (index, tensor<?xindex>, index, tensor<?xindex>) {
+// CHECK: scf.yield %[[LHSRANK]], %[[LHS]], %[[RHSRANK]], %[[RHS]] : index, tensor<?xindex>, index, tensor<?xindex>
+// CHECK: } else {
+// CHECK: scf.yield %[[RHSRANK]], %[[RHS]], %[[LHSRANK]], %[[LHS]] : index, tensor<?xindex>, index, tensor<?xindex>
+// CHECK: }
+// CHECK: %[[RANKDIFF:.*]] = subi %[[IFRESULTS:.*]]#2, %[[IFRESULTS]]#0 : index
+// CHECK: scf.for %[[IV:.*]] = %[[RANKDIFF]] to %[[IFRESULTS]]#2 step %[[C1]] {
+// CHECK: %[[GREATERRANKOPERANDEXTENT:.*]] = extract_element %[[IFRESULTS]]#3{{\[}}%[[IV]]] : tensor<?xindex>
+// CHECK: %[[IVSHIFTED:.*]] = subi %[[IV]], %[[RANKDIFF]] : index
+// CHECK: %[[LESSERRANKOPERANDEXTENT:.*]] = extract_element %[[IFRESULTS]]#1{{\[}}%[[IVSHIFTED]]] : tensor<?xindex>
+// CHECK: %[[GREATERRANKOPERANDEXTENTISONE:.*]] = cmpi "eq", %[[GREATERRANKOPERANDEXTENT]], %[[C1]] : index
+// CHECK: %[[LESSERRANKOPERANDEXTENTISONE:.*]] = cmpi "eq", %[[LESSERRANKOPERANDEXTENT]], %[[C1]] : index
+// CHECK: %[[EXTENTSAGREE:.*]] = cmpi "eq", %[[GREATERRANKOPERANDEXTENT]], %[[LESSERRANKOPERANDEXTENT]] : index
+// CHECK: %[[OR_TMP:.*]] = or %[[GREATERRANKOPERANDEXTENTISONE]], %[[LESSERRANKOPERANDEXTENTISONE]] : i1
+// CHECK: %[[BROADCASTISVALID:.*]] = or %[[EXTENTSAGREE]], %[[OR_TMP]] : i1
+// CHECK: assert %[[BROADCASTISVALID]], "invalid broadcast"
+// CHECK: }
+// CHECK: return %[[RET]] : !shape.witness
+// CHECK: }
+func @cstr_broadcastable(%arg0: tensor<?xindex>, %arg1: tensor<?xindex>) -> !shape.witness {
+  %witness = shape.cstr_broadcastable %arg0, %arg1 : tensor<?xindex>, tensor<?xindex>
+  return %witness : !shape.witness
+}
+
+// CHECK-LABEL: func @cstr_require
+func @cstr_require(%arg0: i1) -> !shape.witness {
+  // CHECK: %[[RET:.*]] = shape.const_witness true
+  // CHECK: assert %arg0, "msg"
+  // CHECK: return %[[RET]]
+  %witness = shape.cstr_require %arg0, "msg"
+  return %witness : !shape.witness
+}
