Skip to content

Commit cbd8cdf

Browse files
author
Gabor Horvath
committed
[cxx-interop] Add rules to recognize escapability of aggregates
For now, this logic is used for importing fewer unannotated types as unsafe. In the future, this logic will be used by escapability inference for other (non-aggregate) types.
1 parent 4052991 commit cbd8cdf

File tree

5 files changed

+135
-9
lines changed

5 files changed

+135
-9
lines changed

include/swift/ClangImporter/ClangImporterRequests.h

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@
1616
#ifndef SWIFT_CLANG_IMPORTER_REQUESTS_H
1717
#define SWIFT_CLANG_IMPORTER_REQUESTS_H
1818

19-
#include "swift/AST/SimpleRequest.h"
2019
#include "swift/AST/ASTTypeIDs.h"
2120
#include "swift/AST/EvaluatorDependencies.h"
22-
#include "swift/AST/FileUnit.h"
2321
#include "swift/AST/Identifier.h"
2422
#include "swift/AST/NameLookup.h"
23+
#include "swift/AST/SimpleRequest.h"
2524
#include "swift/Basic/Statistic.h"
25+
#include "clang/AST/Type.h"
2626
#include "llvm/ADT/Hashing.h"
2727
#include "llvm/ADT/TinyPtrVector.h"
2828

@@ -500,6 +500,45 @@ class CustomRefCountingOperation
500500
CustomRefCountingOperationDescriptor desc) const;
501501
};
502502

503+
enum class CxxEscapability { Escapable, NonEscapable, Unknown };
504+
505+
struct EscapabilityLookupDescriptor final {
506+
const clang::Type *type;
507+
508+
friend llvm::hash_code hash_value(const EscapabilityLookupDescriptor &desc) {
509+
return llvm::hash_combine(desc.type);
510+
}
511+
512+
friend bool operator==(const EscapabilityLookupDescriptor &lhs,
513+
const EscapabilityLookupDescriptor &rhs) {
514+
return lhs.type == rhs.type;
515+
}
516+
517+
friend bool operator!=(const EscapabilityLookupDescriptor &lhs,
518+
const EscapabilityLookupDescriptor &rhs) {
519+
return !(lhs == rhs);
520+
}
521+
};
522+
523+
class ClangTypeEscapability
524+
: public SimpleRequest<ClangTypeEscapability,
525+
CxxEscapability(EscapabilityLookupDescriptor),
526+
RequestFlags::Cached> {
527+
public:
528+
using SimpleRequest::SimpleRequest;
529+
530+
bool isCached() const { return true; }
531+
532+
private:
533+
friend SimpleRequest;
534+
535+
CxxEscapability evaluate(Evaluator &evaluator,
536+
EscapabilityLookupDescriptor desc) const;
537+
};
538+
539+
void simple_display(llvm::raw_ostream &out, EscapabilityLookupDescriptor desc);
540+
SourceLoc extractNearestSourceLoc(EscapabilityLookupDescriptor desc);
541+
503542
#define SWIFT_TYPEID_ZONE ClangImporter
504543
#define SWIFT_TYPEID_HEADER "swift/ClangImporter/ClangImporterTypeIDZone.def"
505544
#include "swift/Basic/DefineTypeIDZone.h"

