Skip to content

[AutoDiff] Add SIL differentiability witnesses. #27487

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
7e5d804
Diff witness
rxwei Jun 28, 2019
a5c1d13
Rename SILDifferentiabilityWitness.h.
dan-zheng Oct 2, 2019
0ce7d91
Update SILDifferentiabilityWitness definition.
dan-zheng Oct 2, 2019
471e660
Add SILDifferentiabilityWitness to SILModule.
dan-zheng Oct 2, 2019
bf87d2e
[WIP] Start SILDifferentiabilityWitness parsing/printing.
dan-zheng Oct 2, 2019
6b78684
Use improved syntax.
dan-zheng Oct 2, 2019
ab304f1
Merge branch 'tensorflow' of github.com:apple/swift into sil-differen…
dan-zheng Oct 10, 2019
a2ae0f2
Finish parsing/printing/serialization.
dan-zheng Oct 11, 2019
419eea2
Revamp serialization to enable lookup by key.
dan-zheng Oct 11, 2019
b6cd1d7
Add SIL verification.
dan-zheng Oct 11, 2019
573dd3e
Add miscellaneous todo comments.
dan-zheng Oct 11, 2019
843b631
Add round-trip parsing/printing test.
dan-zheng Oct 11, 2019
835f1c0
Clean up.
dan-zheng Oct 11, 2019
aca1abb
Add Swift source for parsing test.
dan-zheng Oct 11, 2019
d75f6ce
Merge branch 'tensorflow' of github.com:apple/swift into sil-differen…
dan-zheng Oct 11, 2019
4187700
`AutoDiffIndexSubset` -> `IndexSubset`
dan-zheng Oct 11, 2019
873468f
Minor fix.
dan-zheng Oct 11, 2019
aea64d3
Update differentiability witness syntax.
dan-zheng Oct 12, 2019
bb93af3
Add `AutoDiffConfig` and use in `SILDifferentiabilityWitnessKey`.
dan-zheng Oct 13, 2019
f240ed2
Change `AutoDiffConfig` to a POD.
dan-zheng Oct 13, 2019
ad6b7aa
Add `DeclAttribute *` to `SILDifferentiabilityWitness`.
dan-zheng Oct 13, 2019
7c63d03
Clean up.
dan-zheng Oct 13, 2019
c3959ad
Address review feedback.
dan-zheng Oct 13, 2019
2073f86
Address review feedback.
dan-zheng Oct 13, 2019
27d7abc
Parse/print `[serialized]` flag.
dan-zheng Oct 13, 2019
69209be
Add parsing/printing tests, address review feedback.
dan-zheng Oct 13, 2019
df9cd49
Fix serialization and add test.
dan-zheng Oct 13, 2019
d673b70
Fix verification.
dan-zheng Oct 13, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions include/swift/AST/ASTMangler.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,23 +155,31 @@ class ASTMangler : public Mangler {
ModuleDecl *Module);

// SWIFT_ENABLE_TENSORFLOW
// Mangle the derivative function (JVP/VJP) with the given:
// - Mangled original function name.
// - Derivative function kind.
// - Parameter/result indices.
/// Mangle the derivative function (JVP/VJP) with the given:
/// - Mangled original function name.
/// - Derivative function kind.
/// - Parameter/result indices.
std::string mangleAutoDiffDerivativeFunctionHelper(
StringRef name, AutoDiffDerivativeFunctionKind kind,
const SILAutoDiffIndices &indices);

// SWIFT_ENABLE_TENSORFLOW
// Mangle the autodiff linear map (differential/pullback) with the given:
// - Mangled original function name.
// - Linear map kind.
// - Parameter/result indices.
/// Mangle the autodiff linear map (differential/pullback) with the given:
/// - Mangled original function name.
/// - Linear map kind.
/// - Parameter/result indices.
std::string mangleAutoDiffLinearMapHelper(
StringRef name, AutoDiffLinearMapKind kind,
const SILAutoDiffIndices &indices);

/// Mangle a SIL differentiability witness key.
/// - Mangled original function name.
/// - Parameter indices.
/// - Result indices.
/// - Derivative generic signature (optional).
std::string mangleSILDifferentiabilityWitnessKey(
SILDifferentiabilityWitnessKey key);
// SWIFT_ENABLE_TENSORFLOW END

std::string mangleKeyPathGetterThunkHelper(const AbstractStorageDecl *property,
GenericSignature *signature,
CanType baseType,
Expand Down
50 changes: 48 additions & 2 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,16 @@ struct AutoDiffDerivativeFunctionKind {
}
};

/// Identifies an autodiff derivative function configuration:
/// - Parameter indices.
/// - Result indices.
/// - Derivative generic signature (optional).
struct AutoDiffConfig {
IndexSubset *parameterIndices;
IndexSubset *resultIndices;
GenericSignature *derivativeGenericSignature;
};

