Skip to content

Commit 86377b3

Browse files
authored
Merge pull request #76743 from swiftlang/coro-pa-context
Fix partial apply forwarder emission for coroutines that are methods of structs with type parameters
2 parents 5671688 + 6eefd96 commit 86377b3

File tree

2 files changed

+59
-52
lines changed

2 files changed

+59
-52
lines changed

lib/IRGen/GenFunc.cpp

Lines changed: 25 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,11 +1120,11 @@ class PartialApplicationForwarderEmission {
11201120
virtual void addDynamicFunctionContext(Explosion &explosion) = 0;
11211121
virtual void addDynamicFunctionPointer(Explosion &explosion) = 0;
11221122

1123-
virtual void addSelf(Explosion &explosion) { addArgument(explosion); }
1124-
virtual void addWitnessSelfMetadata(llvm::Value *value) {
1123+
void addSelf(Explosion &explosion) { addArgument(explosion); }
1124+
void addWitnessSelfMetadata(llvm::Value *value) {
11251125
addArgument(value);
11261126
}
1127-
virtual void addWitnessSelfWitnessTable(llvm::Value *value) {
1127+
void addWitnessSelfWitnessTable(llvm::Value *value) {
11281128
addArgument(value);
11291129
}
11301130
virtual void forwardErrorResult() = 0;
@@ -1438,12 +1438,6 @@ class CoroPartialApplicationForwarderEmission
14381438
: public PartialApplicationForwarderEmission {
14391439
using super = PartialApplicationForwarderEmission;
14401440

1441-
private:
1442-
llvm::Value *Self;
1443-
llvm::Value *FirstData;
1444-
llvm::Value *SecondData;
1445-
WitnessMetadata Witness;
1446-
14471441
public:
14481442
CoroPartialApplicationForwarderEmission(
14491443
IRGenModule &IGM, IRGenFunction &subIGF, llvm::Function *fwd,
@@ -1454,8 +1448,7 @@ class CoroPartialApplicationForwarderEmission
14541448
ArrayRef<ParameterConvention> conventions)
14551449
: PartialApplicationForwarderEmission(
14561450
IGM, subIGF, fwd, staticFnPtr, calleeHasContext, origSig, origType,
1457-
substType, outType, subs, layout, conventions),
1458-
Self(nullptr), FirstData(nullptr), SecondData(nullptr) {}
1451+
substType, outType, subs, layout, conventions) {}
14591452

14601453
void begin() override {
14611454
auto unsubstType = substType->getUnsubstitutedType(IGM.getSILModule());
@@ -1499,41 +1492,13 @@ class CoroPartialApplicationForwarderEmission
14991492
void gatherArgumentsFromApply() override {
15001493
super::gatherArgumentsFromApply(false);
15011494
}
1502-
llvm::Value *getDynamicFunctionPointer() override {
1503-
llvm::Value *Ret = SecondData;
1504-
SecondData = nullptr;
1505-
return Ret;
1506-
}
1507-
llvm::Value *getDynamicFunctionContext() override {
1508-
llvm::Value *Ret = FirstData;
1509-
FirstData = nullptr;
1510-
return Ret;
1511-
}
1495+
llvm::Value *getDynamicFunctionPointer() override { return args.takeLast(); }
1496+
llvm::Value *getDynamicFunctionContext() override { return args.takeLast(); }
15121497
void addDynamicFunctionContext(Explosion &explosion) override {
1513-
assert(!Self && "context value overrides 'self'");
1514-
FirstData = explosion.claimNext();
1498+
addArgument(explosion);
15151499
}
15161500
void addDynamicFunctionPointer(Explosion &explosion) override {
1517-
SecondData = explosion.claimNext();
1518-
}
1519-
void addSelf(Explosion &explosion) override {
1520-
assert(!FirstData && "'self' overrides another context value");
1521-
if (!hasSelfContextParameter(origType)) {
1522-
// witness methods can be declared on types that are not classes. Pass
1523-
// such "self" argument as a plain argument.
1524-
addArgument(explosion);
1525-
return;
1526-
}
1527-
Self = explosion.claimNext();
1528-
FirstData = Self;
1529-
}
1530-
1531-
void addWitnessSelfMetadata(llvm::Value *value) override {
1532-
Witness.SelfMetadata = value;
1533-
}
1534-
1535-
void addWitnessSelfWitnessTable(llvm::Value *value) override {
1536-
Witness.SelfWitnessTable = value;
1501+
addArgument(explosion);
15371502
}
15381503

15391504
void forwardErrorResult() override {
@@ -1554,13 +1519,26 @@ class CoroPartialApplicationForwarderEmission
15541519
}
15551520

15561521
Explosion callCoroutine(FunctionPointer &fnPtr) {
1557-
Callee callee({origType, substType, subs}, fnPtr, FirstData, SecondData);
1522+
bool isWitnessMethodCallee = origType->getRepresentation() ==
1523+
SILFunctionTypeRepresentation::WitnessMethod;
1524+
1525+
WitnessMetadata witnessMetadata;
1526+
if (isWitnessMethodCallee) {
1527+
witnessMetadata.SelfWitnessTable = args.takeLast();
1528+
witnessMetadata.SelfMetadata = args.takeLast();
1529+
}
1530+
1531+
llvm::Value *selfValue = nullptr;
1532+
if (calleeHasContext || hasSelfContextParameter(origType))
1533+
selfValue = args.takeLast();
1534+
1535+
Callee callee({origType, substType, subs}, fnPtr, selfValue);
15581536

15591537
std::unique_ptr<CallEmission> emitSuspend =
1560-
getCallEmission(subIGF, Self, std::move(callee));
1538+
getCallEmission(subIGF, callee.getSwiftContext(), std::move(callee));
15611539

15621540
emitSuspend->begin();
1563-
emitSuspend->setArgs(args, /*isOutlined=*/false, &Witness);
1541+
emitSuspend->setArgs(args, /*isOutlined=*/false, &witnessMetadata);
15641542
Explosion yieldedValues;
15651543
emitSuspend->emitToExplosion(yieldedValues, /*isOutlined=*/false);
15661544
emitSuspend->end();
@@ -1966,12 +1944,7 @@ static llvm::Value *emitPartialApplicationForwarder(
19661944
} else {
19671945
argValue = subIGF.Builder.CreateBitCast(rawData, expectedArgTy);
19681946
}
1969-
if (haveContextArgument) {
1970-
Explosion e;
1971-
e.add(argValue);
1972-
emission->addDynamicFunctionContext(e);
1973-
} else
1974-
emission->addArgument(argValue);
1947+
emission->addArgument(argValue);
19751948

19761949
// If there's a data pointer required, grab it and load out the
19771950
// extra, previously-curried parameters.

test/AutoDiff/validation-test/modify_accessor.swift

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,5 +42,39 @@ ModifyAccessorTests.test("SimpleModifyAccessor") {
4242
expectEqual((100, 20), valueWithGradient(at: 10, of: modify_struct))
4343
}
4444

45+
ModifyAccessorTests.test("GenericModifyAccessor") {
46+
struct S<T : Differentiable & SignedNumeric & Comparable>: Differentiable {
47+
private var _x : T
48+
49+
func _endMutation() {}
50+
51+
var x: T {
52+
get{_x}
53+
set(newValue) { _x = newValue }
54+
_modify {
55+
defer { _endMutation() }
56+
if (x > -x) {
57+
yield &_x
58+
} else {
59+
yield &_x
60+
}
61+
}
62+
}
63+
64+
init(_ x : T) {
65+
self._x = x
66+
}
67+
}
68+
69+
func modify_struct(_ x : Float) -> Float {
70+
var s = S<Float>(x)
71+
s.x *= s.x
72+
return s.x
73+
}
74+
75+
expectEqual((100, 20), valueWithGradient(at: 10, of: modify_struct))
76+
}
77+
78+
4579
runAllTests()
4680

0 commit comments

Comments
 (0)