Skip to content

Commit 42b3db7

Browse files
committed
[cxx-interop] Conform to UnsafeCxxContiguousIterator based on iterator_concept nested type
3a200de has a logic bug where we tried to conform C++ iterator types to `UnsafeCxxContiguousIterator` protocol based on their nested type called `iterator_category`. The C++20 standard says we should rely on `iterator_concept` instead. https://en.cppreference.com/w/cpp/iterator/iterator_tags#Iterator_concept Despite what the name suggests, we are not actually using C++ concepts in this change. rdar://137877849
1 parent 844d103 commit 42b3db7

File tree

3 files changed

+119
-36
lines changed

3 files changed

+119
-36
lines changed

lib/ClangImporter/ClangDerivedConformances.cpp

Lines changed: 49 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -125,18 +125,12 @@ lookupNestedClangTypeDecl(const clang::CXXRecordDecl *clangDecl,
125125

126126
static clang::TypeDecl *
127127
getIteratorCategoryDecl(const clang::CXXRecordDecl *clangDecl) {
128-
clang::IdentifierInfo *iteratorCategoryDeclName =
129-
&clangDecl->getASTContext().Idents.get("iterator_category");
130-
auto iteratorCategories = clangDecl->lookup(iteratorCategoryDeclName);
131-
// If this is a templated typedef, Clang might have instantiated several
132-
// equivalent typedef decls. If they aren't equivalent, Clang has already
133-
// complained about this. Let's assume that they are equivalent. (see
134-
// filterNonConflictingPreviousTypedefDecls in clang/Sema/SemaDecl.cpp)
135-
if (iteratorCategories.empty())
136-
return nullptr;
137-
auto iteratorCategory = iteratorCategories.front();
128+
return lookupNestedClangTypeDecl(clangDecl, "iterator_category");
129+
}
138130

139-
return dyn_cast_or_null<clang::TypeDecl>(iteratorCategory);
131+
static clang::TypeDecl *
132+
getIteratorConceptDecl(const clang::CXXRecordDecl *clangDecl) {
133+
return lookupNestedClangTypeDecl(clangDecl, "iterator_concept");
140134
}
141135

142136
static ValueDecl *lookupOperator(NominalTypeDecl *decl, Identifier id,
@@ -435,55 +429,64 @@ void swift::conformToCxxIteratorIfNeeded(
435429
if (!iteratorCategory)
436430
return;
437431

432+
// In C++20, `std::contiguous_iterator_tag` is specified as a type called
433+
// `iterator_concept`. It is not possible to detect a contiguous iterator
434+
// based on its `iterator_category`. The type might not have an
435+
// `iterator_concept` defined.
436+
auto iteratorConcept = getIteratorConceptDecl(clangDecl);
437+
438+
auto unwrapUnderlyingTypeDecl =
439+
[](clang::TypeDecl *typeDecl) -> clang::CXXRecordDecl * {
440+
clang::CXXRecordDecl *underlyingDecl = nullptr;
441+
if (auto typedefDecl = dyn_cast<clang::TypedefNameDecl>(typeDecl)) {
442+
auto type = typedefDecl->getUnderlyingType();
443+
underlyingDecl = type->getAsCXXRecordDecl();
444+
} else {
445+
underlyingDecl = dyn_cast<clang::CXXRecordDecl>(typeDecl);
446+
}
447+
if (underlyingDecl) {
448+
underlyingDecl = underlyingDecl->getDefinition();
449+
}
450+
return underlyingDecl;
451+
};
452+
438453
// If `iterator_category` is a typedef or a using-decl, retrieve the
439454
// underlying struct decl.
440-
clang::CXXRecordDecl *underlyingCategoryDecl = nullptr;
441-
if (auto typedefDecl = dyn_cast<clang::TypedefNameDecl>(iteratorCategory)) {
442-
auto type = typedefDecl->getUnderlyingType();
443-
underlyingCategoryDecl = type->getAsCXXRecordDecl();
444-
} else {
445-
underlyingCategoryDecl = dyn_cast<clang::CXXRecordDecl>(iteratorCategory);
446-
}
447-
if (underlyingCategoryDecl) {
448-
underlyingCategoryDecl = underlyingCategoryDecl->getDefinition();
449-
}
450-
455+
auto underlyingCategoryDecl = unwrapUnderlyingTypeDecl(iteratorCategory);
451456
if (!underlyingCategoryDecl)
452457
return;
453458

454-
auto isIteratorCategoryDecl = [&](const clang::CXXRecordDecl *base,
455-
StringRef tag) {
459+
// Same for `iterator_concept`.
460+
auto underlyingConceptDecl =
461+
iteratorConcept ? unwrapUnderlyingTypeDecl(iteratorConcept) : nullptr;
462+
463+
auto isIteratorTagDecl = [&](const clang::CXXRecordDecl *base,
464+
StringRef tag) {
456465
return base->isInStdNamespace() && base->getIdentifier() &&
457466
base->getName() == tag;
458467
};
459468
auto isInputIteratorDecl = [&](const clang::CXXRecordDecl *base) {
460-
return isIteratorCategoryDecl(base, "input_iterator_tag");
469+
return isIteratorTagDecl(base, "input_iterator_tag");
461470
};
462471
auto isRandomAccessIteratorDecl = [&](const clang::CXXRecordDecl *base) {
463-
return isIteratorCategoryDecl(base, "random_access_iterator_tag");
472+
return isIteratorTagDecl(base, "random_access_iterator_tag");
464473
};
465474
auto isContiguousIteratorDecl = [&](const clang::CXXRecordDecl *base) {
466-
return isIteratorCategoryDecl(base, "contiguous_iterator_tag"); // C++20
475+
return isIteratorTagDecl(base, "contiguous_iterator_tag"); // C++20
467476
};
468477

469478
// Traverse all transitive bases of `underlyingDecl` to check if
470479
// it inherits from `std::input_iterator_tag`.
471480
bool isInputIterator = isInputIteratorDecl(underlyingCategoryDecl);
472481
bool isRandomAccessIterator =
473482
isRandomAccessIteratorDecl(underlyingCategoryDecl);
474-
bool isContiguousIterator = isContiguousIteratorDecl(underlyingCategoryDecl);
475483
underlyingCategoryDecl->forallBases([&](const clang::CXXRecordDecl *base) {
476484
if (isInputIteratorDecl(base)) {
477485
isInputIterator = true;
478486
}
479487
if (isRandomAccessIteratorDecl(base)) {
480488
isRandomAccessIterator = true;
481489
isInputIterator = true;
482-
}
483-
if (isContiguousIteratorDecl(base)) {
484-
isContiguousIterator = true;
485-
isRandomAccessIterator = true;
486-
isInputIterator = true;
487490
return false;
488491
}
489492
return true;
@@ -492,6 +495,19 @@ void swift::conformToCxxIteratorIfNeeded(
492495
if (!isInputIterator)
493496
return;
494497

498+
bool isContiguousIterator = false;
499+
if (underlyingConceptDecl) {
500+
isContiguousIterator = isContiguousIteratorDecl(underlyingConceptDecl);
501+
if (!isContiguousIterator)
502+
underlyingConceptDecl->forallBases([&](const clang::CXXRecordDecl *base) {
503+
if (isContiguousIteratorDecl(base)) {
504+
isContiguousIterator = true;
505+
return false;
506+
}
507+
return true;
508+
});
509+
}
510+
495511
// Check if present: `var pointee: Pointee { get }`
496512
auto pointeeId = ctx.getIdentifier("pointee");
497513
auto pointee = lookupDirectSingleWithoutExtensions<VarDecl>(decl, pointeeId);

test/Interop/Cxx/stdlib/overlay/Inputs/custom-iterator.h

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,8 @@ struct ConstContiguousIterator {
348348
const int *value;
349349

350350
public:
351-
using iterator_category = std::contiguous_iterator_tag;
351+
using iterator_category = std::random_access_iterator_tag;
352+
using iterator_concept = std::contiguous_iterator_tag;
352353
using value_type = int;
353354
using pointer = int *;
354355
using reference = const int &;
@@ -403,7 +404,8 @@ struct HasCustomContiguousIteratorTag {
403404

404405
public:
405406
struct CustomTag : std::contiguous_iterator_tag {};
406-
using iterator_category = CustomTag;
407+
using iterator_category = std::random_access_iterator_tag;
408+
using iterator_concept = CustomTag;
407409
using value_type = int;
408410
using pointer = int *;
409411
using reference = const int &;
@@ -458,7 +460,8 @@ struct MutableContiguousIterator {
458460
int *value;
459461

460462
public:
461-
using iterator_category = std::contiguous_iterator_tag;
463+
using iterator_category = std::random_access_iterator_tag;
464+
using iterator_concept = std::contiguous_iterator_tag;
462465
using value_type = int;
463466
using pointer = int *;
464467
using reference = const int &;
@@ -507,6 +510,63 @@ struct MutableContiguousIterator {
507510
return value != other.value;
508511
}
509512
};
513+
514+
/// This is actually just a random access iterator
515+
struct HasNoContiguousIteratorConcept {
516+
private:
517+
const int *value;
518+
519+
public:
520+
using iterator_category = std::contiguous_iterator_tag;
521+
// no iterator_concept
522+
using value_type = int;
523+
using pointer = int *;
524+
using reference = const int &;
525+
using difference_type = int;
526+
527+
HasNoContiguousIteratorConcept(const int *value) : value(value) {}
528+
HasNoContiguousIteratorConcept(const HasNoContiguousIteratorConcept &other) =
529+
default;
530+
531+
const int &operator*() const { return *value; }
532+
533+
HasNoContiguousIteratorConcept &operator++() {
534+
value++;
535+
return *this;
536+
}
537+
HasNoContiguousIteratorConcept operator++(int) {
538+
auto tmp = HasNoContiguousIteratorConcept(value);
539+
value++;
540+
return tmp;
541+
}
542+
543+
void operator+=(difference_type v) { value += v; }
544+
void operator-=(difference_type v) { value -= v; }
545+
HasNoContiguousIteratorConcept operator+(difference_type v) const {
546+
return HasNoContiguousIteratorConcept(value + v);
547+
}
548+
HasNoContiguousIteratorConcept operator-(difference_type v) const {
549+
return HasNoContiguousIteratorConcept(value - v);
550+
}
551+
friend HasNoContiguousIteratorConcept
552+
operator+(difference_type v, const HasNoContiguousIteratorConcept &it) {
553+
return it + v;
554+
}
555+
int operator-(const HasNoContiguousIteratorConcept &other) const {
556+
return value - other.value;
557+
}
558+
559+
bool operator<(const HasNoContiguousIteratorConcept &other) const {
560+
return value < other.value;
561+
}
562+
563+
bool operator==(const HasNoContiguousIteratorConcept &other) const {
564+
return value == other.value;
565+
}
566+
bool operator!=(const HasNoContiguousIteratorConcept &other) const {
567+
return value != other.value;
568+
}
569+
};
510570
#endif
511571

512572
// MARK: Types that are not actually iterators

test/Interop/Cxx/stdlib/overlay/custom-contiguous-iterator-module-interface.swift

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,10 @@
2626
// CHECK: typealias Pointee = Int32
2727
// CHECK: typealias Distance = Int32
2828
// CHECK: }
29+
30+
// CHECK: struct HasNoContiguousIteratorConcept : UnsafeCxxRandomAccessIterator, UnsafeCxxInputIterator {
31+
// CHECK: func successor() -> HasNoContiguousIteratorConcept
32+
// CHECK: var pointee: Int32
33+
// CHECK: typealias Pointee = Int32
34+
// CHECK: typealias Distance = Int32
35+
// CHECK: }

0 commit comments

Comments
 (0)