Skip to content

Commit 126f1ac

Browse files
authored
[AutoDiff] Disable differentiable_function_extract explicit type as… (#35239)
`differentiability_function_extract` instruction has an optional explicit extractee type. This is currently used by TypeSubstCloner and the LoadableByAddress transform to rewrite `differentiability_function_extract` instructions while preserving `@differentiable` function type invariants. There is an assertion that `differentiability_function_extract` instructions do not have explicit extractee types outside of canonical/lowered SIL. However, this does not handle the SIL deserialization case above: when a function containing a `differentiable_function_extract` instruction with an explicit type is deserialized into a raw SIL module (which happens when optimizations are enabled). Removing the assertion unblocks this encountered use case. A more robust longer-term solution may be to change SIL `@differentiable` function types to explicitly store component original/JVP/VJP function types. Also fix `differentiable_function_extract` extractee type serialization. Resolves SR-14004.
1 parent 650c1e6 commit 126f1ac

File tree

7 files changed

+53
-20
lines changed

7 files changed

+53
-20
lines changed

include/swift/SIL/SILInstruction.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8920,7 +8920,12 @@ class DifferentiableFunctionExtractInst
89208920
SILModule &module);
89218921

89228922
public:
8923-
/// Note: explicit extractee type may be specified only in lowered SIL.
8923+
/// Note: explicit extractee type is used to avoid inconsistent typing in:
8924+
/// - Canonical SIL, due to generic specialization.
8925+
/// - Lowered SIL, due to LoadableByAddress.
8926+
/// - Raw SIL, due to deserialization of canonical/lowered SIL functions.
8927+
/// See `TypeSubstCloner::visitDifferentiableFunctionExtractInst` for an
8928+
/// explanation of how explicit extractee type is used.
89248929
explicit DifferentiableFunctionExtractInst(
89258930
SILModule &module, SILDebugLocation debugLoc,
89268931
NormalDifferentiableFunctionTypeComponent extractee, SILValue function,

lib/SIL/IR/SILInstructions.cpp

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -731,18 +731,6 @@ DifferentiableFunctionExtractInst::DifferentiableFunctionExtractInst(
731731
: getExtracteeType(function, extractee, module),
732732
function.getOwnershipKind()),
733733
Extractee(extractee), HasExplicitExtracteeType(extracteeType.hasValue()) {
734-
#ifndef NDEBUG
735-
if (extracteeType.hasValue()) {
736-
// Note: explicit extractee type is used to avoid inconsistent typing in:
737-
// - Canonical SIL, due to generic specialization.
738-
// - Lowered SIL, due to LoadableByAddress.
739-
// See `TypeSubstCloner::visitDifferentiableFunctionExtractInst` for an
740-
// explanation of how explicit extractee type is used.
741-
assert((module.getStage() == SILStage::Canonical ||
742-
module.getStage() == SILStage::Lowered) &&
743-
"Explicit type is valid only in canonical or lowered SIL");
744-
}
745-
#endif
746734
}
747735

748736
SILType LinearFunctionExtractInst::

lib/Serialization/DeserializeSIL.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1186,7 +1186,7 @@ bool SILDeserializer::readSILInstruction(SILFunction *Fn,
11861186
case SIL_INST_DIFFERENTIABLE_FUNCTION_EXTRACT:
11871187
SILInstDifferentiableFunctionExtractLayout::readRecord(
11881188
scratch, TyID, TyCategory, ValID, /*extractee*/ Attr,
1189-
/*hasExplicitExtracteeType*/ Attr2);
1189+
/*hasExplicitExtracteeType*/ Attr2, /*explicitExtracteeType*/ TyID2);
11901190
RawOpCode = (unsigned)SILInstructionKind::DifferentiableFunctionExtractInst;
11911191
break;
11921192
case SIL_INST_LINEAR_FUNCTION_EXTRACT:
@@ -2747,8 +2747,11 @@ bool SILDeserializer::readSILInstruction(SILFunction *Fn,
27472747
auto val = getLocalValue(ValID, silTy);
27482748
NormalDifferentiableFunctionTypeComponent extractee(Attr);
27492749
Optional<SILType> explicitExtracteeType = None;
2750-
if (Attr2)
2751-
explicitExtracteeType = silTy;
2750+
if (Attr2) {
2751+
auto extracteeASTType = MF->getType(TyID2);
2752+
explicitExtracteeType =
2753+
getSILType(extracteeASTType, SILValueCategory::Object, Fn);
2754+
}
27522755
ResultInst = Builder.createDifferentiableFunctionExtract(
27532756
Loc, extractee, val, explicitExtracteeType);
27542757
break;

lib/Serialization/ModuleFormat.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ const uint16_t SWIFTMODULE_VERSION_MAJOR = 0;
5656
/// describe what change you made. The content of this comment isn't important;
5757
/// it just ensures a conflict if two people change the module format.
5858
/// Don't worry about adhering to the 80-column limit for this line.
59-
const uint16_t SWIFTMODULE_VERSION_MINOR = 589; // cache prespecialization decls.
59+
const uint16_t SWIFTMODULE_VERSION_MINOR = 590; // differentiable_function_extract explicit extractee type
6060

6161
/// A standard hash seed used for all string hashes in a serialized module.
6262
///

lib/Serialization/SILFormat.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -478,8 +478,9 @@ namespace sil_block {
478478
TypeIDField,
479479
SILTypeCategoryField,
480480
ValueIDField,
481-
BCFixed<2>, // extractee
482-
BCFixed<1> // has explicit extractee type?
481+
BCFixed<2>, // extractee
482+
BCFixed<1>, // has explicit extractee type?
483+
TypeIDField // explicit extractee type
483484
>;
484485

485486
using SILInstLinearFunctionExtractLayout = BCRecordLayout<

lib/Serialization/SerializeSIL.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2294,11 +2294,13 @@ void SILSerializer::writeSILInstruction(const SILInstruction &SI) {
22942294
auto operandType = dfei->getOperand()->getType();
22952295
auto operandTypeRef = S.addTypeRef(operandType.getASTType());
22962296
auto rawExtractee = (unsigned)dfei->getExtractee();
2297+
auto extracteeTypeRef = S.addTypeRef(dfei->getType().getASTType());
22972298
SILInstDifferentiableFunctionExtractLayout::emitRecord(
22982299
Out, ScratchRecord,
22992300
SILAbbrCodes[SILInstDifferentiableFunctionExtractLayout::Code],
23002301
operandTypeRef, (unsigned)operandType.getCategory(), operandRef,
2301-
rawExtractee, (unsigned)dfei->hasExplicitExtracteeType());
2302+
rawExtractee, (unsigned)dfei->hasExplicitExtracteeType(),
2303+
extracteeTypeRef);
23022304
break;
23032305
}
23042306
case SILInstructionKind::LinearFunctionExtractInst: {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// RUN: %empty-directory(%t)
2+
// RUN: %target-build-swift-dylib(%t/%target-library-name(Library)) -emit-module -emit-module-path %t/Library.swiftmodule -module-name Library -DLIBRARY %s
3+
// RUN: %target-build-swift -I %t -O -emit-module %s
4+
5+
// SR-14004: Assertion failure due to function with `differentiable_function_extract`
6+
// with explicit extractee type being deserialized into a raw SIL module.
7+
8+
#if LIBRARY
9+
10+
import _Differentiation
11+
12+
public struct Struct<Scalar>: Differentiable {}
13+
14+
@differentiable
15+
public func foo<Scalar>(_ x: Struct<Scalar>) -> Struct<Scalar> { x }
16+
17+
@inlinable
18+
@differentiable
19+
public func bar<Scalar>(_ x: Struct<Scalar>) -> Struct<Scalar> {
20+
foo(x)
21+
}
22+
23+
#else
24+
25+
import _Differentiation
26+
import Library
27+
28+
public func foo(
29+
body: @differentiable (Struct<Float>) -> Struct<Float> = bar
30+
) {
31+
fatalError()
32+
}
33+
34+
#endif

0 commit comments

Comments
 (0)