Skip to content

Commit 67daddc

Browse files
committed
[ClangImporter] Attach _SwiftifyImportProtocol to imported protocols
with bounds attributes This creates safe overloads for any methods in the protocol annotated with bounds information. rdar://144335990
1 parent 3dbe569 commit 67daddc

File tree

5 files changed

+207
-21
lines changed

5 files changed

+207
-21
lines changed

lib/ClangImporter/ImportDecl.cpp

Lines changed: 92 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5546,6 +5546,8 @@ namespace {
55465546

55475547
result->setMemberLoader(&Impl, 0);
55485548

5549+
Impl.importProtocolBoundsAttributes(result);
5550+
55495551
return result;
55505552
}
55515553

@@ -8802,13 +8804,15 @@ ClangImporter::Implementation::importSwiftAttrAttributes(Decl *MappedDecl) {
88028804

88038805
namespace {
88048806
class SwiftifyInfoPrinter {
8805-
public:
8807+
protected:
88068808
clang::ASTContext &ctx;
88078809
llvm::raw_ostream &out;
88088810
bool firstParam = true;
8811+
8812+
public:
88098813
SwiftifyInfoPrinter(clang::ASTContext &ctx, llvm::raw_ostream &out)
88108814
: ctx(ctx), out(out) {
8811-
out << "@_SwiftifyImport(";
8815+
out << "(";
88128816
}
88138817
~SwiftifyInfoPrinter() { out << ")"; }
88148818

@@ -8865,21 +8869,48 @@ class SwiftifyInfoPrinter {
88658869
}
88668870

88678871
private:
8872+
void printParamOrReturn(ssize_t pointerIndex) {
8873+
if (pointerIndex == -2)
8874+
out << ".self";
8875+
else if (pointerIndex == -1)
8876+
out << ".return";
8877+
else
8878+
out << ".param(" << pointerIndex + 1 << ")";
8879+
}
8880+
8881+
protected:
88688882
void printSeparator() {
88698883
if (!firstParam) {
88708884
out << ", ";
88718885
} else {
88728886
firstParam = false;
88738887
}
88748888
}
8889+
};
88758890

8876-
void printParamOrReturn(ssize_t pointerIndex) {
8877-
if (pointerIndex == -2)
8878-
out << ".self";
8879-
else if (pointerIndex == -1)
8880-
out << ".return";
8881-
else
8882-
out << ".param(" << pointerIndex + 1 << ")";
8891+
class SwiftifyProtocolInfoPrinter : SwiftifyInfoPrinter {
8892+
public:
8893+
SwiftifyProtocolInfoPrinter(clang::ASTContext &ctx, llvm::raw_ostream &out)
8894+
: SwiftifyInfoPrinter(ctx, out) {}
8895+
8896+
bool printMethod(const FuncDecl *Method,
8897+
bool (*importFunctionBoundsAttributes)(
8898+
SwiftifyInfoPrinter &, const clang::ObjCMethodDecl *)) {
8899+
const auto *ClangDecl =
8900+
dyn_cast_or_null<clang::ObjCMethodDecl>(Method->getClangDecl());
8901+
if (!ClangDecl)
8902+
return false;
8903+
8904+
printSeparator();
8905+
out << ".method(name: \"" << Method->getName().getBaseIdentifier()
8906+
<< "\", paramInfo: [";
8907+
// reset firstParam inside paramInfo array. At this point firstParam will
8908+
// always be false, so no need to save the current value.
8909+
firstParam = true;
8910+
bool hadAttributes = importFunctionBoundsAttributes(*this, ClangDecl);
8911+
firstParam = false;
8912+
out << "])";
8913+
return hadAttributes;
88838914
}
88848915
};
88858916
} // namespace
@@ -8896,6 +8927,7 @@ void ClangImporter::Implementation::importSpanAttributes(FuncDecl *MappedDecl) {
88968927
bool attachMacro = false;
88978928
{
88988929
llvm::raw_svector_ostream out(MacroString);
8930+
out << "@_SwiftifyImport";
88998931
llvm::StringMap<std::string> typeMapping;
89008932

89018933
auto registerSwiftifyMacro =
@@ -8951,6 +8983,54 @@ void ClangImporter::Implementation::importSpanAttributes(FuncDecl *MappedDecl) {
89518983
importNontrivialAttribute(MappedDecl, MacroString);
89528984
}
89538985

8986+
template <class FuncType>
8987+
bool importFuncBoundsAttrs(SwiftifyInfoPrinter &printer, FuncType *ClangDecl) {
8988+
bool hadBoundsAttribute = false;
8989+
for (auto [index, param] : llvm::enumerate(ClangDecl->parameters())) {
8990+
if (auto CAT =
8991+
param->getType()->template getAs<clang::CountAttributedType>()) {
8992+
printer.printCountedBy(CAT, index);
8993+
if (param->template hasAttr<clang::NoEscapeAttr>()) {
8994+
printer.printNonEscaping(index);
8995+
}
8996+
hadBoundsAttribute = true;
8997+
}
8998+
}
8999+
if (auto CAT = ClangDecl->getReturnType()
9000+
->template getAs<clang::CountAttributedType>()) {
9001+
printer.printCountedBy(CAT, -1);
9002+
hadBoundsAttribute = true;
9003+
}
9004+
return hadBoundsAttribute;
9005+
}
9006+
9007+
void ClangImporter::Implementation::importProtocolBoundsAttributes(
9008+
ProtocolDecl *MappedDecl) {
9009+
if (!SwiftContext.LangOpts.hasFeature(Feature::SafeInteropWrappers))
9010+
return;
9011+
9012+
llvm::SmallString<128> MacroString;
9013+
{
9014+
llvm::raw_svector_ostream out(MacroString);
9015+
out << "@_SwiftifyImportProtocol";
9016+
9017+
bool hasBoundsAttributes = false;
9018+
SwiftifyProtocolInfoPrinter printer(getClangASTContext(), out);
9019+
for (Decl *SwiftMember : MappedDecl->getAllMembers()) {
9020+
FuncDecl *SwiftDecl = dyn_cast<FuncDecl>(SwiftMember);
9021+
if (!SwiftDecl)
9022+
continue;
9023+
hasBoundsAttributes |=
9024+
printer.printMethod(SwiftDecl, importFuncBoundsAttrs);
9025+
}
9026+
9027+
if (!hasBoundsAttributes)
9028+
return;
9029+
}
9030+
9031+
importNontrivialAttribute(MappedDecl, MacroString);
9032+
}
9033+
89549034
void ClangImporter::Implementation::importBoundsAttributes(
89559035
FuncDecl *MappedDecl) {
89569036
assert(SwiftContext.LangOpts.hasFeature(Feature::SafeInteropWrappers));
@@ -8964,20 +9044,11 @@ void ClangImporter::Implementation::importBoundsAttributes(
89649044
llvm::SmallString<128> MacroString;
89659045
{
89669046
llvm::raw_svector_ostream out(MacroString);
9047+
out << "@_SwiftifyImport";
89679048

89689049
SwiftifyInfoPrinter printer(getClangASTContext(), out);
8969-
for (auto [index, param] : llvm::enumerate(ClangDecl->parameters())) {
8970-
if (auto CAT = param->getType()->getAs<clang::CountAttributedType>()) {
8971-
printer.printCountedBy(CAT, index);
8972-
if (param->hasAttr<clang::NoEscapeAttr>()) {
8973-
printer.printNonEscaping(index);
8974-
}
8975-
}
8976-
}
8977-
if (auto CAT =
8978-
ClangDecl->getReturnType()->getAs<clang::CountAttributedType>()) {
8979-
printer.printCountedBy(CAT, -1);
8980-
}
9050+
if (!importFuncBoundsAttrs(printer, ClangDecl))
9051+
return;
89819052
}
89829053

89839054
importNontrivialAttribute(MappedDecl, MacroString);

lib/ClangImporter/ImporterImpl.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1745,6 +1745,7 @@ class LLVM_LIBRARY_VISIBILITY ClangImporter::Implementation
17451745
}
17461746

17471747
void importSwiftAttrAttributes(Decl *decl);
1748+
void importProtocolBoundsAttributes(ProtocolDecl *MappedDecl);
17481749
void importBoundsAttributes(FuncDecl *MappedDecl);
17491750
void importSpanAttributes(FuncDecl *MappedDecl);
17501751

lib/Macros/Sources/SwiftMacros/SwiftifyImportMacro.swift

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1107,6 +1107,10 @@ func constructOverloadFunction(forDecl funcDecl: FunctionDeclSyntax,
11071107
atSign: .atSignToken(),
11081108
attributeName: IdentifierTypeSyntax(name: "_disfavoredOverload")))
11091109
] : [])
1110+
let hasVisibilityModifier = funcDecl.modifiers.contains { modifier in
1111+
let modName = modifier.name.trimmed.text
1112+
return modName == "public" || modName == "internal" || modName == "open" || modName == "private" || modName == "filePrivate"
1113+
}
11101114
let newFunc =
11111115
funcDecl
11121116
.with(\.signature, newSignature)
@@ -1129,6 +1133,7 @@ func constructOverloadFunction(forDecl funcDecl: FunctionDeclSyntax,
11291133
]
11301134
+ lifetimeAttrs
11311135
+ disfavoredOverload)
1136+
.with(\.modifiers, funcDecl.modifiers + (hasVisibilityModifier ? [] : [DeclModifierSyntax(name: .identifier("public"))]))
11321137
return DeclSyntax(newFunc)
11331138
}
11341139

