Skip to content

Commit 6eefd96

Browse files
committed
Fix partial apply forwarder emission for coroutines that are methods
of structs with type parameters. Simplify the code while here
1 parent a2d9808 commit 6eefd96

File tree

2 files changed

+59
-52
lines changed

2 files changed

+59
-52
lines changed

lib/IRGen/GenFunc.cpp

Lines changed: 25 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,11 +1120,11 @@ class PartialApplicationForwarderEmission {
11201120
virtual void addDynamicFunctionContext(Explosion &explosion) = 0;
11211121
virtual void addDynamicFunctionPointer(Explosion &explosion) = 0;
11221122

1123-
virtual void addSelf(Explosion &explosion) { addArgument(explosion); }
1124-
virtual void addWitnessSelfMetadata(llvm::Value *value) {
1123+
void addSelf(Explosion &explosion) { addArgument(explosion); }
1124+
void addWitnessSelfMetadata(llvm::Value *value) {
11251125
addArgument(value);
11261126
}
1127-
virtual void addWitnessSelfWitnessTable(llvm::Value *value) {
1127+
void addWitnessSelfWitnessTable(llvm::Value *value) {
11281128
addArgument(value);
11291129
}
11301130
virtual void forwardErrorResult() = 0;
@@ -1412,12 +1412,6 @@ class CoroPartialApplicationForwarderEmission
14121412
: public PartialApplicationForwarderEmission {
14131413
using super = PartialApplicationForwarderEmission;
14141414

1415-
private:
1416-
llvm::Value *Self;
1417-
llvm::Value *FirstData;
1418-
llvm::Value *SecondData;
1419-
WitnessMetadata Witness;
1420-
14211415
public:
14221416
CoroPartialApplicationForwarderEmission(
14231417
IRGenModule &IGM, IRGenFunction &subIGF, llvm::Function *fwd,
@@ -1428,8 +1422,7 @@ class CoroPartialApplicationForwarderEmission
14281422
ArrayRef<ParameterConvention> conventions)
14291423
: PartialApplicationForwarderEmission(
14301424
IGM, subIGF, fwd, staticFnPtr, calleeHasContext, origSig, origType,
1431-
substType, outType, subs, layout, conventions),
1432-
Self(nullptr), FirstData(nullptr), SecondData(nullptr) {}
1425+
substType, outType, subs, layout, conventions) {}
14331426

14341427
void begin() override {
14351428
auto unsubstType = substType->getUnsubstitutedType(IGM.getSILModule());
@@ -1473,41 +1466,13 @@ class CoroPartialApplicationForwarderEmission
14731466
void gatherArgumentsFromApply() override {
14741467
super::gatherArgumentsFromApply(false);
14751468
}
1476-
llvm::Value *getDynamicFunctionPointer() override {
1477-
llvm::Value *Ret = SecondData;
1478-
SecondData = nullptr;
1479-
return Ret;
1480-
}
1481-
llvm::Value *getDynamicFunctionContext() override {
1482-
llvm::Value *Ret = FirstData;
1483-
FirstData = nullptr;
1484-
return Ret;
1485-
}
1469+
llvm::Value *getDynamicFunctionPointer() override { return args.takeLast(); }
1470+
llvm::Value *getDynamicFunctionContext() override { return args.takeLast(); }
14861471
void addDynamicFunctionContext(Explosion &explosion) override {
1487-
assert(!Self && "context value overrides 'self'");
1488-
FirstData = explosion.claimNext();
1472+
addArgument(explosion);
14891473
}
14901474
void addDynamicFunctionPointer(Explosion &explosion) override {
1491-
SecondData = explosion.claimNext();
1492-
}
1493-
void addSelf(Explosion &explosion) override {
1494-
assert(!FirstData && "'self' overrides another context value");
1495-
if (!hasSelfContextParameter(origType)) {
1496-
// witness methods can be declared on types that are not classes. Pass
1497-
// such "self" argument as a plain argument.
1498-
addArgument(explosion);
1499-
return;
1500-
}
1501-
Self = explosion.claimNext();
1502-
FirstData = Self;
1503-
}
1504-
1505-
void addWitnessSelfMetadata(llvm::Value *value) override {
1506-
Witness.SelfMetadata = value;
1507-
}
1508-
1509-
void addWitnessSelfWitnessTable(llvm::Value *value) override {
1510-
Witness.SelfWitnessTable = value;
1475+
addArgument(explosion);
15111476
}
15121477

15131478
void forwardErrorResult() override {
@@ -1528,13 +1493,26 @@ class CoroPartialApplicationForwarderEmission
15281493
}
15291494

15301495
Explosion callCoroutine(FunctionPointer &fnPtr) {
1531-
Callee callee({origType, substType, subs}, fnPtr, FirstData, SecondData);
1496+
bool isWitnessMethodCallee = origType->getRepresentation() ==
1497+
SILFunctionTypeRepresentation::WitnessMethod;
1498+
1499+
WitnessMetadata witnessMetadata;
1500+
if (isWitnessMethodCallee) {
1501+
witnessMetadata.SelfWitnessTable = args.takeLast();
1502+
witnessMetadata.SelfMetadata = args.takeLast();
1503+
}
1504+
1505+
llvm::Value *selfValue = nullptr;
1506+
if (calleeHasContext || hasSelfContextParameter(origType))
1507+
selfValue = args.takeLast();
1508+
1509+
Callee callee({origType, substType, subs}, fnPtr, selfValue);
15321510

15331511
std::unique_ptr<CallEmission> emitSuspend =
1534-
getCallEmission(subIGF, Self, std::move(callee));
1512+
getCallEmission(subIGF, callee.getSwiftContext(), std::move(callee));
15351513

15361514
emitSuspend->begin();
1537-
emitSuspend->setArgs(args, /*isOutlined=*/false, &Witness);
1515+
emitSuspend->setArgs(args, /*isOutlined=*/false, &witnessMetadata);
15381516
Explosion yieldedValues;
15391517
emitSuspend->emitToExplosion(yieldedValues, /*isOutlined=*/false);
15401518
emitSuspend->end();
@@ -1940,12 +1918,7 @@ static llvm::Value *emitPartialApplicationForwarder(
19401918
} else {
19411919
argValue = subIGF.Builder.CreateBitCast(rawData, expectedArgTy);
19421920
}
1943-
if (haveContextArgument) {
1944-
Explosion e;
1945-
e.add(argValue);
1946-
emission->addDynamicFunctionContext(e);
1947-
} else
1948-
emission->addArgument(argValue);
1921+
emission->addArgument(argValue);
19491922

19501923
// If there's a data pointer required, grab it and load out the
19511924
// extra, previously-curried parameters.

test/AutoDiff/validation-test/modify_accessor.swift

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,39 @@ ModifyAccessorTests.test("SimpleModifyAccessor") {
3939
expectEqual((100, 20), valueWithGradient(at: 10, of: modify_struct))
4040
}
4141

42+
ModifyAccessorTests.test("GenericModifyAccessor") {
43+
struct S<T : Differentiable & SignedNumeric & Comparable>: Differentiable {
44+
private var _x : T
45+
46+
func _endMutation() {}
47+
48+
var x: T {
49+
get{_x}
50+
set(newValue) { _x = newValue }
51+
_modify {
52+
defer { _endMutation() }
53+
if (x > -x) {
54+
yield &_x
55+
} else {
56+
yield &_x
57+
}
58+
}
59+
}
60+
61+
init(_ x : T) {
62+
self._x = x
63+
}
64+
}
65+
66+
func modify_struct(_ x : Float) -> Float {
67+
var s = S<Float>(x)
68+
s.x *= s.x
69+
return s.x
70+
}
71+
72+
expectEqual((100, 20), valueWithGradient(at: 10, of: modify_struct))
73+
}
74+
75+
4276
runAllTests()
4377

0 commit comments

Comments
 (0)