Skip to content

Commit 3d61c21

Browse files
authored
[AutoDiff] Revert obsolete SIL undef hack. (#33571)
Previously, JVP/VJP generation used a "return undef" hack when differential/pullback values did not match the expected return type. This was relevant before differentiation supported "loadable types with address-only tangent types", which was diagnosed. Now that support for the above exists, "return undef" should be removed and replaced with an assertion.
1 parent 6b07132 commit 3d61c21

File tree

2 files changed

+10
-12
lines changed

2 files changed

+10
-12
lines changed

lib/SILOptimizer/Differentiation/JVPCloner.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -704,9 +704,11 @@ class JVPCloner::Implementation final
704704
getModule(), jvpSubstMap, TypeExpansionContext::minimal());
705705
differentialType = differentialType.subst(getModule(), jvpSubstMap);
706706
auto differentialFnType = differentialType.castTo<SILFunctionType>();
707-
708707
auto differentialSubstType =
709708
differentialPartialApply->getType().castTo<SILFunctionType>();
709+
710+
// If necessary, convert the differential value to the returned differential
711+
// function type.
710712
SILValue differentialValue;
711713
if (differentialSubstType == differentialFnType) {
712714
differentialValue = differentialPartialApply;
@@ -717,11 +719,8 @@ class JVPCloner::Implementation final
717719
loc, differentialPartialApply, differentialType,
718720
/*withoutActuallyEscaping*/ false);
719721
} else {
720-
// When `diag::autodiff_loadable_value_addressonly_tangent_unsupported`
721-
// applies, the return type may be ABI-incomaptible with the type of the
722-
// partially applied differential. In these cases, produce an undef and
723-
// rely on other code to emit a diagnostic.
724-
differentialValue = SILUndef::get(differentialType, *jvp);
722+
llvm::report_fatal_error("Differential value type is not ABI-compatible "
723+
"with the returned differential type");
725724
}
726725

727726
// Return a tuple of the original result and differential.

lib/SILOptimizer/Differentiation/VJPCloner.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -201,9 +201,11 @@ class VJPCloner::Implementation final
201201
getModule(), vjpSubstMap, TypeExpansionContext::minimal());
202202
pullbackType = pullbackType.subst(getModule(), vjpSubstMap);
203203
auto pullbackFnType = pullbackType.castTo<SILFunctionType>();
204-
205204
auto pullbackSubstType =
206205
pullbackPartialApply->getType().castTo<SILFunctionType>();
206+
207+
// If necessary, convert the pullback value to the returned pullback
208+
// function type.
207209
SILValue pullbackValue;
208210
if (pullbackSubstType == pullbackFnType) {
209211
pullbackValue = pullbackPartialApply;
@@ -213,11 +215,8 @@ class VJPCloner::Implementation final
213215
builder.createConvertFunction(loc, pullbackPartialApply, pullbackType,
214216
/*withoutActuallyEscaping*/ false);
215217
} else {
216-
// When `diag::autodiff_loadable_value_addressonly_tangent_unsupported`
217-
// applies, the return type may be ABI-incomaptible with the type of the
218-
// partially applied pullback. In these cases, produce an undef and rely
219-
// on other code to emit a diagnostic.
220-
pullbackValue = SILUndef::get(pullbackType, *vjp);
218+
llvm::report_fatal_error("Pullback value type is not ABI-compatible "
219+
"with the returned pullback type");
221220
}
222221

223222
// Return a tuple of the original result and pullback.

0 commit comments

Comments
 (0)