Skip to content

[CIR][NFC] Simplify BoolAttr builders #136366

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
public:
CIRBaseBuilderTy(mlir::MLIRContext &mlirContext)
: mlir::OpBuilder(&mlirContext) {}
CIRBaseBuilderTy(mlir::OpBuilder &builder) : mlir::OpBuilder(builder) {}

mlir::Value getConstAPInt(mlir::Location loc, mlir::Type typ,
const llvm::APInt &val) {
Expand Down Expand Up @@ -98,13 +99,13 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
if (auto recordTy = mlir::dyn_cast<cir::RecordType>(ty))
return getZeroAttr(recordTy);
if (mlir::isa<cir::BoolType>(ty)) {
return getCIRBoolAttr(false);
return getFalseAttr();
}
llvm_unreachable("Zero initializer for given type is NYI");
}

cir::ConstantOp getBool(bool state, mlir::Location loc) {
return create<cir::ConstantOp>(loc, getBoolTy(), getCIRBoolAttr(state));
return create<cir::ConstantOp>(loc, getCIRBoolAttr(state));
}
cir::ConstantOp getFalse(mlir::Location loc) { return getBool(false, loc); }
cir::ConstantOp getTrue(mlir::Location loc) { return getBool(true, loc); }
Expand All @@ -120,9 +121,12 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
}

cir::BoolAttr getCIRBoolAttr(bool state) {
return cir::BoolAttr::get(getContext(), getBoolTy(), state);
return cir::BoolAttr::get(getContext(), state);
}

cir::BoolAttr getTrueAttr() { return getCIRBoolAttr(true); }
cir::BoolAttr getFalseAttr() { return getCIRBoolAttr(false); }

mlir::Value createNot(mlir::Value value) {
return create<cir::UnaryOp>(value.getLoc(), value.getType(),
cir::UnaryOpKind::Not, value);
Expand Down
6 changes: 6 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ def CIR_BoolAttr : CIR_Attr<"Bool", "bool", [TypedAttrInterface]> {
"", "cir::BoolType">:$type,
"bool":$value);

let builders = [
AttrBuilder<(ins "bool":$value), [{
return $_get($_ctxt, cir::BoolType::get($_ctxt), value);
}]>,
];

let assemblyFormat = [{
`<` $value `>`
}];
Expand Down
8 changes: 7 additions & 1 deletion clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,12 @@ def ConstantOp : CIR_Op<"const",
// The constant operation returns a single value of CIR_AnyType.
let results = (outs CIR_AnyType:$res);

let builders = [
OpBuilder<(ins "cir::BoolAttr":$value), [{
build($_builder, $_state, value.getType(), value);
}]>
];

let assemblyFormat = "attr-dict $value";

let hasVerifier = 1;
Expand Down Expand Up @@ -844,7 +850,7 @@ def UnaryOp : CIR_Op<"unary", [Pure, SameOperandsAndResultType]> {
let assemblyFormat = [{
`(` $kind `,` $input `)`
(`nsw` $no_signed_wrap^)?
`:` type($input) `,` type($result) attr-dict
`:` type($input) `,` type($result) attr-dict
}];

let hasVerifier = 1;
Expand Down
13 changes: 3 additions & 10 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,7 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
}

mlir::Value VisitCXXBoolLiteralExpr(const CXXBoolLiteralExpr *e) {
mlir::Type type = cgf.convertType(e->getType());
return builder.create<cir::ConstantOp>(
cgf.getLoc(e->getExprLoc()), type,
builder.getCIRBoolAttr(e->getValue()));
return builder.getBool(e->getValue(), cgf.getLoc(e->getExprLoc()));
}

mlir::Value VisitCastExpr(CastExpr *e);
Expand Down Expand Up @@ -215,9 +212,7 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {

if (llvm::isa<MemberPointerType>(srcType)) {
cgf.getCIRGenModule().errorNYI(loc, "member pointer to bool conversion");
mlir::Type boolType = builder.getBoolTy();
return builder.create<cir::ConstantOp>(loc, boolType,
builder.getCIRBoolAttr(false));
return builder.getFalse(loc);
}

if (srcType->isIntegerType())
Expand Down Expand Up @@ -354,9 +349,7 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
// An interesting aspect of this is that increment is always true.
// Decrement does not have this property.
if (isInc && type->isBooleanType()) {
value = builder.create<cir::ConstantOp>(cgf.getLoc(e->getExprLoc()),
cgf.convertType(type),
builder.getCIRBoolAttr(true));
value = builder.getTrue(cgf.getLoc(e->getExprLoc()));
} else if (type->isIntegerType()) {
QualType promotedType;
bool canPerformLossyDemotionCheck = false;
Expand Down
4 changes: 1 addition & 3 deletions clang/lib/CIR/CodeGen/CIRGenStmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -456,9 +456,7 @@ mlir::LogicalResult CIRGenFunction::emitForStmt(const ForStmt &s) {
// scalar type.
condVal = evaluateExprAsBool(s.getCond());
} else {
cir::BoolType boolTy = cir::BoolType::get(b.getContext());
condVal = b.create<cir::ConstantOp>(
loc, boolTy, cir::BoolAttr::get(b.getContext(), boolTy, true));
condVal = b.create<cir::ConstantOp>(loc, builder.getTrueAttr());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
condVal = b.create<cir::ConstantOp>(loc, builder.getTrueAttr());
condVal = b.getTrue();

Does this work?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

b is mlir::OpBuilder here which does not have this method.

It can be done as CIRBaseBuilderTy(b).getTrue().

}
builder.createCondition(condVal);
},
Expand Down
4 changes: 1 addition & 3 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -692,9 +692,7 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite(
// during a pass as long as they don't live past the end of the pass.
attr = op.getValue();
} else if (mlir::isa<cir::BoolType>(op.getType())) {
int value = (op.getValue() ==
cir::BoolAttr::get(getContext(),
cir::BoolType::get(getContext()), true));
int value = mlir::cast<cir::BoolAttr>(op.getValue()).getValue();
attr = rewriter.getIntegerAttr(typeConverter->convertType(op.getType()),
value);
} else if (mlir::isa<cir::IntType>(op.getType())) {
Expand Down
Loading