Skip to content

Commit 66ea6e6

Browse files
Merge pull request #30839 from aschwaighofer/irgen_fix_partial_apply_on_stack_indirect
IRGen: Fix capture of indirect values in ``[onstack]`` closures
2 parents 3a317b3 + 5ab6df8 commit 66ea6e6

File tree

2 files changed

+169
-25
lines changed

2 files changed

+169
-25
lines changed

lib/IRGen/GenFunc.cpp

Lines changed: 64 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -654,17 +654,24 @@ static void emitApplyArgument(IRGenFunction &IGF,
654654
silConv.getSILType(substParam, substFnTy), in, out);
655655
}
656656

657-
static CanType getArgumentLoweringType(CanType type,
658-
SILParameterInfo paramInfo) {
657+
static CanType getArgumentLoweringType(CanType type, SILParameterInfo paramInfo,
658+
bool isNoEscape) {
659659
switch (paramInfo.getConvention()) {
660660
// Capture value parameters by value, consuming them.
661661
case ParameterConvention::Direct_Owned:
662662
case ParameterConvention::Direct_Unowned:
663663
case ParameterConvention::Direct_Guaranteed:
664+
return type;
665+
// Capture indirect parameters if the closure is not [onstack]. [onstack]
666+
// closures don't take ownership of their arguments so we just capture the
667+
// address.
664668
case ParameterConvention::Indirect_In:
665669
case ParameterConvention::Indirect_In_Constant:
666670
case ParameterConvention::Indirect_In_Guaranteed:
667-
return type;
671+
if (isNoEscape)
672+
return CanInOutType::get(type);
673+
else
674+
return type;
668675

669676
// Capture inout parameters by pointer.
670677
case ParameterConvention::Indirect_Inout:
@@ -1067,33 +1074,57 @@ static llvm::Function *emitPartialApplicationForwarder(IRGenModule &IGM,
10671074
auto fieldConvention = conventions[nextCapturedField];
10681075
Address fieldAddr = fieldLayout.project(subIGF, data, offsets);
10691076
auto &fieldTI = fieldLayout.getType();
1070-
auto fieldSchema = fieldTI.getSchema();
10711077
lastCapturedFieldPtr = fieldAddr.getAddress();
10721078

10731079
Explosion param;
10741080
switch (fieldConvention) {
10751081
case ParameterConvention::Indirect_In:
10761082
case ParameterConvention::Indirect_In_Constant: {
1077-
// The +1 argument is passed indirectly, so we need to copy into a
1078-
// temporary.
1079-
needsAllocas = true;
1080-
auto stackAddr = fieldTI.allocateStack(subIGF, fieldTy, "arg.temp");
1081-
auto addressPointer = stackAddr.getAddress().getAddress();
1082-
fieldTI.initializeWithCopy(subIGF, stackAddr.getAddress(), fieldAddr,
1083-
fieldTy, false);
1084-
param.add(addressPointer);
1085-
1086-
// Remember to deallocate later.
1087-
addressesToDeallocate.push_back(
1088-
AddressToDeallocate{fieldTy, fieldTI, stackAddr});
10891083

1084+
auto initStackCopy = [&addressesToDeallocate, &needsAllocas, &param,
1085+
&subIGF](const TypeInfo &fieldTI, SILType fieldTy,
1086+
Address fieldAddr) {
1087+
// The +1 argument is passed indirectly, so we need to copy into a
1088+
// temporary.
1089+
needsAllocas = true;
1090+
auto stackAddr = fieldTI.allocateStack(subIGF, fieldTy, "arg.temp");
1091+
auto addressPointer = stackAddr.getAddress().getAddress();
1092+
fieldTI.initializeWithCopy(subIGF, stackAddr.getAddress(), fieldAddr,
1093+
fieldTy, false);
1094+
param.add(addressPointer);
1095+
1096+
// Remember to deallocate later.
1097+
addressesToDeallocate.push_back(
1098+
AddressToDeallocate{fieldTy, fieldTI, stackAddr});
1099+
};
1100+
1101+
if (outType->isNoEscape()) {
1102+
// If the closure is [onstack] it only captured the address of the
1103+
// value. Load that address from the context.
1104+
Explosion addressExplosion;
1105+
cast<LoadableTypeInfo>(fieldTI).loadAsCopy(subIGF, fieldAddr,
1106+
addressExplosion);
1107+
assert(fieldTy.isAddress());
1108+
auto newFieldTy = fieldTy.getObjectType();
1109+
auto &newFieldTI =
1110+
subIGF.getTypeInfoForLowered(newFieldTy.getASTType());
1111+
fieldAddr =
1112+
newFieldTI.getAddressForPointer(addressExplosion.claimNext());
1113+
initStackCopy(newFieldTI, newFieldTy, fieldAddr);
1114+
} else {
1115+
initStackCopy(fieldTI, fieldTy, fieldAddr);
1116+
}
10901117
break;
10911118
}
10921119
case ParameterConvention::Indirect_In_Guaranteed:
1093-
// The argument is +0, so we can use the address of the param in
1094-
// the context directly.
1095-
param.add(fieldAddr.getAddress());
1096-
dependsOnContextLifetime = true;
1120+
if (outType->isNoEscape()) {
1121+
cast<LoadableTypeInfo>(fieldTI).loadAsCopy(subIGF, fieldAddr, param);
1122+
} else {
1123+
// The argument is +0, so we can use the address of the param in
1124+
// the context directly.
1125+
param.add(fieldAddr.getAddress());
1126+
dependsOnContextLifetime = true;
1127+
}
10971128
break;
10981129
case ParameterConvention::Indirect_Inout:
10991130
case ParameterConvention::Indirect_InoutAliasable:
@@ -1343,7 +1374,8 @@ Optional<StackAddress> irgen::emitFunctionPartialApplication(
13431374
bool considerParameterSources = true;
13441375
for (auto param : params) {
13451376
SILType argType = IGF.IGM.silConv.getSILType(param, origType);
1346-
auto argLoweringTy = getArgumentLoweringType(argType.getASTType(), param);
1377+
auto argLoweringTy = getArgumentLoweringType(argType.getASTType(), param,
1378+
outType->isNoEscape());
13471379
auto &ti = IGF.getTypeInfoForLowered(argLoweringTy);
13481380

13491381
if (!isa<FixedTypeInfo>(ti)) {
@@ -1370,7 +1402,8 @@ Optional<StackAddress> irgen::emitFunctionPartialApplication(
13701402
for (auto param : params) {
13711403
SILType argType = IGF.IGM.silConv.getSILType(param, origType);
13721404

1373-
auto argLoweringTy = getArgumentLoweringType(argType.getASTType(), param);
1405+
auto argLoweringTy = getArgumentLoweringType(argType.getASTType(), param,
1406+
outType->isNoEscape());
13741407

13751408
auto &ti = IGF.getTypeInfoForLowered(argLoweringTy);
13761409

@@ -1608,9 +1641,15 @@ Optional<StackAddress> irgen::emitFunctionPartialApplication(
16081641
case ParameterConvention::Indirect_In:
16091642
case ParameterConvention::Indirect_In_Constant:
16101643
case ParameterConvention::Indirect_In_Guaranteed: {
1611-
auto addr = fieldLayout.getType().getAddressForPointer(args.claimNext());
1612-
fieldLayout.getType().initializeWithTake(IGF, fieldAddr, addr, fieldTy,
1613-
isOutlined);
1644+
if (outType->isNoEscape()) {
1645+
cast<LoadableTypeInfo>(fieldLayout.getType())
1646+
.initialize(IGF, args, fieldAddr, isOutlined);
1647+
} else {
1648+
auto addr =
1649+
fieldLayout.getType().getAddressForPointer(args.claimNext());
1650+
fieldLayout.getType().initializeWithTake(IGF, fieldAddr, addr,
1651+
fieldTy, isOutlined);
1652+
}
16141653
break;
16151654
}
16161655
// Take direct value arguments and inout pointers by value.

test/IRGen/partial_apply.sil

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,3 +633,108 @@ unwind:
633633
unwind
634634
}
635635
sil_vtable A3 {}
636+
637+
638+
// CHECK-LABEL: define{{.*}} swiftcc { i8*, %swift.refcounted* } @partial_apply_callee_guaranteed_indirect_guaranteed_class_pair_param
639+
// CHECK-NOT: ret
640+
// CHECK: call void @llvm.memcpy
641+
// CHECK: insertvalue { i8*, %swift.refcounted* } { i8* bitcast (i64 (i64, %swift.refcounted*)* [[FORWARDER:@.*]] to
642+
// CHECK: ret
643+
// CHECK: define{{.*}} swiftcc i64 [[FORWARDER]](i64 %0, %swift.refcounted* swiftself %1)
644+
// CHECK: entry:
645+
// CHECK: [[CAST:%.*]] = bitcast
646+
// CHECK: [[FIELD:%.*]] = getelementptr inbounds {{.*}}* [[CAST]], i32 0, i32 1
647+
// CHECK: call swiftcc i64 @indirect_guaranteed_captured_class_pair_param(i64 %0, %T13partial_apply14SwiftClassPairV* {{.*}} [[FIELD]])
648+
// CHECK: ret
649+
// CHECK: }
650+
651+
sil @partial_apply_callee_guaranteed_indirect_guaranteed_class_pair_param : $@convention(thin) (@in SwiftClassPair) -> @owned @callee_guaranteed (Int) -> Int {
652+
bb0(%x : $*SwiftClassPair):
653+
%f = function_ref @indirect_guaranteed_captured_class_pair_param : $@convention(thin) (Int, @in_guaranteed SwiftClassPair) -> Int
654+
%p = partial_apply [callee_guaranteed] %f(%x) : $@convention(thin) (Int, @in_guaranteed SwiftClassPair) -> Int
655+
return %p : $@callee_guaranteed(Int) -> (Int)
656+
}
657+
658+
sil public_external @use_closure2 : $@convention(thin) (@noescape @callee_guaranteed (Int) -> Int) -> ()
659+
660+
// CHECK-LABEL: define{{.*}} swiftcc void @partial_apply_stack_callee_guaranteed_indirect_guaranteed_class_pair_param(%T13partial_apply14SwiftClassPairV* {{.*}} %0)
661+
// CHECK: [[CLOSURE_STACK_ADDR:%.*]] = getelementptr inbounds <{ %swift.refcounted, %T13partial_apply14SwiftClassPairV* }>, <{ %swift.refcounted, %T13partial_apply14SwiftClassPairV* }>* {{.*}}, i32 0, i32 1
662+
// CHECK: store %T13partial_apply14SwiftClassPairV* %0, %T13partial_apply14SwiftClassPairV** [[CLOSURE_STACK_ADDR]]
663+
// CHECK: call swiftcc void @use_closure2({{.*}}* [[FORWARDER:@.*]] to i8*), %swift.opaque* {{.*}})
664+
// CHECK: ret void
665+
666+
// CHECK: define{{.*}} swiftcc i64 [[FORWARDER]](i64 %0, %swift.refcounted* swiftself %1)
667+
// CHECK: entry:
668+
// CHECK: [[ADDR:%.*]] = load %T13partial_apply14SwiftClassPairV*, %T13partial_apply14SwiftClassPairV**
669+
// CHECK: [[RES:%.*]] = tail call swiftcc i64 @indirect_guaranteed_captured_class_pair_param(i64 %0, %T13partial_apply14SwiftClassPairV* {{.*}} [[ADDR]])
670+
// CHECK: ret i64 [[RES]]
671+
// CHECK: }
672+
673+
sil @partial_apply_stack_callee_guaranteed_indirect_guaranteed_class_pair_param : $@convention(thin) (@in_guaranteed SwiftClassPair) -> () {
674+
bb0(%x : $*SwiftClassPair):
675+
%f = function_ref @indirect_guaranteed_captured_class_pair_param : $@convention(thin) (Int, @in_guaranteed SwiftClassPair) -> Int
676+
%p = partial_apply [callee_guaranteed] [on_stack] %f(%x) : $@convention(thin) (Int, @in_guaranteed SwiftClassPair) -> Int
677+
%u = function_ref @use_closure2 : $@convention(thin) (@noescape @callee_guaranteed (Int) -> Int) -> ()
678+
%r = apply %u(%p) : $@convention(thin) (@noescape @callee_guaranteed (Int) -> Int) -> ()
679+
dealloc_stack %p : $@noescape @callee_guaranteed (Int) ->(Int)
680+
%t = tuple()
681+
return %t : $()
682+
}
683+
684+
// CHECK: define{{.*}} swiftcc void @partial_apply_stack_callee_guaranteed_indirect_in_class_pair_param(%T13partial_apply14SwiftClassPairV* {{.*}} %0)
685+
// CHECK: [[CLOSURE_STACK_ADDR:%.*]] = getelementptr inbounds <{ %swift.refcounted, %T13partial_apply14SwiftClassPairV* }>, <{ %swift.refcounted, %T13partial_apply14SwiftClassPairV* }>* {{.*}}, i32 0, i32 1
686+
// CHECK: store %T13partial_apply14SwiftClassPairV* %0, %T13partial_apply14SwiftClassPairV** [[CLOSURE_STACK_ADDR]]
687+
// CHECK: call swiftcc void @use_closure2({{.*}}* [[FORWARDER:@.*]] to i8*), %swift.opaque* {{.*}})
688+
// CHECK: ret void
689+
690+
// CHECK: define{{.*}} swiftcc i64 [[FORWARDER]](i64 %0, %swift.refcounted* swiftself %1)
691+
// CHECK: entry:
692+
// CHECK: [[TEMP:%.*]] = alloca %T13partial_apply14SwiftClassPairV
693+
// CHECK: [[VALUE_ADDR:%.*]] = load %T13partial_apply14SwiftClassPairV*, %T13partial_apply14SwiftClassPairV** {{.*}}
694+
// CHECK: call %T13partial_apply14SwiftClassPairV* @"$s13partial_apply14SwiftClassPairVWOc"(%T13partial_apply14SwiftClassPairV* [[VALUE_ADDR]], %T13partial_apply14SwiftClassPairV* [[TEMP]])
695+
// CHECK: [[RES:%.*]] = call swiftcc i64 @indirect_in_captured_class_pair_param(i64 %0, %T13partial_apply14SwiftClassPairV* {{.*}} [[TEMP]])
696+
// CHECK: ret i64 [[RES]]
697+
698+
sil public_external @indirect_in_captured_class_pair_param : $@convention(thin) (Int, @in SwiftClassPair) -> Int
699+
700+
sil @partial_apply_stack_callee_guaranteed_indirect_in_class_pair_param : $@convention(thin) (@in SwiftClassPair) -> () {
701+
bb0(%x : $*SwiftClassPair):
702+
%f = function_ref @indirect_in_captured_class_pair_param : $@convention(thin) (Int, @in SwiftClassPair) -> Int
703+
%p = partial_apply [callee_guaranteed] [on_stack] %f(%x) : $@convention(thin) (Int, @in SwiftClassPair) -> Int
704+
%u = function_ref @use_closure2 : $@convention(thin) (@noescape @callee_guaranteed (Int) -> Int) -> ()
705+
%r = apply %u(%p) : $@convention(thin) (@noescape @callee_guaranteed (Int) -> Int) -> ()
706+
dealloc_stack %p : $@noescape @callee_guaranteed (Int) ->(Int)
707+
destroy_addr %x: $*SwiftClassPair
708+
%t = tuple()
709+
return %t : $()
710+
}
711+
712+
713+
// CHECK-LABEL: define{{.*}}swiftcc void @partial_apply_stack_callee_guaranteed_indirect_in_constant_class_pair_param(%T13partial_apply14SwiftClassPairV* {{.*}} %0)
714+
// CHECK: [[CLOSURE_STACK_ADDR:%.*]] = getelementptr inbounds <{ %swift.refcounted, %T13partial_apply14SwiftClassPairV* }>, <{ %swift.refcounted, %T13partial_apply14SwiftClassPairV* }>* {{.*}}, i32 0, i32 1
715+
// CHECK: store %T13partial_apply14SwiftClassPairV* %0, %T13partial_apply14SwiftClassPairV** [[CLOSURE_STACK_ADDR]]
716+
// CHECK: call swiftcc void @use_closure2({{.*}}* [[FORWARDER:@.*]] to i8*), %swift.opaque* {{.*}})
717+
// CHECK: ret void
718+
719+
// CHECK: define{{.*}} swiftcc i64 @"$s46indirect_in_constant_captured_class_pair_paramTA"(i64 %0, %swift.refcounted* swiftself %1)
720+
// CHECK: entry:
721+
// CHECK: [[TEMP:%.*]] = alloca %T13partial_apply14SwiftClassPairV, align 8
722+
// CHECK: [[VALUE_ADDR:%.*]] = load %T13partial_apply14SwiftClassPairV*, %T13partial_apply14SwiftClassPairV** {{.*}}
723+
// CHECK: call %T13partial_apply14SwiftClassPairV* @"$s13partial_apply14SwiftClassPairVWOc"(%T13partial_apply14SwiftClassPairV* [[VALUE_ADDR]], %T13partial_apply14SwiftClassPairV* [[TEMP]])
724+
// CHECK: [[RES:%.*]] = call swiftcc i64 @indirect_in_constant_captured_class_pair_param(i64 %0, %T13partial_apply14SwiftClassPairV* {{.*}} [[TEMP]])
725+
// CHECK: ret i64 [[RES]]
726+
// CHECK: }
727+
728+
sil public_external @indirect_in_constant_captured_class_pair_param : $@convention(thin) (Int, @in_constant SwiftClassPair) -> Int
729+
730+
sil @partial_apply_stack_callee_guaranteed_indirect_in_constant_class_pair_param : $@convention(thin) (@in SwiftClassPair) -> () {
731+
bb0(%x : $*SwiftClassPair):
732+
%f = function_ref @indirect_in_constant_captured_class_pair_param : $@convention(thin) (Int, @in_constant SwiftClassPair) -> Int
733+
%p = partial_apply [callee_guaranteed] [on_stack] %f(%x) : $@convention(thin) (Int, @in_constant SwiftClassPair) -> Int
734+
%u = function_ref @use_closure2 : $@convention(thin) (@noescape @callee_guaranteed (Int) -> Int) -> ()
735+
%r = apply %u(%p) : $@convention(thin) (@noescape @callee_guaranteed (Int) -> Int) -> ()
736+
dealloc_stack %p : $@noescape @callee_guaranteed (Int) ->(Int)
737+
destroy_addr %x: $*SwiftClassPair
738+
%t = tuple()
739+
return %t : $()
740+
}

0 commit comments

Comments
 (0)