Skip to content

Commit 93518a5

Browse files
committed
PR77615
1 parent f2ad9f3 commit 93518a5

File tree

7 files changed

+196
-84
lines changed

7 files changed

+196
-84
lines changed

include/swift/AST/Expr.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4424,6 +4424,8 @@ class ClosureExpr : public AbstractClosureExpr {
44244424
class AutoClosureExpr : public AbstractClosureExpr {
44254425
BraceStmt *Body;
44264426

4427+
ApplyExpr *getUnwrappedCurryThunkImpl() const;
4428+
44274429
public:
44284430
enum class Kind : uint8_t {
44294431
// An autoclosure with type () -> Result. Formed from type checking an
@@ -4485,6 +4487,10 @@ class AutoClosureExpr : public AbstractClosureExpr {
44854487
/// - otherwise, returns nullptr for convenience.
44864488
Expr *getUnwrappedCurryThunkExpr() const;
44874489

4490+
/// Same as getUnwrappedCurryThunkExpr, but get the called ValueDecl instead
4491+
/// of the expr.
4492+
ValueDecl *getUnwrappedCurryThunkCalledValue() const;
4493+
44884494
// Implement isa/cast/dyncast/etc.
44894495
static bool classof(const Expr *E) {
44904496
return E->getKind() == ExprKind::AutoClosure;

lib/AST/Expr.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2140,7 +2140,7 @@ Expr *AutoClosureExpr::getSingleExpressionBody() const {
21402140
return cast<ReturnStmt>(Body->getLastElement().get<Stmt *>())->getResult();
21412141
}
21422142

2143-
Expr *AutoClosureExpr::getUnwrappedCurryThunkExpr() const {
2143+
ApplyExpr *AutoClosureExpr::getUnwrappedCurryThunkImpl() const {
21442144
auto maybeUnwrapOpenExistential = [](Expr *expr) {
21452145
if (auto *openExistential = dyn_cast<OpenExistentialExpr>(expr)) {
21462146
expr = openExistential->getSubExpr()->getSemanticsProvidingExpr();
@@ -2184,7 +2184,7 @@ Expr *AutoClosureExpr::getUnwrappedCurryThunkExpr() const {
21842184
body = maybeUnwrapConversions(body);
21852185

21862186
if (auto *outerCall = dyn_cast<ApplyExpr>(body)) {
2187-
return outerCall->getFn();
2187+
return outerCall;
21882188
}
21892189

21902190
assert(false && "Malformed curry thunk?");
@@ -2205,7 +2205,7 @@ Expr *AutoClosureExpr::getUnwrappedCurryThunkExpr() const {
22052205
if (auto *outerCall = dyn_cast<ApplyExpr>(innerBody)) {
22062206
auto outerFn = maybeUnwrapConversions(outerCall->getFn());
22072207
if (auto *innerCall = dyn_cast<ApplyExpr>(outerFn)) {
2208-
return innerCall->getFn();
2208+
return innerCall;
22092209
}
22102210
}
22112211
}
@@ -2218,6 +2218,20 @@ Expr *AutoClosureExpr::getUnwrappedCurryThunkExpr() const {
22182218
return nullptr;
22192219
}
22202220

2221+
Expr *AutoClosureExpr::getUnwrappedCurryThunkExpr() const {
2222+
ApplyExpr *ae = getUnwrappedCurryThunkImpl();
2223+
if (ae == nullptr)
2224+
return nullptr;
2225+
return ae->getFn();
2226+
}
2227+
2228+
ValueDecl *AutoClosureExpr::getUnwrappedCurryThunkCalledValue() const {
2229+
ApplyExpr *ae = getUnwrappedCurryThunkImpl();
2230+
if (ae == nullptr)
2231+
return nullptr;
2232+
return ae->getCalledValue(/*skipFunctionConversions=*/true);
2233+
}
2234+
22212235
FORWARD_SOURCE_LOCS_TO(UnresolvedPatternExpr, subPattern)
22222236

22232237
TypeExpr::TypeExpr(TypeRepr *Repr)

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 100 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
4848
#include "llvm/ADT/APSInt.h"
4949
#include "llvm/ADT/BreadthFirstIterator.h"
50+
#include "llvm/ADT/DenseMap.h"
5051
#include "llvm/ADT/DenseSet.h"
5152
#include "llvm/ADT/SmallSet.h"
5253
#include "llvm/Support/CommandLine.h"
@@ -84,6 +85,9 @@ class DifferentiationTransformer {
8485
/// Context necessary for performing the transformations.
8586
ADContext context;
8687

88+
/// Cache used in getUnwrappedCurryThunkFunction.
89+
llvm::DenseMap<AbstractFunctionDecl *, SILFunction *> afdToSILFn;
90+
8791
/// Promotes the given `differentiable_function` instruction to a valid
8892
/// `@differentiable` function-typed value.
8993
SILValue promoteToDifferentiableFunction(DifferentiableFunctionInst *inst,
@@ -96,6 +100,25 @@ class DifferentiationTransformer {
96100
SILBuilder &builder, SILLocation loc,
97101
DifferentiationInvoker invoker);
98102

103+
/// Emits a reference to a derivative function of `original`, differentiated
104+
/// with respect to a superset of `desiredIndices`. Returns the `SILValue` for
105+
/// the derivative function and the actual indices that the derivative
106+
/// function is with respect to.
107+
///
108+
/// Returns `None` on failure, signifying that a diagnostic has been emitted
109+
/// using `invoker`.
110+
std::optional<std::pair<SILValue, AutoDiffConfig>>
111+
emitDerivativeFunctionReference(
112+
SILBuilder &builder, const AutoDiffConfig &desiredConfig,
113+
AutoDiffDerivativeFunctionKind kind, SILValue original,
114+
DifferentiationInvoker invoker,
115+
SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc);
116+
117+
/// If the given function corresponds to AutoClosureExpr with either
118+
/// SingleCurryThunk or DoubleCurryThunk kind, get the SILFunction
119+
/// corresponding to the function being wrapped in the thunk.
120+
SILFunction *getUnwrappedCurryThunkFunction(SILFunction *originalFn);
121+
99122
public:
100123
/// Construct an `DifferentiationTransformer` for the given module.
101124
explicit DifferentiationTransformer(SILModuleTransform &transform)
@@ -453,21 +476,55 @@ static SILValue reapplyFunctionConversion(
453476
llvm_unreachable("Unhandled function conversion instruction");
454477
}
455478

456-
/// Emits a reference to a derivative function of `original`, differentiated
457-
/// with respect to a superset of `desiredIndices`. Returns the `SILValue` for
458-
/// the derivative function and the actual indices that the derivative function
459-
/// is with respect to.
460-
///
461-
/// Returns `None` on failure, signifying that a diagnostic has been emitted
462-
/// using `invoker`.
463-
static std::optional<std::pair<SILValue, AutoDiffConfig>>
464-
emitDerivativeFunctionReference(
465-
DifferentiationTransformer &transformer, SILBuilder &builder,
466-
const AutoDiffConfig &desiredConfig, AutoDiffDerivativeFunctionKind kind,
467-
SILValue original, DifferentiationInvoker invoker,
468-
SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc) {
469-
ADContext &context = transformer.getContext();
479+
SILFunction *DifferentiationTransformer::getUnwrappedCurryThunkFunction(
480+
SILFunction *originalFn) {
481+
auto *abstractCE = originalFn->getDeclRef().getAbstractClosureExpr();
482+
if (abstractCE == nullptr)
483+
return nullptr;
484+
auto *autoCE = dyn_cast<AutoClosureExpr>(abstractCE);
485+
if (autoCE == nullptr)
486+
return nullptr;
487+
488+
auto *afd =
489+
cast<AbstractFunctionDecl>(autoCE->getUnwrappedCurryThunkCalledValue());
490+
491+
auto silFnIt = afdToSILFn.find(afd);
492+
if (silFnIt == afdToSILFn.end()) {
493+
assert(afdToSILFn.empty());
494+
495+
SILModule *module = getTransform().getModule();
496+
for (SILFunction &currentFunc : module->getFunctions()) {
497+
if (auto *currentAFD =
498+
currentFunc.getDeclRef().getAbstractFunctionDecl()) {
499+
// Update cache only with AFDs which might be potentially wrapped by a
500+
// curry thunk. This includes member function references and references
501+
// to functions having external property wrapper parameters (see
502+
// ExprRewriter::buildDeclRef). If new use cases of curry thunks appear
503+
// in future, the assertion after the loop will be a trigger for such
504+
// cases being unhandled here.
505+
//
506+
// FIXME: References to functions having external property wrapper
507+
// parameters are not handled since we can't now construct a test case
508+
// for that due to the crash
509+
// https://github.com/swiftlang/swift/issues/77613
510+
if (currentAFD->hasCurriedSelf())
511+
afdToSILFn.insert({currentAFD, &currentFunc});
512+
}
513+
}
470514

515+
silFnIt = afdToSILFn.find(afd);
516+
assert(silFnIt != afdToSILFn.end());
517+
}
518+
519+
return silFnIt->second;
520+
}
521+
522+
std::optional<std::pair<SILValue, AutoDiffConfig>>
523+
DifferentiationTransformer::emitDerivativeFunctionReference(
524+
SILBuilder &builder, const AutoDiffConfig &desiredConfig,
525+
AutoDiffDerivativeFunctionKind kind, SILValue original,
526+
DifferentiationInvoker invoker,
527+
SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc) {
471528
// If `original` is itself an `DifferentiableFunctionExtractInst` whose kind
472529
// matches the given kind and desired differentiation parameter indices,
473530
// simply extract the derivative function of its function operand, retain the
@@ -610,26 +667,36 @@ emitDerivativeFunctionReference(
610667
DifferentiabilityKind::Reverse, desiredParameterIndices,
611668
desiredResultIndices, derivativeConstrainedGenSig, /*jvp*/ nullptr,
612669
/*vjp*/ nullptr, /*isSerialized*/ false);
613-
if (transformer.canonicalizeDifferentiabilityWitness(
614-
minimalWitness, invoker, IsNotSerialized))
670+
if (canonicalizeDifferentiabilityWitness(minimalWitness, invoker,
671+
IsNotSerialized))
615672
return std::nullopt;
616673
}
617674
assert(minimalWitness);
618-
if (original->getFunction()->isSerialized() &&
619-
!hasPublicVisibility(minimalWitness->getLinkage())) {
620-
enum { Inlinable = 0, DefaultArgument = 1 };
621-
unsigned fragileKind = Inlinable;
622-
// FIXME: This is not a very robust way of determining if the function is
623-
// a default argument. Also, we have not exhaustively listed all the kinds
624-
// of fragility.
625-
if (original->getFunction()->getLinkage() == SILLinkage::PublicNonABI)
626-
fragileKind = DefaultArgument;
627-
context.emitNondifferentiabilityError(
628-
original, invoker, diag::autodiff_private_derivative_from_fragile,
629-
fragileKind,
630-
isa_and_nonnull<AbstractClosureExpr>(
631-
originalFRI->getLoc().getAsASTNode<Expr>()));
632-
return std::nullopt;
675+
if (original->getFunction()->isSerialized()) {
676+
// When dealing with curry thunk, look at the function being wrapped
677+
// inside implicit closure. If it has public visibility, the corresponding
678+
// differentiability witness also has public visibility. It should be OK
679+
// for implicit wrapper closure and its witness to have private linkage.
680+
SILFunction *unwrappedFn = getUnwrappedCurryThunkFunction(originalFn);
681+
bool isWitnessPublic =
682+
unwrappedFn == nullptr
683+
? hasPublicVisibility(minimalWitness->getLinkage())
684+
: hasPublicVisibility(unwrappedFn->getLinkage());
685+
if (!isWitnessPublic) {
686+
enum { Inlinable = 0, DefaultArgument = 1 };
687+
unsigned fragileKind = Inlinable;
688+
// FIXME: This is not a very robust way of determining if the function
689+
// is a default argument. Also, we have not exhaustively listed all the
690+
// kinds of fragility.
691+
if (original->getFunction()->getLinkage() == SILLinkage::PublicNonABI)
692+
fragileKind = DefaultArgument;
693+
context.emitNondifferentiabilityError(
694+
original, invoker, diag::autodiff_private_derivative_from_fragile,
695+
fragileKind,
696+
isa_and_nonnull<AbstractClosureExpr>(
697+
originalFRI->getLoc().getAsASTNode<Expr>()));
698+
return std::nullopt;
699+
}
633700
}
634701
// TODO(TF-482): Move generic requirement checking logic to
635702
// `getExactDifferentiabilityWitness` and
@@ -1121,8 +1188,8 @@ SILValue DifferentiationTransformer::promoteToDifferentiableFunction(
11211188
for (auto derivativeFnKind : {AutoDiffDerivativeFunctionKind::JVP,
11221189
AutoDiffDerivativeFunctionKind::VJP}) {
11231190
auto derivativeFnAndIndices = emitDerivativeFunctionReference(
1124-
*this, builder, desiredConfig, derivativeFnKind, origFnOperand,
1125-
invoker, newBuffersToDealloc);
1191+
builder, desiredConfig, derivativeFnKind, origFnOperand, invoker,
1192+
newBuffersToDealloc);
11261193
// Show an error at the operator, highlight the argument, and show a note
11271194
// at the definition site of the argument.
11281195
if (!derivativeFnAndIndices)

test/AutoDiff/SILOptimizer/differentiation_diagnostics.swift

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -771,32 +771,6 @@ public func fragileDifferentiable(_ x: Float) -> Float {
771771
implicitlyDifferentiableFromFragile(x)
772772
}
773773

774-
775-
// FIXME: Differentiable curry thunk RequirementMachine error (rdar://87429620, https://github.com/apple/swift/issues/54819).
776-
#if false
777-
// TF-1208: Test curry thunk differentiation regression.
778-
public struct Struct_54819<Scalar> {
779-
var x: Scalar
780-
}
781-
extension Struct_54819: Differentiable where Scalar: Differentiable {
782-
@differentiable(reverse)
783-
public static func id(x: Self) -> Self {
784-
return x
785-
}
786-
}
787-
@differentiable(reverse, wrt: x)
788-
public func f_54819<Scalar: Differentiable>(
789-
_ x: Struct_54819<Scalar>,
790-
// NOTE(TF-1208): This diagnostic is unexpected because `Struct_54819.id` is marked `@differentiable`.
791-
// xpected-error @+3 2 {{function is not differentiable}}
792-
// xpected-note @+2 {{differentiated functions in '@inlinable' functions must be marked '@differentiable' or have a public '@derivative'; this is not possible with a closure, make a top-level function instead}}
793-
// xpected-note @+1 {{opaque non-'@differentiable' function is not differentiable}}
794-
reduction: @differentiable(reverse) (Struct_54819<Scalar>) -> Struct_54819<Scalar> = Struct_54819.id
795-
) -> Struct_54819<Scalar> {
796-
reduction(x)
797-
}
798-
#endif
799-
800774
//===----------------------------------------------------------------------===//
801775
// Coroutines (SIL function yields, `begin_apply`) (not yet supported)
802776
//===----------------------------------------------------------------------===//
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
// RUN: %target-swift-frontend -emit-sil -verify -primary-file %s -o /dev/null
2+
3+
import _Differentiation
4+
5+
/// Minimal reproducer for both single and double curry thunk
6+
7+
@inlinable
8+
func caller<Thing: Differentiable & FloatingPoint>(
9+
of f: @differentiable(reverse) (_: Thing) -> Thing
10+
) -> Int where Thing.TangentVector == Thing {
11+
return 42
12+
}
13+
14+
public struct Struct<Thing: Differentiable & FloatingPoint>: Differentiable where Thing.TangentVector == Thing {
15+
@inlinable
16+
static func foo_single() -> Int {
17+
return caller(of: callee_single) // No error expected
18+
}
19+
20+
@inlinable
21+
@differentiable(reverse)
22+
static func callee_single(input: Thing) -> Thing {
23+
return input
24+
}
25+
26+
@inlinable
27+
func foo_double() -> Int {
28+
return caller(of: callee_double) // No error expected
29+
}
30+
31+
@inlinable
32+
@differentiable(reverse)
33+
func callee_double(input: Thing) -> Thing {
34+
return input
35+
}
36+
}
37+
38+
/// Reproducer from https://github.com/swiftlang/swift/issues/75776
39+
40+
public struct Solution2<Thing: Differentiable & FloatingPoint>: Differentiable where Thing.TangentVector == Thing {
41+
@inlinable
42+
public static func optimization() -> Thing {
43+
var initial = Thing.zero
44+
let (_, delta) = valueWithGradient(at: initial, of: simulationWithLoss) // No error expected
45+
initial.move(by: delta)
46+
return initial
47+
}
48+
49+
@inlinable
50+
@differentiable(reverse)
51+
static func simulationWithLoss(input: Thing) -> Thing {
52+
return input // implementation
53+
}
54+
}
55+
56+
/// Reproducer from https://github.com/swiftlang/swift/issues/54819
57+
58+
public struct TF_688_Struct<Scalar> {
59+
var x: Scalar
60+
}
61+
extension TF_688_Struct: Differentiable where Scalar: Differentiable {
62+
@differentiable(reverse)
63+
public static func id(x: Self) -> Self {
64+
return x
65+
}
66+
}
67+
@differentiable(reverse, wrt: x)
68+
public func TF_688<Scalar: Differentiable>(
69+
_ x: TF_688_Struct<Scalar>,
70+
reduction: @differentiable(reverse) (TF_688_Struct<Scalar>) -> TF_688_Struct<Scalar> = TF_688_Struct.id // No error expected
71+
) -> TF_688_Struct<Scalar> {
72+
reduction(x)
73+
}

test/AutoDiff/SILOptimizer/generics.swift

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -250,27 +250,6 @@ extension TF_682_Proto where Self : Differentiable,
250250
}
251251
}
252252

253-
// NOTE(TF-1208): Differentiation regression due to changes in curry thunk generation.
254-
/*
255-
// TF-688: Test generic curry thunk cloning.
256-
public struct TF_688_Struct<Scalar> {
257-
var x: Scalar
258-
}
259-
extension TF_688_Struct: Differentiable where Scalar: Differentiable {
260-
@differentiable(reverse)
261-
public static func id(x: Self) -> Self {
262-
return x
263-
}
264-
}
265-
@differentiable(reverse, wrt: x)
266-
public func TF_688<Scalar: Differentiable>(
267-
_ x: TF_688_Struct<Scalar>,
268-
reduction: @differentiable(reverse) (TF_688_Struct<Scalar>) -> TF_688_Struct<Scalar> = TF_688_Struct.id
269-
) -> TF_688_Struct<Scalar> {
270-
reduction(x)
271-
}
272-
*/
273-
274253
// TF-697: Test generic requirements of generated derivative function.
275254
protocol TF_697_Module: Differentiable {
276255
associatedtype Input

test/AutoDiff/compiler_crashers/rdar87429620-differentiable-curry-thunk-reqmachine.swift renamed to test/AutoDiff/compiler_crashers_fixed/rdar87429620-differentiable-curry-thunk-reqmachine.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
// RUN: %target-swift-frontend -emit-sil -verify %s
2-
// XFAIL: *
32

43
// rdar://87429620
54
// https://github.com/apple/swift/issues/54819

0 commit comments

Comments
 (0)