24
24
#include " mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
25
25
#include " mlir/Target/LLVMIR/Export.h"
26
26
#include " mlir/Transforms/DialectConversion.h"
27
- #include " clang/CIR/Dialect/IR/CIRAttrVisitor.h"
28
27
#include " clang/CIR/Dialect/IR/CIRDialect.h"
29
28
#include " clang/CIR/MissingFeatures.h"
30
29
#include " clang/CIR/Passes.h"
30
+ #include " llvm/ADT/TypeSwitch.h"
31
31
#include " llvm/IR/Module.h"
32
32
#include " llvm/Support/TimeProfiler.h"
33
33
@@ -37,63 +37,78 @@ using namespace llvm;
37
37
namespace cir {
38
38
namespace direct {
39
39
40
- class CIRAttrToValue : public CirAttrVisitor <CIRAttrToValue, mlir::Value> {
40
+ class CIRAttrToValue {
41
41
public:
42
42
CIRAttrToValue (mlir::Operation *parentOp,
43
43
mlir::ConversionPatternRewriter &rewriter,
44
44
const mlir::TypeConverter *converter)
45
45
: parentOp(parentOp), rewriter(rewriter), converter(converter) {}
46
46
47
- mlir::Value lowerCirAttrAsValue (mlir::Attribute attr) { return visit (attr); }
48
-
49
- mlir::Value visitCirIntAttr (cir::IntAttr intAttr) {
50
- mlir::Location loc = parentOp->getLoc ();
51
- return rewriter.create <mlir::LLVM::ConstantOp>(
52
- loc, converter->convertType (intAttr.getType ()), intAttr.getValue ());
53
- }
54
-
55
- mlir::Value visitCirFPAttr (cir::FPAttr fltAttr) {
56
- mlir::Location loc = parentOp->getLoc ();
57
- return rewriter.create <mlir::LLVM::ConstantOp>(
58
- loc, converter->convertType (fltAttr.getType ()), fltAttr.getValue ());
47
+ mlir::Value visit (mlir::Attribute attr) {
48
+ return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr)
49
+ .Case <cir::IntAttr, cir::FPAttr, cir::ConstPtrAttr>(
50
+ [&](auto attrT) { return visitCirAttr (attrT); })
51
+ .Default ([&](auto attrT) { return mlir::Value (); });
59
52
}
60
53
61
- mlir::Value visitCirConstPtrAttr (cir::ConstPtrAttr ptrAttr) {
62
- mlir::Location loc = parentOp->getLoc ();
63
- if (ptrAttr.isNullValue ()) {
64
- return rewriter.create <mlir::LLVM::ZeroOp>(
65
- loc, converter->convertType (ptrAttr.getType ()));
66
- }
67
- mlir::DataLayout layout (parentOp->getParentOfType <mlir::ModuleOp>());
68
- mlir::Value ptrVal = rewriter.create <mlir::LLVM::ConstantOp>(
69
- loc,
70
- rewriter.getIntegerType (layout.getTypeSizeInBits (ptrAttr.getType ())),
71
- ptrAttr.getValue ().getInt ());
72
- return rewriter.create <mlir::LLVM::IntToPtrOp>(
73
- loc, converter->convertType (ptrAttr.getType ()), ptrVal);
74
- }
54
+ mlir::Value visitCirAttr (cir::IntAttr intAttr);
55
+ mlir::Value visitCirAttr (cir::FPAttr fltAttr);
56
+ mlir::Value visitCirAttr (cir::ConstPtrAttr ptrAttr);
75
57
76
58
private:
77
59
mlir::Operation *parentOp;
78
60
mlir::ConversionPatternRewriter &rewriter;
79
61
const mlir::TypeConverter *converter;
80
62
};
81
63
64
+ // / IntAttr visitor.
65
+ mlir::Value CIRAttrToValue::visitCirAttr (cir::IntAttr intAttr) {
66
+ mlir::Location loc = parentOp->getLoc ();
67
+ return rewriter.create <mlir::LLVM::ConstantOp>(
68
+ loc, converter->convertType (intAttr.getType ()), intAttr.getValue ());
69
+ }
70
+
71
+ // / ConstPtrAttr visitor.
72
+ mlir::Value CIRAttrToValue::visitCirAttr (cir::ConstPtrAttr ptrAttr) {
73
+ mlir::Location loc = parentOp->getLoc ();
74
+ if (ptrAttr.isNullValue ()) {
75
+ return rewriter.create <mlir::LLVM::ZeroOp>(
76
+ loc, converter->convertType (ptrAttr.getType ()));
77
+ }
78
+ mlir::DataLayout layout (parentOp->getParentOfType <mlir::ModuleOp>());
79
+ mlir::Value ptrVal = rewriter.create <mlir::LLVM::ConstantOp>(
80
+ loc, rewriter.getIntegerType (layout.getTypeSizeInBits (ptrAttr.getType ())),
81
+ ptrAttr.getValue ().getInt ());
82
+ return rewriter.create <mlir::LLVM::IntToPtrOp>(
83
+ loc, converter->convertType (ptrAttr.getType ()), ptrVal);
84
+ }
85
+
86
+ // / FPAttr visitor.
87
+ mlir::Value CIRAttrToValue::visitCirAttr (cir::FPAttr fltAttr) {
88
+ mlir::Location loc = parentOp->getLoc ();
89
+ return rewriter.create <mlir::LLVM::ConstantOp>(
90
+ loc, converter->convertType (fltAttr.getType ()), fltAttr.getValue ());
91
+ }
92
+
82
93
// This class handles rewriting initializer attributes for types that do not
83
94
// require region initialization.
84
- class GlobalInitAttrRewriter
85
- : public CirAttrVisitor<GlobalInitAttrRewriter, mlir::Attribute> {
95
+ class GlobalInitAttrRewriter {
86
96
public:
87
97
GlobalInitAttrRewriter (mlir::Type type,
88
98
mlir::ConversionPatternRewriter &rewriter)
89
99
: llvmType(type), rewriter(rewriter) {}
90
100
91
- mlir::Attribute rewriteInitAttr (mlir::Attribute attr) { return visit (attr); }
101
+ mlir::Attribute visit (mlir::Attribute attr) {
102
+ return llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(attr)
103
+ .Case <cir::IntAttr, cir::FPAttr>(
104
+ [&](auto attrT) { return visitCirAttr (attrT); })
105
+ .Default ([&](auto attrT) { return mlir::Attribute (); });
106
+ }
92
107
93
- mlir::Attribute visitCirIntAttr (cir::IntAttr attr) {
108
+ mlir::Attribute visitCirAttr (cir::IntAttr attr) {
94
109
return rewriter.getIntegerAttr (llvmType, attr.getValue ());
95
110
}
96
- mlir::Attribute visitCirFPAttr (cir::FPAttr attr) {
111
+ mlir::Attribute visitCirAttr (cir::FPAttr attr) {
97
112
return rewriter.getFloatAttr (llvmType, attr.getValue ());
98
113
}
99
114
@@ -124,12 +139,6 @@ struct ConvertCIRToLLVMPass
124
139
StringRef getArgument () const override { return " cir-flat-to-llvm" ; }
125
140
};
126
141
127
- bool CIRToLLVMGlobalOpLowering::attrRequiresRegionInitialization (
128
- mlir::Attribute attr) const {
129
- // There will be more cases added later.
130
- return isa<cir::ConstPtrAttr>(attr);
131
- }
132
-
133
142
// / Replace CIR global with a region initialized LLVM global and update
134
143
// / insertion point to the end of the initializer block.
135
144
void CIRToLLVMGlobalOpLowering::setupRegionInitializedLLVMGlobalOp (
@@ -176,8 +185,8 @@ CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal(
176
185
// to the appropriate value.
177
186
const mlir::Location loc = op.getLoc ();
178
187
setupRegionInitializedLLVMGlobalOp (op, rewriter);
179
- CIRAttrToValue attrVisitor (op, rewriter, typeConverter);
180
- mlir::Value value = attrVisitor. lowerCirAttrAsValue (init);
188
+ CIRAttrToValue valueConverter (op, rewriter, typeConverter);
189
+ mlir::Value value = valueConverter. visit (init);
181
190
rewriter.create <mlir::LLVM::ReturnOp>(loc, value);
182
191
return mlir::success ();
183
192
}
@@ -188,12 +197,6 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
188
197
189
198
std::optional<mlir::Attribute> init = op.getInitialValue ();
190
199
191
- // If we have an initializer and it requires region initialization, handle
192
- // that separately
193
- if (init.has_value () && attrRequiresRegionInitialization (init.value ())) {
194
- return matchAndRewriteRegionInitializedGlobal (op, init.value (), rewriter);
195
- }
196
-
197
200
// Fetch required values to create LLVM op.
198
201
const mlir::Type cirSymType = op.getSymType ();
199
202
@@ -218,12 +221,25 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
218
221
SmallVector<mlir::NamedAttribute> attributes;
219
222
220
223
if (init.has_value ()) {
221
- GlobalInitAttrRewriter initRewriter (llvmType, rewriter);
222
- init = initRewriter.rewriteInitAttr (init.value ());
223
- // If initRewriter returned a null attribute, init will have a value but
224
- // the value will be null. If that happens, initRewriter didn't handle the
225
- // attribute type. It probably needs to be added to GlobalInitAttrRewriter.
226
- if (!init.value ()) {
224
+ if (mlir::isa<cir::FPAttr, cir::IntAttr>(init.value ())) {
225
+ GlobalInitAttrRewriter initRewriter (llvmType, rewriter);
226
+ init = initRewriter.visit (init.value ());
227
+ // If initRewriter returned a null attribute, init will have a value but
228
+ // the value will be null. If that happens, initRewriter didn't handle the
229
+ // attribute type. It probably needs to be added to
230
+ // GlobalInitAttrRewriter.
231
+ if (!init.value ()) {
232
+ op.emitError () << " unsupported initializer '" << init.value () << " '" ;
233
+ return mlir::failure ();
234
+ }
235
+ } else if (mlir::isa<cir::ConstPtrAttr>(init.value ())) {
236
+ // TODO(cir): once LLVM's dialect has proper equivalent attributes this
237
+ // should be updated. For now, we use a custom op to initialize globals
238
+ // to the appropriate value.
239
+ return matchAndRewriteRegionInitializedGlobal (op, init.value (), rewriter);
240
+ } else {
241
+ // We will only get here if new initializer types are added and this
242
+ // code is not updated to handle them.
227
243
op.emitError () << " unsupported initializer '" << init.value () << " '" ;
228
244
return mlir::failure ();
229
245
}
0 commit comments