Skip to content

Commit 5f62183

Browse files
authored
[AutoDiff] Fix differentiable_function-related specialization crashes. (#29800)
Fix crashes related to generic specialization of `partial_apply` operands to `differentiable_function` instructions. `differentiable_function` requires derivative function operand types to match expected derivative function types computed from the original function operand's type, so operands cannot be specialized individually without specializing the others. Resolves TF-891 and TF-1126.
1 parent e40d33e commit 5f62183

File tree

4 files changed

+79
-3
lines changed

4 files changed

+79
-3
lines changed

lib/SILOptimizer/IPO/CapturePropagation.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,17 @@ static SILFunction *getSpecializedWithDeadParams(
408408
return nullptr;
409409
}
410410

411+
// SWIFT_ENABLE_TENSORFLOW
412+
// Disable specialization for instructions that are operands of
413+
// `differentiable_function` instructions. `differentiable_function`
414+
// requires derivative function operand types to match expected derivative
415+
// function types computed from the original function operand's type, so
416+
// operands cannot be specialized individually without specializing the
417+
// others.
418+
if (!PAI->getUsersOfType<DifferentiableFunctionInst>().empty())
419+
return nullptr;
420+
// SWIFT_ENABLE_TENSORFLOW END
421+
411422
auto Rep = Specialized->getLoweredFunctionType()->getRepresentation();
412423
if (getSILFunctionLanguage(Rep) != SILFunctionLanguage::Swift)
413424
return nullptr;
@@ -463,7 +474,7 @@ bool CapturePropagation::optimizePartialApply(PartialApplyInst *PAI) {
463474
if (auto *NewFunc = getSpecializedWithDeadParams(FuncBuilder,
464475
PAI, SubstF, PAI->getNumArguments(), GenericSpecialized)) {
465476
// SWIFT_ENABLE_TENSORFLOW
466-
// Add a previously unexercised check to prevent AD crash. Rewrite
477+
// Add a previously unexercised check to prevent AutoDiff crash. Rewrite
467478
// `partial_apply` only if the specialized function is `@convention(thin)`.
468479
// Revert check when `VJPEmitter::visitApplyInst` no longer produces
469480
// argumentless `partial_apply` instructions.
@@ -476,6 +487,7 @@ bool CapturePropagation::optimizePartialApply(PartialApplyInst *PAI) {
476487
}
477488
return true;
478489
}
490+
// SWIFT_ENABLE_TENSORFLOW END
479491
}
480492

481493
// Second possibility: Are all partially applied arguments constant?

lib/SILOptimizer/Utils/Generics.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,11 @@ bool ReabstractionInfo::prepareAndCheck(ApplySite Apply, SILFunction *Callee,
531531

532532
// SWIFT_ENABLE_TENSORFLOW
533533
// Disable specialization for instructions that are operands of
534-
// `differentiable_function` instructions.
534+
// `differentiable_function` instructions. `differentiable_function`
535+
// requires derivative function operand types to match expected derivative
536+
// function types computed from the original function operand's type, so
537+
// operands cannot be specialized individually without specializing the
538+
// others.
535539
if (Apply.getInstruction())
536540
for (auto result : Apply.getInstruction()->getResults())
537541
for (auto use : result->getUses())
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// RUN: %target-swift-frontend -O -emit-sil %s -verify
2+
// REQUIRES: asserts
3+
4+
// TF-1126: Generic specialization crash during capture propagation.
5+
// Related to `@differentiable` function with `partial_apply` operands,
6+
// to be specialized. Occurs only with `-O`.
7+
8+
struct A: Differentiable{
9+
var b: SIMD8<Float>
10+
}
11+
12+
@differentiable
13+
func function(a: A) -> A {
14+
var a = a
15+
a.b = a.b - SIMD8<Float>(repeating: 1.0)
16+
return a
17+
}
18+
19+
let masks: [SIMD8<Float>] = [[1,0,0,0,0,0,0,0],
20+
[0,1,0,0,0,0,0,0],
21+
[0,0,1,0,0,0,0,0],
22+
[0,0,0,1,0,0,0,0],
23+
[0,0,0,0,1,0,0,0],
24+
[0,0,0,0,0,1,0,0],
25+
[0,0,0,0,0,0,1,0],
26+
[0,0,0,0,0,0,0,1]]
27+
28+
extension SIMD8 where Scalar == Float{
29+
@differentiable(where Scalar: Differentiable)
30+
func updated(at index: Int, with newValue: Scalar) -> Self {
31+
let mask = masks[index]
32+
let result = self - (self * mask) + (newValue * mask)
33+
return result
34+
}
35+
}
36+
37+
// Looking for a function: $ss4SIMDPss14DifferentiableRzSB6Scalars11SIMDStoragePRpzsAA13TangentVectorsACPRpzSBAhI_AdFRPzrlE12_vjpSubtract3lhs3rhsx5value_AJ_AJtAJc8pullbacktx_xtFZs5SIMD8VySfG_Tg5
38+
// Expected type: @convention(method) (@in_guaranteed SIMD8<Float>, @in_guaranteed SIMD8<Float>, @thick SIMD8<Float>.Type) -> (@out SIMD8<Float>, @owned @callee_guaranteed (@in_guaranteed SIMD8<Float>) -> (@out SIMD8<Float>, @out SIMD8<Float>))
39+
// Found type: @convention(method) (SIMD8<Float>, SIMD8<Float>, @thick SIMD8<Float>.Type) -> (@out SIMD8<Float>, @owned @callee_guaranteed (@in_guaranteed SIMD8<Float>) -> (@out SIMD8<Float>, @out SIMD8<Float>))
40+
// Assertion failed: (ReInfo.getSpecializedType() == SpecializedF->getLoweredFunctionType() && "Previously specialized function does not match expected type."), function lookupSpecialization, file /Users/swiftninjas/s4tf/swift/lib/SILOptimizer/Utils/Generics.cpp, line 1833.
41+
// Stack dump:
42+
// ...
43+
// 1. Swift version 5.2-dev (Swift bf631dc2e4)
44+
// 2. While running pass #113021 SILFunctionTransform "CapturePropagation" on SILFunction "@AD__$ss5SIMD8V6deleteSfRszrlE7updated2at4withABySfGSi_SftF__vjp_src_0_wrt_1_2".
45+
// for 'updated(at:with:)' (at /Users/porter/Dropbox (PassiveLogic)/Team/Team Members Scratch Space/Porter/Experiments/Playgrounds/delete/delete/main.swift:75:5)
46+
// llvm::sys::PrintStackTrace(llvm::raw_ostream&) + 37
47+
// llvm::sys::RunSignalHandlers() + 85
48+
// SignalHandler(int) + 278
49+
// _sigtramp + 29
50+
// _sigtramp + 2821162056
51+
// abort + 127
52+
// basename_r + 0
53+
// swift::GenericFuncSpecializer::lookupSpecialization() (.cold.1) + 35
54+
// swift::GenericFuncSpecializer::lookupSpecialization() + 2109
55+
// (anonymous namespace)::CapturePropagation::optimizePartialApply(swift::PartialApplyInst*) + 1301
56+
// (anonymous namespace)::CapturePropagation::run() + 265

test/AutoDiff/downstream/compiler_crashers/tf891-protocol-req-capture-propagation.swift renamed to test/AutoDiff/downstream/compiler_crashers_fixed/tf891-protocol-req-capture-propagation.swift

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1-
// RUN: not --crash %target-swift-frontend -O -emit-ir %s
1+
// RUN: %target-swift-frontend -O -emit-ir %s
22
// REQUIRES: asserts
33

4+
// TF-891: Generic specialization crash during capture propagation.
5+
// Related to `@differentiable` function with `partial_apply` operands,
6+
// to be specialized. Occurs only with `-O`.
7+
48
public protocol Protocol: Differentiable {
59
@differentiable
610
func requirement1<T: Protocol>(_ arg: T) -> Float

0 commit comments

Comments
 (0)