Skip to content

[AutoDiff] Clean up differentiation-related attributes. #28466

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 25, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 23 additions & 32 deletions include/swift/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1570,9 +1570,8 @@ class DifferentiableAttr final
/// has a where clause.
GenericSignature DerivativeGenericSignature = GenericSignature();

explicit DifferentiableAttr(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
bool linear,
explicit DifferentiableAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange, bool linear,
ArrayRef<ParsedAutoDiffParameter> parameters,
Optional<DeclNameWithLoc> jvp,
Optional<DeclNameWithLoc> vjp,
Expand Down Expand Up @@ -1683,38 +1682,31 @@ class DifferentiatingAttr final
DeclNameWithLoc Original;
/// The original function, resolved by the type checker.
FuncDecl *OriginalFunction = nullptr;
/// Whether this function is linear.
bool Linear;
/// The number of parsed parameters specified in 'wrt:'.
unsigned NumParsedParameters = 0;
/// The differentiation parameters' indices, resolved by the type checker.
IndexSubset *ParameterIndices = nullptr;

explicit DifferentiatingAttr(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
DeclNameWithLoc original, bool linear,
explicit DifferentiatingAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange, DeclNameWithLoc original,
ArrayRef<ParsedAutoDiffParameter> params);

explicit DifferentiatingAttr(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
DeclNameWithLoc original, bool linear,
explicit DifferentiatingAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange, DeclNameWithLoc original,
IndexSubset *indices);

public:
static DifferentiatingAttr *create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
DeclNameWithLoc original, bool linear,
DeclNameWithLoc original,
ArrayRef<ParsedAutoDiffParameter> params);

static DifferentiatingAttr *create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
DeclNameWithLoc original, bool linear,
DeclNameWithLoc original,
IndexSubset *indices);

DeclNameWithLoc getOriginal() const { return Original; }

bool isLinear() const { return Linear; }

FuncDecl *getOriginalFunction() const { return OriginalFunction; }
void setOriginalFunction(FuncDecl *decl) { OriginalFunction = decl; }

Expand Down Expand Up @@ -1765,34 +1757,33 @@ class TransposingAttr final
unsigned NumParsedParameters = 0;
/// The differentiation parameters' indices, resolved by the type checker.
IndexSubset *ParameterIndices = nullptr;
explicit TransposingAttr(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
TypeRepr *baseType, DeclNameWithLoc original,

explicit TransposingAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange, TypeRepr *baseType,
DeclNameWithLoc original,
ArrayRef<ParsedAutoDiffParameter> params);

explicit TransposingAttr(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
TypeRepr *baseType, DeclNameWithLoc original,
IndexSubset *indices);


explicit TransposingAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange, TypeRepr *baseType,
DeclNameWithLoc original, IndexSubset *indices);

public:
static TransposingAttr *create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
TypeRepr *baseType, DeclNameWithLoc original,
ArrayRef<ParsedAutoDiffParameter> params);

static TransposingAttr *create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
TypeRepr *baseType, DeclNameWithLoc original,
IndexSubset *indices);

TypeRepr *getBaseType() const { return BaseType; }
DeclNameWithLoc getOriginal() const { return Original; }

FuncDecl *getOriginalFunction() const { return OriginalFunction; }
void setOriginalFunction(FuncDecl *decl) { OriginalFunction = decl; }

/// The parsed transposing parameters, i.e. the list of parameters
/// specified in 'wrt:'.
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
Expand All @@ -1804,14 +1795,14 @@ class TransposingAttr final
size_t numTrailingObjects(OverloadToken<ParsedAutoDiffParameter>) const {
return NumParsedParameters;
}

IndexSubset *getParameterIndices() const {
return ParameterIndices;
}
void setParameterIndices(IndexSubset *pi) {
ParameterIndices = pi;
}

