Skip to content

[CIR] Add cir-simplify pass #138317

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion clang/include/clang/CIR/CIRToCIRPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ namespace cir {
mlir::LogicalResult runCIRToCIRPasses(mlir::ModuleOp theModule,
mlir::MLIRContext &mlirCtx,
clang::ASTContext &astCtx,
bool enableVerifier);
bool enableVerifier,
bool enableCIRSimplify);

} // namespace cir

Expand Down
7 changes: 7 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ def CIR_Dialect : Dialect {
let useDefaultAttributePrinterParser = 0;
let useDefaultTypePrinterParser = 0;

// Enable constant materialization for the CIR dialect. This generates a
// declaration for the cir::CIRDialect::materializeConstant function. This
// hook is necessary for canonicalization to properly handle attributes
// returned by fold methods, allowing them to be materialized as constant
// operations in the IR.
let hasConstantMaterializer = 1;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh boy would i love a comment explaining this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's some documentation here: https://mlir.llvm.org/docs/Canonicalization/#generating-constants-from-attributes

But yes I can add a small comment explaining that we need this for canonicalization.


let extraClassDeclaration = [{
static llvm::StringRef getTripleAttrName() { return "cir.triple"; }

Expand Down
2 changes: 2 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1464,6 +1464,8 @@ def SelectOp : CIR_Op<"select", [Pure,
qualified(type($false_value))
`)` `->` qualified(type($result)) attr-dict
}];

let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/CIR/Dialect/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ namespace mlir {

std::unique_ptr<Pass> createCIRCanonicalizePass();
std::unique_ptr<Pass> createCIRFlattenCFGPass();
std::unique_ptr<Pass> createCIRSimplifyPass();
std::unique_ptr<Pass> createHoistAllocasPass();

void populateCIRPreLoweringPasses(mlir::OpPassManager &pm);
Expand Down
19 changes: 19 additions & 0 deletions clang/include/clang/CIR/Dialect/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,25 @@ def CIRCanonicalize : Pass<"cir-canonicalize"> {
let dependentDialects = ["cir::CIRDialect"];
}

def CIRSimplify : Pass<"cir-simplify"> {
let summary = "Performs CIR simplification and code optimization";
let description = [{
The pass performs semantics-preserving code simplifications and optimizations
on CIR while maintaining strict program correctness.

Unlike the `cir-canonicalize` pass, these transformations may reduce the IR's
structural similarity to the original source code as a trade-off for improved
code quality. This can affect debugging fidelity by altering intermediate
representations of folded expressions, hoisted operations, and other
optimized constructs.

Example transformations include ternary expression folding and code hoisting
while preserving program semantics.
}];
let constructor = "mlir::createCIRSimplifyPass()";
let dependentDialects = ["cir::CIRDialect"];
}

def HoistAllocas : Pass<"cir-hoist-allocas"> {
let summary = "Hoist allocas to the entry of the function";
let description = [{
Expand Down
1 change: 0 additions & 1 deletion clang/include/clang/CIR/MissingFeatures.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,6 @@ struct MissingFeatures {
static bool labelOp() { return false; }
static bool ptrDiffOp() { return false; }
static bool ptrStrideOp() { return false; }
static bool selectOp() { return false; }
static bool switchOp() { return false; }
static bool ternaryOp() { return false; }
static bool tryOp() { return false; }
Expand Down
30 changes: 30 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,14 @@ void cir::CIRDialect::initialize() {
addInterfaces<CIROpAsmDialectInterface>();
}

Operation *cir::CIRDialect::materializeConstant(mlir::OpBuilder &builder,
mlir::Attribute value,
mlir::Type type,
mlir::Location loc) {
return builder.create<cir::ConstantOp>(loc, type,
mlir::cast<mlir::TypedAttr>(value));
}

//===----------------------------------------------------------------------===//
// Helpers
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1261,6 +1269,28 @@ void cir::TernaryOp::build(
result.addTypes(TypeRange{yield.getOperandTypes().front()});
}

//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//

OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) {
mlir::Attribute condition = adaptor.getCondition();
if (condition) {
bool conditionValue = mlir::cast<cir::BoolAttr>(condition).getValue();
return conditionValue ? getTrueValue() : getFalseValue();
}

// cir.select if %0 then x else x -> x
mlir::Attribute trueValue = adaptor.getTrueValue();
mlir::Attribute falseValue = adaptor.getFalseValue();
if (trueValue == falseValue)
return trueValue;
if (getTrueValue() == getFalseValue())
return getTrueValue();

return {};
}

//===----------------------------------------------------------------------===//
// ShiftOp
//===----------------------------------------------------------------------===//
Expand Down
3 changes: 1 addition & 2 deletions clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,14 +121,13 @@ void CIRCanonicalizePass::runOnOperation() {
getOperation()->walk([&](Operation *op) {
assert(!cir::MissingFeatures::switchOp());
assert(!cir::MissingFeatures::tryOp());
assert(!cir::MissingFeatures::selectOp());
assert(!cir::MissingFeatures::complexCreateOp());
assert(!cir::MissingFeatures::complexRealOp());
assert(!cir::MissingFeatures::complexImagOp());
assert(!cir::MissingFeatures::callOp());
// CastOp and UnaryOp are here to perform a manual `fold` in
// applyOpPatternsGreedily.
if (isa<BrOp, BrCondOp, ScopeOp, CastOp, UnaryOp>(op))
if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SelectOp, UnaryOp>(op))
ops.push_back(op);
});

Expand Down
202 changes: 202 additions & 0 deletions clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Region.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/Passes.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;
using namespace cir;

//===----------------------------------------------------------------------===//
// Rewrite patterns
//===----------------------------------------------------------------------===//

namespace {

/// Simplify suitable ternary operations into select operations.
///
/// For now we only simplify those ternary operations whose true and false
/// branches directly yield a value or a constant. That is, both of the true and
/// the false branch must either contain a cir.yield operation as the only
/// operation in the branch, or contain a cir.const operation followed by a
/// cir.yield operation that yields the constant value.
///
/// For example, we will simplify the following ternary operation:
///
/// %0 = ...
/// %1 = cir.ternary (%condition, true {
/// %2 = cir.const ...
/// cir.yield %2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This example is confusing. The identifier %2 gives the impression that it was defined sometime after %0 but I don't think that's the intention. This will only happen if the false case returns a value that exists prior to the ternary, right?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would probably look a bit better:

///   %0 = ...
///   %1 = cir.ternary (%condition, true {
///     %2 = cir.const ...
///     cir.yield %2
///   } false {
///     cir.yield %0

/// } false {
/// cir.yield %0
///
/// into the following sequence of operations:
///
/// %1 = cir.const ...
/// %0 = cir.select if %condition then %1 else %2
struct SimplifyTernary final : public OpRewritePattern<TernaryOp> {
using OpRewritePattern<TernaryOp>::OpRewritePattern;

LogicalResult matchAndRewrite(TernaryOp op,
PatternRewriter &rewriter) const override {
if (op->getNumResults() != 1)
return mlir::failure();

if (!isSimpleTernaryBranch(op.getTrueRegion()) ||
!isSimpleTernaryBranch(op.getFalseRegion()))
return mlir::failure();

cir::YieldOp trueBranchYieldOp =
mlir::cast<cir::YieldOp>(op.getTrueRegion().front().getTerminator());
cir::YieldOp falseBranchYieldOp =
mlir::cast<cir::YieldOp>(op.getFalseRegion().front().getTerminator());
mlir::Value trueValue = trueBranchYieldOp.getArgs()[0];
mlir::Value falseValue = falseBranchYieldOp.getArgs()[0];

rewriter.inlineBlockBefore(&op.getTrueRegion().front(), op);
rewriter.inlineBlockBefore(&op.getFalseRegion().front(), op);
rewriter.eraseOp(trueBranchYieldOp);
rewriter.eraseOp(falseBranchYieldOp);
rewriter.replaceOpWithNewOp<cir::SelectOp>(op, op.getCond(), trueValue,
falseValue);

return mlir::success();
}

private:
bool isSimpleTernaryBranch(mlir::Region &region) const {
if (!region.hasOneBlock())
return false;

mlir::Block &onlyBlock = region.front();
mlir::Block::OpListType &ops = onlyBlock.getOperations();

// The region/block could only contain at most 2 operations.
if (ops.size() > 2)
return false;

if (ops.size() == 1) {
// The region/block only contain a cir.yield operation.
return true;
}

// Check whether the region/block contains a cir.const followed by a
// cir.yield that yields the value.
auto yieldOp = mlir::cast<cir::YieldOp>(onlyBlock.getTerminator());
auto yieldValueDefOp = mlir::dyn_cast_if_present<cir::ConstantOp>(
yieldOp.getArgs()[0].getDefiningOp());
return yieldValueDefOp && yieldValueDefOp->getBlock() == &onlyBlock;
}
};

/// Simplify select operations with boolean constants into simpler forms.
///
/// This pattern simplifies select operations where both true and false values
/// are boolean constants. Two specific cases are handled:
///
/// 1. When selecting between true and false based on a condition,
/// the operation simplifies to just the condition itself:
///
/// %0 = cir.select if %condition then true else false
/// ->
/// (replaced with %condition directly)
///
/// 2. When selecting between false and true based on a condition,
/// the operation simplifies to the logical negation of the condition:
///
/// %0 = cir.select if %condition then false else true
/// ->
/// %0 = cir.unary not %condition
struct SimplifySelect : public OpRewritePattern<SelectOp> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be great if this had an explanatory comment like SimplifyTernary does. If there are going to be more cases added later, a general comment is fine. Otherwise, the two cases handled here can be explained pretty easily here.

using OpRewritePattern<SelectOp>::OpRewritePattern;

LogicalResult matchAndRewrite(SelectOp op,
PatternRewriter &rewriter) const final {
mlir::Operation *trueValueOp = op.getTrueValue().getDefiningOp();
mlir::Operation *falseValueOp = op.getFalseValue().getDefiningOp();
auto trueValueConstOp =
mlir::dyn_cast_if_present<cir::ConstantOp>(trueValueOp);
auto falseValueConstOp =
mlir::dyn_cast_if_present<cir::ConstantOp>(falseValueOp);
if (!trueValueConstOp || !falseValueConstOp)
return mlir::failure();

auto trueValue = mlir::dyn_cast<cir::BoolAttr>(trueValueConstOp.getValue());
auto falseValue =
mlir::dyn_cast<cir::BoolAttr>(falseValueConstOp.getValue());
if (!trueValue || !falseValue)
return mlir::failure();

// cir.select if %0 then #true else #false -> %0
if (trueValue.getValue() && !falseValue.getValue()) {
rewriter.replaceAllUsesWith(op, op.getCondition());
rewriter.eraseOp(op);
return mlir::success();
}

// cir.select if %0 then #false else #true -> cir.unary not %0
if (!trueValue.getValue() && falseValue.getValue()) {
rewriter.replaceOpWithNewOp<cir::UnaryOp>(op, cir::UnaryOpKind::Not,
op.getCondition());
return mlir::success();
}

return mlir::failure();
}
};

//===----------------------------------------------------------------------===//
// CIRSimplifyPass
//===----------------------------------------------------------------------===//

struct CIRSimplifyPass : public CIRSimplifyBase<CIRSimplifyPass> {
using CIRSimplifyBase::CIRSimplifyBase;

void runOnOperation() override;
};

void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
// clang-format off
patterns.add<
SimplifyTernary,
SimplifySelect
>(patterns.getContext());
// clang-format on
}

void CIRSimplifyPass::runOnOperation() {
// Collect rewrite patterns.
RewritePatternSet patterns(&getContext());
populateMergeCleanupPatterns(patterns);

// Collect operations to apply patterns.
llvm::SmallVector<Operation *, 16> ops;
getOperation()->walk([&](Operation *op) {
if (isa<TernaryOp, SelectOp>(op))
ops.push_back(op);
});

// Apply patterns.
if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())
signalPassFailure();
}

} // namespace

std::unique_ptr<Pass> mlir::createCIRSimplifyPass() {
return std::make_unique<CIRSimplifyPass>();
}
1 change: 1 addition & 0 deletions clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
add_clang_library(MLIRCIRTransforms
CIRCanonicalize.cpp
CIRSimplify.cpp
FlattenCFG.cpp
HoistAllocas.cpp

Expand Down
12 changes: 7 additions & 5 deletions clang/lib/CIR/FrontendAction/CIRGenAction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,16 @@ class CIRGenConsumer : public clang::ASTConsumer {
IntrusiveRefCntPtr<llvm::vfs::FileSystem> FS;
std::unique_ptr<CIRGenerator> Gen;
const FrontendOptions &FEOptions;
CodeGenOptions &CGO;

public:
CIRGenConsumer(CIRGenAction::OutputType Action, CompilerInstance &CI,
std::unique_ptr<raw_pwrite_stream> OS)
CodeGenOptions &CGO, std::unique_ptr<raw_pwrite_stream> OS)
: Action(Action), CI(CI), OutputStream(std::move(OS)),
FS(&CI.getVirtualFileSystem()),
Gen(std::make_unique<CIRGenerator>(CI.getDiagnostics(), std::move(FS),
CI.getCodeGenOpts())),
FEOptions(CI.getFrontendOpts()) {}
FEOptions(CI.getFrontendOpts()), CGO(CGO) {}

void Initialize(ASTContext &Ctx) override {
assert(!Context && "initialized multiple times");
Expand Down Expand Up @@ -102,7 +103,8 @@ class CIRGenConsumer : public clang::ASTConsumer {
if (!FEOptions.ClangIRDisablePasses) {
// Setup and run CIR pipeline.
if (runCIRToCIRPasses(MlirModule, MlirCtx, C,
!FEOptions.ClangIRDisableCIRVerifier)
!FEOptions.ClangIRDisableCIRVerifier,
CGO.OptimizationLevel > 0)
.failed()) {
CI.getDiagnostics().Report(diag::err_cir_to_cir_transform_failed);
return;
Expand Down Expand Up @@ -168,8 +170,8 @@ CIRGenAction::CreateASTConsumer(CompilerInstance &CI, StringRef InFile) {
if (!Out)
Out = getOutputStream(CI, InFile, Action);

auto Result =
std::make_unique<cir::CIRGenConsumer>(Action, CI, std::move(Out));
auto Result = std::make_unique<cir::CIRGenConsumer>(
Action, CI, CI.getCodeGenOpts(), std::move(Out));

return Result;
}
Expand Down
Loading