include/swift/ClangImporter/ClangImporterTypeIDZone.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,6 @@ SWIFT_REQUEST(ClangImporter, IsSafeUseOfCxxDecl,
4242
SWIFT_REQUEST(ClangImporter, CustomRefCountingOperation,
4343
CustomRefCountingOperationResult(CustomRefCountingOperationDescriptor), Cached,
4444
NoLocationInfo)
45+
SWIFT_REQUEST(ClangImporter, ClangTypeEscapability,
46+
CxxEscapability(EscapabilityLookupDescriptor), Cached,
47+
NoLocationInfo)

lib/ClangImporter/ClangImporter.cpp

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@
5252
#include "swift/Strings.h"
5353
#include "swift/Subsystems.h"
5454
#include "clang/AST/ASTContext.h"
55+
#include "clang/AST/DeclCXX.h"
5556
#include "clang/AST/Mangle.h"
57+
#include "clang/AST/Type.h"
5658
#include "clang/Basic/DiagnosticOptions.h"
5759
#include "clang/Basic/FileEntry.h"
5860
#include "clang/Basic/IdentifierTable.h"
@@ -5018,6 +5020,74 @@ TinyPtrVector<ValueDecl *> CXXNamespaceMemberLookup::evaluate(
50185020
return result;
50195021
}
50205022

5023+
CxxEscapability
5024+
ClangTypeEscapability::evaluate(Evaluator &evaluator,
5025+
EscapabilityLookupDescriptor desc) const {
5026+
auto desugared = desc.type->getUnqualifiedDesugaredType();
5027+
if (const auto *recordType = desugared->getAs<clang::RecordType>()) {
5028+
if (importer::hasNonEscapableAttr(recordType->getDecl()))
5029+
return CxxEscapability::NonEscapable;
5030+
if (importer::hasEscapableAttr(recordType->getDecl()))
5031+
return CxxEscapability::Escapable;
5032+
auto recordDecl = recordType->getDecl();
5033+
auto cxxRecordDecl = dyn_cast<clang::CXXRecordDecl>(recordDecl);
5034+
if (!cxxRecordDecl || cxxRecordDecl->isAggregate()) {
5035+
bool hadUnknown = false;
5036+
auto evaluateEscapability = [&](const clang::Type *type) {
5037+
auto escapability = evaluateOrDefault(
5038+
evaluator, ClangTypeEscapability({type}), CxxEscapability::Unknown);
5039+
if (escapability == CxxEscapability::Unknown)
5040+
hadUnknown = true;
5041+
return escapability;
5042+
};
5043+
5044+
if (cxxRecordDecl) {
5045+
for (auto base : cxxRecordDecl->bases()) {
5046+
auto baseEscapability = evaluateEscapability(
5047+
base.getType()->getUnqualifiedDesugaredType());
5048+
if (baseEscapability == CxxEscapability::NonEscapable)
5049+
return CxxEscapability::NonEscapable;
5050+
}
5051+
}
5052+
5053+
for (auto field : recordDecl->fields()) {
5054+
auto fieldEscapability = evaluateEscapability(
5055+
field->getType()->getUnqualifiedDesugaredType());
5056+
if (fieldEscapability == CxxEscapability::NonEscapable)
5057+
return CxxEscapability::NonEscapable;
5058+
}
5059+
5060+
return hadUnknown ? CxxEscapability::Unknown : CxxEscapability::Escapable;
5061+
}
5062+
}
5063+
if (desugared->isArrayType()) {
5064+
auto elemTy = cast<clang::ArrayType>(desugared)
5065+
->getElementType()
5066+
->getUnqualifiedDesugaredType();
5067+
return evaluateOrDefault(evaluator, ClangTypeEscapability({elemTy}),
5068+
CxxEscapability::Unknown);
5069+
}
5070+
5071+
// Base cases
5072+
if (desugared->isAnyPointerType() || desugared->isBlockPointerType() ||
5073+
desugared->isMemberPointerType() || desugared->isReferenceType())
5074+
return CxxEscapability::NonEscapable;
5075+
if (desugared->isScalarType())
5076+
return CxxEscapability::Escapable;
5077+
return CxxEscapability::Unknown;
5078+
}
5079+
5080+
void swift::simple_display(llvm::raw_ostream &out,
5081+
EscapabilityLookupDescriptor desc) {
5082+
out << "Computing escapability for type '";
5083+
out << clang::QualType(desc.type, 0).getAsString();
5084+
out << "'";
5085+
}
5086+
5087+
SourceLoc swift::extractNearestSourceLoc(EscapabilityLookupDescriptor) {
5088+
return SourceLoc();
5089+
}
5090+
50215091
// Just create a specialized function decl for "__swift_interopStaticCast"
50225092
// using the types base and derived.
50235093
static

lib/ClangImporter/ImportDecl.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8155,8 +8155,7 @@ bool swift::importer::isMutabilityAttr(const clang::SwiftAttrAttr *swiftAttr) {
81558155
swiftAttr->getAttribute() == "nonmutating";
81568156
}
81578157

8158-
static bool importAsUnsafe(const ASTContext &context,
8159-
const clang::RecordDecl *decl,
8158+
static bool importAsUnsafe(ASTContext &context, const clang::RecordDecl *decl,
81608159
const Decl *MappedDecl) {
81618160
if (!context.LangOpts.hasFeature(Feature::SafeInterop) ||
81628161
!context.LangOpts.hasFeature(Feature::AllowUnsafeAttribute) || !decl)
@@ -8165,9 +8164,9 @@ static bool importAsUnsafe(const ASTContext &context,
81658164
if (isa<ClassDecl>(MappedDecl))
81668165
return false;
81678166

8168-
// TODO: Add logic to cover structural rules.
8169-
return !importer::hasNonEscapableAttr(decl) &&
8170-
!importer::hasEscapableAttr(decl);
8167+
return evaluateOrDefault(
8168+
context.evaluator, ClangTypeEscapability({decl->getTypeForDecl()}),
8169+
CxxEscapability::Unknown) == CxxEscapability::Unknown;
81718170
}
81728171

81738172
void

test/Interop/Cxx/class/safe-interop-mode.swift

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,22 @@ private:
2626

2727
struct SWIFT_ESCAPABLE Owner {};
2828

29-
struct Unannotated {};
29+
struct Unannotated {
30+
Unannotated();
31+
};
3032

3133
struct SWIFT_UNSAFE_REFERENCE UnsafeReference {};
3234

35+
struct SafeEscapableAggregate {
36+
int a;
37+
float b[5];
38+
};
39+
40+
struct UnknownEscapabilityAggregate {
41+
SafeEscapableAggregate agg;
42+
Unannotated unann;
43+
};
44+
3345
//--- test.swift
3446

3547
import Test
@@ -42,7 +54,10 @@ func useUnsafeParam(x: Unannotated) { // expected-warning{{reference to unsafe s
4254
func useUnsafeParam2(x: UnsafeReference) { // expected-warning{{reference to unsafe class 'UnsafeReference'}}
4355
}
4456

45-
func useSafeParams(x: Owner, y: View) {
57+
func useUnsafeParam3(x: UnknownEscapabilityAggregate) { // expected-warning{{reference to unsafe struct 'UnknownEscapabilityAggregate'}}
58+
}
59+
60+
func useSafeParams(x: Owner, y: View, z: SafeEscapableAggregate) {
4661
}
4762

4863
func useCfType(x: CFArray) {

0 commit comments

Comments
 (0)