static bool classof(const DeclAttribute *DA) {
return DA->getKind() == DAK_Transposing;
}
Expand Down
8 changes: 4 additions & 4 deletions include/swift/AST/DiagnosticsParse.def
Original file line number Diff line number Diff line change
Expand Up @@ -1557,17 +1557,17 @@ ERROR(attr_differentiable_expected_parameter_list,PointsToFirstBadToken,
"expected a list of parameters to differentiate with respect to", ())
ERROR(attr_differentiable_use_wrt_not_withrespectto,none,
"use 'wrt:' to specify parameters to differentiate with respect to", ())
ERROR(attr_differentiable_missing_label,PointsToFirstBadToken,
"missing label '%0:' in '@differentiable' attribute", (StringRef))
ERROR(attr_differentiable_expected_label,none,
"expected either 'wrt:' or a function specifier label, e.g. 'jvp:', "
"or 'vjp:'", ())

// differentiating
ERROR(attr_differentiating_expected_original_name,PointsToFirstBadToken,
"expected an original function name", ())
ERROR(attr_differentiating_expected_label_linear_or_wrt,none,
"expected either 'linear' or 'wrt:'", ())
ERROR(attr_missing_label,PointsToFirstBadToken,
"missing label '%0:' in '@%1' attribute", (StringRef, StringRef))
ERROR(attr_expected_label,none,
"expected label '%0:' in '@%1' attribute", (StringRef, StringRef))

