Skip to content

[cxx-interop] Conform to UnsafeCxxContiguousIterator based on iterator_concept nested type #77049

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 47 additions & 33 deletions lib/ClangImporter/ClangDerivedConformances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,18 +125,12 @@ lookupNestedClangTypeDecl(const clang::CXXRecordDecl *clangDecl,

static clang::TypeDecl *
getIteratorCategoryDecl(const clang::CXXRecordDecl *clangDecl) {
clang::IdentifierInfo *iteratorCategoryDeclName =
&clangDecl->getASTContext().Idents.get("iterator_category");
auto iteratorCategories = clangDecl->lookup(iteratorCategoryDeclName);
// If this is a templated typedef, Clang might have instantiated several
// equivalent typedef decls. If they aren't equivalent, Clang has already
// complained about this. Let's assume that they are equivalent. (see
// filterNonConflictingPreviousTypedefDecls in clang/Sema/SemaDecl.cpp)
if (iteratorCategories.empty())
return nullptr;
auto iteratorCategory = iteratorCategories.front();
return lookupNestedClangTypeDecl(clangDecl, "iterator_category");
}

return dyn_cast_or_null<clang::TypeDecl>(iteratorCategory);
static clang::TypeDecl *
getIteratorConceptDecl(const clang::CXXRecordDecl *clangDecl) {
return lookupNestedClangTypeDecl(clangDecl, "iterator_concept");
}

static ValueDecl *lookupOperator(NominalTypeDecl *decl, Identifier id,
Expand Down Expand Up @@ -435,55 +429,54 @@ void swift::conformToCxxIteratorIfNeeded(
if (!iteratorCategory)
return;

auto unwrapUnderlyingTypeDecl =
[](clang::TypeDecl *typeDecl) -> clang::CXXRecordDecl * {
clang::CXXRecordDecl *underlyingDecl = nullptr;
if (auto typedefDecl = dyn_cast<clang::TypedefNameDecl>(typeDecl)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we have multiple layers of TypedefNameDecls? Would that be a problem?
E.g.:

struct MyIterator {
  using iterator_concept =  std::vector<int>::iterator::iterator_concept;
// ...
private:
  std::vector<int>::iterator it;
};

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! Let's make sure this works. I'll add a test.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually this already works fine, I'll add a test in a separate PR.

auto type = typedefDecl->getUnderlyingType();
underlyingDecl = type->getAsCXXRecordDecl();
} else {
underlyingDecl = dyn_cast<clang::CXXRecordDecl>(typeDecl);
}
if (underlyingDecl) {
underlyingDecl = underlyingDecl->getDefinition();
}
return underlyingDecl;
};

// If `iterator_category` is a typedef or a using-decl, retrieve the
// underlying struct decl.
clang::CXXRecordDecl *underlyingCategoryDecl = nullptr;
if (auto typedefDecl = dyn_cast<clang::TypedefNameDecl>(iteratorCategory)) {
auto type = typedefDecl->getUnderlyingType();
underlyingCategoryDecl = type->getAsCXXRecordDecl();
} else {
underlyingCategoryDecl = dyn_cast<clang::CXXRecordDecl>(iteratorCategory);
}
if (underlyingCategoryDecl) {
underlyingCategoryDecl = underlyingCategoryDecl->getDefinition();
}

auto underlyingCategoryDecl = unwrapUnderlyingTypeDecl(iteratorCategory);
if (!underlyingCategoryDecl)
return;

auto isIteratorCategoryDecl = [&](const clang::CXXRecordDecl *base,
StringRef tag) {
auto isIteratorTagDecl = [&](const clang::CXXRecordDecl *base,
StringRef tag) {
return base->isInStdNamespace() && base->getIdentifier() &&
base->getName() == tag;
};
auto isInputIteratorDecl = [&](const clang::CXXRecordDecl *base) {
return isIteratorCategoryDecl(base, "input_iterator_tag");
return isIteratorTagDecl(base, "input_iterator_tag");
};
auto isRandomAccessIteratorDecl = [&](const clang::CXXRecordDecl *base) {
return isIteratorCategoryDecl(base, "random_access_iterator_tag");
return isIteratorTagDecl(base, "random_access_iterator_tag");
};
auto isContiguousIteratorDecl = [&](const clang::CXXRecordDecl *base) {
return isIteratorCategoryDecl(base, "contiguous_iterator_tag"); // C++20
return isIteratorTagDecl(base, "contiguous_iterator_tag"); // C++20
};

// Traverse all transitive bases of `underlyingDecl` to check if
// it inherits from `std::input_iterator_tag`.
bool isInputIterator = isInputIteratorDecl(underlyingCategoryDecl);
bool isRandomAccessIterator =
isRandomAccessIteratorDecl(underlyingCategoryDecl);
bool isContiguousIterator = isContiguousIteratorDecl(underlyingCategoryDecl);
underlyingCategoryDecl->forallBases([&](const clang::CXXRecordDecl *base) {
if (isInputIteratorDecl(base)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do not shortcut the search if we only found an input iterator but we do when we found a random access one. Is this intended?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeap, this is intended. If we found an input_iterator_tag, we might still find a random_access_iterator_tag later during traversal, which gives us stronger guarantees. This isn't very likely to happen in practice, but it's valid to have a complex inheritance tree of the tags.

isInputIterator = true;
}
if (isRandomAccessIteratorDecl(base)) {
isRandomAccessIterator = true;
isInputIterator = true;
}
if (isContiguousIteratorDecl(base)) {
isContiguousIterator = true;
isRandomAccessIterator = true;
isInputIterator = true;
return false;
}
return true;
Expand All @@ -492,6 +485,27 @@ void swift::conformToCxxIteratorIfNeeded(
if (!isInputIterator)
return;

bool isContiguousIterator = false;
// In C++20, `std::contiguous_iterator_tag` is specified as a type called
// `iterator_concept`. It is not possible to detect a contiguous iterator
// based on its `iterator_category`. The type might not have an
// `iterator_concept` defined.
if (auto iteratorConcept = getIteratorConceptDecl(clangDecl)) {
if (auto underlyingConceptDecl =
unwrapUnderlyingTypeDecl(iteratorConcept)) {
isContiguousIterator = isContiguousIteratorDecl(underlyingConceptDecl);
if (!isContiguousIterator)
underlyingConceptDecl->forallBases(
[&](const clang::CXXRecordDecl *base) {
if (isContiguousIteratorDecl(base)) {
isContiguousIterator = true;
return false;
}
return true;
});
}
}

// Check if present: `var pointee: Pointee { get }`
auto pointeeId = ctx.getIdentifier("pointee");
auto pointee = lookupDirectSingleWithoutExtensions<VarDecl>(decl, pointeeId);
Expand Down
66 changes: 63 additions & 3 deletions test/Interop/Cxx/stdlib/overlay/Inputs/custom-iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,8 @@ struct ConstContiguousIterator {
const int *value;

public:
using iterator_category = std::contiguous_iterator_tag;
using iterator_category = std::random_access_iterator_tag;
using iterator_concept = std::contiguous_iterator_tag;
using value_type = int;
using pointer = int *;
using reference = const int &;
Expand Down Expand Up @@ -403,7 +404,8 @@ struct HasCustomContiguousIteratorTag {

public:
struct CustomTag : std::contiguous_iterator_tag {};
using iterator_category = CustomTag;
using iterator_category = std::random_access_iterator_tag;
using iterator_concept = CustomTag;
using value_type = int;
using pointer = int *;
using reference = const int &;
Expand Down Expand Up @@ -458,7 +460,8 @@ struct MutableContiguousIterator {
int *value;

public:
using iterator_category = std::contiguous_iterator_tag;
using iterator_category = std::random_access_iterator_tag;
using iterator_concept = std::contiguous_iterator_tag;
using value_type = int;
using pointer = int *;
using reference = const int &;
Expand Down Expand Up @@ -507,6 +510,63 @@ struct MutableContiguousIterator {
return value != other.value;
}
};

/// This is actually just a random access iterator
struct HasNoContiguousIteratorConcept {
private:
const int *value;

public:
using iterator_category = std::contiguous_iterator_tag;
// no iterator_concept
using value_type = int;
using pointer = int *;
using reference = const int &;
using difference_type = int;

HasNoContiguousIteratorConcept(const int *value) : value(value) {}
HasNoContiguousIteratorConcept(const HasNoContiguousIteratorConcept &other) =
default;

const int &operator*() const { return *value; }

HasNoContiguousIteratorConcept &operator++() {
value++;
return *this;
}
HasNoContiguousIteratorConcept operator++(int) {
auto tmp = HasNoContiguousIteratorConcept(value);
value++;
return tmp;
}

void operator+=(difference_type v) { value += v; }
void operator-=(difference_type v) { value -= v; }
HasNoContiguousIteratorConcept operator+(difference_type v) const {
return HasNoContiguousIteratorConcept(value + v);
}
HasNoContiguousIteratorConcept operator-(difference_type v) const {
return HasNoContiguousIteratorConcept(value - v);
}
friend HasNoContiguousIteratorConcept
operator+(difference_type v, const HasNoContiguousIteratorConcept &it) {
return it + v;
}
int operator-(const HasNoContiguousIteratorConcept &other) const {
return value - other.value;
}

bool operator<(const HasNoContiguousIteratorConcept &other) const {
return value < other.value;
}

bool operator==(const HasNoContiguousIteratorConcept &other) const {
return value == other.value;
}
bool operator!=(const HasNoContiguousIteratorConcept &other) const {
return value != other.value;
}
};
#endif

// MARK: Types that are not actually iterators
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,10 @@
// CHECK: typealias Pointee = Int32
// CHECK: typealias Distance = Int32
// CHECK: }

// CHECK: struct HasNoContiguousIteratorConcept : UnsafeCxxRandomAccessIterator, UnsafeCxxInputIterator {
// CHECK: func successor() -> HasNoContiguousIteratorConcept
// CHECK: var pointee: Int32
// CHECK: typealias Pointee = Int32
// CHECK: typealias Distance = Int32
// CHECK: }