Skip to content

[AutoDiff] Fix subset parameters thunk partial_apply substitutions. #27604

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 30 additions & 22 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1145,9 +1145,11 @@ class ADContext {
/// purposes.
void foldDifferentiableFunctionExtraction(DifferentiableFunctionInst *source);

/// Get or create a derivative function index subset thunk from
/// `actualIndices` to `desiredIndices` for the given derivative function
/// value and original function operand.
/// Get or create a derivative function parameter index subset thunk from
/// `actualIndices` to `desiredIndices` for the given associated function
/// value and original function operand. Returns a pair of the parameter
/// index subset thunk and its interface substitution map (used to partially
/// apply the thunk).
/// Calls `getOrCreateSubsetParametersThunkForLinearMap` to thunk the linear
/// map returned by the derivative function.
std::pair<SILFunction *, SubstitutionMap>
Expand All @@ -1156,11 +1158,14 @@ class ADContext {
AutoDiffDerivativeFunctionKind kind, SILAutoDiffIndices desiredIndices,
SILAutoDiffIndices actualIndices);

/// Get or create a derivative function index subset thunk from
/// `actualIndices` to `desiredIndices` for the given derivative function
/// value and original function operand.
SILFunction *getOrCreateSubsetParametersThunkForLinearMap(
SILFunction *derivativeFn, CanSILFunctionType linearMapType,
/// Get or create a derivative function parameter index subset thunk from
/// `actualIndices` to `desiredIndices` for the given associated function
/// value and original function operand. Returns a pair of the parameter
/// index subset thunk and its interface substitution map (used to partially
/// apply the thunk).
std::pair<SILFunction *, SubstitutionMap>
getOrCreateSubsetParametersThunkForLinearMap(
SILFunction *assocFn, CanSILFunctionType linearMapType,
CanSILFunctionType targetType, AutoDiffDerivativeFunctionKind kind,
SILAutoDiffIndices desiredIndices, SILAutoDiffIndices actualIndices);

Expand Down Expand Up @@ -8098,7 +8103,7 @@ class Differentiation : public SILModuleTransform {
};
} // end anonymous namespace

SILFunction *
std::pair<SILFunction *, SubstitutionMap>
ADContext::getOrCreateSubsetParametersThunkForLinearMap(
SILFunction *parentThunk, CanSILFunctionType linearMapType,
CanSILFunctionType targetType, AutoDiffDerivativeFunctionKind kind,
Expand All @@ -8107,8 +8112,8 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap(
<< "Getting a subset parameters thunk for " << linearMapType
<< " from " << actualIndices << " to " << desiredIndices << '\n');

SubstitutionMap interfaceSubs = parentThunk->getForwardingSubstitutionMap();
GenericEnvironment *genericEnv = parentThunk->getGenericEnvironment();
SubstitutionMap interfaceSubs;
GenericEnvironment *genericEnv = nullptr;
auto thunkType = buildThunkType(
parentThunk, linearMapType, targetType, genericEnv, interfaceSubs,
/*withoutActuallyEscaping*/ true,
Expand Down Expand Up @@ -8141,7 +8146,7 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap(
ProfileCounter(), IsThunk, IsNotDynamic);

if (!thunk->empty())
return thunk;
return {thunk, interfaceSubs};

thunk->setGenericEnvironment(genericEnv);
thunk->setOwnershipEliminated();
Expand Down Expand Up @@ -8289,7 +8294,7 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap(
for (auto *alloc : reversed(localAllocations))
builder.createDeallocStack(loc, alloc);
builder.createReturn(loc, ai);
return thunk;
return {thunk, interfaceSubs};
}

// If pullback thunk, return only the desired results and clean up the
Expand Down Expand Up @@ -8325,7 +8330,7 @@ ADContext::getOrCreateSubsetParametersThunkForLinearMap(
builder.createReturn(loc, result);

getGeneratedFunctions().push_back(thunk);
return thunk;
return {thunk, interfaceSubs};
}

std::pair<SILFunction *, SubstitutionMap>
Expand Down Expand Up @@ -8465,22 +8470,25 @@ ADContext::getOrCreateSubsetParametersThunkForDerivativeFunction(
auto linearMapTargetType = targetType->getResults().back().getSILStorageType()
.castTo<SILFunctionType>();

auto *innerThunk = getOrCreateSubsetParametersThunkForLinearMap(
thunk, linearMapType, linearMapTargetType, kind,
desiredIndices, actualIndices);
SILFunction *linearMapThunk;
SubstitutionMap linearMapSubs;
std::tie(linearMapThunk, linearMapSubs) =
getOrCreateSubsetParametersThunkForLinearMap(
thunk, linearMapType, linearMapTargetType, kind,
desiredIndices, actualIndices);

auto *innerThunkFRI = builder.createFunctionRef(loc, innerThunk);
auto *newDerivative = builder.createPartialApply(
loc, innerThunkFRI, thunk->getForwardingSubstitutionMap(), {linearMap},
auto *linearMapThunkFRI = builder.createFunctionRef(loc, linearMapThunk);
auto *thunkedLinearMap = builder.createPartialApply(
loc, linearMapThunkFRI, linearMapSubs, {linearMap},
ParameterConvention::Direct_Guaranteed);

assert(origFnType->getResults().size() == 1);
if (origFnType->getResults().front().isFormalDirect()) {
auto result = joinElements(
{originalDirectResult, newDerivative}, builder, loc);
{originalDirectResult, thunkedLinearMap}, builder, loc);
builder.createReturn(loc, result);
} else {
builder.createReturn(loc, newDerivative);
builder.createReturn(loc, thunkedLinearMap);
}

getGeneratedFunctions().push_back(thunk);
Expand Down
10 changes: 10 additions & 0 deletions test/AutoDiff/generics.swift
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,16 @@ extension TF_817 {
}
}

// TF-886: Test `partial_apply` of linear map subset parameters thunk.
@differentiable
func TF_886_foo<T, U: Differentiable>(_: Float, _: T, _: U) -> Float {
return 0
}
@differentiable
func TF_886_bar<T>(x: Float, y: T) -> Float {
return TF_886_foo(x, y, 0)
}

// Test layout requirements.

// The layout requirement is "contextual": the requirement is not on `T`, the
Expand Down