Skip to content

Commit 2035bce

Browse files
committed
Fix assorted tests.
- Fix `SILGenModule::getOrCreateCustomDerivativeThunk` for class initializers. - Update `ArrayRef<ManagedValue>` input to `forwardFunctionArguments` instead of the `SmallVectorImpl<SILValue> &forwardedArgs` output. - Polish `convert_function` code for JVP/VJP emission. - Create `convert_function` only if actual differential/pullback type does not match the expected type. 1 remaining test failure with `-Onone`. More test failures with `-O`.
1 parent 66b799f commit 2035bce

File tree

4 files changed

+56
-52
lines changed

4 files changed

+56
-52
lines changed

lib/SILGen/SILGenPoly.cpp

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3880,19 +3880,10 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
38803880
thunkSGF.collectThunkParams(loc, params, &indirectResults);
38813881

38823882
auto *fnRef = thunkSGF.B.createFunctionRef(loc, customDerivativeFn);
3883-
// auto fnRefType = fnRef->getType().mapTypeOutOfContext().castTo<SILFunctionType>();
3884-
auto fnRefType = thunkSGF.F.mapTypeIntoContext(fnRef->getType().mapTypeOutOfContext()).castTo<SILFunctionType>();
3885-
llvm::errs() << "FN REF TYPE\n";
3886-
fnRefType->dump();
3887-
llvm::errs() << "FN REF TYPE 2\n";
3888-
fnRefType = fnRefType->getUnsubstitutedType(M);
3889-
fnRefType->dump();
3890-
3891-
// Collect thunk arguments, converting ownership.
3892-
SmallVector<SILValue, 8> arguments;
3893-
for (auto *indRes : indirectResults)
3894-
arguments.push_back(indRes);
3895-
forwardFunctionArguments(thunkSGF, loc, fnRefType, params, arguments);
3883+
auto fnRefType =
3884+
thunkSGF.F.mapTypeIntoContext(fnRef->getType().mapTypeOutOfContext())
3885+
.castTo<SILFunctionType>()
3886+
->getUnsubstitutedType(M);
38963887