/// In conjunction with the original function declaration, identifies an
/// autodiff derivative function.
///
Expand All @@ -218,8 +228,7 @@ class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode {
IndexSubset *const parameterIndices;

AutoDiffDerivativeFunctionIdentifier(
AutoDiffDerivativeFunctionKind kind,
IndexSubset *parameterIndices) :
AutoDiffDerivativeFunctionKind kind, IndexSubset *parameterIndices) :
kind(kind), parameterIndices(parameterIndices) {}

public:
Expand All @@ -238,6 +247,11 @@ class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode {
}
};

/// The key type used for uniquing `SILDifferentiabilityWitness` in
/// `SILModule`: original function name, parameter indices, result indices, and
/// derivative generic signature.
using SILDifferentiabilityWitnessKey = std::pair<StringRef, AutoDiffConfig>;

/// Automatic differentiation utility namespace.
namespace autodiff {
/// Appends the subset's parameter's types to `result`, in the order in
Expand Down Expand Up @@ -363,10 +377,42 @@ class VectorSpace {

namespace llvm {

using swift::AutoDiffConfig;
using swift::AutoDiffDerivativeFunctionKind;
using swift::GenericSignature;
using swift::IndexSubset;
using swift::SILAutoDiffIndices;

template<typename T> struct DenseMapInfo;

template<> struct DenseMapInfo<AutoDiffConfig> {
static AutoDiffConfig getEmptyKey() {
auto *ptr = llvm::DenseMapInfo<void *>::getEmptyKey();
return {static_cast<IndexSubset *>(ptr), static_cast<IndexSubset *>(ptr),
static_cast<GenericSignature *>(ptr)};
}

static AutoDiffConfig getTombstoneKey() {
auto *ptr = llvm::DenseMapInfo<void *>::getTombstoneKey();
return {static_cast<IndexSubset *>(ptr), static_cast<IndexSubset *>(ptr),
static_cast<GenericSignature *>(ptr)};
}

static unsigned getHashValue(const AutoDiffConfig &Val) {
unsigned combinedHash = hash_combine(
~1U, DenseMapInfo<void *>::getHashValue(Val.parameterIndices),
DenseMapInfo<void *>::getHashValue(Val.resultIndices),
DenseMapInfo<void *>::getHashValue(Val.derivativeGenericSignature));
return combinedHash;
}

static bool isEqual(const AutoDiffConfig &LHS, const AutoDiffConfig &RHS) {
return LHS.parameterIndices == RHS.parameterIndices &&
LHS.resultIndices == RHS.resultIndices &&
LHS.derivativeGenericSignature == RHS.derivativeGenericSignature;
}
};

template<> struct DenseMapInfo<SILAutoDiffIndices> {
static SILAutoDiffIndices getEmptyKey() {
return { DenseMapInfo<unsigned>::getEmptyKey(), nullptr };
Expand Down
10 changes: 10 additions & 0 deletions include/swift/AST/DiagnosticsParse.def
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,16 @@ ERROR(sil_witness_assoc_conf_not_found,none,
ERROR(sil_witness_protocol_conformance_not_found,none,
"sil protocol conformance not found", ())

// SIL differentiability witnesses
ERROR(sil_diff_witness_expected_token,PointsToFirstBadToken,
"expected '%0' in differentiability witness", (StringRef))
ERROR(sil_diff_witness_expected_index_list,PointsToFirstBadToken,
"expected a space-separated list of indices, e.g. '0 1'", ())
ERROR(sil_diff_witness_expected_parameter_index,PointsToFirstBadToken,
"expected a parameter index to differentiate with respect to", ())
ERROR(sil_diff_witness_expected_result_index,PointsToFirstBadToken,
"expected a result index to differentiate with respect to", ())

// SIL Coverage Map
ERROR(sil_coverage_func_not_found, none,
"sil function not found %0", (Identifier))
Expand Down
3 changes: 3 additions & 0 deletions include/swift/Parse/ParseSILSupport.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ namespace swift {
virtual bool parseSILGlobal(Parser &P) = 0;
virtual bool parseSILWitnessTable(Parser &P) = 0;
virtual bool parseSILDefaultWitnessTable(Parser &P) = 0;
// SWIFT_ENABLE_TENSORFLOW
virtual bool parseSILDifferentiabilityWitness(Parser &P) = 0;
// SWIFT_ENABLE_TENSORFLOW END
virtual bool parseSILCoverageMap(Parser &P) = 0;
virtual bool parseSILProperty(Parser &P) = 0;
virtual bool parseSILScope(Parser &P) = 0;
Expand Down
153 changes: 153 additions & 0 deletions include/swift/SIL/SILDifferentiabilityWitness.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
//===--- SILDifferentiabilityWitness.h - Differentiability witnesses ------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// This file defines the SILDifferentiabilityWitness class, which maps an
// original SILFunction and derivative configuration (parameter indices, result
// indices, derivative generic signature) to derivative functions (JVP and VJP).
//
// SIL differentiability witnesses are generated from the `@differentiable`
// and `@differentiating` attributes AST declaration attributes.
// Differentiability witnesses are canonicalized by the differentiation SIL
// transform, which fills in missing derivative functions. Canonical
// differentiability witnesses from other modules can be deserialized to look up
// derivative functions.
//
//===----------------------------------------------------------------------===//

#ifndef SWIFT_SIL_SILDIFFERENTIABILITYWITNESS_H
#define SWIFT_SIL_SILDIFFERENTIABILITYWITNESS_H

#include "swift/AST/Attr.h"
#include "swift/AST/AutoDiff.h"
#include "swift/AST/GenericSignature.h"
#include "swift/SIL/SILAllocated.h"
#include "llvm/ADT/ilist_node.h"
#include "llvm/ADT/ilist.h"

namespace swift {

class SILPrintContext;

class SILDifferentiabilityWitness
: public llvm::ilist_node<SILDifferentiabilityWitness>,
public SILAllocated<SILDifferentiabilityWitness>
{
private:
/// The module which contains the differentiability witness.
SILModule &module;
/// The linkage of the differentiability witness.
SILLinkage linkage;
/// The original function.
SILFunction *originalFunction;
/// The parameter indices.
IndexSubset *parameterIndices;
/// The result indices.
IndexSubset *resultIndices;
/// The derivative generic signature (optional).
GenericSignature *derivativeGenericSignature;
/// The JVP (Jacobian-vector products) derivative function.
SILFunction *jvp;
/// The VJP (vector-Jacobian products) derivative function.
SILFunction *vjp;
/// Whether or not this differentiability witness is serialized, which allows
/// devirtualization from another module.
bool serialized;
/// The AST `@differentiable` or `@differentiating` attribute from which the
/// differentiability witness is generated. Used for diagnostics.
/// Null if the differentiability witness is parsed from SIL or if it is
/// deserialized.
DeclAttribute *attribute = nullptr;

SILDifferentiabilityWitness(SILModule &module, SILLinkage linkage,
SILFunction *originalFunction,
IndexSubset *parameterIndices,
IndexSubset *resultIndices,
GenericSignature *derivativeGenSig,
SILFunction *jvp, SILFunction *vjp,
bool isSerialized, DeclAttribute *attribute)
: module(module), linkage(linkage), originalFunction(originalFunction),
parameterIndices(parameterIndices), resultIndices(resultIndices),
derivativeGenericSignature(derivativeGenSig), jvp(jvp), vjp(vjp),
serialized(isSerialized), attribute(attribute) {}

public:
static SILDifferentiabilityWitness *create(
SILModule &module, SILLinkage linkage, SILFunction *originalFunction,
IndexSubset *parameterIndices, IndexSubset *resultIndices,
GenericSignature *derivativeGenSig, SILFunction *jvp, SILFunction *vjp,
bool isSerialized, DeclAttribute *attribute = nullptr);

SILDifferentiabilityWitnessKey getKey() const;
SILModule &getModule() const { return module; }
SILLinkage getLinkage() const { return linkage; }
SILFunction *getOriginalFunction() const { return originalFunction; }
IndexSubset *getParameterIndices() const {
return parameterIndices;
}
IndexSubset *getResultIndices() const {
return resultIndices;
}
GenericSignature *getDerivativeGenericSignature() const {
return derivativeGenericSignature;
}
SILFunction *getJVP() const { return jvp; }
SILFunction *getVJP() const { return vjp; }
SILFunction *getDerivative(AutoDiffDerivativeFunctionKind kind) const {
switch (kind) {
case AutoDiffDerivativeFunctionKind::JVP: return jvp;
case AutoDiffDerivativeFunctionKind::VJP: return vjp;
}
}
void setJVP(SILFunction *jvp) { this->jvp = jvp; }
void setVJP(SILFunction *vjp) { this->vjp = vjp; }
void setDerivative(AutoDiffDerivativeFunctionKind kind,
SILFunction *derivative) {
switch (kind) {
case AutoDiffDerivativeFunctionKind::JVP: jvp = derivative; break;
case AutoDiffDerivativeFunctionKind::VJP: vjp = derivative; break;
}
}
bool isSerialized() const { return serialized; }
DeclAttribute *getAttribute() const { return attribute; }

/// Verify that the differentiability witness is well-formed.
void verify(const SILModule &module) const;

void print(llvm::raw_ostream &os, bool verbose = false) const;
void dump() const;
};

} // end namespace swift

namespace llvm {

//===----------------------------------------------------------------------===//
// ilist_traits for SILDifferentiabilityWitness
//===----------------------------------------------------------------------===//

template <>
struct ilist_traits<::swift::SILDifferentiabilityWitness>
: public ilist_node_traits<::swift::SILDifferentiabilityWitness> {
using SILDifferentiabilityWitness = ::swift::SILDifferentiabilityWitness;

public:
static void deleteNode(SILDifferentiabilityWitness *DW) {
DW->~SILDifferentiabilityWitness();
}

private:
void createNode(const SILDifferentiabilityWitness &);
};

} // namespace llvm

#endif // SWIFT_SIL_SILDIFFERENTIABILITYWITNESS_H
41 changes: 41 additions & 0 deletions include/swift/SIL/SILModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
#include "swift/SIL/SILCoverageMap.h"
#include "swift/SIL/SILDeclRef.h"
#include "swift/SIL/SILDefaultWitnessTable.h"
// SWIFT_ENABLE_TENSORFLOW
#include "swift/SIL/SILDifferentiabilityWitness.h"
#include "swift/SIL/SILFunction.h"
#include "swift/SIL/SILGlobalVariable.h"
#include "swift/SIL/SILPrintContext.h"
Expand Down Expand Up @@ -113,6 +115,10 @@ class SILModule {
using PropertyListType = llvm::ilist<SILProperty>;
using WitnessTableListType = llvm::ilist<SILWitnessTable>;
using DefaultWitnessTableListType = llvm::ilist<SILDefaultWitnessTable>;
// SWIFT_ENABLE_TENSORFLOW
using DifferentiabilityWitnessListType =
llvm::ilist<SILDifferentiabilityWitness>;
// SWIFT_ENABLE_TENSORFLOW END
using CoverageMapCollectionType =
llvm::MapVector<StringRef, SILCoverageMap *>;

Expand All @@ -139,6 +145,9 @@ class SILModule {
friend SILProperty;
friend SILUndef;
friend SILWitnessTable;
// SWIFT_ENABLE_TENSORFLOW
friend SILDifferentiabilityWitness;
// SWIFT_ENABLE_TENSORFLOW END
friend Lowering::SILGenModule;
friend Lowering::TypeConverter;
class SerializationCallback;
Expand Down Expand Up @@ -194,6 +203,17 @@ class SILModule {
/// The list of SILDefaultWitnessTables in the module.
DefaultWitnessTableListType defaultWitnessTables;

// SWIFT_ENABLE_TENSORFLOW
/// Lookup table for SIL differentiability witnesses from original functions.
/// Indexed by key type: original function, parameter indices, result indices,
/// and derivative generic signature.
llvm::DenseMap<SILDifferentiabilityWitnessKey, SILDifferentiabilityWitness *>
DifferentiabilityWitnessMap;

/// The list of SILDifferentiabilityWitnesses in the module.
DifferentiabilityWitnessListType differentiabilityWitnesses;
// SWIFT_ENABLE_TENSORFLOW END

/// Lookup table for SIL Global Variables.
llvm::StringMap<SILGlobalVariable *> GlobalVariableMap;

Expand Down Expand Up @@ -446,6 +466,27 @@ class SILModule {
return {defaultWitnessTables.begin(), defaultWitnessTables.end()};
}

// SWIFT_ENABLE_TENSORFLOW
using differentiability_witness_iterator = DifferentiabilityWitnessListType::iterator;
using differentiability_witness_const_iterator = DifferentiabilityWitnessListType::const_iterator;
DifferentiabilityWitnessListType &getDifferentiabilityWitnessList() { return differentiabilityWitnesses; }
const DifferentiabilityWitnessListType &getDifferentiabilityWitnessList() const { return differentiabilityWitnesses; }
differentiability_witness_iterator differentiability_witness_begin() { return differentiabilityWitnesses.begin(); }
differentiability_witness_iterator differentiability_witness_end() { return differentiabilityWitnesses.end(); }
differentiability_witness_const_iterator differentiability_witness_begin() const { return differentiabilityWitnesses.begin(); }
differentiability_witness_const_iterator differentiability_witness_end() const { return differentiabilityWitnesses.end(); }
iterator_range<differentiability_witness_iterator>
getDifferentiabilityWitnesses() {
return {differentiabilityWitnesses.begin(),
differentiabilityWitnesses.end()};
}
iterator_range<differentiability_witness_const_iterator>
getDifferentiabilityWitnesses() const {
return {differentiabilityWitnesses.begin(),
differentiabilityWitnesses.end()};
}
// SWIFT_ENABLE_TENSORFLOW END

using sil_global_iterator = GlobalListType::iterator;
using sil_global_const_iterator = GlobalListType::const_iterator;
GlobalListType &getSILGlobalList() { return silGlobals; }
Expand Down
Loading