Skip to content

[AutoDiff] Handle materializing adjoints with non-differentiable fields #67319

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 96 additions & 30 deletions include/swift/SILOptimizer/Differentiation/AdjointValue.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_ADJOINTVALUE_H

#include "swift/AST/Decl.h"
#include "swift/SIL/SILDebugVariable.h"
#include "swift/SIL/SILLocation.h"
#include "swift/SIL/SILValue.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/Debug.h"
#include <type_traits>
#include <variant>
Expand All @@ -38,10 +40,18 @@ enum AdjointValueKind {

/// A concrete SIL value.
Concrete,

/// A special adjoint, made up of 2 adjoints -- an aggregate base adjoint and
/// an element adjoint to add to one of its fields. This case exists to avoid
/// eager materialization of a base adjoint upon addition with one of its
/// fields.
AddElement,
};

class AdjointValue;

struct AddElementValue;

class AdjointValueBase {
friend class AdjointValue;

Expand All @@ -60,9 +70,13 @@ class AdjointValueBase {
/// Kind-specific payload storage. The accessors on `AdjointValue` check the
/// adjoint kind before reading the corresponding member.
union Value {
  /// Element count used by aggregate adjoints.
  unsigned numAggregateElements;
  /// The wrapped SIL value used by concrete adjoints.
  SILValue concrete;
  /// Pointer to the payload used by `AddElement` adjoints.
  AddElementValue *addElementValue;

  /// Default construction initializes no member.
  Value() {}

  /// Stores an aggregate element count.
  Value(unsigned numAggregateElements)
      : numAggregateElements(numAggregateElements) {}

  /// Stores a concrete SIL value.
  Value(SILValue v)
      : concrete(v) {}

  /// Stores a pointer to an `AddElement` payload.
  Value(AddElementValue *addElementValue)
      : addElementValue(addElementValue) {}
} value;

Expand All @@ -86,6 +100,11 @@ class AdjointValueBase {

/// Builds a `Zero` adjoint of the given type. No payload is supplied, so the
/// `value` union is default-constructed.
explicit AdjointValueBase(SILType type, llvm::Optional<DebugInfo> debugInfo)
    : kind(AdjointValueKind::Zero),
      type(type),
      debugInfo(debugInfo) {}

/// Builds an `AddElement` adjoint of the given type whose payload is the
/// given `AddElementValue`. Only the pointer is stored.
explicit AdjointValueBase(SILType type, AddElementValue *addElementValue,
                          llvm::Optional<DebugInfo> debugInfo)
    : kind(AdjointValueKind::AddElement),
      type(type),
      debugInfo(debugInfo),
      value(addElementValue) {}
};

/// A symbolic adjoint value that wraps a `SILValue`, a zero, or an aggregate
Expand Down Expand Up @@ -127,6 +146,14 @@ class AdjointValue final {
return new (buf) AdjointValueBase(type, elements, debugInfo);
}

/// Creates an `AddElement` adjoint value, placement-constructing its backing
/// `AdjointValueBase` in storage obtained from `allocator`.
///
/// \param allocator  allocator that provides storage for the backing object.
/// \param type  the type recorded on the new adjoint.
/// \param addElementValue  the payload (base adjoint, element to add, field
///        locator). Must be non-null; only the pointer is stored, so the
///        payload must outlive the returned adjoint.
/// \param debugInfo  optional debug info recorded on the new adjoint.
static AdjointValue
createAddElement(llvm::BumpPtrAllocator &allocator, SILType type,
                 AddElementValue *addElementValue,
                 llvm::Optional<DebugInfo> debugInfo = llvm::None) {
  // Every consumer of an `AddElement` adjoint dereferences the payload, so
  // reject a null pointer at creation time rather than at first use.
  assert(addElementValue && "AddElement adjoint requires a non-null payload");
  auto *buf = allocator.Allocate<AdjointValueBase>();
  return new (buf) AdjointValueBase(type, addElementValue, debugInfo);
}

/// The kind tag stored on the backing object.
AdjointValueKind getKind() const {
  return base->kind;
}
/// The SIL type stored on the backing object.
SILType getType() const {
  return base->type;
}
/// The AST type underlying the stored SIL type.
CanType getSwiftType() const {
  return base->type.getASTType();
}
Expand All @@ -140,6 +167,9 @@ class AdjointValue final {
bool isZero() const { return getKind() == AdjointValueKind::Zero; }
bool isAggregate() const { return getKind() == AdjointValueKind::Aggregate; }
bool isConcrete() const { return getKind() == AdjointValueKind::Concrete; }
bool isAddElement() const {
return getKind() == AdjointValueKind::AddElement;
}

unsigned getNumAggregateElements() const {
assert(isAggregate());
Expand All @@ -162,41 +192,77 @@ class AdjointValue final {
return base->value.concrete;
}

void print(llvm::raw_ostream &s) const {
switch (getKind()) {
case AdjointValueKind::Zero:
s << "Zero[" << getType() << ']';
break;
case AdjointValueKind::Aggregate:
s << "Aggregate[" << getType() << "](";
if (auto *decl =
getType().getASTType()->getStructOrBoundGenericStruct()) {
interleave(
llvm::zip(decl->getStoredProperties(), getAggregateElements()),
[&s](std::tuple<VarDecl *, const AdjointValue &> elt) {
s << std::get<0>(elt)->getName() << ": ";
std::get<1>(elt).print(s);
},
[&s] { s << ", "; });
} else if (getType().is<TupleType>()) {
interleave(
getAggregateElements(),
[&s](const AdjointValue &elt) { elt.print(s); },
[&s] { s << ", "; });
} else {
llvm_unreachable("Invalid aggregate");
}
s << ')';
break;
case AdjointValueKind::Concrete:
s << "Concrete[" << getType() << "](" << base->value.concrete << ')';
break;
}
/// Returns the payload pointer of an `AddElement` adjoint. Asserts that this
/// adjoint actually has the `AddElement` kind.
AddElementValue *getAddElementValue() const {
  assert(isAddElement());
  AddElementValue *payload = base->value.addElementValue;
  return payload;
}

void print(llvm::raw_ostream &s) const;

/// Debug helper declared through `SWIFT_DEBUG_DUMP`: prints this adjoint value
/// to the LLVM debug stream.
SWIFT_DEBUG_DUMP {
  print(llvm::dbgs());
}
};

/// Identifies the field that an `AddElement` adjoint targets within its
/// `baseAdjoint`.
///
/// A locator built from an `unsigned int` holds a tuple element index; one
/// built from a `VarDecl *` holds a struct field declaration. The stored
/// alternative is only readable by `AddElementValue`, which is a friend.
struct FieldLocator final {
  FieldLocator(unsigned int index) : inner(index) {}
  FieldLocator(VarDecl *field) : inner(field) {}

  friend AddElementValue;

private:
  /// Tags that select the `getInner` overload: `true_type` for the tuple
  /// index, `false_type` for the struct field declaration.
  static constexpr std::true_type TUPLE_FIELD_LOCATOR_TAG{};
  static constexpr std::false_type STRUCT_FIELD_LOCATOR_TAG{};

  /// Either a tuple element index or a struct field declaration.
  std::variant<unsigned int, VarDecl *> inner;

  /// True when the locator holds a tuple element index.
  bool isTupleFieldLocator() const {
    return std::holds_alternative<unsigned int>(inner);
  }

  /// Reads the tuple element index; the locator must hold that alternative.
  unsigned int getInner(std::true_type) const {
    return std::get<unsigned int>(inner);
  }

  /// Reads the struct field declaration; the locator must hold that
  /// alternative.
  VarDecl *getInner(std::false_type) const {
    return std::get<VarDecl *>(inner);
  }
};

/// Payload of an `AddElement` adjoint: a base aggregate adjoint, an element
/// adjoint to add to one of its fields, and the locator naming that field.
struct AddElementValue final {
  AdjointValue baseAdjoint;
  AdjointValue eltToAdd;
  FieldLocator fieldLocator;

  /// Stores the three components. Asserts that the base adjoint's type is a
  /// tuple type or a (bound generic) struct type.
  AddElementValue(AdjointValue baseAdjoint, AdjointValue eltToAdd,
                  FieldLocator fieldLocator)
      : baseAdjoint(baseAdjoint),
        eltToAdd(eltToAdd),
        fieldLocator(fieldLocator) {
    assert(baseAdjoint.getType().is<TupleType>() ||
           baseAdjoint.getType().getStructOrBoundGenericStruct() != nullptr);
  }

  /// True when the field locator holds a tuple element index.
  bool isTupleAdjoint() const {
    return fieldLocator.isTupleFieldLocator();
  }

  /// True when the field locator holds a struct field declaration.
  bool isStructAdjoint() const {
    return !fieldLocator.isTupleFieldLocator();
  }

  /// Returns the targeted struct field declaration. Asserts that this is a
  /// struct adjoint.
  VarDecl *getFieldDecl() const {
    assert(isStructAdjoint());
    return fieldLocator.getInner(FieldLocator::STRUCT_FIELD_LOCATOR_TAG);
  }

  /// Returns the targeted tuple element index. Asserts that this is a tuple
  /// adjoint.
  unsigned int getFieldIndex() const {
    assert(isTupleAdjoint());
    return fieldLocator.getInner(FieldLocator::TUPLE_FIELD_LOCATOR_TAG);
  }
};

inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
const AdjointValue &adjVal) {
adjVal.print(os);
Expand Down
70 changes: 70 additions & 0 deletions lib/SILOptimizer/Differentiation/AdjointValue.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
//===--- AdjointValue.cpp - Helper class for differentiation --------------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// AdjointValue - a symbolic representation for adjoint values enabling
// efficient differentiation by avoiding zero materialization.
//
//===----------------------------------------------------------------------===//

#define DEBUG_TYPE "differentiation"

#include "swift/SILOptimizer/Differentiation/AdjointValue.h"

/// Writes a textual rendering of this adjoint value to `s`.
///
/// The output depends on the kind:
///  - Zero:       `Zero[<type>]`
///  - Aggregate:  `Aggregate[<type>](<elements>)`; struct elements are
///                prefixed with their stored property name, tuple elements
///                are printed bare, both separated by `, `.
///  - Concrete:   `Concrete[<type>](<value>)`
///  - AddElement: `AddElement[<base>, Field(<index or name>), <element>]`
void swift::autodiff::AdjointValue::print(llvm::raw_ostream &s) const {
  // Separator emitted between the elements of an aggregate.
  auto emitSeparator = [&s] { s << ", "; };

  switch (getKind()) {
  case AdjointValueKind::Zero: {
    s << "Zero[" << getType() << ']';
    return;
  }

  case AdjointValueKind::Aggregate: {
    s << "Aggregate[" << getType() << "](";
    auto *structDecl = getType().getASTType()->getStructOrBoundGenericStruct();
    if (structDecl) {
      // Pair every stored property with its element adjoint so that each
      // element is printed as `name: <adjoint>`.
      auto namedElements =
          llvm::zip(structDecl->getStoredProperties(), getAggregateElements());
      interleave(
          namedElements,
          [&s](std::tuple<VarDecl *, const AdjointValue &> field) {
            s << std::get<0>(field)->getName() << ": ";
            std::get<1>(field).print(s);
          },
          emitSeparator);
    } else if (getType().is<TupleType>()) {
      interleave(
          getAggregateElements(),
          [&s](const AdjointValue &element) { element.print(s); },
          emitSeparator);
    } else {
      llvm_unreachable("Invalid aggregate");
    }
    s << ')';
    return;
  }

  case AdjointValueKind::Concrete: {
    s << "Concrete[" << getType() << "](" << base->value.concrete << ')';
    return;
  }

  case AdjointValueKind::AddElement: {
    auto *payload = getAddElementValue();

    s << "AddElement[";
    payload->baseAdjoint.print(s);

    // The field is shown as an index for tuples and as a name for structs.
    s << ", Field(";
    if (payload->isTupleAdjoint())
      s << payload->getFieldIndex();
    else
      s << payload->getFieldDecl()->getNameStr();
    s << "), ";

    payload->eltToAdd.print(s);

    s << "]";
    return;
  }
  }
}
1 change: 1 addition & 0 deletions lib/SILOptimizer/Differentiation/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
target_sources(swiftSILOptimizer PRIVATE
ADContext.cpp
AdjointValue.cpp
Common.cpp
DifferentiationInvoker.cpp
JVPCloner.cpp
Expand Down
5 changes: 3 additions & 2 deletions lib/SILOptimizer/Differentiation/JVPCloner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,11 +228,12 @@ class JVPCloner::Implementation final
auto zeroVal = emitZeroDirect(val.getSwiftType(), loc);
return zeroVal;
}
case AdjointValueKind::Concrete:
return val.getConcreteValue();
case AdjointValueKind::Aggregate:
case AdjointValueKind::AddElement:
llvm_unreachable(
"Tuples and structs are not supported in forward mode yet.");
case AdjointValueKind::Concrete:
return val.getConcreteValue();
}
llvm_unreachable("Invalid adjoint value kind"); // silences MSVC C4715
}
Expand Down
Loading