Skip to content

Commit 7505a8a

Browse files
authored
Merge pull request #64517 from rjmccall/arity-reabstraction-closures
Implement arity reabstraction for closures
2 parents b984933 + 5cf05f5 commit 7505a8a

19 files changed

+819
-386
lines changed

include/swift/AST/Types.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3103,6 +3103,8 @@ class AnyFunctionType : public TypeBase {
31033103
return true;
31043104
return false;
31053105
}
3106+
3107+
Param getCanonical(CanGenericSignature genericSig) const;
31063108

31073109
ParameterTypeFlags getParameterFlags() const { return Flags; }
31083110

@@ -6864,6 +6866,8 @@ class PackType final : public TypeBase, public llvm::FoldingSetNode,
68646866
BEGIN_CAN_TYPE_WRAPPER(PackType, Type)
68656867
static CanPackType get(const ASTContext &ctx, ArrayRef<CanType> elements);
68666868
static CanPackType get(const ASTContext &ctx, CanTupleEltTypeArrayRef elts);
6869+
static CanPackType get(const ASTContext &ctx,
6870+
AnyFunctionType::CanParamArrayRef params);
68676871

68686872
static CanTypeWrapper<PackType>
68696873
getSingletonPackExpansion(CanType packParameter);

include/swift/Basic/ArrayRefView.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,17 @@ template<typename Orig, typename Projected>
167167
using CastArrayRefView =
168168
ArrayRefView<Orig, Projected *, arrayRefViewCastHelper<Projected, Orig>>;
169169

170+
namespace generator_details {
171+
template <class T> struct is_array_ref_like;
172+
173+
template <class Orig, class Projected, Projected (&Project)(const Orig &),
174+
bool AllowOrigAccess>
175+
struct is_array_ref_like<ArrayRefView<Orig, Projected, Project,
176+
AllowOrigAccess>> {
177+
enum { value = true };
178+
};
179+
}
180+
170181
} // end namespace swift
171182

172183
#endif // SWIFT_BASIC_ARRAYREFVIEW_H

