Skip to content

Commit 1ab1598

Browse files
committed
Correctly infer result indices & types for multiple semantic inout results
1 parent fe2df1e commit 1ab1598

17 files changed

+257
-153
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,12 @@ inline llvm::raw_ostream &operator<<(llvm::raw_ostream &s,
248248
/// an `inout` parameter type. Used in derivative function type calculation.
249249
struct AutoDiffSemanticFunctionResultType {
250250
Type type;
251-
bool isInout;
251+
unsigned index : 30;
252+
bool isInout : 1;
253+
bool isWrtParam : 1;
254+
255+
AutoDiffSemanticFunctionResultType(Type t, unsigned idx, bool inout, bool wrt)
256+
: type(t), index(idx), isInout(inout), isWrtParam(wrt) { }
252257
};
253258

254259
/// Key for caching SIL derivative function types.
@@ -569,19 +574,22 @@ namespace autodiff {
569574
/// `inout` parameter types.
570575
///
571576
/// The function type may have at most two parameter lists.
572-
///
573-
/// Remaps the original semantic result using `genericEnv`, if specified.
574-
void getFunctionSemanticResultTypes(
575-
AnyFunctionType *functionType,
576-
SmallVectorImpl<AutoDiffSemanticFunctionResultType> &result,
577-
GenericEnvironment *genericEnv = nullptr);
577+
void getFunctionSemanticResults(
578+
const AnyFunctionType *functionType,
579+
const IndexSubset *parameterIndices,
580+
SmallVectorImpl<AutoDiffSemanticFunctionResultType> &resultTypes);
581+
582+
/// Returns the indices of semantic results for a given function.
583+
IndexSubset *getFunctionSemanticResultIndices(
584+
const AnyFunctionType *functionType,
585+
const IndexSubset *parameterIndices);
578586

579-
/// Returns the indices of all semantic results for a given function.
580-
IndexSubset *getAllFunctionSemanticResultIndices(
581-
const AbstractFunctionDecl *AFD);
587+
IndexSubset *getFunctionSemanticResultIndices(
588+
const AbstractFunctionDecl *AFD,
589+
const IndexSubset *parameterIndices);
582590

583591
/// Returns the lowered SIL parameter indices for the given AST parameter
584-
/// indices and `AnyfunctionType`.
592+
/// indices and `AnyFunctionType`.
585593
///
586594
/// Notable lowering-related changes:
587595
/// - AST tuple parameter types are exploded when lowered to SIL.

lib/AST/AutoDiff.cpp

Lines changed: 63 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -176,57 +176,89 @@ void AnyFunctionType::getSubsetParameters(
176176
}
177177
}
178178

179-
void autodiff::getFunctionSemanticResultTypes(
180-
AnyFunctionType *functionType,
181-
SmallVectorImpl<AutoDiffSemanticFunctionResultType> &result,
182-
GenericEnvironment *genericEnv) {
179+
void autodiff::getFunctionSemanticResults(
180+
const AnyFunctionType *functionType,
181+
const IndexSubset *parameterIndices,
182+
SmallVectorImpl<AutoDiffSemanticFunctionResultType> &resultTypes) {
183183
auto &ctx = functionType->getASTContext();
184184

185-
// Remap type in `genericEnv`, if specified.
186-
auto remap = [&](Type type) {
187-
if (!genericEnv)
188-
return type;
189-
return genericEnv->mapTypeIntoContext(type);
190-
};
191-
192185
// Collect formal result type as a semantic result, unless it is
193186
// `Void`.
194187
auto formalResultType = functionType->getResult();
195188
if (auto *resultFunctionType =
196-
functionType->getResult()->getAs<AnyFunctionType>()) {
189+
functionType->getResult()->getAs<AnyFunctionType>())
197190
formalResultType = resultFunctionType->getResult();
198-
}
191+
192+
unsigned resultIdx = 0;
199193
if (!formalResultType->isEqual(ctx.TheEmptyTupleType)) {
200194
// Separate tuple elements into individual results.
201195
if (formalResultType->is<TupleType>()) {
202196
for (auto elt : formalResultType->castTo<TupleType>()->getElements()) {
203-
result.push_back({remap(elt.getType()), /*isInout*/ false});
197+
resultTypes.emplace_back(elt.getType(), resultIdx++,
198+
/*isInout*/ false, /*isWrt*/ false);
204199
}
205200
} else {
206-
result.push_back({remap(formalResultType), /*isInout*/ false});
201+
resultTypes.emplace_back(formalResultType, resultIdx++,
202+
/*isInout*/ false, /*isWrt*/ false);
207203
}
208204
}
209205

210-
// Collect `inout` parameters as semantic results.
211-
for (auto param : functionType->getParams())
212-
if (param.isInOut())
213-
result.push_back({remap(param.getPlainType()), /*isInout*/ true});
214-
if (auto *resultFunctionType =
215-
functionType->getResult()->getAs<AnyFunctionType>()) {
216-
for (auto param : resultFunctionType->getParams())
217-
if (param.isInOut())
218-
result.push_back({remap(param.getPlainType()), /*isInout*/ true});
219-
}
206+
bool addNonWrts = resultTypes.empty();
207+
208+
// Collect wrt `inout` parameters as semantic results
209+
// As an extention, collect all (including non-wrt) inouts as results for
210+
// functions returning void.
211+
auto collectSemanticResults = [&](const AnyFunctionType *functionType,
212+
unsigned curryOffset = 0) {
213+
for (auto paramAndIndex : enumerate(functionType->getParams())) {
214+
if (!paramAndIndex.value().isInOut())
215+
continue;
216+
217+
unsigned idx = paramAndIndex.index() + curryOffset;
218+
assert(idx < parameterIndices->getCapacity() &&
219+
"invalid parameter index");
220+
bool isWrt = parameterIndices->contains(idx);
221+
if (addNonWrts || isWrt)
222+
resultTypes.emplace_back(paramAndIndex.value().getPlainType(),
223+
resultIdx, /*isInout*/ true, isWrt);
224+
resultIdx += 1;
225+
}
226+
};
227+
228+
if (auto *resultFnType =
229+
functionType->getResult()->getAs<AnyFunctionType>()) {
230+
// Here we assume that the input is a function type with curried `Self`
231+
assert(functionType->getNumParams() == 1 && "unexpected function type");
232+
233+
collectSemanticResults(resultFnType);
234+
collectSemanticResults(functionType, resultFnType->getNumParams());
235+
} else
236+
collectSemanticResults(functionType);
220237
}
221238

222239
IndexSubset *
223-
autodiff::getAllFunctionSemanticResultIndices(const AbstractFunctionDecl *AFD) {
224-
auto originalFn = AFD->getInterfaceType()->castTo<AnyFunctionType>();
240+
autodiff::getFunctionSemanticResultIndices(const AnyFunctionType *functionType,
241+
const IndexSubset *parameterIndices) {
242+
auto &ctx = functionType->getASTContext();
243+
225244
SmallVector<AutoDiffSemanticFunctionResultType, 1> semanticResults;
226-
autodiff::getFunctionSemanticResultTypes(originalFn, semanticResults);
227-
auto numResults = semanticResults.size();
228-
return IndexSubset::getDefault(
229-
AFD->getASTContext(), numResults, /*includeAll*/ true);
245+
autodiff::getFunctionSemanticResults(functionType, parameterIndices,
246+
semanticResults);
247+
SmallVector<unsigned> resultIndices;
248+
unsigned cap = 0;
249+
for (const auto& result : semanticResults) {
250+
resultIndices.push_back(result.index);
251+
cap = std::max(cap, result.index + 1U);
252+
}
253+
254+
return IndexSubset::get(ctx, cap, resultIndices);
255+
}
256+
257+
IndexSubset *
258+
autodiff::getFunctionSemanticResultIndices(const AbstractFunctionDecl *AFD,
259+
const IndexSubset *parameterIndices) {
260+
return getFunctionSemanticResultIndices(AFD->getInterfaceType()->castTo<AnyFunctionType>(),
261+
parameterIndices);
230262
}
231263

232264
// TODO(TF-874): Simplify this helper. See TF-874 for WIP.

lib/AST/Type.cpp

Lines changed: 24 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -5550,85 +5550,41 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
55505550

55515551
// Get the original non-inout semantic result types.
55525552
SmallVector<AutoDiffSemanticFunctionResultType, 1> originalResults;
5553-
autodiff::getFunctionSemanticResultTypes(this, originalResults);
5553+
autodiff::getFunctionSemanticResults(this, parameterIndices, originalResults);
55545554
// Error if no original semantic results.
55555555
if (originalResults.empty())
55565556
return llvm::make_error<DerivativeFunctionTypeError>(
55575557
this, DerivativeFunctionTypeError::Kind::NoSemanticResults);
5558+
55585559
// Accumulate non-inout result tangent spaces.
5559-
SmallVector<Type, 1> resultTanTypes;
5560-
bool hasInoutResult = false;
5560+
SmallVector<Type, 1> resultTanTypes, inoutTanTypes;
55615561
for (auto i : range(originalResults.size())) {
55625562
auto originalResult = originalResults[i];
55635563
auto originalResultType = originalResult.type;
5564+
55645565
// Voids currently have a defined tangent vector, so ignore them.
55655566
if (originalResultType->isVoid())
55665567
continue;
5567-
if (originalResult.isInout) {
5568-
hasInoutResult = true;
5569-
continue;
5570-
}
5568+
55715569
// Get the original semantic result type's `TangentVector` associated type.
5570+
// Error if a semantic result has no tangent space.
55725571
auto resultTan =
55735572
originalResultType->getAutoDiffTangentSpace(lookupConformance);
55745573
if (!resultTan)
5575-
continue;
5576-
auto resultTanType = resultTan->getType();
5577-
resultTanTypes.push_back(resultTanType);
5578-
}
5579-
// Append non-wrt inout result tangent spaces.
5580-
// This uses the logic from getSubsetParameters(), only operating over all
5581-
// parameter indices and looking for non-wrt indices.
5582-
SmallVector<AnyFunctionType *, 2> curryLevels;
5583-
// An inlined version of unwrapCurryLevels().
5584-
AnyFunctionType *fnTy = this;
5585-
while (fnTy != nullptr) {
5586-
curryLevels.push_back(fnTy);
5587-
fnTy = fnTy->getResult()->getAs<AnyFunctionType>();
5588-
}
5589-
5590-
SmallVector<unsigned, 2> curryLevelParameterIndexOffsets(curryLevels.size());
5591-
unsigned currentOffset = 0;
5592-
for (unsigned curryLevelIndex : llvm::reverse(indices(curryLevels))) {
5593-
curryLevelParameterIndexOffsets[curryLevelIndex] = currentOffset;
5594-
currentOffset += curryLevels[curryLevelIndex]->getNumParams();
5595-
}
5596-
5597-
if (!makeSelfParamFirst) {
5598-
std::reverse(curryLevels.begin(), curryLevels.end());
5599-
std::reverse(curryLevelParameterIndexOffsets.begin(),
5600-
curryLevelParameterIndexOffsets.end());
5601-
}
5602-
5603-
for (unsigned curryLevelIndex : indices(curryLevels)) {
5604-
auto *curryLevel = curryLevels[curryLevelIndex];
5605-
unsigned parameterIndexOffset =
5606-
curryLevelParameterIndexOffsets[curryLevelIndex];
5607-
for (unsigned paramIndex : range(curryLevel->getNumParams())) {
5608-
if (parameterIndices->contains(parameterIndexOffset + paramIndex))
5609-
continue;
5610-
5611-
auto param = curryLevel->getParams()[paramIndex];
5612-
if (param.isInOut()) {
5613-
auto resultType = param.getPlainType();
5614-
if (resultType->isVoid())
5615-
continue;
5616-
auto resultTan = resultType->getAutoDiffTangentSpace(lookupConformance);
5617-
if (!resultTan)
5618-
continue;
5619-
auto resultTanType = resultTan->getType();
5620-
resultTanTypes.push_back(resultTanType);
5621-
}
5622-
}
5623-
}
5624-
5625-
// Error if no semantic result has a tangent space.
5626-
if (resultTanTypes.empty() && !hasInoutResult) {
5627-
return llvm::make_error<DerivativeFunctionTypeError>(
5574+
return llvm::make_error<DerivativeFunctionTypeError>(
56285575
this, DerivativeFunctionTypeError::Kind::NonDifferentiableResult,
5629-
std::make_pair(originalResults.front().type, /*index*/ 0));
5576+
std::make_pair(originalResultType, unsigned(originalResult.index)));
5577+
5578+
if (!originalResult.isInout)
5579+
resultTanTypes.push_back(resultTan->getType());
5580+
else if (originalResult.isInout && !originalResult.isWrtParam)
5581+
inoutTanTypes.push_back(resultTan->getType());
56305582
}
56315583

5584+
// Treat non-wrt inouts as semantic results for functions returning Void
5585+
if (resultTanTypes.empty())
5586+
resultTanTypes = inoutTanTypes;
5587+
56325588
// Compute the result linear map function type.
56335589
FunctionType *linearMapType;
56345590
switch (kind) {
@@ -5641,24 +5597,24 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
56415597
//
56425598
// Case 2: original function has a non-wrt `inout` parameter.
56435599
// - Original: `(T0, inout T1, ...) -> Void`
5644-
// - Differential: `(T0.Tan, ...) -> T1.Tan`
5600+
// - Differential: `(T0.Tan, ...) -> T1.Tan`
56455601
//
56465602
// Case 3: original function has a wrt `inout` parameter.
5647-
// - Original: `(T0, inout T1, ...) -> Void`
5648-
// - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void`
5603+
// - Original: `(T0, inout T1, ...) -> Void`
5604+
// - Differential: `(T0.Tan, inout T1.Tan, ...) -> Void`
56495605
SmallVector<AnyFunctionType::Param, 4> differentialParams;
56505606
for (auto i : range(diffParams.size())) {
56515607
auto diffParam = diffParams[i];
56525608
auto paramType = diffParam.getPlainType();
56535609
auto paramTan = paramType->getAutoDiffTangentSpace(lookupConformance);
56545610
// Error if parameter has no tangent space.
5655-
if (!paramTan) {
5611+
if (!paramTan)
56565612
return llvm::make_error<DerivativeFunctionTypeError>(
56575613
this,
56585614
DerivativeFunctionTypeError::Kind::
56595615
NonDifferentiableDifferentiabilityParameter,
56605616
std::make_pair(paramType, i));
5661-
}
5617+
56625618
differentialParams.push_back(AnyFunctionType::Param(
56635619
paramTan->getType(), Identifier(), diffParam.getParameterFlags()));
56645620
}
@@ -5704,13 +5660,13 @@ AnyFunctionType::getAutoDiffDerivativeFunctionLinearMapType(
57045660
auto paramType = diffParam.getPlainType();
57055661
auto paramTan = paramType->getAutoDiffTangentSpace(lookupConformance);
57065662
// Error if parameter has no tangent space.
5707-
if (!paramTan) {
5663+
if (!paramTan)
57085664
return llvm::make_error<DerivativeFunctionTypeError>(
57095665
this,
57105666
DerivativeFunctionTypeError::Kind::
57115667
NonDifferentiableDifferentiabilityParameter,
57125668
std::make_pair(paramType, i));
5713-
}
5669+
57145670
if (diffParam.isInOut()) {
57155671
if (paramType->isVoid())
57165672
continue;

lib/IRGen/IRGenMangler.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ class IRGenMangler : public Mangle::ASTMangler {
5858
beginManglingWithAutoDiffOriginalFunction(func);
5959
auto kind = Demangle::getAutoDiffFunctionKind(derivativeId->getKind());
6060
auto *resultIndices =
61-
autodiff::getAllFunctionSemanticResultIndices(func);
61+
autodiff::getFunctionSemanticResultIndices(func,
62+
derivativeId->getParameterIndices());
6263
AutoDiffConfig config(
6364
derivativeId->getParameterIndices(),
6465
resultIndices,
@@ -89,7 +90,8 @@ class IRGenMangler : public Mangle::ASTMangler {
8990
beginManglingWithAutoDiffOriginalFunction(func);
9091
auto kind = Demangle::getAutoDiffFunctionKind(derivativeId->getKind());
9192
auto *resultIndices =
92-
autodiff::getAllFunctionSemanticResultIndices(func);
93+
autodiff::getFunctionSemanticResultIndices(func,
94+
derivativeId->getParameterIndices());
9395
AutoDiffConfig config(
9496
derivativeId->getParameterIndices(),
9597
resultIndices,

lib/SIL/IR/SILDeclRef.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1067,8 +1067,10 @@ std::string SILDeclRef::mangle(ManglingKind MKind) const {
10671067
auto *silParameterIndices = autodiff::getLoweredParameterIndices(
10681068
derivativeFunctionIdentifier->getParameterIndices(),
10691069
getDecl()->getInterfaceType()->castTo<AnyFunctionType>());
1070-
auto *resultIndices = autodiff::getAllFunctionSemanticResultIndices(
1071-
asAutoDiffOriginalFunction().getAbstractFunctionDecl());
1070+
// FIXME: is this correct in the presence of curried types?
1071+
auto *resultIndices = autodiff::getFunctionSemanticResultIndices(
1072+
asAutoDiffOriginalFunction().getAbstractFunctionDecl(),
1073+
derivativeFunctionIdentifier->getParameterIndices());
10721074
AutoDiffConfig silConfig(
10731075
silParameterIndices, resultIndices,
10741076
derivativeFunctionIdentifier->getDerivativeGenericSignature());

0 commit comments

Comments
 (0)