Skip to content

Commit 050ca5e

Browse files
authored
[CIR][NFC] Simplify BoolAttr builders (#136366)
This mirrors incubator changes from llvm/clangir#1572
1 parent 8435de0 commit 050ca5e

File tree

6 files changed

+25
-20
lines changed

6 files changed

+25
-20
lines changed

clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
5757
public:
5858
CIRBaseBuilderTy(mlir::MLIRContext &mlirContext)
5959
: mlir::OpBuilder(&mlirContext) {}
60+
CIRBaseBuilderTy(mlir::OpBuilder &builder) : mlir::OpBuilder(builder) {}
6061

6162
mlir::Value getConstAPInt(mlir::Location loc, mlir::Type typ,
6263
const llvm::APInt &val) {
@@ -98,13 +99,13 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
9899
if (auto recordTy = mlir::dyn_cast<cir::RecordType>(ty))
99100
return getZeroAttr(recordTy);
100101
if (mlir::isa<cir::BoolType>(ty)) {
101-
return getCIRBoolAttr(false);
102+
return getFalseAttr();
102103
}
103104
llvm_unreachable("Zero initializer for given type is NYI");
104105
}
105106

106107
cir::ConstantOp getBool(bool state, mlir::Location loc) {
107-
return create<cir::ConstantOp>(loc, getBoolTy(), getCIRBoolAttr(state));
108+
return create<cir::ConstantOp>(loc, getCIRBoolAttr(state));
108109
}
109110
cir::ConstantOp getFalse(mlir::Location loc) { return getBool(false, loc); }
110111
cir::ConstantOp getTrue(mlir::Location loc) { return getBool(true, loc); }
@@ -120,9 +121,12 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
120121
}
121122

122123
cir::BoolAttr getCIRBoolAttr(bool state) {
123-
return cir::BoolAttr::get(getContext(), getBoolTy(), state);
124+
return cir::BoolAttr::get(getContext(), state);
124125
}
125126

127+
cir::BoolAttr getTrueAttr() { return getCIRBoolAttr(true); }
128+
cir::BoolAttr getFalseAttr() { return getCIRBoolAttr(false); }
129+
126130
mlir::Value createNot(mlir::Value value) {
127131
return create<cir::UnaryOp>(value.getLoc(), value.getType(),
128132
cir::UnaryOpKind::Not, value);

clang/include/clang/CIR/Dialect/IR/CIRAttrs.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@ def CIR_BoolAttr : CIR_Attr<"Bool", "bool", [TypedAttrInterface]> {
4949
"", "cir::BoolType">:$type,
5050
"bool":$value);
5151

52+
let builders = [
53+
AttrBuilder<(ins "bool":$value), [{
54+
return $_get($_ctxt, cir::BoolType::get($_ctxt), value);
55+
}]>,
56+
];
57+
5258
let assemblyFormat = [{
5359
`<` $value `>`
5460
}];

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,12 @@ def ConstantOp : CIR_Op<"const",
294294
// The constant operation returns a single value of CIR_AnyType.
295295
let results = (outs CIR_AnyType:$res);
296296

297+
let builders = [
298+
OpBuilder<(ins "cir::BoolAttr":$value), [{
299+
build($_builder, $_state, value.getType(), value);
300+
}]>
301+
];
302+
297303
let assemblyFormat = "attr-dict $value";
298304

299305
let hasVerifier = 1;
@@ -844,7 +850,7 @@ def UnaryOp : CIR_Op<"unary", [Pure, SameOperandsAndResultType]> {
844850
let assemblyFormat = [{
845851
`(` $kind `,` $input `)`
846852
(`nsw` $no_signed_wrap^)?
847-
`:` type($input) `,` type($result) attr-dict
853+
`:` type($input) `,` type($result) attr-dict
848854
}];
849855

850856
let hasVerifier = 1;

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,7 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
152152
}
153153

154154
mlir::Value VisitCXXBoolLiteralExpr(const CXXBoolLiteralExpr *e) {
155-
mlir::Type type = cgf.convertType(e->getType());
156-
return builder.create<cir::ConstantOp>(
157-
cgf.getLoc(e->getExprLoc()), type,
158-
builder.getCIRBoolAttr(e->getValue()));
155+
return builder.getBool(e->getValue(), cgf.getLoc(e->getExprLoc()));
159156
}
160157

161158
mlir::Value VisitCastExpr(CastExpr *e);
@@ -215,9 +212,7 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
215212

216213
if (llvm::isa<MemberPointerType>(srcType)) {
217214
cgf.getCIRGenModule().errorNYI(loc, "member pointer to bool conversion");
218-
mlir::Type boolType = builder.getBoolTy();
219-
return builder.create<cir::ConstantOp>(loc, boolType,
220-
builder.getCIRBoolAttr(false));
215+
return builder.getFalse(loc);
221216
}
222217

223218
if (srcType->isIntegerType())
@@ -354,9 +349,7 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
354349
// An interesting aspect of this is that increment is always true.
355350
// Decrement does not have this property.
356351
if (isInc && type->isBooleanType()) {
357-
value = builder.create<cir::ConstantOp>(cgf.getLoc(e->getExprLoc()),
358-
cgf.convertType(type),
359-
builder.getCIRBoolAttr(true));
352+
value = builder.getTrue(cgf.getLoc(e->getExprLoc()));
360353
} else if (type->isIntegerType()) {
361354
QualType promotedType;
362355
bool canPerformLossyDemotionCheck = false;

clang/lib/CIR/CodeGen/CIRGenStmt.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -456,9 +456,7 @@ mlir::LogicalResult CIRGenFunction::emitForStmt(const ForStmt &s) {
456456
// scalar type.
457457
condVal = evaluateExprAsBool(s.getCond());
458458
} else {
459-
cir::BoolType boolTy = cir::BoolType::get(b.getContext());
460-
condVal = b.create<cir::ConstantOp>(
461-
loc, boolTy, cir::BoolAttr::get(b.getContext(), boolTy, true));
459+
condVal = b.create<cir::ConstantOp>(loc, builder.getTrueAttr());
462460
}
463461
builder.createCondition(condVal);
464462
},

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -692,9 +692,7 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite(
692692
// during a pass as long as they don't live past the end of the pass.
693693
attr = op.getValue();
694694
} else if (mlir::isa<cir::BoolType>(op.getType())) {
695-
int value = (op.getValue() ==
696-
cir::BoolAttr::get(getContext(),
697-
cir::BoolType::get(getContext()), true));
695+
int value = mlir::cast<cir::BoolAttr>(op.getValue()).getValue();
698696
attr = rewriter.getIntegerAttr(typeConverter->convertType(op.getType()),
699697
value);
700698
} else if (mlir::isa<cir::IntType>(op.getType())) {

0 commit comments

Comments
 (0)