Skip to content

Commit 686d1b4

Browse files
authored
Merge pull request #74783 from susmonteiro/susmonteiro/cxx-span-from-ubpointer
[cxx-interop] Implements constructor for std::span from UnsafeBufferPointer
2 parents 12484ff + e86099c commit 686d1b4

File tree

9 files changed

+354
-5
lines changed

9 files changed

+354
-5
lines changed

include/swift/AST/KnownProtocols.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ PROTOCOL(CxxRandomAccessCollection)
135135
PROTOCOL(CxxSequence)
136136
PROTOCOL(CxxUniqueSet)
137137
PROTOCOL(CxxVector)
138+
PROTOCOL(CxxSpan)
138139
PROTOCOL(UnsafeCxxInputIterator)
139140
PROTOCOL(UnsafeCxxMutableInputIterator)
140141
PROTOCOL(UnsafeCxxRandomAccessIterator)

lib/AST/ASTContext.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1341,6 +1341,7 @@ ProtocolDecl *ASTContext::getProtocol(KnownProtocolKind kind) const {
13411341
case KnownProtocolKind::CxxSequence:
13421342
case KnownProtocolKind::CxxUniqueSet:
13431343
case KnownProtocolKind::CxxVector:
1344+
case KnownProtocolKind::CxxSpan:
13441345
case KnownProtocolKind::UnsafeCxxInputIterator:
13451346
case KnownProtocolKind::UnsafeCxxMutableInputIterator:
13461347
case KnownProtocolKind::UnsafeCxxRandomAccessIterator:

lib/ClangImporter/ClangDerivedConformances.cpp

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,22 @@ static bool isStdDecl(const clang::CXXRecordDecl *clangDecl,
122122
return llvm::is_contained(names, name);
123123
}
124124

125+
static clang::TypeDecl *
126+
lookupNestedClangTypeDecl(const clang::CXXRecordDecl *clangDecl,
127+
StringRef name) {
128+
clang::IdentifierInfo *nestedDeclName =
129+
&clangDecl->getASTContext().Idents.get(name);
130+
auto nestedDecls = clangDecl->lookup(nestedDeclName);
131+
// If this is a templated typedef, Clang might have instantiated several
132+
// equivalent typedef decls. If they aren't equivalent, Clang has already
133+
// complained about this. Let's assume that they are equivalent. (see
134+
// filterNonConflictingPreviousTypedefDecls in clang/Sema/SemaDecl.cpp)
135+
if (nestedDecls.empty())
136+
return nullptr;
137+
auto nestedDecl = nestedDecls.front();
138+
return dyn_cast_or_null<clang::TypeDecl>(nestedDecl);
139+
}
140+
125141
static clang::TypeDecl *
126142
getIteratorCategoryDecl(const clang::CXXRecordDecl *clangDecl) {
127143
clang::IdentifierInfo *iteratorCategoryDeclName =
@@ -1128,4 +1144,94 @@ void swift::conformToCxxFunctionIfNeeded(
11281144
decl->addMember(importedConstructor);
11291145

11301146
// TODO: actually conform to some form of CxxFunction protocol
1147+
1148+
}
1149+
1150+
void swift::conformToCxxSpanIfNeeded(ClangImporter::Implementation &impl,
1151+
NominalTypeDecl *decl,
1152+
const clang::CXXRecordDecl *clangDecl) {
1153+
PrettyStackTraceDecl trace("conforming to CxxSpan", decl);
1154+
1155+
assert(decl);
1156+
assert(clangDecl);
1157+
ASTContext &ctx = decl->getASTContext();
1158+
clang::ASTContext &clangCtx = impl.getClangASTContext();
1159+
clang::Sema &clangSema = impl.getClangSema();
1160+
1161+
// Only auto-conform types from the C++ standard library. Custom user types
1162+
// might have a similar interface but different semantics.
1163+
if (!isStdDecl(clangDecl, {"span"}))
1164+
return;
1165+
1166+
auto elementType = lookupDirectSingleWithoutExtensions<TypeAliasDecl>(
1167+
decl, ctx.getIdentifier("element_type"));
1168+
auto sizeType = lookupDirectSingleWithoutExtensions<TypeAliasDecl>(
1169+
decl, ctx.getIdentifier("size_type"));
1170+
1171+
if (!elementType || !sizeType)
1172+
return;
1173+
1174+
auto constPointerTypeDecl =
1175+
lookupNestedClangTypeDecl(clangDecl, "const_pointer");
1176+
auto countTypeDecl = lookupNestedClangTypeDecl(clangDecl, "size_type");
1177+
1178+
if (!constPointerTypeDecl || !countTypeDecl)
1179+
return;
1180+
1181+
// create fake variable for constPointer (constructor arg 1)
1182+
auto constPointerType = clangCtx.getTypeDeclType(constPointerTypeDecl);
1183+
auto fakeConstPointerVarDecl = clang::VarDecl::Create(
1184+
clangCtx, /*DC*/ clangCtx.getTranslationUnitDecl(),
1185+
clang::SourceLocation(), clang::SourceLocation(), /*Id*/ nullptr,
1186+
constPointerType, clangCtx.getTrivialTypeSourceInfo(constPointerType),
1187+
clang::StorageClass::SC_None);
1188+
1189+
auto fakeConstPointer = new (clangCtx) clang::DeclRefExpr(
1190+
clangCtx, fakeConstPointerVarDecl, false, constPointerType,
1191+
clang::ExprValueKind::VK_LValue, clang::SourceLocation());
1192+
1193+
// create fake variable for count (constructor arg 2)
1194+
auto countType = clangCtx.getTypeDeclType(countTypeDecl);
1195+
auto fakeCountVarDecl = clang::VarDecl::Create(
1196+
clangCtx, /*DC*/ clangCtx.getTranslationUnitDecl(),
1197+
clang::SourceLocation(), clang::SourceLocation(), /*Id*/ nullptr,
1198+
countType, clangCtx.getTrivialTypeSourceInfo(countType),
1199+
clang::StorageClass::SC_None);
1200+
1201+
auto fakeCount = new (clangCtx) clang::DeclRefExpr(
1202+
clangCtx, fakeCountVarDecl, false, countType,
1203+
clang::ExprValueKind::VK_LValue, clang::SourceLocation());
1204+
1205+
// Use clangSema.BuildCxxTypeConstructExpr to create a CXXTypeConstructExpr,
1206+
// passing constPointer and count
1207+
SmallVector<clang::Expr *, 2> constructExprArgs = {fakeConstPointer,
1208+
fakeCount};
1209+
1210+
auto clangDeclTyInfo = clangCtx.getTrivialTypeSourceInfo(
1211+
clang::QualType(clangDecl->getTypeForDecl(), 0));
1212+
1213+
// Instantiate the templated constructor that would accept this fake variable.
1214+
auto constructExprResult = clangSema.BuildCXXTypeConstructExpr(
1215+
clangDeclTyInfo, clangDecl->getLocation(), constructExprArgs,
1216+
clangDecl->getLocation(), /*ListInitialization*/ false);
1217+
if (!constructExprResult.isUsable())
1218+
return;
1219+
1220+
auto constructExpr =
1221+
dyn_cast_or_null<clang::CXXConstructExpr>(constructExprResult.get());
1222+
if (!constructExpr)
1223+
return;
1224+
1225+
auto constructorDecl = constructExpr->getConstructor();
1226+
auto importedConstructor =
1227+
impl.importDecl(constructorDecl, impl.CurrentVersion);
1228+
if (!importedConstructor)
1229+
return;
1230+
decl->addMember(importedConstructor);
1231+
1232+
impl.addSynthesizedTypealias(decl, ctx.Id_Element,
1233+
elementType->getUnderlyingType());
1234+
impl.addSynthesizedTypealias(decl, ctx.getIdentifier("Size"),
1235+
sizeType->getUnderlyingType());
1236+
impl.addSynthesizedProtocolAttrs(decl, {KnownProtocolKind::CxxSpan});
11311237
}

lib/ClangImporter/ClangDerivedConformances.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,12 @@ void conformToCxxVectorIfNeeded(ClangImporter::Implementation &impl,
7676
void conformToCxxFunctionIfNeeded(ClangImporter::Implementation &impl,
7777
NominalTypeDecl *decl,
7878
const clang::CXXRecordDecl *clangDecl);
79+
80+
/// If the decl is an instantiation of C++ `std::span`, synthesize a
81+
/// conformance to CxxSpan, which is defined in the Cxx module.
82+
void conformToCxxSpanIfNeeded(ClangImporter::Implementation &impl,
83+
NominalTypeDecl *decl,
84+
const clang::CXXRecordDecl *clangDecl);
7985

8086
} // namespace swift
8187

lib/ClangImporter/ImportDecl.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2826,6 +2826,7 @@ namespace {
28262826
conformToCxxOptionalIfNeeded(Impl, nominalDecl, decl);
28272827
conformToCxxVectorIfNeeded(Impl, nominalDecl, decl);
28282828
conformToCxxFunctionIfNeeded(Impl, nominalDecl, decl);
2829+
conformToCxxSpanIfNeeded(Impl, nominalDecl, decl);
28292830
}
28302831
}
28312832

lib/IRGen/GenMeta.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6836,6 +6836,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
68366836
case KnownProtocolKind::CxxSequence:
68376837
case KnownProtocolKind::CxxUniqueSet:
68386838
case KnownProtocolKind::CxxVector:
6839+
case KnownProtocolKind::CxxSpan:
68396840
case KnownProtocolKind::UnsafeCxxInputIterator:
68406841
case KnownProtocolKind::UnsafeCxxMutableInputIterator:
68416842
case KnownProtocolKind::UnsafeCxxRandomAccessIterator:

stdlib/public/Cxx/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ add_swift_target_library(swiftCxx STATIC NO_LINK_NAME IS_STDLIB IS_SWIFT_ONLY IS
1616
CxxRandomAccessCollection.swift
1717
CxxSequence.swift
1818
CxxVector.swift
19+
CxxSpan.swift
1920
UnsafeCxxIterators.swift
2021

2122
SWIFT_COMPILE_FLAGS ${SWIFT_RUNTIME_SWIFT_COMPILE_FLAGS} ${SWIFT_STANDARD_LIBRARY_SWIFT_FLAGS}

stdlib/public/Cxx/CxxSpan.swift

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2023 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
/// A C++ type that is an object that can refer to a contiguous sequence of objects.
14+
///
15+
/// C++ standard library type `std::span` conforms to this protocol.
16+
public protocol CxxSpan<Element> {
17+
associatedtype Element
18+
associatedtype Size: BinaryInteger
19+
20+
init()
21+
init(_ unsafePointer : UnsafePointer<Element>, _ count: Size)
22+
}
23+
24+
extension CxxSpan {
25+
/// Creates a C++ span from a Swift UnsafeBufferPointer
26+
@inlinable
27+
public init(_ unsafeBufferPointer: UnsafeBufferPointer<Element>) {
28+
precondition(unsafeBufferPointer.baseAddress != nil,
29+
"UnsafeBufferPointer should not point to nil")
30+
self.init(unsafeBufferPointer.baseAddress!, Size(unsafeBufferPointer.count))
31+
}
32+
}

0 commit comments

Comments
 (0)