Skip to content

[AutoDiff] Parameter indices data structure overhaul #24761

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
May 14, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
241 changes: 224 additions & 17 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "ASTContext.h"
#include "llvm/ADT/SmallBitVector.h"
#include "swift/Basic/Range.h"

namespace swift {

Expand Down Expand Up @@ -73,6 +74,7 @@ class ParsedAutoDiffParameter {
};

class AnyFunctionType;
class AutoDiffIndexSubset;
class AutoDiffParameterIndicesBuilder;
class Type;

Expand Down Expand Up @@ -173,7 +175,8 @@ class AutoDiffParameterIndices : public llvm::FoldingSetNode {
/// ==> returns 1110
/// (because the lowered SIL type is (A, B, C, D) -> R)
///
llvm::SmallBitVector getLowered(AnyFunctionType *functionType) const;
AutoDiffIndexSubset *getLowered(ASTContext &ctx,
AnyFunctionType *functionType) const;

void Profile(llvm::FoldingSetNodeID &ID) const {
ID.AddInteger(parameters.size());
Expand Down Expand Up @@ -219,6 +222,216 @@ class AutoDiffParameterIndicesBuilder {
unsigned size() { return parameters.size(); }
};

class AutoDiffIndexSubset : public llvm::FoldingSetNode {
public:
typedef uint64_t BitWord;

static constexpr unsigned bitWordSize = sizeof(BitWord);
static constexpr unsigned numBitsPerBitWord = bitWordSize * 8;

static std::pair<unsigned, unsigned>
getBitWordIndexAndOffset(unsigned index) {
auto bitWordIndex = index / numBitsPerBitWord;
auto bitWordOffset = index % numBitsPerBitWord;
return {bitWordIndex, bitWordOffset};
}

static unsigned getNumBitWordsNeededForCapacity(unsigned capacity) {
if (capacity == 0) return 0;
return capacity / numBitsPerBitWord + 1;
}

private:
/// The total capacity of the index subset, which is `1` less than the largest
/// index.
unsigned capacity;
/// The number of bit words in the index subset.
unsigned numBitWords;

BitWord *getBitWordsData() {
return reinterpret_cast<BitWord *>(this + 1);
}

const BitWord *getBitWordsData() const {
return reinterpret_cast<const BitWord *>(this + 1);
}

ArrayRef<BitWord> getBitWords() const {
return {getBitWordsData(), getNumBitWords()};
}

BitWord getBitWord(unsigned i) const {
return getBitWordsData()[i];
}

BitWord &getBitWord(unsigned i) {
return getBitWordsData()[i];
}

MutableArrayRef<BitWord> getMutableBitWords() {
return {const_cast<BitWord *>(getBitWordsData()), getNumBitWords()};
}

explicit AutoDiffIndexSubset(unsigned capacity, ArrayRef<unsigned> indices)
: capacity(capacity),
numBitWords(getNumBitWordsNeededForCapacity(capacity)) {
std::uninitialized_fill_n(getBitWordsData(), numBitWords, 0);
for (auto i : indices) {
unsigned bitWordIndex, offset;
std::tie(bitWordIndex, offset) = getBitWordIndexAndOffset(i);
getBitWord(bitWordIndex) |= (1 << offset);
}
}

public:
AutoDiffIndexSubset() = delete;
AutoDiffIndexSubset(const AutoDiffIndexSubset &) = delete;
AutoDiffIndexSubset &operator=(const AutoDiffIndexSubset &) = delete;

// Defined in ASTContext.h.
static AutoDiffIndexSubset *get(ASTContext &ctx,
unsigned capacity,
ArrayRef<unsigned> indices);

static AutoDiffIndexSubset *getDefault(ASTContext &ctx,
unsigned capacity,
bool includeAll = false) {
if (includeAll)
return getFromRange(ctx, capacity, IntRange<>(capacity));
return get(ctx, capacity, {});
}

static AutoDiffIndexSubset *getFromRange(ASTContext &ctx,
unsigned capacity,
IntRange<> range) {
return get(ctx, capacity,
SmallVector<unsigned, 8>(range.begin(), range.end()));
}

unsigned getNumBitWords() const {
return numBitWords;
}

unsigned getCapacity() const {
return capacity;
}

class iterator;

iterator begin() const {
return iterator(this);
}

iterator end() const {
return iterator(this, (int)capacity);
}

iterator_range<iterator> getIndices() const {
return make_range(begin(), end());
}

unsigned getNumIndices() const {
return (unsigned)std::distance(begin(), end());
}

bool contains(unsigned index) const {
unsigned bitWordIndex, offset;
std::tie(bitWordIndex, offset) = getBitWordIndexAndOffset(index);
return getBitWord(bitWordIndex) & (1 << offset);
}

bool isEmpty() const {
return llvm::all_of(getBitWords(), [](BitWord bw) { return !(bool)bw; });
}

bool equals(AutoDiffIndexSubset *other) const {
return capacity == other->getCapacity() &&
getBitWords().equals(other->getBitWords());
}

bool isSubsetOf(AutoDiffIndexSubset *other) const;
bool isSupersetOf(AutoDiffIndexSubset *other) const;

AutoDiffIndexSubset *adding(unsigned index, ASTContext &ctx) const;
AutoDiffIndexSubset *extendingCapacity(ASTContext &ctx,
unsigned newCapacity) const;

void Profile(llvm::FoldingSetNodeID &id) const {
id.AddInteger(capacity);
for (auto index : getIndices())
id.AddInteger(index);
}

void print(llvm::raw_ostream &s = llvm::outs()) const {
s << '{';
interleave(range(capacity), [this, &s](unsigned i) { s << contains(i); },
[&s] { s << ", "; });
s << '}';
}

void dump(llvm::raw_ostream &s = llvm::errs()) const {
s << "(autodiff_index_subset capacity=" << capacity << " indices=(";
interleave(getIndices(), [&s](unsigned i) { s << i; },
[&s] { s << ", "; });
s << "))";
}

int findNext(int startIndex) const;
int findFirst() const { return findNext(-1); }
int findPrevious(int endIndex) const;
int findLast() const { return findPrevious(capacity); }

class iterator {
public:
typedef unsigned value_type;
typedef unsigned difference_type;
typedef unsigned * pointer;
typedef unsigned & reference;
typedef std::forward_iterator_tag iterator_category;

private:
const AutoDiffIndexSubset *parent;
int current = 0;

void advance() {
assert(current != -1 && "Trying to advance past end.");
current = parent->findNext(current);
}

public:
iterator(const AutoDiffIndexSubset *parent, int current)
: parent(parent), current(current) {}
explicit iterator(const AutoDiffIndexSubset *parent)
: iterator(parent, parent->findFirst()) {}
iterator(const iterator &) = default;

iterator operator++(int) {
auto prev = *this;
advance();
return prev;
}

iterator &operator++() {
advance();
return *this;
}

unsigned operator*() const { return current; }

bool operator==(const iterator &other) const {
assert(parent == other.parent &&
"Comparing iterators from different AutoDiffIndexSubsets");
return current == other.current;
}

bool operator!=(const iterator &other) const {
assert(parent == other.parent &&
"Comparing iterators from different AutoDiffIndexSubsets");
return current != other.current;
}
};
};

/// SIL-level automatic differentiation indices. Consists of a source index,
/// i.e. index of the dependent result to differentiate from, and parameter
/// indices, i.e. index of independent parameters to differentiate with
Expand All @@ -242,38 +455,33 @@ struct SILAutoDiffIndices {
/// Function type: (A, B) -> (C, D) -> R
/// Bits: [C][D][A][B]
///
llvm::SmallBitVector parameters;
AutoDiffIndexSubset *parameters;

/// Creates a set of AD indices from the given source index and a bit vector
/// representing parameter indices.
/*implicit*/ SILAutoDiffIndices(unsigned source,
llvm::SmallBitVector parameters)
AutoDiffIndexSubset *parameters)
: source(source), parameters(parameters) {}

/// Creates a set of AD indices from the given source index and an array of
/// parameter indices. Elements in `parameters` must be ascending integers.
/*implicit*/ SILAutoDiffIndices(unsigned source,
ArrayRef<unsigned> parameters);

bool operator==(const SILAutoDiffIndices &other) const;

/// Queries whether the function's parameter with index `parameterIndex` is
/// one of the parameters to differentiate with respect to.
bool isWrtParameter(unsigned parameterIndex) const {
return parameterIndex < parameters.size() &&
parameters.test(parameterIndex);
return parameterIndex < parameters->getCapacity() &&
parameters->contains(parameterIndex);
}

void print(llvm::raw_ostream &s = llvm::outs()) const {
s << "(source=" << source << " parameters=(";
interleave(parameters.set_bits(),
interleave(parameters->getIndices(),
[&s](unsigned p) { s << p; }, [&s]{ s << ' '; });
s << "))";
}

std::string mangle() const {
std::string result = "src_" + llvm::utostr(source) + "_wrt_";
interleave(parameters.set_bits(),
interleave(parameters->getIndices(),
[&](unsigned idx) { result += llvm::utostr(idx); },
[&] { result += '_'; });
return result;
Expand Down Expand Up @@ -449,19 +657,18 @@ template<typename T> struct DenseMapInfo;

template<> struct DenseMapInfo<SILAutoDiffIndices> {
static SILAutoDiffIndices getEmptyKey() {
return { DenseMapInfo<unsigned>::getEmptyKey(), SmallBitVector() };
return { DenseMapInfo<unsigned>::getEmptyKey(), nullptr };
}

static SILAutoDiffIndices getTombstoneKey() {
return { DenseMapInfo<unsigned>::getTombstoneKey(),
SmallBitVector(sizeof(intptr_t), true) };
return { DenseMapInfo<unsigned>::getTombstoneKey(), nullptr };
}

static unsigned getHashValue(const SILAutoDiffIndices &Val) {
auto params = Val.parameters.set_bits();
unsigned combinedHash =
hash_combine(~1U, DenseMapInfo<unsigned>::getHashValue(Val.source),
hash_combine_range(params.begin(), params.end()));
hash_combine_range(Val.parameters->begin(),
Val.parameters->end()));
return combinedHash;
}

Expand Down
4 changes: 3 additions & 1 deletion include/swift/AST/DiagnosticsParse.def
Original file line number Diff line number Diff line change
Expand Up @@ -1535,7 +1535,9 @@ ERROR(sil_inst_autodiff_operand_list_expected_rbrace,PointsToFirstBadToken,
ERROR(sil_inst_autodiff_num_operand_list_order_mismatch,PointsToFirstBadToken,
"the number of operand lists does not match the order", ())
ERROR(sil_inst_autodiff_expected_associated_function_kind_attr,PointsToFirstBadToken,
"expects an assoiacted function kind attribute, e.g. '[jvp]'", ())
"expected an associated function kind attribute, e.g. '[jvp]'", ())
ERROR(sil_inst_autodiff_expected_function_type_operand,PointsToFirstBadToken,
"expected an operand of a function type", ())

//------------------------------------------------------------------------------
// MARK: Generics parsing diagnostics
Expand Down
6 changes: 3 additions & 3 deletions include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -4132,14 +4132,14 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,

// SWIFT_ENABLE_TENSORFLOW
CanSILFunctionType getWithDifferentiability(
unsigned differentiationOrder, const SmallBitVector &parameterIndices);
unsigned differentiationOrder, AutoDiffIndexSubset *parameterIndices);

CanSILFunctionType getWithoutDifferentiability();

/// Returns the type of a differentiation function that is associated with
/// a function of this type.
CanSILFunctionType getAutoDiffAssociatedFunctionType(
const SmallBitVector &parameterIndices, unsigned resultIndex,
AutoDiffIndexSubset *parameterIndices, unsigned resultIndex,
unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind,
SILModule &module, LookupConformanceFn lookupConformance,
GenericSignature *whereClauseGenericSignature = nullptr);
Expand All @@ -4148,7 +4148,7 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
/// differentiate with respect to for this differentiable function type. (e.g.
/// which parameters are not @nondiff). The function type must be
/// differentiable.
SmallBitVector getDifferentiationParameterIndices() const;
AutoDiffIndexSubset *getDifferentiationParameterIndices();

/// If this is a @convention(witness_method) function with a class
/// constrained self parameter, return the class constraint for the
Expand Down
2 changes: 1 addition & 1 deletion include/swift/SIL/SILBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ class SILBuilder {

/// SWIFT_ENABLE_TENSORFLOW
AutoDiffFunctionInst *createAutoDiffFunction(
SILLocation loc, const llvm::SmallBitVector &parameterIndices,
SILLocation loc, AutoDiffIndexSubset *parameterIndices,
unsigned differentiationOrder, SILValue original,
ArrayRef<SILValue> associatedFunctions = {}) {
return insert(AutoDiffFunctionInst::create(getModule(),
Expand Down
3 changes: 3 additions & 0 deletions include/swift/SIL/SILFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ class SILDifferentiableAttr final {
SILFunction *getOriginal() const { return Original; }

const SILAutoDiffIndices &getIndices() const { return indices; }
void setIndices(const SILAutoDiffIndices &indices) {
this->indices = indices;
}

TrailingWhereClause *getWhereClause() const { return WhereClause; }

Expand Down
Loading