@@ -4846,6 +4846,11 @@ static clang::CXXMethodDecl *synthesizeCxxBaseMethod(
4846
4846
newMethod->setImplicit ();
4847
4847
newMethod->setImplicitlyInline ();
4848
4848
newMethod->setAccess (clang::AccessSpecifier::AS_public);
4849
+ if (method->hasAttr <clang::CFReturnsRetainedAttr>()) {
4850
+ // Return an FRT field at +1 if the base method also follows this
4851
+ // convention.
4852
+ newMethod->addAttr (clang::CFReturnsRetainedAttr::CreateImplicit (clangCtx));
4853
+ }
4849
4854
4850
4855
llvm::SmallVector<clang::ParmVarDecl *, 4 > params;
4851
4856
for (size_t i = 0 ; i < method->getNumParams (); ++i) {
@@ -5047,7 +5052,8 @@ synthesizeBaseClassMethodBody(AbstractFunctionDecl *afd, void *context) {
5047
5052
// to the base class while the field is accessed.
5048
5053
static clang::CXXMethodDecl *synthesizeCxxBaseGetterAccessorMethod (
5049
5054
ClangImporter &impl, const clang::CXXRecordDecl *derivedClass,
5050
- const clang::CXXRecordDecl *baseClass, const clang::FieldDecl *field) {
5055
+ const clang::CXXRecordDecl *baseClass, const clang::FieldDecl *field,
5056
+ ValueDecl *retainOperationFn) {
5051
5057
auto &clangCtx = impl.getClangASTContext ();
5052
5058
auto &clangSema = impl.getClangSema ();
5053
5059
@@ -5078,51 +5084,95 @@ static clang::CXXMethodDecl *synthesizeCxxBaseGetterAccessorMethod(
5078
5084
newMethod->setImplicit ();
5079
5085
newMethod->setImplicitlyInline ();
5080
5086
newMethod->setAccess (clang::AccessSpecifier::AS_public);
5087
+ if (retainOperationFn) {
5088
+ // Return an FRT field at +1.
5089
+ newMethod->addAttr (clang::CFReturnsRetainedAttr::CreateImplicit (clangCtx));
5090
+ }
5081
5091
5082
5092
// Create a new Clang diagnostic pool to capture any diagnostics
5083
5093
// emitted during the construction of the method.
5084
5094
clang::sema::DelayedDiagnosticPool diagPool{
5085
5095
clangSema.DelayedDiagnostics .getCurrentPool ()};
5086
5096
auto diagState = clangSema.DelayedDiagnostics .push (diagPool);
5087
5097
5098
+ // Returns the expression that accesses the base field from derived type.
5099
+ auto createFieldAccess = [&]() -> clang::Expr * {
5100
+ auto *thisExpr = new (clangCtx)
5101
+ clang::CXXThisExpr (clang::SourceLocation (), newMethod->getThisType (),
5102
+ /* IsImplicit=*/ false );
5103
+ clang::QualType baseClassPtr = clangCtx.getRecordType (baseClass);
5104
+ baseClassPtr.addConst ();
5105
+ baseClassPtr = clangCtx.getPointerType (baseClassPtr);
5106
+
5107
+ clang::CastKind Kind;
5108
+ clang::CXXCastPath Path;
5109
+ clangSema.CheckPointerConversion (thisExpr, baseClassPtr, Kind, Path,
5110
+ /* IgnoreBaseAccess=*/ false ,
5111
+ /* Diagnose=*/ true );
5112
+ auto conv = clangSema.ImpCastExprToType (thisExpr, baseClassPtr, Kind,
5113
+ clang::VK_PRValue, &Path);
5114
+ if (!conv.isUsable ())
5115
+ return nullptr ;
5116
+ auto memberExpr = clangSema.BuildMemberExpr (
5117
+ conv.get (), /* isArrow=*/ true , clang::SourceLocation (),
5118
+ clang::NestedNameSpecifierLoc (), clang::SourceLocation (),
5119
+ const_cast <clang::FieldDecl *>(field),
5120
+ clang::DeclAccessPair::make (const_cast <clang::FieldDecl *>(field),
5121
+ clang::AS_public),
5122
+ /* HadMultipleCandidates=*/ false ,
5123
+ clang::DeclarationNameInfo (field->getDeclName (),
5124
+ clang::SourceLocation ()),
5125
+ returnType, clang::VK_LValue, clang::OK_Ordinary);
5126
+ auto returnCast = clangSema.ImpCastExprToType (
5127
+ memberExpr, returnType, clang::CK_LValueToRValue, clang::VK_PRValue);
5128
+ if (!returnCast.isUsable ())
5129
+ return nullptr ;
5130
+ return returnCast.get ();
5131
+ };
5132
+
5133
+ llvm::SmallVector<clang::Stmt *, 2 > body;
5134
+ if (retainOperationFn) {
5135
+ // Check if the returned value needs to be retained. This might occur if the
5136
+ // field getter is returning a shared reference type using, as it needs to
5137
+ // perform the retain to match the expected @owned convention.
5138
+ auto *retainClangFn =
5139
+ dyn_cast<clang::FunctionDecl>(retainOperationFn->getClangDecl ());
5140
+ if (!retainClangFn) {
5141
+ return nullptr ;
5142
+ }
5143
+ auto *fnRef = new (clangCtx) clang::DeclRefExpr (
5144
+ clangCtx, const_cast <clang::FunctionDecl *>(retainClangFn), false ,
5145
+ retainClangFn->getType (), clang::ExprValueKind::VK_LValue,
5146
+ clang::SourceLocation ());
5147
+ auto fieldExpr = createFieldAccess ();
5148
+ if (!fieldExpr)
5149
+ return nullptr ;
5150
+ auto retainCall = clangSema.BuildResolvedCallExpr (
5151
+ fnRef, const_cast <clang::FunctionDecl *>(retainClangFn),
5152
+ clang::SourceLocation (), {fieldExpr}, clang::SourceLocation ());
5153
+ if (!retainCall.isUsable ())
5154
+ return nullptr ;
5155
+ body.push_back (retainCall.get ());
5156
+ }
5157
+
5088
5158
// Construct the method's body.
5089
- auto *thisExpr = new (clangCtx) clang::CXXThisExpr (
5090
- clang::SourceLocation (), newMethod->getThisType (), /* IsImplicit=*/ false );
5091
- clang::QualType baseClassPtr = clangCtx.getRecordType (baseClass);
5092
- baseClassPtr.addConst ();
5093
- baseClassPtr = clangCtx.getPointerType (baseClassPtr);
5094
-
5095
- clang::CastKind Kind;
5096
- clang::CXXCastPath Path;
5097
- clangSema.CheckPointerConversion (thisExpr, baseClassPtr, Kind, Path,
5098
- /* IgnoreBaseAccess=*/ false ,
5099
- /* Diagnose=*/ true );
5100
- auto conv = clangSema.ImpCastExprToType (thisExpr, baseClassPtr, Kind,
5101
- clang::VK_PRValue, &Path);
5102
- if (!conv.isUsable ())
5103
- return nullptr ;
5104
- auto memberExpr = clangSema.BuildMemberExpr (
5105
- conv.get (), /* isArrow=*/ true , clang::SourceLocation (),
5106
- clang::NestedNameSpecifierLoc (), clang::SourceLocation (),
5107
- const_cast <clang::FieldDecl *>(field),
5108
- clang::DeclAccessPair::make (const_cast <clang::FieldDecl *>(field),
5109
- clang::AS_public),
5110
- /* HadMultipleCandidates=*/ false ,
5111
- clang::DeclarationNameInfo (field->getDeclName (), clang::SourceLocation ()),
5112
- returnType, clang::VK_LValue, clang::OK_Ordinary);
5113
- auto returnCast = clangSema.ImpCastExprToType (
5114
- memberExpr, returnType, clang::CK_LValueToRValue, clang::VK_PRValue);
5115
- if (!returnCast.isUsable ())
5159
+ auto fieldExpr = createFieldAccess ();
5160
+ if (!fieldExpr)
5116
5161
return nullptr ;
5117
5162
auto returnStmt = clang::ReturnStmt::Create (clangCtx, clang::SourceLocation (),
5118
- returnCast.get (), nullptr );
5163
+ fieldExpr, nullptr );
5164
+ body.push_back (returnStmt);
5119
5165
5120
5166
// Check if there were any Clang errors during the construction
5121
5167
// of the method body.
5122
5168
clangSema.DelayedDiagnostics .popWithoutEmitting (diagState);
5123
5169
if (!diagPool.empty ())
5124
5170
return nullptr ;
5125
- newMethod->setBody (returnStmt);
5171
+ newMethod->setBody (body.size () > 1
5172
+ ? clang::CompoundStmt::Create (
5173
+ clangCtx, body, clang::FPOptionsOverride (),
5174
+ clang::SourceLocation (), clang::SourceLocation ())
5175
+ : body[0 ]);
5126
5176
return newMethod;
5127
5177
}
5128
5178
@@ -5163,12 +5213,28 @@ synthesizeBaseClassFieldGetterBody(AbstractFunctionDecl *afd, void *context) {
5163
5213
RemoveReference,
5164
5214
/* forceConstQualifier=*/ true );
5165
5215
} else if (auto *fd = dyn_cast_or_null<clang::FieldDecl>(baseClangDecl)) {
5216
+ ValueDecl *retainOperationFn = nullptr ;
5217
+ // Check if this field getter is returning a retainable FRT.
5218
+ if (getterDecl->getResultInterfaceType ()->isForeignReferenceType ()) {
5219
+ auto retainOperation = evaluateOrDefault (
5220
+ ctx.evaluator ,
5221
+ CustomRefCountingOperation ({getterDecl->getResultInterfaceType ()
5222
+ ->lookThroughAllOptionalTypes ()
5223
+ ->getClassOrBoundGenericClass (),
5224
+ CustomRefCountingOperationKind::retain}),
5225
+ {});
5226
+ if (retainOperation.kind ==
5227
+ CustomRefCountingOperationResult::foundOperation) {
5228
+ retainOperationFn = retainOperation.operation ;
5229
+ }
5230
+ }
5166
5231
// Field getter is represented through a generated
5167
5232
// C++ method call that returns the value of the base field.
5168
5233
baseGetterCxxMethod = synthesizeCxxBaseGetterAccessorMethod (
5169
5234
*static_cast <ClangImporter *>(ctx.getClangModuleLoader ()),
5170
5235
cast<clang::CXXRecordDecl>(derivedStruct->getClangDecl ()),
5171
- cast<clang::CXXRecordDecl>(baseStruct->getClangDecl ()), fd);
5236
+ cast<clang::CXXRecordDecl>(baseStruct->getClangDecl ()), fd,
5237
+ retainOperationFn);
5172
5238
}
5173
5239
5174
5240
if (!baseGetterCxxMethod) {
0 commit comments