Skip to content

Commit 3e45b91

Browse files
author
marcrasi
authored
TF-917: fix irgen for witness method partial_apply (#27726)
1 parent 7498390 commit 3e45b91

File tree

2 files changed

+36
-11
lines changed

2 files changed

+36
-11
lines changed

lib/IRGen/GenFunc.cpp

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,16 @@ static llvm::Function *emitPartialApplicationForwarder(IRGenModule &IGM,
735735
// Create a new explosion for potentially reabstracted parameters.
736736
Explosion args;
737737

738+
// SWIFT_ENABLE_TENSORFLOW
739+
// The witness method self argument comes after polymorphic arguments (and is
740+
// followed by the self type and the witness table). However, we may encounter
741+
// the witness method self value before reaching the polymorphic arguments. So
742+
// we create a special explosion for storing the witness method self value
743+
// until it's time to add it to 'args'.
744+
bool isWitnessMethodCallee = origType->getRepresentation() ==
745+
SILFunctionTypeRepresentation::WitnessMethod;
746+
Explosion witnessMethodSelfValue;
747+
738748
Address resultValueAddr;
739749

740750
{
@@ -775,6 +785,10 @@ static llvm::Function *emitPartialApplicationForwarder(IRGenModule &IGM,
775785

776786
// Reemit the parameters as unsubstituted.
777787
for (unsigned i = 0; i < outType->getParameters().size(); ++i) {
788+
// SWIFT_ENABLE_TENSORFLOW
789+
bool isWitnessMethodCalleeSelf =
790+
(isWitnessMethodCallee && i + 1 == origType->getParameters().size());
791+
778792
auto origParamInfo = origType->getParameters()[i];
779793
auto &ti = IGM.getTypeInfoForLowered(origParamInfo.getType());
780794
auto schema = ti.getSchema();
@@ -788,16 +802,19 @@ static llvm::Function *emitPartialApplicationForwarder(IRGenModule &IGM,
788802
if (addr->getType() != ti.getStorageType()->getPointerTo())
789803
addr = subIGF.Builder.CreateBitCast(addr,
790804
ti.getStorageType()->getPointerTo());
791-
args.add(addr);
805+
// SWIFT_ENABLE_TENSORFLOW
806+
(isWitnessMethodCalleeSelf ? witnessMethodSelfValue : args).add(addr);
792807
continue;
793808
}
794809

795810
auto outTypeParamInfo = outType->getParameters()[i];
796811
// Indirect parameters need no mapping through the native calling
797812
// convention.
798813
if (isIndirectParam) {
799-
emitApplyArgument(subIGF, origParamInfo, outTypeParamInfo, origParams,
800-
args);
814+
emitApplyArgument(
815+
subIGF, origParamInfo, outTypeParamInfo, origParams,
816+
// SWIFT_ENABLE_TENSORFLOW
817+
(isWitnessMethodCalleeSelf ? witnessMethodSelfValue : args));
801818
continue;
802819
}
803820

@@ -824,7 +841,10 @@ static llvm::Function *emitPartialApplicationForwarder(IRGenModule &IGM,
824841
Explosion nativeApplyArg = nativeSchemaOrigParam.mapIntoNative(
825842
subIGF.IGM, subIGF, nonNativeApplyArg, origParamSILType, false);
826843
assert(nonNativeApplyArg.empty());
827-
nativeApplyArg.transferInto(args, nativeApplyArg.size());
844+
// SWIFT_ENABLE_TENSORFLOW
845+
nativeApplyArg.transferInto(
846+
(isWitnessMethodCalleeSelf ? witnessMethodSelfValue : args),
847+
nativeApplyArg.size());
828848
}
829849
}
830850

@@ -934,13 +954,6 @@ static llvm::Function *emitPartialApplicationForwarder(IRGenModule &IGM,
934954
auto haveContextArgument =
935955
calleeHasContext || hasSelfContextParameter(origType);
936956

937-
// Witness method calls expect self, followed by the self type followed by,
938-
// the witness table at the end of the parameter list. But polymorphic
939-
// arguments come before this.
940-
bool isWitnessMethodCallee = origType->getRepresentation() ==
941-
SILFunctionTypeRepresentation::WitnessMethod;
942-
Explosion witnessMethodSelfValue;
943-
944957
// If there's a data pointer required, but it's a swift-retainable
945958
// value being passed as the context, just forward it down.
946959
if (!layout) {

test/AutoDiff/irgen_crashers.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// RUN: %target-swift-frontend -emit-ir %s
2+
3+
// TF-917: `partial_apply` IRGen crash.
4+
public protocol TF_917: Differentiable {
5+
@differentiable
6+
func r<A>(_ a: A) -> Float
7+
}
8+
@differentiable
9+
public func tf_917<B: TF_917>(_ b: B) -> Float {
10+
return b.r(0.0)
11+
}
12+

0 commit comments

Comments
 (0)