Skip to content

[AutoDiff upstream] Add differentiable_function canonicalization. #30818

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions include/swift/AST/DiagnosticsSIL.def
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,42 @@ NOTE(autodiff_control_flow_not_supported,none,
"cannot differentiate unsupported control flow", ())
NOTE(autodiff_missing_return,none,
"missing return for differentiation", ())
NOTE(autodiff_external_nondifferentiable_function,none,
"cannot differentiate functions that have not been marked "
"'@differentiable' and that are defined in other files", ())
NOTE(autodiff_opaque_function_not_differentiable,none,
"opaque non-'@differentiable' function is not differentiable", ())
NOTE(autodiff_private_derivative_from_fragile,none,
"differentiated functions in "
"%select{'@inlinable' functions|default arguments}0 must be marked "
"'@differentiable' or have a public '@derivative'"
"%select{|; this is not possible with a closure, make a top-level "
"function instead}1", (unsigned, bool))
NOTE(autodiff_function_noderivative_parameter_not_differentiable,none,
"cannot differentiate with respect to a '@noDerivative' parameter", ())
NOTE(autodiff_function_assoc_func_unmet_requirements,none,
"function call is not differentiable because generic requirements are not "
"met: '%0'", (/*requirements*/ StringRef))
NOTE(autodiff_nondifferentiable_argument,none,
"cannot differentiate through a non-differentiable argument; do you want "
"to use 'withoutDerivative(at:)'?", ())
NOTE(autodiff_nondifferentiable_result,none,
"cannot differentiate through a non-differentiable result; do you want to "
"use 'withoutDerivative(at:)'?", ())
NOTE(autodiff_protocol_member_not_differentiable,none,
"member is not differentiable because the corresponding protocol "
"requirement is not '@differentiable'", ())
NOTE(autodiff_class_member_not_differentiable,none,
"member is not differentiable because the corresponding class member "
"is not '@differentiable'", ())
NOTE(autodiff_member_subset_indices_not_differentiable,none,
"member is differentiable only with respect to a smaller subset of "
"arguments", ())
// TODO(TF-642): Remove when `partial_apply` works with `@differentiable`
// functions.
NOTE(autodiff_cannot_param_subset_thunk_partially_applied_orig_fn,none,
"cannot convert a direct method reference to a '@differentiable' "
"function; use an explicit closure instead", ())

