Skip to content

Commit 3faf474

Browse files
Marc Rasirxwei
authored andcommitted
fix conversion problem
1 parent bb9c2fe commit 3faf474

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

include/swift/SIL/SILInstruction.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7936,7 +7936,7 @@ class DifferentiableFunctionExtractInst
79367936
} rawValue;
79377937
Extractee() = default;
79387938
Extractee(innerty rawValue) : rawValue(rawValue) {}
7939-
Extractee(unsigned rawValue) : Extractee((innerty)rawValue) {}
7939+
explicit Extractee(unsigned rawValue) : Extractee((innerty)rawValue) {}
79407940
Extractee(AutoDiffAssociatedFunctionKind kind);
79417941
explicit Extractee(StringRef name);
79427942
operator innerty() const { return rawValue; }

lib/SIL/TypeLowering.cpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -896,16 +896,23 @@ namespace {
896896
DifferentiableFunctionExtractee::Original,
897897
TC.getTypeLowering(origFnTy, getResilienceExpansion())
898898
});
899-
for (auto kind : {AutoDiffAssociatedFunctionKind::JVP,
900-
AutoDiffAssociatedFunctionKind::VJP}) {
899+
for (AutoDiffAssociatedFunctionKind kind :
900+
{AutoDiffAssociatedFunctionKind::JVP,
901+
AutoDiffAssociatedFunctionKind::VJP}) {
901902
auto assocFnTy = origFnTy->getAutoDiffAssociatedFunctionType(
902903
paramIndices, 0, kind, TC,
903904
LookUpConformanceInModule(&TC.M));
904905
auto silTy = SILType::getPrimitiveObjectType(assocFnTy);
906+
auto extractee = DifferentiableFunctionExtractee(kind);
907+
908+
// A bug caused by implicit conversions caused us to get the wrong
909+
// extractee, so assert that we have the right extractee to prevent
910+
// reoccurrence of the bug.
911+
assert(extractee.getExtracteeAsAssociatedFunction() ==
912+
Optional<AutoDiffAssociatedFunctionKind>(kind));
913+
905914
children.push_back(Child{
906-
DifferentiableFunctionExtractee(kind),
907-
TC.getTypeLowering(silTy, getResilienceExpansion())
908-
});
915+
extractee, TC.getTypeLowering(silTy, getResilienceExpansion())});
909916
}
910917
assert(children.size() == 3);
911918
}

0 commit comments

Comments
 (0)