38973888
// Special support for thunking class initializer derivatives.
38983889
//
@@ -3906,19 +3897,32 @@ SILFunction *SILGenModule::getOrCreateCustomDerivativeThunk(
39063897
// - Create a `@thick Class.Type` value and pass it as the last argument.
39073898
auto *origAFD =
39083899
cast<AbstractFunctionDecl>(originalFn->getDeclContext()->getAsDecl());
3909-
if (isa<ConstructorDecl>(origAFD) &&
3900+
bool isClassInitializer =
3901+
isa<ConstructorDecl>(origAFD) &&
3902+
origAFD->getDeclContext()->getSelfClassDecl() &&
39103903
SILDeclRef(origAFD, SILDeclRef::Kind::Initializer).mangle() ==
3911-
originalFn->getName()) {
3912-
auto classArgument = arguments.pop_back_val();
3913-
auto *classDecl = classArgument->getType().getClassOrBoundGenericClass();
3904+
originalFn->getName();
3905+
if (isClassInitializer) {
3906+
params.pop_back();
3907+
auto *classDecl = thunkFnTy->getParameters()
3908+
.back()
3909+
.getInterfaceType()
3910+
->getClassOrBoundGenericClass();
39143911
assert(classDecl && "Expected last argument to have class type");
39153912
auto classMetatype = MetatypeType::get(
39163913
classDecl->getDeclaredInterfaceType(), MetatypeRepresentation::Thick);
39173914
auto canClassMetatype = classMetatype->getCanonicalType();
39183915
auto *metatype = thunkSGF.B.createMetatype(
39193916
loc, SILType::getPrimitiveObjectType(canClassMetatype));
3920-
arguments.push_back(metatype);
3917+
params.push_back(ManagedValue::forUnmanaged(metatype));
39213918
}
3919+
3920+
// Collect thunk arguments, converting ownership.
3921+
SmallVector<SILValue, 8> arguments;
3922+
for (auto *indRes : indirectResults)
3923+
arguments.push_back(indRes);
3924+
forwardFunctionArguments(thunkSGF, loc, fnRefType, params, arguments);
3925+
39223926
// Apply function argument.
39233927
auto apply = thunkSGF.emitApplyWithRethrow(
39243928
loc, fnRef, /*substFnType*/ fnRef->getType(),

lib/SILOptimizer/Utils/Differentiation/JVPEmitter.cpp

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1393,8 +1393,6 @@ void JVPEmitter::visitApplyInst(ApplyInst *ai) {
13931393
auto *differentialDecl = differentialInfo.lookUpLinearMapDecl(ai);
13941394
auto originalDifferentialType =
13951395
getOpType(differential->getType()).getAs<SILFunctionType>();
1396-
auto differentialType =
1397-
remapType(differential->getType()).castTo<SILFunctionType>();
13981396
auto loweredDifferentialType =
13991397
getOpType(getLoweredType(differentialDecl->getInterfaceType()))
14001398
.castTo<SILFunctionType>();
@@ -1436,19 +1434,29 @@ void JVPEmitter::visitReturnInst(ReturnInst *ri) {
14361434
auto differentialType = jvp->getLoweredFunctionType()->getResults().back().getSILStorageInterfaceType();
14371435
differentialType = differentialType.substGenericArgs(getModule(), jvpSubstMap, TypeExpansionContext::minimal());
14381436
differentialType = differentialType.subst(getModule(), jvpSubstMap);
1437+
auto differentialFnType = differentialType.castTo<SILFunctionType>();
14391438

1439+
auto differentialSubstType =
1440+
differentialPartialApply->getType().castTo<SILFunctionType>();
14401441
SILValue differentialValue;
1441-
if (differentialPartialApply->getType().castTo<SILFunctionType>()->isABICompatibleWith(differentialType.castTo<SILFunctionType>(), *differentialPartialApply->getFunction()).isCompatible()) {
1442-
differentialValue = builder.createConvertFunction(loc, differentialPartialApply, differentialType, /*withoutActuallyEscaping*/ false);
1442+
if (differentialSubstType == differentialFnType) {
1443+
differentialValue = differentialPartialApply;
1444+
} else if (differentialSubstType
1445+
->isABICompatibleWith(differentialFnType, *jvp)
1446+
.isCompatible()) {
1447+
differentialValue = builder.createConvertFunction(
1448+
loc, differentialPartialApply, differentialType,
1449+
/*withoutActuallyEscaping*/ false);
14431450
} else {
14441451
// When `diag::autodiff_loadable_value_addressonly_tangent_unsupported`
14451452
// applies, the return type may be ABI-incomaptible with the type of the
1446-
// partially applied differential. In these cases, produce an undef and rely on
1447-
// other code to emit a diagnostic.
1448-
differentialValue = SILUndef::get(differentialType, *differentialPartialApply->getFunction());
1453+
// partially applied differential. In these cases, produce an undef and rely
1454+
// on other code to emit a diagnostic.
1455+
differentialValue = SILUndef::get(differentialType,
1456+
*differentialPartialApply->getFunction());
14491457
}
14501458

1451-
// Return a tuple of the original result and pullback.
1459+
// Return a tuple of the original result and differential.
14521460
SmallVector<SILValue, 8> directResults;
14531461
directResults.append(origResults.begin(), origResults.end());
14541462
directResults.push_back(differentialValue);

lib/SILOptimizer/Utils/Differentiation/VJPEmitter.cpp

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -339,33 +339,25 @@ void VJPEmitter::visitReturnInst(ReturnInst *ri) {
339339
auto *pullbackPartialApply =
340340
builder.createPartialApply(loc, pullbackRef, vjpSubstMap, {pbStructVal},
341341
ParameterConvention::Direct_Guaranteed);
342-
llvm::errs() << "PULLBACK TYPES\n";
343-
pullbackPartialApply->getType().dump();
344-
// auto pullbackType = getLoweredType(vjp->getLoweredFunctionType()->getResults().back().getInterfaceType());
345-
auto pullbackType = vjp->getLoweredFunctionType()->getResults().back().getSILStorageInterfaceType();
346-
pullbackType.dump();
347-
llvm::errs() << "BEFORE REPLACEMENT TYPES\n";
348-
for (auto type : vjpSubstMap.getReplacementTypes())
349-
type->dump();
350-
llvm::errs() << "AFTER REPLACEMENT TYPES\n";
351-
for (auto type : vjp->getLoweredFunctionType()->getSubstitutions().getReplacementTypes())
352-
type->dump();
353-
llvm::errs() << "FINAL REPLACEMENT TYPES\n";
354-
for (auto type : pullbackType.castTo<SILFunctionType>()->getSubstitutions().getReplacementTypes())
355-
type->dump();
356-
// pullbackType = SILType::getPrimitiveObjectType(pullbackType.castTo<SILFunctionType>()->withSubstitutions(vjpSubstMap));
357-
pullbackType = pullbackType.substGenericArgs(getModule(), vjpSubstMap, TypeExpansionContext::minimal());
358-
llvm::errs() << "FINAL2 REPLACEMENT TYPES\n";
359-
for (auto type : pullbackType.castTo<SILFunctionType>()->getSubstitutions().getReplacementTypes())
360-
type->dump();
342+
auto pullbackType = vjp->getLoweredFunctionType()
343+
->getResults()
344+
.back()
345+
.getSILStorageInterfaceType();
346+
pullbackType = pullbackType.substGenericArgs(getModule(), vjpSubstMap,
347+
TypeExpansionContext::minimal());
361348
pullbackType = pullbackType.subst(getModule(), vjpSubstMap);
362-
llvm::errs() << "FINAL3 REPLACEMENT TYPES\n";
363-
for (auto type : pullbackType.castTo<SILFunctionType>()->getSubstitutions().getReplacementTypes())
364-
type->dump();
349+
auto pullbackFnType = pullbackType.castTo<SILFunctionType>();
365350

351+
auto pullbackSubstType =
352+
pullbackPartialApply->getType().castTo<SILFunctionType>();
366353
SILValue pullbackValue;
367-
if (pullbackPartialApply->getType().castTo<SILFunctionType>()->isABICompatibleWith(pullbackType.castTo<SILFunctionType>(), *pullbackPartialApply->getFunction()).isCompatible()) {
368-
pullbackValue = builder.createConvertFunction(loc, pullbackPartialApply, pullbackType, /*withoutActuallyEscaping*/ false);
354+
if (pullbackSubstType == pullbackFnType) {
355+
pullbackValue = pullbackPartialApply;
356+
} else if (pullbackSubstType->isABICompatibleWith(pullbackFnType, *vjp)
357+
.isCompatible()) {
358+
pullbackValue =
359+
builder.createConvertFunction(loc, pullbackPartialApply, pullbackType,
360+
/*withoutActuallyEscaping*/ false);
369361
} else {
370362
// When `diag::autodiff_loadable_value_addressonly_tangent_unsupported`
371363
// applies, the return type may be ABI-incomaptible with the type of the

test/AutoDiff/downstream/sildeclref_parse.sil

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@ bb0(%0 : $Class<T>):
4343
%1 = class_method %0 : $Class<T>, #Class.f!1 : <T> (Class<T>) -> (T, Float) -> T, $@convention(method) <τ_0_0> (@in_guaranteed τ_0_0, Float, @guaranteed Class<τ_0_0>) -> @out τ_0_0
4444

4545
// CHECK: class_method %0 : $Class<T>, #Class.f!1.jvp.SSU
46-
%2 = class_method %0 : $Class<T>, #Class.f!1.jvp.SSU.<T where T : Differentiable> : <T> (Class<T>) -> (T, Float) -> T, $@convention(method) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float, @guaranteed Class<τ_0_0>) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector)
46+
%2 = class_method %0 : $Class<T>, #Class.f!1.jvp.SSU.<T where T : Differentiable> : <T> (Class<T>) -> (T, Float) -> T, $@convention(method) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float, @guaranteed Class<τ_0_0>) -> (@out τ_0_0, @owned @callee_guaranteed <τ_0_0, τ_0_1> in (@in_guaranteed τ_0_0, Float) -> @out τ_0_1 for <τ_0_0.TangentVector, τ_0_0.TangentVector>)
4747

4848
// CHECK: class_method %0 : $Class<T>, #Class.f!1.vjp.SSU
49-
%3 = class_method %0 : $Class<T>, #Class.f!1.vjp.SSU.<T where T : Differentiable> : <T> (Class<T>) -> (T, Float) -> T, $@convention(method) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float, @guaranteed Class<τ_0_0>) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float))
49+
%3 = class_method %0 : $Class<T>, #Class.f!1.vjp.SSU.<T where T : Differentiable> : <T> (Class<T>) -> (T, Float) -> T, $@convention(method) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0, Float, @guaranteed Class<τ_0_0>) -> (@out τ_0_0, @owned @callee_guaranteed <τ_0_0, τ_0_1> in (@in_guaranteed τ_0_0) -> (@out τ_0_1, Float) for <τ_0_0.TangentVector, τ_0_0.TangentVector>)
5050

5151
%6 = tuple ()
5252
return %6 : $()

0 commit comments

Comments
 (0)