ERROR(non_physical_addressof,none,
"addressof only works with purely physical lvalues; "
Expand Down
32 changes: 32 additions & 0 deletions include/swift/SIL/SILInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -2056,6 +2056,25 @@ class ApplyInstBase<Impl, Base, false> : public Base {
/// does it have the given semantics?
bool doesApplyCalleeHaveSemantics(SILValue callee, StringRef semantics);

/// Predicate used to filter InoutArgumentRange.
struct OperandToInoutArgument {
ArrayRef<SILParameterInfo> paramInfos;
OperandValueArrayRef arguments;
OperandToInoutArgument(ArrayRef<SILParameterInfo> paramInfos,
OperandValueArrayRef arguments)
: paramInfos(paramInfos), arguments(arguments) {
assert(paramInfos.size() == arguments.size());
}
Optional<SILValue> operator()(size_t i) const {
if (paramInfos[i].isIndirectMutating())
return arguments[i];
return None;
}
};

using InoutArgumentRange =
OptionalTransformRange<IntRange<size_t>, OperandToInoutArgument>;

/// The partial specialization of ApplyInstBase for full applications.
/// Adds some methods relating to 'self' and to result types that don't
/// make sense for partial applications.
Expand All @@ -2068,6 +2087,9 @@ class ApplyInstBase<Impl, Base, true>
ApplyInstBase(As &&...args)
: ApplyInstBase<Impl, Base, false>(std::forward<As>(args)...) {}

private:
const Impl &asImpl() const { return static_cast<const Impl &>(*this); }

public:
using super::getCallee;
using super::getSubstCalleeType;
Expand Down Expand Up @@ -2152,6 +2174,16 @@ class ApplyInstBase<Impl, Base, true>
return getArguments().slice(getNumIndirectResults());
}

/// Returns all `@inout` and `@inout_aliasable` arguments passed to the
/// instruction.
InoutArgumentRange getInoutArguments() const {
auto &impl = asImpl();
return InoutArgumentRange(
indices(getArgumentsWithoutIndirectResults()),
OperandToInoutArgument(impl.getSubstCalleeConv().getParameters(),
impl.getArgumentsWithoutIndirectResults()));
}

bool hasSemantics(StringRef semanticsString) const {
return doesApplyCalleeHaveSemantics(getCallee(), semanticsString);
}
Expand Down
3 changes: 3 additions & 0 deletions include/swift/SIL/SILType.h
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,9 @@ class SILType {
bool isLoweringOf(TypeExpansionContext context, SILModule &M,
CanType formalType);

/// Returns true if this SILType is a differentiable type.
bool isDifferentiable(SILModule &M) const;

/// Returns the hash code for the SILType.
llvm::hash_code getHashCode() const {
return llvm::hash_combine(*this);
Expand Down
99 changes: 99 additions & 0 deletions include/swift/SILOptimizer/Utils/Differentiation/Common.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "swift/SIL/SILDifferentiabilityWitness.h"
#include "swift/SIL/SILFunction.h"
#include "swift/SIL/SILModule.h"
#include "swift/SIL/TypeSubstCloner.h"

namespace swift {

Expand All @@ -33,6 +34,12 @@ namespace autodiff {
/// This is being used to print short debug messages within the AD pass.
raw_ostream &getADDebugStream();

/// Given a function call site, gathers all of its actual results (both direct
/// and indirect) in an order defined by its result type.
void collectAllActualResultsInTypeOrder(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a number of problems with this. I am pretty sure we already have something like this.

ApplyInst *ai, ArrayRef<SILValue> extractedDirectResults,
SmallVectorImpl<SILValue> &results);

/// Returns the underlying instruction for the given SILValue, if it exists,
/// peering through function conversion instructions.
template <class Inst> Inst *peerThroughFunctionConversions(SILValue value) {
Expand All @@ -51,6 +58,70 @@ template <class Inst> Inst *peerThroughFunctionConversions(SILValue value) {
return nullptr;
}

/// Given a range of elements, joins these into a single value. If there's
/// exactly one element, returns that element. Otherwise, creates a tuple using
/// a `tuple` instruction.
SILValue joinElements(ArrayRef<SILValue> elements, SILBuilder &builder,
SILLocation loc);

/// Given a value, extracts all elements to `results` from this value if it has
/// a tuple type. Otherwise, add this value directly to `results`.
void extractAllElements(SILValue value, SILBuilder &builder,
SmallVectorImpl<SILValue> &results);

/// Emit a zero value into the given buffer access by calling
/// `AdditiveArithmetic.zero`. The given type must conform to
/// `AdditiveArithmetic`.
void emitZeroIntoBuffer(SILBuilder &builder, CanType type,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please fix these APIs to not take a SILBuilder. Instead, it should take an insertion point and a SILBuilderContext.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change seems important for correctness. It can probably help clean up this code:
https://github.com/apple/swift/blob/13d5a8addbe3605984edc4ce7c6cbf1f0649a9e1/lib/SILOptimizer/Utils/Differentiation/PullbackEmitter.cpp#L486-L497

I started looking into this but ran into some difficulties. The user above wants to insert at the end of a basic block (SILBasicBlock::iterator), but no appropriate constructor exists. Only SILBuilderWithScope(SILInstruction *I, SILBuilderContext &C) is recommended.

SILBuilderWithScope(SILBasicBlock::iterator I, SILBuilderContext &C) would work but doesn't exist. Can we add it, or are there other considerations?

SILValue bufferAccess, SILLocation loc);

//===----------------------------------------------------------------------===//
// Utilities for looking up derivatives of functions
//===----------------------------------------------------------------------===//

/// Returns a differentiability witness (definition or declaration) exactly
/// matching the specified indices. If none are found in the given `module`,
/// returns `nullptr`.
///
/// \param parameterIndices must be lowered to SIL.
/// \param resultIndices must be lowered to SIL.
SILDifferentiabilityWitness *
getExactDifferentiabilityWitness(SILModule &module, SILFunction *original,
IndexSubset *parameterIndices,
IndexSubset *resultIndices);

/// Finds the derivative configuration (from `@differentiable` and
/// `@derivative` attributes) for `original` whose parameter indices are a
/// minimal superset of the specified AST parameter indices. Returns `None` if
/// no such configuration is found.
///
/// \param parameterIndices must be lowered to SIL.
/// \param minimalASTParameterIndices is an output parameter that is set to the
/// AST indices of the minimal configuration, or to `nullptr` if no such
/// configuration exists.
Optional<AutoDiffConfig>
findMinimalDerivativeConfiguration(AbstractFunctionDecl *original,
IndexSubset *parameterIndices,
IndexSubset *&minimalASTParameterIndices);

/// Returns a differentiability witness for `original` whose parameter indices
/// are a minimal superset of the specified parameter indices and whose result
/// indices match the given result indices, out of all
/// differentiability witnesses that come from AST "@differentiable" or
/// "@differentiating" attributes.
///
/// This function never creates new differentiability witness definitions.
/// However, this function may create new differentiability witness declarations
/// referring to definitions in other modules when these witnesses have not yet
/// been declared in the current module.
///
/// \param module is the SILModule in which to get or create the witnesses.
/// \param parameterIndices must be lowered to SIL.
/// \param resultIndices must be lowered to SIL.
SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness(
SILModule &module, SILFunction *original, IndexSubset *parameterIndices,
IndexSubset *resultIndices);

} // end namespace autodiff

/// Creates arguments in the entry block based on the function type.
Expand Down Expand Up @@ -85,6 +156,34 @@ inline void createEntryArguments(SILFunction *f) {
}
}

/// Cloner that remaps types using the target function's generic environment.
class BasicTypeSubstCloner final
: public TypeSubstCloner<BasicTypeSubstCloner, SILOptFunctionBuilder> {

static SubstitutionMap getSubstitutionMap(SILFunction *target) {
if (auto *targetGenEnv = target->getGenericEnvironment())
return targetGenEnv->getForwardingSubstitutionMap();
return SubstitutionMap();
}

public:
explicit BasicTypeSubstCloner(SILFunction *original, SILFunction *target)
: TypeSubstCloner(*target, *original, getSubstitutionMap(target)) {}

void postProcess(SILInstruction *orig, SILInstruction *cloned) {
SILClonerWithScopes::postProcess(orig, cloned);
}

void cloneFunction() {
auto &newFunction = Builder.getFunction();
auto *entry = newFunction.createBasicBlock();
createEntryArguments(&newFunction);
SmallVector<SILValue, 8> entryArguments(newFunction.getArguments().begin(),
newFunction.getArguments().end());
cloneFunctionBody(&Original, entry, entryArguments);
}
};

} // end namespace swift

#endif // SWIFT_SILOPTIMIZER_MANDATORY_DIFFERENTIATION_COMMON_H
113 changes: 113 additions & 0 deletions include/swift/SILOptimizer/Utils/Differentiation/Thunk.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
//===--- Thunk.h - Automatic differentiation thunks -----------*- C++ -*---===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// Automatic differentiation thunk generation utilities.
//
//===----------------------------------------------------------------------===//

#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_THUNK_H
#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_THUNK_H

#include "swift/AST/AutoDiff.h"
#include "swift/Basic/LLVM.h"
#include "swift/SIL/SILBuilder.h"

namespace swift {

class SILOptFunctionBuilder;
class SILModule;
class SILLocation;
class SILValue;
class OpenedArchetypeType;
class GenericEnvironment;
class SubstitutionMap;
class ArchetypeType;

//===----------------------------------------------------------------------===//
// Helpers
//===----------------------------------------------------------------------===//

namespace autodiff {

//===----------------------------------------------------------------------===//
// Thunk helpers
//===----------------------------------------------------------------------===//
// These helpers are copied/adapted from SILGen. They should be refactored and
// moved to a shared location.
//===----------------------------------------------------------------------===//

/// The thunk kinds used in the differentiation transform.
enum class DifferentiationThunkKind {
/// A reabstraction thunk.
///
/// Reabstraction thunks transform a function-typed value to another one with
/// different parameter/result abstraction patterns. This is identical to the
/// thunks generated by SILGen.
Reabstraction,

/// An index subset thunk.
///
/// An index subset thunk is used transform JVP/VJPs into a version that is
/// "wrt" fewer differentiation parameters.
/// - Differentials of thunked JVPs use zero for non-requested differentiation
/// parameters.
/// - Pullbacks of thunked VJPs discard results for non-requested
/// differentiation parameters.
IndexSubset
};

CanGenericSignature buildThunkSignature(SILFunction *fn, bool inheritGenericSig,
OpenedArchetypeType *openedExistential,
GenericEnvironment *&genericEnv,
SubstitutionMap &contextSubs,
SubstitutionMap &interfaceSubs,
ArchetypeType *&newArchetype);

/// Build the type of a function transformation thunk.
CanSILFunctionType buildThunkType(SILFunction *fn,
CanSILFunctionType &sourceType,
CanSILFunctionType &expectedType,
GenericEnvironment *&genericEnv,
SubstitutionMap &interfaceSubs,
bool withoutActuallyEscaping,
DifferentiationThunkKind thunkKind);

/// Get or create a derivative function parameter index subset thunk from
/// `actualIndices` to `desiredIndices` for the given associated function
/// value and original function operand. Returns a pair of the parameter
/// index subset thunk and its interface substitution map (used to partially
/// apply the thunk).
/// Calls `getOrCreateSubsetParametersThunkForLinearMap` to thunk the linear
/// map returned by the derivative function.
std::pair<SILFunction *, SubstitutionMap>
getOrCreateSubsetParametersThunkForDerivativeFunction(
SILOptFunctionBuilder &fb, SILValue origFnOperand, SILValue derivativeFn,
AutoDiffDerivativeFunctionKind kind, SILAutoDiffIndices desiredIndices,
SILAutoDiffIndices actualIndices);

/// Get or create a derivative function parameter index subset thunk from
/// `actualIndices` to `desiredIndices` for the given associated function
/// value and original function operand. Returns a pair of the parameter
/// index subset thunk and its interface substitution map (used to partially
/// apply the thunk).
std::pair<SILFunction *, SubstitutionMap>
getOrCreateSubsetParametersThunkForLinearMap(
SILOptFunctionBuilder &fb, SILFunction *assocFn,
CanSILFunctionType origFnType, CanSILFunctionType linearMapType,
CanSILFunctionType targetType, AutoDiffDerivativeFunctionKind kind,
SILAutoDiffIndices desiredIndices, SILAutoDiffIndices actualIndices);

} // end namespace autodiff

} // end namespace swift

#endif // SWIFT_SILOPTIMIZER_MANDATORY_DIFFERENTIATION_THUNK_H
6 changes: 6 additions & 0 deletions lib/SIL/IR/SILType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,12 @@ bool SILType::isLoweringOf(TypeExpansionContext context, SILModule &Mod,
return loweredType.getASTType() == formalType;
}

bool SILType::isDifferentiable(SILModule &M) const {
return getASTType()
->getAutoDiffTangentSpace(LookUpConformanceInModule(M.getSwiftModule()))
.hasValue();
}

Type
TypeBase::replaceSubstitutedSILFunctionTypesWithUnsubstituted(SILModule &M) const {
return Type(const_cast<TypeBase*>(this)).transform([&](Type t) -> Type {
Expand Down
Loading