include/swift/Basic/Generators.h

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
//===--- Generators.h - "Coroutines" for doing traversals -------*- C++ -*-===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2023 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
//
13+
// This file defines a few types for defining types that follow this
14+
// simple generator concept:
15+
//
16+
// concept Generator {
17+
// // ...some number of accessors for the current value...
18+
//
19+
// /// Is this generator finished producing values?
20+
// bool isFinished() const;
21+
//
22+
// /// Given that this generator is not finished, advance to the
23+
// /// next value.
24+
// void advance();
25+
//
26+
// /// Finish the generator, asserting that all values have been
27+
// /// produced.
28+
// void finish();
29+
// };
30+
//
31+
// concept SimpleGenerator : Generator {
32+
// type reference;
33+
//
34+
// reference claimNext();
35+
// }
36+
//
37+
// Generators are useful when some structure needs to be traversed but
38+
// that traversal can't be done in a simple lexical loop. For example,
39+
// you can't do two traversals in parallel with a single loop unless you
40+
// break down all the details of the traversal. This is a minor problem
41+
// for simple traversals like walking a flat array, but it's a significant
42+
// problem when traversals get more complex, like when different components
43+
// of an array are grouped together according to some additional structure
44+
// (such as the abstraction pattern of a function's parameter list).
45+
// It's tempting to write those traversals as higher-order functions that
46+
// invoke a callback for each element, but this breaks down when parallel
47+
// traversal is required. Expressing the traversal as a generator
48+
// allows the traversal logic to to be reused without that limitation.
49+
//
50+
//===----------------------------------------------------------------------===//
51+
52+
#ifndef SWIFT_BASIC_GENERATORS_H
53+
#define SWIFT_BASIC_GENERATORS_H
54+
55+
#include "llvm/ADT/ArrayRef.h"
56+
57+
namespace swift {
58+
59+
namespace generator_details {
60+
61+
template <class T>
62+
struct is_array_ref_like {
63+
enum { value = false };
64+
};
65+
66+
template <class T>
67+
struct is_array_ref_like<llvm::ArrayRef<T>> {
68+
enum { value = true };
69+
};
70+
71+
template <class T>
72+
struct is_array_ref_like<llvm::MutableArrayRef<T>> {
73+
enum { value = true };
74+
};
75+
}
76+
77+
/// A class for generating the elements of an ArrayRef-like collection.
78+
template <class CollectionType>
79+
class ArrayRefGenerator {
80+
static_assert(generator_details::is_array_ref_like<CollectionType>::value,
81+
"ArrayRefGenerator should only be used with ArrayRef-like "
82+
"types");
83+
84+
CollectionType values;
85+
86+
public:
87+
using reference =
88+
typename std::iterator_traits<typename CollectionType::iterator>::reference;
89+
90+
ArrayRefGenerator() {}
91+
ArrayRefGenerator(CollectionType values) : values(values) {}
92+
93+
// Prevent accidental copying of the generator.
94+
ArrayRefGenerator(const ArrayRefGenerator &other) = delete;
95+
ArrayRefGenerator &operator=(const ArrayRefGenerator &other) = delete;
96+
97+
ArrayRefGenerator(ArrayRefGenerator &&other) = default;
98+
ArrayRefGenerator &operator=(ArrayRefGenerator &&other) = default;
99+
100+
/// Explicitly copy the current generator state.
101+
ArrayRefGenerator clone() const {
102+
return ArrayRefGenerator(values);
103+
}
104+
105+
/// Return the current element of the array.
106+
reference getCurrent() const {
107+
assert(!isFinished());
108+
return values.front();
109+
}
110+
111+
/// Claim the current element of the array and advance past it.
112+
reference claimNext() {
113+
assert(!isFinished());
114+
reference result = getCurrent();
115+
advance();
116+
return result;
117+
}
118+
119+
/// Claim the next N elements of the array and advance past them.
120+
CollectionType claimNext(size_t count) {
121+
assert(count <= values.size() && "claiming too many values");
122+
CollectionType result = values.slice(0, count);
123+
values = values.slice(count);
124+
return result;
125+
}
126+
127+
/// Is this generation finished?
128+
bool isFinished() const {
129+
return values.empty();
130+
}
131+
132+
/// Given that this generation is not finished, advance to the
133+
/// next element.
134+
void advance() {
135+
assert(!isFinished());
136+
values = values.slice(1);
137+
}
138+
139+
/// Perform any final work required to complete the generation.
140+
void finish() {
141+
assert(isFinished() && "didn't finish generating the collection");
142+
}
143+
};
144+
145+
/// An abstracting reference to an existing generator.
146+
///
147+
/// The implementation of this type holds the reference to the existing
148+
/// generator without allocating any additional storage; it is sufficient
149+
/// for the caller ensures that the object passed to the constructor
150+
/// stays valid. Values of this type can otherwise be safely copied
151+
/// around.
152+
template <class T>
153+
class SimpleGeneratorRef {
154+
public:
155+
using reference = T;
156+
157+
private:
158+
struct VTable {
159+
bool (*isFinished)(const void *impl);
160+
reference (*claimNext)(void *impl);
161+
void (*advance)(void *impl);
162+
void (*finish)(void *impl);
163+
};
164+
165+
template <class G> struct VTableImpl {
166+
static constexpr VTable vtable = {
167+
[](const void *p) { return static_cast<const G*>(p)->isFinished(); },
168+
[](void *p) -> reference { return static_cast<G*>(p)->claimNext(); },
169+
[](void *p) { static_cast<G*>(p)->advance(); },
170+
[](void *p) { static_cast<G*>(p)->finish(); },
171+
};
172+
};
173+
174+
const VTable *vtable;
175+
void *pointer;
176+
177+
public:
178+
constexpr SimpleGeneratorRef() : vtable(nullptr), pointer(nullptr) {}
179+
180+
template <class G>
181+
constexpr SimpleGeneratorRef(G &generator)
182+
: vtable(&VTableImpl<G>::vtable), pointer(&generator) {}
183+
184+
/// Test whether this generator ref was initialized with a
185+
/// valid reference to a generator.
186+
explicit operator bool() const {
187+
return pointer != nullptr;
188+
}
189+
190+
bool isFinished() const {
191+
assert(pointer);
192+
return vtable->isFinished(pointer);
193+
}
194+
195+
reference claimNext() {
196+
assert(pointer);
197+
return vtable->claimNext(pointer);
198+
}
199+
200+
void advance() {
201+
assert(pointer);
202+
vtable->advance(pointer);
203+
}
204+
205+
void finish() {
206+
assert(pointer);
207+
vtable->finish(pointer);
208+
}
209+
};
210+
211+
} // end namespace swift
212+
213+
#endif

include/swift/SIL/SILArgument.h

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,12 +338,14 @@ class SILFunctionArgument : public SILArgument {
338338
ValueOwnershipKind ownershipKind, const ValueDecl *decl = nullptr,
339339
bool isNoImplicitCopy = false,
340340
LifetimeAnnotation lifetimeAnnotation = LifetimeAnnotation::None,
341-
bool isCapture = false)
341+
bool isCapture = false,
342+
bool isParameterPack = false)
342343
: SILArgument(ValueKind::SILFunctionArgument, parentBlock, type,
343344
ownershipKind, decl) {
344345
sharedUInt32().SILFunctionArgument.noImplicitCopy = isNoImplicitCopy;
345346
sharedUInt32().SILFunctionArgument.lifetimeAnnotation = lifetimeAnnotation;
346347
sharedUInt32().SILFunctionArgument.closureCapture = isCapture;
348+
sharedUInt32().SILFunctionArgument.parameterPack = isParameterPack;
347349
}
348350

349351
// A special constructor, only intended for use in
@@ -369,6 +371,23 @@ class SILFunctionArgument : public SILArgument {
369371
sharedUInt32().SILFunctionArgument.closureCapture = newValue;
370372
}
371373