// transposing
ERROR(attr_transposing_expected_original_name,PointsToFirstBadToken,
Expand Down
3 changes: 2 additions & 1 deletion include/swift/Parse/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -1007,7 +1007,8 @@ class Parser {
/// Parse the @differentiating attribute.
ParserResult<DifferentiatingAttr>
parseDifferentiatingAttribute(SourceLoc AtLoc, SourceLoc Loc);


/// Parse the @transposing attribute.
ParserResult<TransposingAttr> parseTransposingAttribute(SourceLoc AtLoc,
SourceLoc Loc);

Expand Down
71 changes: 33 additions & 38 deletions lib/AST/Attr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1451,9 +1451,8 @@ SpecializeAttr *SpecializeAttr::create(ASTContext &Ctx, SourceLoc atLoc,


// SWIFT_ENABLE_TENSORFLOW
DifferentiableAttr::DifferentiableAttr(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
bool linear,
DifferentiableAttr::DifferentiableAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange, bool linear,
ArrayRef<ParsedAutoDiffParameter> params,
Optional<DeclNameWithLoc> jvp,
Optional<DeclNameWithLoc> vjp,
Expand Down Expand Up @@ -1488,8 +1487,8 @@ DifferentiableAttr::create(ASTContext &context, bool implicit,
TrailingWhereClause *clause) {
unsigned size = totalSizeToAlloc<ParsedAutoDiffParameter>(parameters.size());
void *mem = context.Allocate(size, alignof(DifferentiableAttr));
return new (mem) DifferentiableAttr(context, implicit, atLoc, baseRange,
linear, parameters, std::move(jvp),
return new (mem) DifferentiableAttr(implicit, atLoc, baseRange, linear,
parameters, std::move(jvp),
std::move(vjp), clause);
}

Expand Down Expand Up @@ -1570,48 +1569,45 @@ void DifferentiableAttr::print(llvm::raw_ostream &OS, const Decl *D,

// SWIFT_ENABLE_TENSORFLOW
DifferentiatingAttr::DifferentiatingAttr(
ASTContext &context, bool implicit, SourceLoc atLoc, SourceRange baseRange,
DeclNameWithLoc original, bool linear,
ArrayRef<ParsedAutoDiffParameter> params)
bool implicit, SourceLoc atLoc, SourceRange baseRange,
DeclNameWithLoc original, ArrayRef<ParsedAutoDiffParameter> params)
: DeclAttribute(DAK_Differentiating, atLoc, baseRange, implicit),
Original(std::move(original)), Linear(linear),
NumParsedParameters(params.size()) {
Original(std::move(original)), NumParsedParameters(params.size()) {
std::copy(params.begin(), params.end(),
getTrailingObjects<ParsedAutoDiffParameter>());
}

DifferentiatingAttr::DifferentiatingAttr(
ASTContext &context, bool implicit, SourceLoc atLoc, SourceRange baseRange,
DeclNameWithLoc original, bool linear, IndexSubset *indices)
DifferentiatingAttr::DifferentiatingAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange,
DeclNameWithLoc original,
IndexSubset *indices)
: DeclAttribute(DAK_Differentiating, atLoc, baseRange, implicit),
Original(std::move(original)), Linear(linear), ParameterIndices(indices) {
}
Original(std::move(original)), ParameterIndices(indices) {}

DifferentiatingAttr *
DifferentiatingAttr::create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
DeclNameWithLoc original, bool linear,
DifferentiatingAttr::create(ASTContext &context, bool implicit, SourceLoc atLoc,
SourceRange baseRange, DeclNameWithLoc original,
ArrayRef<ParsedAutoDiffParameter> params) {
unsigned size = totalSizeToAlloc<ParsedAutoDiffParameter>(params.size());
void *mem = context.Allocate(size, alignof(DifferentiatingAttr));
return new (mem) DifferentiatingAttr(context, implicit, atLoc, baseRange,
std::move(original), linear, params);
return new (mem) DifferentiatingAttr(implicit, atLoc, baseRange,
std::move(original), params);
}

DifferentiatingAttr *
DifferentiatingAttr::create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
DeclNameWithLoc original, bool linear,
IndexSubset *indices) {
DifferentiatingAttr *DifferentiatingAttr::create(ASTContext &context,
bool implicit, SourceLoc atLoc,
SourceRange baseRange,
DeclNameWithLoc original,
IndexSubset *indices) {
void *mem = context.Allocate(sizeof(DifferentiatingAttr),
alignof(DifferentiatingAttr));
return new (mem) DifferentiatingAttr(context, implicit, atLoc, baseRange,
std::move(original), linear, indices);
return new (mem) DifferentiatingAttr(implicit, atLoc, baseRange,
std::move(original), indices);
}

TransposingAttr::TransposingAttr(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
TypeRepr *baseType, DeclNameWithLoc original,
TransposingAttr::TransposingAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange, TypeRepr *baseType,
DeclNameWithLoc original,
ArrayRef<ParsedAutoDiffParameter> params)
: DeclAttribute(DAK_Transposing, atLoc, baseRange, implicit),
BaseType(baseType), Original(std::move(original)),
Expand All @@ -1620,10 +1616,9 @@ TransposingAttr::TransposingAttr(ASTContext &context, bool implicit,
getTrailingObjects<ParsedAutoDiffParameter>());
}

TransposingAttr::TransposingAttr(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
TypeRepr *baseType, DeclNameWithLoc original,
IndexSubset *indices)
TransposingAttr::TransposingAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange, TypeRepr *baseType,
DeclNameWithLoc original, IndexSubset *indices)
: DeclAttribute(DAK_Transposing, atLoc, baseRange, implicit),
BaseType(baseType), Original(std::move(original)),
ParameterIndices(indices) {}
Expand All @@ -1635,8 +1630,8 @@ TransposingAttr::create(ASTContext &context, bool implicit, SourceLoc atLoc,
ArrayRef<ParsedAutoDiffParameter> params) {
unsigned size = totalSizeToAlloc<ParsedAutoDiffParameter>(params.size());
void *mem = context.Allocate(size, alignof(TransposingAttr));
return new (mem) TransposingAttr(context, implicit, atLoc, baseRange,
baseType, std::move(original), params);
return new (mem) TransposingAttr(implicit, atLoc, baseRange, baseType,
std::move(original), params);
}

TransposingAttr *
Expand All @@ -1646,8 +1641,8 @@ TransposingAttr::create(ASTContext &context, bool implicit, SourceLoc atLoc,
IndexSubset *indices) {
void *mem =
context.Allocate(sizeof(TransposingAttr), alignof(TransposingAttr));
return new (mem) TransposingAttr(context, implicit, atLoc, baseRange,
baseType, std::move(original), indices);
return new (mem) TransposingAttr(implicit, atLoc, baseRange, baseType,
std::move(original), indices);
}

ImplementsAttr::ImplementsAttr(SourceLoc atLoc, SourceRange range,
Expand Down
Loading