Skip to content

Commit a19e685

Browse files
authored
[CIR] Realign CIR-to-LLVM IR lowering code with incubator (#129293)
The previously upstreamed lowering from ClangIR to LLVM IR diverged from the incubator implementation, but when the incubator was updated to incorporate these changes some issues arose which require the upstream implementation to be modified to re-align with the incubator. First, in the earlier upstream implementation a CIRAttrVisitor class was introduced with the intention that an mlir-tblgen based extension would be created to automatically add all CIR attributes to the visitor. When I proposed this in mlir-tblgen a reviewer suggested that what I wanted could be better accomplished with TypeSwitch. See #126332 This was done in the incubator, and here I am bringing that implementation upstream. The other issue was that the global op initialization in the incubator had more cases than I had accounted for in my previous upstream refactoring. I did still refactor the incubator code, but not in quite the same way as the upstream code. This change re-aligns the two.
1 parent 6ff0f69 commit a19e685

File tree

3 files changed

+69
-107
lines changed

3 files changed

+69
-107
lines changed

clang/include/clang/CIR/Dialect/IR/CIRAttrVisitor.h

Lines changed: 0 additions & 52 deletions
This file was deleted.

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 69 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@
2424
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
2525
#include "mlir/Target/LLVMIR/Export.h"
2626
#include "mlir/Transforms/DialectConversion.h"
27-
#include "clang/CIR/Dialect/IR/CIRAttrVisitor.h"
2827
#include "clang/CIR/Dialect/IR/CIRDialect.h"
2928
#include "clang/CIR/MissingFeatures.h"
3029
#include "clang/CIR/Passes.h"
30+
#include "llvm/ADT/TypeSwitch.h"
3131
#include "llvm/IR/Module.h"
3232
#include "llvm/Support/TimeProfiler.h"
3333

@@ -37,63 +37,78 @@ using namespace llvm;
3737
namespace cir {
3838
namespace direct {
3939

40-
class CIRAttrToValue : public CirAttrVisitor<CIRAttrToValue, mlir::Value> {
40+
class CIRAttrToValue {
4141
public:
4242
CIRAttrToValue(mlir::Operation *parentOp,
4343
mlir::ConversionPatternRewriter &rewriter,
4444
const mlir::TypeConverter *converter)
4545
: parentOp(parentOp), rewriter(rewriter), converter(converter) {}
4646

47-
mlir::Value lowerCirAttrAsValue(mlir::Attribute attr) { return visit(attr); }
48-
49-
mlir::Value visitCirIntAttr(cir::IntAttr intAttr) {
50-
mlir::Location loc = parentOp->getLoc();
51-
return rewriter.create<mlir::LLVM::ConstantOp>(
52-
loc, converter->convertType(intAttr.getType()), intAttr.getValue());
53-
}
54-
55-
mlir::Value visitCirFPAttr(cir::FPAttr fltAttr) {
56-
mlir::Location loc = parentOp->getLoc();
57-
return rewriter.create<mlir::LLVM::ConstantOp>(
58-
loc, converter->convertType(fltAttr.getType()), fltAttr.getValue());
47+
mlir::Value visit(mlir::Attribute attr) {
48+
return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr)
49+
.Case<cir::IntAttr, cir::FPAttr, cir::ConstPtrAttr>(
50+
[&](auto attrT) { return visitCirAttr(attrT); })
51+
.Default([&](auto attrT) { return mlir::Value(); });
5952
}
6053

61-
mlir::Value visitCirConstPtrAttr(cir::ConstPtrAttr ptrAttr) {
62-
mlir::Location loc = parentOp->getLoc();
63-
if (ptrAttr.isNullValue()) {
64-
return rewriter.create<mlir::LLVM::ZeroOp>(
65-
loc, converter->convertType(ptrAttr.getType()));
66-
}
67-
mlir::DataLayout layout(parentOp->getParentOfType<mlir::ModuleOp>());
68-
mlir::Value ptrVal = rewriter.create<mlir::LLVM::ConstantOp>(
69-
loc,
70-
rewriter.getIntegerType(layout.getTypeSizeInBits(ptrAttr.getType())),
71-
ptrAttr.getValue().getInt());
72-
return rewriter.create<mlir::LLVM::IntToPtrOp>(
73-
loc, converter->convertType(ptrAttr.getType()), ptrVal);
74-
}
54+
mlir::Value visitCirAttr(cir::IntAttr intAttr);
55+
mlir::Value visitCirAttr(cir::FPAttr fltAttr);
56+
mlir::Value visitCirAttr(cir::ConstPtrAttr ptrAttr);
7557

7658
private:
7759
mlir::Operation *parentOp;
7860
mlir::ConversionPatternRewriter &rewriter;
7961
const mlir::TypeConverter *converter;
8062
};
8163

64+
/// IntAttr visitor.
65+
mlir::Value CIRAttrToValue::visitCirAttr(cir::IntAttr intAttr) {
66+
mlir::Location loc = parentOp->getLoc();
67+
return rewriter.create<mlir::LLVM::ConstantOp>(
68+
loc, converter->convertType(intAttr.getType()), intAttr.getValue());
69+
}
70+
71+
/// ConstPtrAttr visitor.
72+
mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstPtrAttr ptrAttr) {
73+
mlir::Location loc = parentOp->getLoc();
74+
if (ptrAttr.isNullValue()) {
75+
return rewriter.create<mlir::LLVM::ZeroOp>(
76+
loc, converter->convertType(ptrAttr.getType()));
77+
}
78+
mlir::DataLayout layout(parentOp->getParentOfType<mlir::ModuleOp>());
79+
mlir::Value ptrVal = rewriter.create<mlir::LLVM::ConstantOp>(
80+
loc, rewriter.getIntegerType(layout.getTypeSizeInBits(ptrAttr.getType())),
81+
ptrAttr.getValue().getInt());
82+
return rewriter.create<mlir::LLVM::IntToPtrOp>(
83+
loc, converter->convertType(ptrAttr.getType()), ptrVal);
84+
}
85+
86+
/// FPAttr visitor.
87+
mlir::Value CIRAttrToValue::visitCirAttr(cir::FPAttr fltAttr) {
88+
mlir::Location loc = parentOp->getLoc();
89+
return rewriter.create<mlir::LLVM::ConstantOp>(
90+
loc, converter->convertType(fltAttr.getType()), fltAttr.getValue());
91+
}
92+
8293
// This class handles rewriting initializer attributes for types that do not
8394
// require region initialization.
84-
class GlobalInitAttrRewriter
85-
: public CirAttrVisitor<GlobalInitAttrRewriter, mlir::Attribute> {
95+
class GlobalInitAttrRewriter {
8696
public:
8797
GlobalInitAttrRewriter(mlir::Type type,
8898
mlir::ConversionPatternRewriter &rewriter)
8999
: llvmType(type), rewriter(rewriter) {}
90100

91-
mlir::Attribute rewriteInitAttr(mlir::Attribute attr) { return visit(attr); }
101+
mlir::Attribute visit(mlir::Attribute attr) {
102+
return llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(attr)
103+
.Case<cir::IntAttr, cir::FPAttr>(
104+
[&](auto attrT) { return visitCirAttr(attrT); })
105+
.Default([&](auto attrT) { return mlir::Attribute(); });
106+
}
92107

93-
mlir::Attribute visitCirIntAttr(cir::IntAttr attr) {
108+
mlir::Attribute visitCirAttr(cir::IntAttr attr) {
94109
return rewriter.getIntegerAttr(llvmType, attr.getValue());
95110
}
96-
mlir::Attribute visitCirFPAttr(cir::FPAttr attr) {
111+
mlir::Attribute visitCirAttr(cir::FPAttr attr) {
97112
return rewriter.getFloatAttr(llvmType, attr.getValue());
98113
}
99114

@@ -124,12 +139,6 @@ struct ConvertCIRToLLVMPass
124139
StringRef getArgument() const override { return "cir-flat-to-llvm"; }
125140
};
126141

127-
bool CIRToLLVMGlobalOpLowering::attrRequiresRegionInitialization(
128-
mlir::Attribute attr) const {
129-
// There will be more cases added later.
130-
return isa<cir::ConstPtrAttr>(attr);
131-
}
132-
133142
/// Replace CIR global with a region initialized LLVM global and update
134143
/// insertion point to the end of the initializer block.
135144
void CIRToLLVMGlobalOpLowering::setupRegionInitializedLLVMGlobalOp(
@@ -176,8 +185,8 @@ CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal(
176185
// to the appropriate value.
177186
const mlir::Location loc = op.getLoc();
178187
setupRegionInitializedLLVMGlobalOp(op, rewriter);
179-
CIRAttrToValue attrVisitor(op, rewriter, typeConverter);
180-
mlir::Value value = attrVisitor.lowerCirAttrAsValue(init);
188+
CIRAttrToValue valueConverter(op, rewriter, typeConverter);
189+
mlir::Value value = valueConverter.visit(init);
181190
rewriter.create<mlir::LLVM::ReturnOp>(loc, value);
182191
return mlir::success();
183192
}
@@ -188,12 +197,6 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
188197

189198
std::optional<mlir::Attribute> init = op.getInitialValue();
190199

191-
// If we have an initializer and it requires region initialization, handle
192-
// that separately
193-
if (init.has_value() && attrRequiresRegionInitialization(init.value())) {
194-
return matchAndRewriteRegionInitializedGlobal(op, init.value(), rewriter);
195-
}
196-
197200
// Fetch required values to create LLVM op.
198201
const mlir::Type cirSymType = op.getSymType();
199202

@@ -218,12 +221,25 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
218221
SmallVector<mlir::NamedAttribute> attributes;
219222

220223
if (init.has_value()) {
221-
GlobalInitAttrRewriter initRewriter(llvmType, rewriter);
222-
init = initRewriter.rewriteInitAttr(init.value());
223-
// If initRewriter returned a null attribute, init will have a value but
224-
// the value will be null. If that happens, initRewriter didn't handle the
225-
// attribute type. It probably needs to be added to GlobalInitAttrRewriter.
226-
if (!init.value()) {
224+
if (mlir::isa<cir::FPAttr, cir::IntAttr>(init.value())) {
225+
GlobalInitAttrRewriter initRewriter(llvmType, rewriter);
226+
init = initRewriter.visit(init.value());
227+
// If initRewriter returned a null attribute, init will have a value but
228+
// the value will be null. If that happens, initRewriter didn't handle the
229+
// attribute type. It probably needs to be added to
230+
// GlobalInitAttrRewriter.
231+
if (!init.value()) {
232+
op.emitError() << "unsupported initializer '" << init.value() << "'";
233+
return mlir::failure();
234+
}
235+
} else if (mlir::isa<cir::ConstPtrAttr>(init.value())) {
236+
// TODO(cir): once LLVM's dialect has proper equivalent attributes this
237+
// should be updated. For now, we use a custom op to initialize globals
238+
// to the appropriate value.
239+
return matchAndRewriteRegionInitializedGlobal(op, init.value(), rewriter);
240+
} else {
241+
// We will only get here if new initializer types are added and this
242+
// code is not updated to handle them.
227243
op.emitError() << "unsupported initializer '" << init.value() << "'";
228244
return mlir::failure();
229245
}

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@ class CIRToLLVMGlobalOpLowering
3636
mlir::ConversionPatternRewriter &rewriter) const override;
3737

3838
private:
39-
bool attrRequiresRegionInitialization(mlir::Attribute attr) const;
40-
4139
mlir::LogicalResult matchAndRewriteRegionInitializedGlobal(
4240
cir::GlobalOp op, mlir::Attribute init,
4341
mlir::ConversionPatternRewriter &rewriter) const;

0 commit comments

Comments
 (0)