lib/Sema/TypeCheckMacros.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1337,6 +1337,13 @@ static SourceFile *evaluateAttachedMacro(MacroDecl *macro, Decl *attachedTo,
13371337
} else if (role == MacroRole::Conformance || role == MacroRole::Extension) {
13381338
// Conformance macros always expand to extensions at file-scope.
13391339
dc = attachedTo->getDeclContext()->getParentSourceFile();
1340+
if (!dc) {
1341+
assert(isa<ClangModuleUnit>(
1342+
attachedTo->getDeclContext()->getModuleScopeContext()));
1343+
dc = attachedTo->getDeclContext();
1344+
// decls imported from clang do not have a SourceFile
1345+
assert(isa<FileUnit>(dc));
1346+
}
13401347
} else {
13411348
dc = attachedTo->getInnermostDeclContext();
13421349
}
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
// REQUIRES: swift_feature_SafeInteropWrappers
2+
3+
// swift-ide-test doesn't currently trigger extension macro expansion, nor does it typecheck macro expansions, so dump macros with swift-frontend
4+
5+
// RUN: %empty-directory(%t)
6+
// RUN: split-file %s %t
7+
8+
// RUN: %target-swift-frontend -emit-module -plugin-path %swift-plugin-dir -o %t/CountedByProtocol.swiftmodule -I %t/Inputs -enable-experimental-feature SafeInteropWrappers %t/counted-by-protocol.swift -dump-macro-expansions 2>&1 | %FileCheck %s
9+
10+
//--- Inputs/module.modulemap
11+
module CountedByProtocolClang {
12+
header "counted-by-protocol.h"
13+
export *
14+
}
15+
16+
//--- Inputs/counted-by-protocol.h
17+
#pragma once
18+
19+
#define __counted_by(x) __attribute__((__counted_by__(x)))
20+
21+
@protocol CountedByProtocol
22+
- (void) simple:(int)len :(int * __counted_by(len))p;
23+
- (void) shared:(int)len :(int * __counted_by(len))p1 :(int * __counted_by(len))p2;
24+
- (void) complexExpr:(int)len :(int) offset :(int * __counted_by(len - offset))p;
25+
- (void) nullUnspecified:(int)len :(int * __counted_by(len) _Null_unspecified)p;
26+
- (void) nonnull:(int)len :(int * __counted_by(len) _Nonnull)p;
27+
- (void) nullable:(int)len :(int * __counted_by(len) _Nullable)p;
28+
- (int * __counted_by(len)) returnPointer:(int)len;
29+
@end
30+
31+
// CHECK-LABEL: extension CountedByProtocol {
32+
// CHECK-NEXT: @_alwaysEmitIntoClient public
33+
// CHECK-NEXT: func simple(_ p: UnsafeMutableBufferPointer<Int{{.*}}>) {
34+
// CHECK-NEXT: return simple(Int{{.*}}(exactly: p.count)!, p.baseAddress!)
35+
// CHECK-NEXT: }
36+
// CHECK-NEXT: @_alwaysEmitIntoClient public
37+
// CHECK-NEXT: func shared(_ len: Int{{.*}}, _ p1: UnsafeMutableBufferPointer<Int{{.*}}>, _ p2: UnsafeMutableBufferPointer<Int{{.*}}>) {
38+
// CHECK-NEXT: let _p1Count: some BinaryInteger = len
39+
// CHECK-NEXT: if p1.count < _p1Count || _p1Count < 0 {
40+
// CHECK-NEXT: fatalError("bounds check failure when calling unsafe function")
41+
// CHECK-NEXT: }
42+
// CHECK-NEXT: let _p2Count: some BinaryInteger = len
43+
// CHECK-NEXT: if p2.count < _p2Count || _p2Count < 0 {
44+
// CHECK-NEXT: fatalError("bounds check failure when calling unsafe function")
45+
// CHECK-NEXT: }
46+
// CHECK-NEXT: return shared(len, p1.baseAddress!, p2.baseAddress!)
47+
// CHECK-NEXT: }
48+
// CHECK-NEXT: @_alwaysEmitIntoClient public
49+
// CHECK-NEXT: func complexExpr(_ len: Int{{.*}}, _ offset: Int{{.*}}, _ p: UnsafeMutableBufferPointer<Int{{.*}}>) {
50+
// CHECK-NEXT: let _pCount: some BinaryInteger = len - offset
51+
// CHECK-NEXT: if p.count < _pCount || _pCount < 0 {
52+
// CHECK-NEXT: fatalError("bounds check failure when calling unsafe function")
53+
// CHECK-NEXT: }
54+
// CHECK-NEXT: return complexExpr(len, offset, p.baseAddress!)
55+
// CHECK-NEXT: }
56+
// CHECK-NEXT: @_alwaysEmitIntoClient public
57+
// CHECK-NEXT: func nullUnspecified(_ p: UnsafeMutableBufferPointer<Int{{.*}}>) {
58+
// CHECK-NEXT: return nullUnspecified(Int{{.*}}(exactly: p.count)!, p.baseAddress!)
59+
// CHECK-NEXT: }
60+
// CHECK-NEXT: @_alwaysEmitIntoClient public
61+
// CHECK-NEXT: func nonnull(_ p: UnsafeMutableBufferPointer<Int{{.*}}>) {
62+
// CHECK-NEXT: return nonnull(Int{{.*}}(exactly: p.count)!, p.baseAddress!)
63+
// CHECK-NEXT: }
64+
// CHECK-NEXT: @_alwaysEmitIntoClient public
65+
// CHECK-NEXT: func nullable(_ p: UnsafeMutableBufferPointer<Int{{.*}}>?) {
66+
// CHECK-NEXT: return nullable(Int{{.*}}(exactly: p?.count ?? 0)!, p?.baseAddress)
67+
// CHECK-NEXT: }
68+
// CHECK-NEXT: @_alwaysEmitIntoClient @_disfavoredOverload public
69+
// CHECK-NEXT: func returnPointer(_ len: Int{{.*}}) -> UnsafeMutableBufferPointer<Int{{.*}}> {
70+
// CHECK-NEXT: return UnsafeMutableBufferPointer<Int{{.*}}>(start: returnPointer(len), count: Int(len))
71+
// CHECK-NEXT: }
72+
// CHECK-NEXT: }
73+
74+
__attribute__((swift_attr("@_SwiftifyImportProtocol(.method(name: \"swiftAttr\", paramInfo: [.countedBy(pointer: .param(2), count: \"len\")]))")))
75+
@protocol SwiftAttrProtocol
76+
- (void)swiftAttr:(int)len :(int *)p;
77+
@end
78+
79+
// CHECK-LABEL: extension SwiftAttrProtocol {
80+
// CHECK-NEXT: @_alwaysEmitIntoClient public
81+
// CHECK-NEXT: func swiftAttr(_ p: UnsafeMutableBufferPointer<Int32>) {
82+
// CHECK-NEXT: return swiftAttr(Int32(exactly: p.count)!, p.baseAddress!)
83+
// CHECK-NEXT: }
84+
// CHECK-NEXT: }
85+
86+
//--- counted-by-protocol.swift
87+
import CountedByProtocolClang
88+
89+
@inlinable
90+
public func call(p: UnsafeMutableBufferPointer<CInt>, x: CInt, y: CInt, a: CountedByProtocol, b: SwiftAttrProtocol) {
91+
a.simple(p)
92+
a.shared(x, p, p)
93+
a.complexExpr(x, y, p)
94+
a.nullUnspecified(p)
95+
a.nonnull(p)
96+
a.nullable(p)
97+
let r1: UnsafeMutableBufferPointer<CInt> = a.returnPointer(x)
98+
let r2 = a.returnPointer(x)
99+
let r3: UnsafeMutablePointer<CInt>? = r2 // make sure the original is the favored overload
100+
b.swiftAttr(p)
101+
}
102+

0 commit comments

Comments
 (0)