@@ -40,15 +40,18 @@ struct CIRRecordLowering final {
40
40
// member type that ensures correct rounding.
41
41
struct MemberInfo final {
42
42
CharUnits offset;
43
- enum class InfoKind { Field } kind;
43
+ enum class InfoKind { Field, Base } kind;
44
44
mlir::Type data;
45
45
union {
46
46
const FieldDecl *fieldDecl;
47
- // CXXRecordDecl will be used here when base types are supported.
47
+ const CXXRecordDecl *cxxRecordDecl;
48
48
};
49
49
MemberInfo (CharUnits offset, InfoKind kind, mlir::Type data,
50
50
const FieldDecl *fieldDecl = nullptr )
51
- : offset(offset), kind(kind), data(data), fieldDecl(fieldDecl) {};
51
+ : offset{offset}, kind{kind}, data{data}, fieldDecl{fieldDecl} {}
52
+ MemberInfo (CharUnits offset, InfoKind kind, mlir::Type data,
53
+ const CXXRecordDecl *rd)
54
+ : offset{offset}, kind{kind}, data{data}, cxxRecordDecl{rd} {}
52
55
// MemberInfos are sorted so we define a < operator.
53
56
bool operator <(const MemberInfo &other) const {
54
57
return offset < other.offset ;
@@ -71,6 +74,8 @@ struct CIRRecordLowering final {
71
74
// / Inserts padding everywhere it's needed.
72
75
void insertPadding ();
73
76
77
+ void accumulateBases (const CXXRecordDecl *cxxRecordDecl);
78
+ void accumulateVPtrs ();
74
79
void accumulateFields ();
75
80
76
81
CharUnits bitsToCharUnits (uint64_t bitOffset) {
@@ -89,6 +94,9 @@ struct CIRRecordLowering final {
89
94
bool isZeroInitializable (const FieldDecl *fd) {
90
95
return cirGenTypes.isZeroInitializable (fd->getType ());
91
96
}
97
+ bool isZeroInitializable (const RecordDecl *rd) {
98
+ return cirGenTypes.isZeroInitializable (rd);
99
+ }
92
100
93
101
// / Wraps cir::IntType with some implicit arguments.
94
102
mlir::Type getUIntNType (uint64_t numBits) {
@@ -112,6 +120,11 @@ struct CIRRecordLowering final {
112
120
: cir::ArrayType::get (type, numberOfChars.getQuantity ());
113
121
}
114
122
123
+ // Gets the CIR BaseSubobject type from a CXXRecordDecl.
124
+ mlir::Type getStorageType (const CXXRecordDecl *RD) {
125
+ return cirGenTypes.getCIRGenRecordLayout (RD).getBaseSubobjectCIRType ();
126
+ }
127
+
115
128
mlir::Type getStorageType (const FieldDecl *fieldDecl) {
116
129
mlir::Type type = cirGenTypes.convertTypeForMem (fieldDecl->getType ());
117
130
if (fieldDecl->isBitField ()) {
@@ -145,6 +158,7 @@ struct CIRRecordLowering final {
145
158
// Output fields, consumed by CIRGenTypes::computeRecordLayout
146
159
llvm::SmallVector<mlir::Type, 16 > fieldTypes;
147
160
llvm::DenseMap<const FieldDecl *, unsigned > fieldIdxMap;
161
+ llvm::DenseMap<const CXXRecordDecl *, unsigned > nonVirtualBases;
148
162
cir::CIRDataLayout dataLayout;
149
163
150
164
LLVM_PREFERRED_TYPE (bool )
@@ -179,24 +193,20 @@ void CIRRecordLowering::lower() {
179
193
return ;
180
194
}
181
195
182
- assert (!cir::MissingFeatures::cxxSupport ());
183
-
196
+ assert (!cir::MissingFeatures::recordLayoutVirtualBases ());
184
197
CharUnits size = astRecordLayout.getSize ();
185
198
186
199
accumulateFields ();
187
200
188
201
if (const auto *cxxRecordDecl = dyn_cast<CXXRecordDecl>(recordDecl)) {
189
- if (cxxRecordDecl->getNumBases () > 0 ) {
190
- CIRGenModule &cgm = cirGenTypes.getCGModule ();
191
- cgm.errorNYI (recordDecl->getSourceRange (),
192
- " CIRRecordLowering::lower: derived CXXRecordDecl" );
193
- return ;
194
- }
202
+ accumulateVPtrs ();
203
+ accumulateBases (cxxRecordDecl);
195
204
if (members.empty ()) {
196
205
appendPaddingBytes (size);
197
206
assert (!cir::MissingFeatures::bitfields ());
198
207
return ;
199
208
}
209
+ assert (!cir::MissingFeatures::recordLayoutVirtualBases ());
200
210
}
201
211
202
212
llvm::stable_sort (members);
@@ -223,8 +233,10 @@ void CIRRecordLowering::fillOutputFields() {
223
233
fieldTypes.size () - 1 ;
224
234
// A field without storage must be a bitfield.
225
235
assert (!cir::MissingFeatures::bitfields ());
236
+ } else if (member.kind == MemberInfo::InfoKind::Base) {
237
+ nonVirtualBases[member.cxxRecordDecl ] = fieldTypes.size () - 1 ;
226
238
}
227
- assert (!cir::MissingFeatures::cxxSupport ());
239
+ assert (!cir::MissingFeatures::recordLayoutVirtualBases ());
228
240
}
229
241
}
230
242
@@ -254,9 +266,14 @@ void CIRRecordLowering::calculateZeroInit() {
254
266
continue ;
255
267
zeroInitializable = zeroInitializableAsBase = false ;
256
268
return ;
269
+ } else if (member.kind == MemberInfo::InfoKind::Base) {
270
+ if (isZeroInitializable (member.cxxRecordDecl ))
271
+ continue ;
272
+ zeroInitializable = false ;
273
+ if (member.kind == MemberInfo::InfoKind::Base)
274
+ zeroInitializableAsBase = false ;
257
275
}
258
- // TODO(cir): handle base types
259
- assert (!cir::MissingFeatures::cxxSupport ());
276
+ assert (!cir::MissingFeatures::recordLayoutVirtualBases ());
260
277
}
261
278
}
262
279
@@ -317,6 +334,27 @@ CIRGenTypes::computeRecordLayout(const RecordDecl *rd, cir::RecordType *ty) {
317
334
lowering.lower ();
318
335
319
336
// If we're in C++, compute the base subobject type.
337
+ cir::RecordType baseTy;
338
+ if (llvm::isa<CXXRecordDecl>(rd) && !rd->isUnion () &&
339
+ !rd->hasAttr <FinalAttr>()) {
340
+ baseTy = *ty;
341
+ if (lowering.astRecordLayout .getNonVirtualSize () !=
342
+ lowering.astRecordLayout .getSize ()) {
343
+ CIRRecordLowering baseLowering (*this , rd, /* Packed=*/ lowering.packed );
344
+ baseLowering.lower ();
345
+ std::string baseIdentifier = getRecordTypeName (rd, " .base" );
346
+ baseTy =
347
+ builder.getCompleteRecordTy (baseLowering.fieldTypes , baseIdentifier,
348
+ baseLowering.packed , baseLowering.padded );
349
+ // TODO(cir): add something like addRecordTypeName
350
+
351
+ // BaseTy and Ty must agree on their packedness for getCIRFieldNo to work
352
+ // on both of them with the same index.
353
+ assert (lowering.packed == baseLowering.packed &&
354
+ " Non-virtual and complete types must agree on packedness" );
355
+ }
356
+ }
357
+
320
358
if (llvm::isa<CXXRecordDecl>(rd) && !rd->isUnion () &&
321
359
!rd->hasAttr <FinalAttr>()) {
322
360
if (lowering.astRecordLayout .getNonVirtualSize () !=
@@ -332,10 +370,13 @@ CIRGenTypes::computeRecordLayout(const RecordDecl *rd, cir::RecordType *ty) {
332
370
ty->complete (lowering.fieldTypes , lowering.packed , lowering.padded );
333
371
334
372
auto rl = std::make_unique<CIRGenRecordLayout>(
335
- ty ? *ty : cir::RecordType (), ( bool )lowering. zeroInitializable ,
336
- (bool )lowering.zeroInitializableAsBase );
373
+ ty ? *ty : cir::RecordType{}, baseTy ? baseTy : cir::RecordType{} ,
374
+ (bool )lowering.zeroInitializable , ( bool )lowering. zeroInitializableAsBase );
337
375
338
376
assert (!cir::MissingFeatures::recordZeroInit ());
377
+
378
+ rl->nonVirtualBases .swap (lowering.nonVirtualBases );
379
+
339
380
assert (!cir::MissingFeatures::cxxSupport ());
340
381
assert (!cir::MissingFeatures::bitfields ());
341
382
@@ -415,3 +456,38 @@ void CIRRecordLowering::lowerUnion() {
415
456
if (layoutSize % getAlignment (storageType))
416
457
packed = true ;
417
458
}
459
+
460
+ void CIRRecordLowering::accumulateBases (const CXXRecordDecl *cxxRecordDecl) {
461
+ // If we've got a primary virtual base, we need to add it with the bases.
462
+ if (astRecordLayout.isPrimaryBaseVirtual ()) {
463
+ cirGenTypes.getCGModule ().errorNYI (recordDecl->getSourceRange (),
464
+ " accumulateBases: primary virtual base" );
465
+ }
466
+
467
+ // Accumulate the non-virtual bases.
468
+ for ([[maybe_unused]] const auto &base : cxxRecordDecl->bases ()) {
469
+ if (base.isVirtual ()) {
470
+ cirGenTypes.getCGModule ().errorNYI (recordDecl->getSourceRange (),
471
+ " accumulateBases: virtual base" );
472
+ continue ;
473
+ }
474
+ // Bases can be zero-sized even if not technically empty if they
475
+ // contain only a trailing array member.
476
+ const CXXRecordDecl *baseDecl = base.getType ()->getAsCXXRecordDecl ();
477
+ if (!baseDecl->isEmpty () &&
478
+ !astContext.getASTRecordLayout (baseDecl).getNonVirtualSize ().isZero ()) {
479
+ members.push_back (MemberInfo (astRecordLayout.getBaseClassOffset (baseDecl),
480
+ MemberInfo::InfoKind::Base,
481
+ getStorageType (baseDecl), baseDecl));
482
+ }
483
+ }
484
+ }
485
+
486
+ void CIRRecordLowering::accumulateVPtrs () {
487
+ if (astRecordLayout.hasOwnVFPtr ())
488
+ cirGenTypes.getCGModule ().errorNYI (recordDecl->getSourceRange (),
489
+ " accumulateVPtrs: hasOwnVFPtr" );
490
+ if (astRecordLayout.hasOwnVBPtr ())
491
+ cirGenTypes.getCGModule ().errorNYI (recordDecl->getSourceRange (),
492
+ " accumulateVPtrs: hasOwnVBPtr" );
493
+ }
0 commit comments