374+
/// Is this parameter a pack that corresponds to multiple
375+
/// formal parameters? (This could mean multiple ParamDecl*s,
376+
/// or it could mean a ParamDecl* that's a pack expansion.) Note
377+
/// that not all lowered parameters of pack type are parameter packs:
378+
/// they can be part of a single formal parameter of tuple type.
379+
/// This flag indicates that the lowered parameter has a one-to-many
380+
/// relationship with formal parameters.
381+
///
382+
/// TODO: preserve the parameter pack references in SIL in a side table
383+
/// instead of using a single bit.
384+
bool isFormalParameterPack() const {
385+
return sharedUInt32().SILFunctionArgument.parameterPack;
386+
}
387+
void setFormalParameterPack(bool isPack) {
388+
sharedUInt32().SILFunctionArgument.parameterPack = isPack;
389+
}
390+
372391
LifetimeAnnotation getLifetimeAnnotation() const {
373392
return LifetimeAnnotation::Case(
374393
sharedUInt32().SILFunctionArgument.lifetimeAnnotation);
@@ -412,6 +431,7 @@ class SILFunctionArgument : public SILArgument {
412431
setNoImplicitCopy(arg->isNoImplicitCopy());
413432
setLifetimeAnnotation(arg->getLifetimeAnnotation());
414433
setClosureCapture(arg->isClosureCapture());
434+
setFormalParameterPack(arg->isFormalParameterPack());
415435
}
416436

417437
static bool classof(const SILInstruction *) = delete;

include/swift/SIL/SILNode.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,8 @@ class alignas(8) SILNode :
276276
SHARED_FIELD(StringLiteralInst, uint32_t length);
277277
SHARED_FIELD(PointerToAddressInst, uint32_t alignment);
278278
SHARED_FIELD(SILFunctionArgument, uint32_t noImplicitCopy : 1,
279-
lifetimeAnnotation : 2, closureCapture : 1);
279+
lifetimeAnnotation : 2, closureCapture : 1,
280+
parameterPack : 1);
280281

281282
// Do not use `_sharedUInt32_private` outside of SILNode.
282283
} _sharedUInt32_private;

lib/AST/ASTContext.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3306,6 +3306,16 @@ CanPackType CanPackType::get(const ASTContext &C,
33063306
return CanPackType(PackType::get(C, ncElements));
33073307
}
33083308

3309+
CanPackType CanPackType::get(const ASTContext &C,
3310+
AnyFunctionType::CanParamArrayRef params) {
3311+
SmallVector<Type, 8> ncElements;
3312+
ncElements.reserve(params.size());
3313+
for (auto param : params) {
3314+
ncElements.push_back(param.getParameterType());
3315+
}
3316+
return CanPackType(PackType::get(C, ncElements));
3317+
}
3318+
33093319
PackType *PackType::get(const ASTContext &C, ArrayRef<Type> elements) {
33103320
RecursiveTypeProperties properties;
33113321
bool isCanonical = true;

lib/AST/Type.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1606,14 +1606,19 @@ getCanonicalParams(AnyFunctionType *funcType,
16061606
SmallVectorImpl<AnyFunctionType::Param> &canParams) {
16071607
auto origParams = funcType->getParams();
16081608
for (auto param : origParams) {
1609-
// Canonicalize the type and drop the internal label to canonicalize the
1610-
// Param.
1611-
canParams.emplace_back(param.getPlainType()->getReducedType(genericSig),
1612-
param.getLabel(), param.getParameterFlags(),
1613-
/*InternalLabel=*/Identifier());
1609+
canParams.emplace_back(param.getCanonical(genericSig));
16141610
}
16151611
}
16161612

1613+
AnyFunctionType::Param
1614+
AnyFunctionType::Param::getCanonical(CanGenericSignature genericSig) const {
1615+
// Canonicalize the type and drop the internal label to canonicalize the
1616+
// Param.
1617+
return Param(getPlainType()->getReducedType(genericSig),
1618+
getLabel(), getParameterFlags(),
1619+
/*InternalLabel=*/Identifier());
1620+
}
1621+
16171622
CanType TypeBase::computeCanonicalType() {
16181623
assert(!hasCanonicalTypeComputed() && "called unnecessarily");
16191624

lib/SIL/IR/AbstractionPattern.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2310,10 +2310,13 @@ const {
23102310
if (yieldType)
23112311
yieldType = yieldType->getReducedType(substSig);
23122312

2313+
// Note that we specifically do not want to put subMap in the
2314+
// abstraction patterns here, because the types we will be lowering
2315+
// against them will not be substituted.
23132316
return std::make_tuple(
2314-
AbstractionPattern(subMap, substSig, substTy->getReducedType(substSig)),
2317+
AbstractionPattern(substSig, substTy->getReducedType(substSig)),
23152318
subMap,
23162319
yieldType
2317-
? AbstractionPattern(subMap, substSig, yieldType)
2320+
? AbstractionPattern(substSig, yieldType)
23182321
: AbstractionPattern::getInvalid());
23192322
}

0 commit comments

Comments
 (0)