Skip to content

Commit 60378e1

Browse files
authored
Merge pull request #64233 from eeckstein/optimize-convert-function
SILCombine: handle `convert_escape_to_noescape` in the apply-of-convert-function optimization
2 parents 6491dd5 + 9deb942 commit 60378e1

File tree

5 files changed

+61
-14
lines changed

5 files changed

+61
-14
lines changed

lib/SILOptimizer/SILCombiner/SILCombinerApplyVisitors.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,9 @@ SILCombiner::optimizeApplyOfConvertFunctionInst(FullApplySite AI,
135135
if (auto *TTI = dyn_cast<ThinToThickFunctionInst>(funcOper))
136136
funcOper = TTI->getOperand();
137137

138-
auto *FRI = dyn_cast<FunctionRefInst>(funcOper);
139-
if (!FRI)
138+
if (!isa<FunctionRefInst>(funcOper) &&
139+
// Optimizing partial_apply will then enable the partial_apply -> apply peephole.
140+
!isa<PartialApplyInst>(funcOper))
140141
return nullptr;
141142

142143
// Grab our relevant callee types...
@@ -151,8 +152,8 @@ SILCombiner::optimizeApplyOfConvertFunctionInst(FullApplySite AI,
151152
// relevant types from the ConvertFunction function type and AI.
152153
Builder.setCurrentDebugScope(AI.getDebugScope());
153154
OperandValueArrayRef Ops = AI.getArguments();
154-
SILFunctionConventions substConventions(SubstCalleeTy, FRI->getModule());
155-
SILFunctionConventions convertConventions(ConvertCalleeTy, FRI->getModule());
155+
SILFunctionConventions substConventions(SubstCalleeTy, CFI->getModule());
156+
SILFunctionConventions convertConventions(ConvertCalleeTy, CFI->getModule());
156157
auto context = AI.getFunction()->getTypeExpansionContext();
157158
auto oldOpRetTypes = substConventions.getIndirectSILResultTypes(context);
158159
auto newOpRetTypes = convertConventions.getIndirectSILResultTypes(context);
@@ -229,7 +230,7 @@ SILCombiner::optimizeApplyOfConvertFunctionInst(FullApplySite AI,
229230
Builder.createBranch(AI.getLoc(), TAI->getNormalBB(), branchArgs);
230231
}
231232

232-
return Builder.createTryApply(AI.getLoc(), FRI, SubstitutionMap(), Args,
233+
return Builder.createTryApply(AI.getLoc(), funcOper, SubstitutionMap(), Args,
233234
normalBB, TAI->getErrorBB(),
234235
TAI->getApplyOptions());
235236
}
@@ -239,9 +240,9 @@ SILCombiner::optimizeApplyOfConvertFunctionInst(FullApplySite AI,
239240
// otherwise, we would be creating malformed SIL).
240241
ApplyOptions Options = AI.getApplyOptions();
241242
Options -= ApplyFlags::DoesNotThrow;
242-
if (FRI->getFunctionType()->hasErrorResult())
243+
if (funcOper->getType().castTo<SILFunctionType>()->hasErrorResult())
243244
Options |= ApplyFlags::DoesNotThrow;
244-
ApplyInst *NAI = Builder.createApply(AI.getLoc(), FRI, SubstitutionMap(),
245+
ApplyInst *NAI = Builder.createApply(AI.getLoc(), funcOper, SubstitutionMap(),
245246
Args, Options);
246247
SILInstruction *result = NAI;
247248

@@ -1455,7 +1456,11 @@ SILInstruction *SILCombiner::visitApplyInst(ApplyInst *AI) {
14551456
if (isa<PartialApplyInst>(AI->getCallee()))
14561457
return nullptr;
14571458

1458-
if (auto *CFI = dyn_cast<ConvertFunctionInst>(AI->getCallee()))
1459+
SILValue callee = AI->getCallee();
1460+
if (auto *cee = dyn_cast<ConvertEscapeToNoEscapeInst>(callee)) {
1461+
callee = cee->getOperand();
1462+
}
1463+
if (auto *CFI = dyn_cast<ConvertFunctionInst>(callee))
14591464
return optimizeApplyOfConvertFunctionInst(AI, CFI);
14601465

14611466
if (tryOptimizeKeypath(AI))

lib/SILOptimizer/SILCombiner/SILCombinerMiscVisitors.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2377,6 +2377,11 @@ SILCombiner::visitDifferentiableFunctionExtractInst(DifferentiableFunctionExtrac
23772377
// match the type of the original `differentiable_function_extract`,
23782378
// create a `convert_function`.
23792379
if (newValue->getType() != DFEI->getType()) {
2380+
CanSILFunctionType opTI = newValue->getType().castTo<SILFunctionType>();
2381+
CanSILFunctionType resTI = DFEI->getType().castTo<SILFunctionType>();
2382+
if (!opTI->isABICompatibleWith(resTI, *DFEI->getFunction()).isCompatible())
2383+
return nullptr;
2384+
23802385
std::tie(newValue, std::ignore) =
23812386
castValueToABICompatibleType(&Builder, DFEI->getLoc(),
23822387
newValue,

test/AutoDiff/e2e_optimizations.swift

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,10 @@ func test_gradient_float_loop() {
8787
}
8888

8989
// Check whether `apply`s are inlined.
90-
// Currently, the VJP is inlined but the pullback is not.
9190
// CHECK-LABEL: sil hidden @test_gradient_float_loop : $@convention(thin) () -> ()
92-
// CHECK: [[PB_FN_REF:%.*]] = function_ref @{{.*}}24test_gradient_float_loopyyFS2fcfU_TJrSpSr : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
93-
// CHECK: [[GRADIENT_RESULT:%.*]] = apply [[PB_FN_REF]]
94-
// CHECK: [[EXTRACT:%.*]] = tuple_extract [[GRADIENT_RESULT]]
95-
// CHECK: [[GRADIENT_RESULT2:%.*]] = apply [[EXTRACT]]
91+
// CHECK: = function_ref @${{.*24test_gradient_float_loopyyFS2fcfU_TJrSpSr|sSf16_DifferentiationE12_vjpMultiply3lhs3rhsSf5value_Sf_SftSfc8pullbacktSf_SftFZSf_SftSfcfU_}}
9692
// CHECK: [[FN_REF:%.*]] = function_ref @$s9blackHoleSf_Tg5 : $@convention(thin) (Float) -> Float
97-
// CHECK-NEXT: apply [[FN_REF:%.*]]([[GRADIENT_RESULT2]])
93+
// CHECK-NEXT: apply [[FN_REF:%.*]]
9894
// CHECK-NOT: apply
9995
// CHECK-LABEL: } // end sil function 'test_gradient_float_loop'
10096
func array_loop(_ array: [Float]) -> Float {
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// RUN: %target-swift-frontend %s -O -module-name=test -emit-sil | %FileCheck %s
2+
3+
// REQUIRES: swift_stdlib_no_asserts,optimized_stdlib
4+
5+
public struct S {
6+
var x: Int
7+
8+
@inline(never)
9+
mutating func doSomething() { }
10+
}
11+
12+
// Check that all partial_applys can be optimized away so that no closure context needs to be allocated.
13+
14+
// CHECK-LABEL: sil @$s4test6testit_1xySDySiAA1SVGz_SitF :
15+
// CHECK-NOT: partial_apply
16+
// CHECK: } // end sil function '$s4test6testit_1xySDySiAA1SVGz_SitF'
17+
public func testit(_ data: inout [Int: S], x: Int) {
18+
data[x, default: S(x: x)].doSomething()
19+
}

test/SILOptimizer/sil_combine.sil

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1442,6 +1442,28 @@ entry(%a : $AnotherClass):
14421442
return %r : $MyNSObj
14431443
}
14441444

1445+
sil @createMyInt : $@convention(thin) (Builtin.Int32) -> @out MyInt
1446+
1447+
// CHECK-LABEL: sil @convert_function_of_closure :
1448+
// CHECK: [[F:%.*]] = function_ref @createMyInt
1449+
// CHECK: [[S:%.*]] = alloc_stack $MyInt
1450+
// CHECK: apply [[F]]([[S]], %0)
1451+
// CHECK-NOT: strong_release
1452+
// CHECK: } // end sil function 'convert_function_of_closure'
1453+
sil @convert_function_of_closure : $@convention(thin) (Builtin.Int32) -> () {
1454+
bb0(%0 : $Builtin.Int32):
1455+
%1 = function_ref @createMyInt : $@convention(thin) (Builtin.Int32) -> @out MyInt
1456+
%2 = partial_apply [callee_guaranteed] %1(%0) : $@convention(thin) (Builtin.Int32) -> @out MyInt
1457+
%3 = convert_function %2 : $@callee_guaranteed () -> @out MyInt to $@callee_guaranteed @substituted <τ_0_0> () -> @out τ_0_0 for <MyInt>
1458+
%4 = convert_escape_to_noescape %3 : $@callee_guaranteed @substituted <τ_0_0> () -> @out τ_0_0 for <MyInt> to $@noescape @callee_guaranteed @substituted <τ_0_0> () -> @out τ_0_0 for <MyInt>
1459+
%5 = alloc_stack $MyInt
1460+
%6 = apply %4(%5) : $@noescape @callee_guaranteed @substituted <τ_0_0> () -> @out τ_0_0 for <MyInt>
1461+
dealloc_stack %5 : $*MyInt
1462+
strong_release %2 : $@callee_guaranteed () -> @out MyInt
1463+
%9 = tuple ()
1464+
return %9 : $()
1465+
}
1466+
14451467
// CHECK-LABEL: sil @upcast_formation : $@convention(thin) (@inout E, E, @inout B) -> B {
14461468
// CHECK: bb0
14471469
// CHECK-NEXT: upcast

0 commit comments

Comments
 (0)