Skip to content

Commit 2eb6545

Browse files
authored
[CIR] Add cir-simplify pass (#138317)
This patch adds the cir-simplify pass for SelectOp and TernaryOp. It also adds the SelectOp folder and adds the constant materializer for the CIR dialect.
1 parent 3cb480b commit 2eb6545

File tree

15 files changed

+416
-10
lines changed

15 files changed

+416
-10
lines changed

clang/include/clang/CIR/CIRToCIRPasses.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ namespace cir {
3232
mlir::LogicalResult runCIRToCIRPasses(mlir::ModuleOp theModule,
3333
mlir::MLIRContext &mlirCtx,
3434
clang::ASTContext &astCtx,
35-
bool enableVerifier);
35+
bool enableVerifier,
36+
bool enableCIRSimplify);
3637

3738
} // namespace cir
3839

clang/include/clang/CIR/Dialect/IR/CIRDialect.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,13 @@ def CIR_Dialect : Dialect {
2727
let useDefaultAttributePrinterParser = 0;
2828
let useDefaultTypePrinterParser = 0;
2929

30+
// Enable constant materialization for the CIR dialect. This generates a
31+
// declaration for the cir::CIRDialect::materializeConstant function. This
32+
// hook is necessary for canonicalization to properly handle attributes
33+
// returned by fold methods, allowing them to be materialized as constant
34+
// operations in the IR.
35+
let hasConstantMaterializer = 1;
36+
3037
let extraClassDeclaration = [{
3138
static llvm::StringRef getTripleAttrName() { return "cir.triple"; }
3239

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1464,6 +1464,8 @@ def SelectOp : CIR_Op<"select", [Pure,
14641464
qualified(type($false_value))
14651465
`)` `->` qualified(type($result)) attr-dict
14661466
}];
1467+
1468+
let hasFolder = 1;
14671469
}
14681470

14691471
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/Dialect/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ namespace mlir {
2222

2323
std::unique_ptr<Pass> createCIRCanonicalizePass();
2424
std::unique_ptr<Pass> createCIRFlattenCFGPass();
25+
std::unique_ptr<Pass> createCIRSimplifyPass();
2526
std::unique_ptr<Pass> createHoistAllocasPass();
2627

2728
void populateCIRPreLoweringPasses(mlir::OpPassManager &pm);

clang/include/clang/CIR/Dialect/Passes.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,25 @@ def CIRCanonicalize : Pass<"cir-canonicalize"> {
2929
let dependentDialects = ["cir::CIRDialect"];
3030
}
3131

32+
def CIRSimplify : Pass<"cir-simplify"> {
33+
let summary = "Performs CIR simplification and code optimization";
34+
let description = [{
35+
The pass performs semantics-preserving code simplifications and optimizations
36+
on CIR while maintaining strict program correctness.
37+
38+
Unlike the `cir-canonicalize` pass, these transformations may reduce the IR's
39+
structural similarity to the original source code as a trade-off for improved
40+
code quality. This can affect debugging fidelity by altering intermediate
41+
representations of folded expressions, hoisted operations, and other
42+
optimized constructs.
43+
44+
Example transformations include ternary expression folding and code hoisting
45+
while preserving program semantics.
46+
}];
47+
let constructor = "mlir::createCIRSimplifyPass()";
48+
let dependentDialects = ["cir::CIRDialect"];
49+
}
50+
3251
def HoistAllocas : Pass<"cir-hoist-allocas"> {
3352
let summary = "Hoist allocas to the entry of the function";
3453
let description = [{

clang/include/clang/CIR/MissingFeatures.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,6 @@ struct MissingFeatures {
206206
static bool labelOp() { return false; }
207207
static bool ptrDiffOp() { return false; }
208208
static bool ptrStrideOp() { return false; }
209-
static bool selectOp() { return false; }
210209
static bool switchOp() { return false; }
211210
static bool ternaryOp() { return false; }
212211
static bool tryOp() { return false; }

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,14 @@ void cir::CIRDialect::initialize() {
7979
addInterfaces<CIROpAsmDialectInterface>();
8080
}
8181

82+
Operation *cir::CIRDialect::materializeConstant(mlir::OpBuilder &builder,
83+
mlir::Attribute value,
84+
mlir::Type type,
85+
mlir::Location loc) {
86+
return builder.create<cir::ConstantOp>(loc, type,
87+
mlir::cast<mlir::TypedAttr>(value));
88+
}
89+
8290
//===----------------------------------------------------------------------===//
8391
// Helpers
8492
//===----------------------------------------------------------------------===//
@@ -1261,6 +1269,28 @@ void cir::TernaryOp::build(
12611269
result.addTypes(TypeRange{yield.getOperandTypes().front()});
12621270
}
12631271

1272+
//===----------------------------------------------------------------------===//
1273+
// SelectOp
1274+
//===----------------------------------------------------------------------===//
1275+
1276+
OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) {
1277+
mlir::Attribute condition = adaptor.getCondition();
1278+
if (condition) {
1279+
bool conditionValue = mlir::cast<cir::BoolAttr>(condition).getValue();
1280+
return conditionValue ? getTrueValue() : getFalseValue();
1281+
}
1282+
1283+
// cir.select if %0 then x else x -> x
1284+
mlir::Attribute trueValue = adaptor.getTrueValue();
1285+
mlir::Attribute falseValue = adaptor.getFalseValue();
1286+
if (trueValue == falseValue)
1287+
return trueValue;
1288+
if (getTrueValue() == getFalseValue())
1289+
return getTrueValue();
1290+
1291+
return {};
1292+
}
1293+
12641294
//===----------------------------------------------------------------------===//
12651295
// ShiftOp
12661296
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,14 +121,13 @@ void CIRCanonicalizePass::runOnOperation() {
121121
getOperation()->walk([&](Operation *op) {
122122
assert(!cir::MissingFeatures::switchOp());
123123
assert(!cir::MissingFeatures::tryOp());
124-
assert(!cir::MissingFeatures::selectOp());
125124
assert(!cir::MissingFeatures::complexCreateOp());
126125
assert(!cir::MissingFeatures::complexRealOp());
127126
assert(!cir::MissingFeatures::complexImagOp());
128127
assert(!cir::MissingFeatures::callOp());
129128
// CastOp and UnaryOp are here to perform a manual `fold` in
130129
// applyOpPatternsGreedily.
131-
if (isa<BrOp, BrCondOp, ScopeOp, CastOp, UnaryOp>(op))
130+
if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SelectOp, UnaryOp>(op))
132131
ops.push_back(op);
133132
});
134133

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "PassDetail.h"
10+
#include "mlir/Dialect/Func/IR/FuncOps.h"
11+
#include "mlir/IR/Block.h"
12+
#include "mlir/IR/Operation.h"
13+
#include "mlir/IR/PatternMatch.h"
14+
#include "mlir/IR/Region.h"
15+
#include "mlir/Support/LogicalResult.h"
16+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17+
#include "clang/CIR/Dialect/IR/CIRDialect.h"
18+
#include "clang/CIR/Dialect/Passes.h"
19+
#include "llvm/ADT/SmallVector.h"
20+
21+
using namespace mlir;
22+
using namespace cir;
23+
24+
//===----------------------------------------------------------------------===//
25+
// Rewrite patterns
26+
//===----------------------------------------------------------------------===//
27+
28+
namespace {
29+
30+
/// Simplify suitable ternary operations into select operations.
31+
///
32+
/// For now we only simplify those ternary operations whose true and false
33+
/// branches directly yield a value or a constant. That is, both of the true and
34+
/// the false branch must either contain a cir.yield operation as the only
35+
/// operation in the branch, or contain a cir.const operation followed by a
36+
/// cir.yield operation that yields the constant value.
37+
///
38+
/// For example, we will simplify the following ternary operation:
39+
///
40+
/// %0 = ...
41+
/// %1 = cir.ternary (%condition, true {
42+
/// %2 = cir.const ...
43+
/// cir.yield %2
44+
/// } false {
45+
/// cir.yield %0
46+
///
47+
/// into the following sequence of operations:
48+
///
49+
/// %1 = cir.const ...
50+
/// %0 = cir.select if %condition then %1 else %2
51+
struct SimplifyTernary final : public OpRewritePattern<TernaryOp> {
52+
using OpRewritePattern<TernaryOp>::OpRewritePattern;
53+
54+
LogicalResult matchAndRewrite(TernaryOp op,
55+
PatternRewriter &rewriter) const override {
56+
if (op->getNumResults() != 1)
57+
return mlir::failure();
58+
59+
if (!isSimpleTernaryBranch(op.getTrueRegion()) ||
60+
!isSimpleTernaryBranch(op.getFalseRegion()))
61+
return mlir::failure();
62+
63+
cir::YieldOp trueBranchYieldOp =
64+
mlir::cast<cir::YieldOp>(op.getTrueRegion().front().getTerminator());
65+
cir::YieldOp falseBranchYieldOp =
66+
mlir::cast<cir::YieldOp>(op.getFalseRegion().front().getTerminator());
67+
mlir::Value trueValue = trueBranchYieldOp.getArgs()[0];
68+
mlir::Value falseValue = falseBranchYieldOp.getArgs()[0];
69+
70+
rewriter.inlineBlockBefore(&op.getTrueRegion().front(), op);
71+
rewriter.inlineBlockBefore(&op.getFalseRegion().front(), op);
72+
rewriter.eraseOp(trueBranchYieldOp);
73+
rewriter.eraseOp(falseBranchYieldOp);
74+
rewriter.replaceOpWithNewOp<cir::SelectOp>(op, op.getCond(), trueValue,
75+
falseValue);
76+
77+
return mlir::success();
78+
}
79+
80+
private:
81+
bool isSimpleTernaryBranch(mlir::Region &region) const {
82+
if (!region.hasOneBlock())
83+
return false;
84+
85+
mlir::Block &onlyBlock = region.front();
86+
mlir::Block::OpListType &ops = onlyBlock.getOperations();
87+
88+
// The region/block could only contain at most 2 operations.
89+
if (ops.size() > 2)
90+
return false;
91+
92+
if (ops.size() == 1) {
93+
// The region/block only contain a cir.yield operation.
94+
return true;
95+
}
96+
97+
// Check whether the region/block contains a cir.const followed by a
98+
// cir.yield that yields the value.
99+
auto yieldOp = mlir::cast<cir::YieldOp>(onlyBlock.getTerminator());
100+
auto yieldValueDefOp = mlir::dyn_cast_if_present<cir::ConstantOp>(
101+
yieldOp.getArgs()[0].getDefiningOp());
102+
return yieldValueDefOp && yieldValueDefOp->getBlock() == &onlyBlock;
103+
}
104+
};
105+
106+
/// Simplify select operations with boolean constants into simpler forms.
107+
///
108+
/// This pattern simplifies select operations where both true and false values
109+
/// are boolean constants. Two specific cases are handled:
110+
///
111+
/// 1. When selecting between true and false based on a condition,
112+
/// the operation simplifies to just the condition itself:
113+
///
114+
/// %0 = cir.select if %condition then true else false
115+
/// ->
116+
/// (replaced with %condition directly)
117+
///
118+
/// 2. When selecting between false and true based on a condition,
119+
/// the operation simplifies to the logical negation of the condition:
120+
///
121+
/// %0 = cir.select if %condition then false else true
122+
/// ->
123+
/// %0 = cir.unary not %condition
124+
struct SimplifySelect : public OpRewritePattern<SelectOp> {
125+
using OpRewritePattern<SelectOp>::OpRewritePattern;
126+
127+
LogicalResult matchAndRewrite(SelectOp op,
128+
PatternRewriter &rewriter) const final {
129+
mlir::Operation *trueValueOp = op.getTrueValue().getDefiningOp();
130+
mlir::Operation *falseValueOp = op.getFalseValue().getDefiningOp();
131+
auto trueValueConstOp =
132+
mlir::dyn_cast_if_present<cir::ConstantOp>(trueValueOp);
133+
auto falseValueConstOp =
134+
mlir::dyn_cast_if_present<cir::ConstantOp>(falseValueOp);
135+
if (!trueValueConstOp || !falseValueConstOp)
136+
return mlir::failure();
137+
138+
auto trueValue = mlir::dyn_cast<cir::BoolAttr>(trueValueConstOp.getValue());
139+
auto falseValue =
140+
mlir::dyn_cast<cir::BoolAttr>(falseValueConstOp.getValue());
141+
if (!trueValue || !falseValue)
142+
return mlir::failure();
143+
144+
// cir.select if %0 then #true else #false -> %0
145+
if (trueValue.getValue() && !falseValue.getValue()) {
146+
rewriter.replaceAllUsesWith(op, op.getCondition());
147+
rewriter.eraseOp(op);
148+
return mlir::success();
149+
}
150+
151+
// cir.select if %0 then #false else #true -> cir.unary not %0
152+
if (!trueValue.getValue() && falseValue.getValue()) {
153+
rewriter.replaceOpWithNewOp<cir::UnaryOp>(op, cir::UnaryOpKind::Not,
154+
op.getCondition());
155+
return mlir::success();
156+
}
157+
158+
return mlir::failure();
159+
}
160+
};
161+
162+
//===----------------------------------------------------------------------===//
163+
// CIRSimplifyPass
164+
//===----------------------------------------------------------------------===//
165+
166+
struct CIRSimplifyPass : public CIRSimplifyBase<CIRSimplifyPass> {
167+
using CIRSimplifyBase::CIRSimplifyBase;
168+
169+
void runOnOperation() override;
170+
};
171+
172+
void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
173+
// clang-format off
174+
patterns.add<
175+
SimplifyTernary,
176+
SimplifySelect
177+
>(patterns.getContext());
178+
// clang-format on
179+
}
180+
181+
void CIRSimplifyPass::runOnOperation() {
182+
// Collect rewrite patterns.
183+
RewritePatternSet patterns(&getContext());
184+
populateMergeCleanupPatterns(patterns);
185+
186+
// Collect operations to apply patterns.
187+
llvm::SmallVector<Operation *, 16> ops;
188+
getOperation()->walk([&](Operation *op) {
189+
if (isa<TernaryOp, SelectOp>(op))
190+
ops.push_back(op);
191+
});
192+
193+
// Apply patterns.
194+
if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())
195+
signalPassFailure();
196+
}
197+
198+
} // namespace
199+
200+
std::unique_ptr<Pass> mlir::createCIRSimplifyPass() {
201+
return std::make_unique<CIRSimplifyPass>();
202+
}

clang/lib/CIR/Dialect/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_clang_library(MLIRCIRTransforms
22
CIRCanonicalize.cpp
3+
CIRSimplify.cpp
34
FlattenCFG.cpp
45
HoistAllocas.cpp
56

clang/lib/CIR/FrontendAction/CIRGenAction.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,15 +62,16 @@ class CIRGenConsumer : public clang::ASTConsumer {
6262
IntrusiveRefCntPtr<llvm::vfs::FileSystem> FS;
6363
std::unique_ptr<CIRGenerator> Gen;
6464
const FrontendOptions &FEOptions;
65+
CodeGenOptions &CGO;
6566

6667
public:
6768
CIRGenConsumer(CIRGenAction::OutputType Action, CompilerInstance &CI,
68-
std::unique_ptr<raw_pwrite_stream> OS)
69+
CodeGenOptions &CGO, std::unique_ptr<raw_pwrite_stream> OS)
6970
: Action(Action), CI(CI), OutputStream(std::move(OS)),
7071
FS(&CI.getVirtualFileSystem()),
7172
Gen(std::make_unique<CIRGenerator>(CI.getDiagnostics(), std::move(FS),
7273
CI.getCodeGenOpts())),
73-
FEOptions(CI.getFrontendOpts()) {}
74+
FEOptions(CI.getFrontendOpts()), CGO(CGO) {}
7475

7576
void Initialize(ASTContext &Ctx) override {
7677
assert(!Context && "initialized multiple times");
@@ -102,7 +103,8 @@ class CIRGenConsumer : public clang::ASTConsumer {
102103
if (!FEOptions.ClangIRDisablePasses) {
103104
// Setup and run CIR pipeline.
104105
if (runCIRToCIRPasses(MlirModule, MlirCtx, C,
105-
!FEOptions.ClangIRDisableCIRVerifier)
106+
!FEOptions.ClangIRDisableCIRVerifier,
107+
CGO.OptimizationLevel > 0)
106108
.failed()) {
107109
CI.getDiagnostics().Report(diag::err_cir_to_cir_transform_failed);
108110
return;
@@ -168,8 +170,8 @@ CIRGenAction::CreateASTConsumer(CompilerInstance &CI, StringRef InFile) {
168170
if (!Out)
169171
Out = getOutputStream(CI, InFile, Action);
170172

171-
auto Result =
172-
std::make_unique<cir::CIRGenConsumer>(Action, CI, std::move(Out));
173+
auto Result = std::make_unique<cir::CIRGenConsumer>(
174+
Action, CI, CI.getCodeGenOpts(), std::move(Out));
173175

174176
return Result;
175177
}

0 commit comments

Comments
 (0)