[AutoDiff] Support '@differentiable(linear)' function conversion. #27687

Merged 2 commits on Oct 15, 2019

3 changes: 3 additions & 0 deletions include/swift/AST/DiagnosticsSIL.def
@@ -445,6 +445,9 @@ ERROR(constexpr_imported_func_not_onone, none, "imported constant evaluable "
// Automatic differentiation diagnostics
ERROR(autodiff_internal_swift_not_imported,none,
"AD internal error: the Swift module is not imported", ())
ERROR(autodiff_conversion_to_linear_function_not_supported,none,
"conversion to '@differentiable(linear)' function type is not yet "
"supported", ())
ERROR(autodiff_function_not_differentiable_error,none,
"function is not differentiable", ())
ERROR(autodiff_expression_not_differentiable_error,none,
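
A minimal Swift sketch (illustrative only, not part of this diff; the names are made up) of the situation the new SIL diagnostic covers. With this patch the type checker accepts converting a closure literal to a '@differentiable(linear)' function type and SILGen emits a `linear_function` instruction, but the mandatory differentiation transform does not yet handle that instruction and is expected to report the error above:

// Accepted by the type checker after this change: a literal closure converted
// to a '@differentiable(linear)' function type.
let scale: @differentiable(linear) (Float) -> Float = { 10 * $0 }

// Expected result at this stage of the work: the differentiation pass sees the
// resulting 'linear_function' instruction and emits "conversion to
// '@differentiable(linear)' function type is not yet supported".
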
9 changes: 7 additions & 2 deletions include/swift/AST/DiagnosticsSema.def
@@ -1177,9 +1177,14 @@ ERROR(c_function_pointer_from_generic_function,none,
"a C function pointer cannot be formed from a reference to a generic "
"function", ())
// SWIFT_ENABLE_TENSORFLOW
// TODO(TF-908): Remove this diagnostic once linear-to-differentiable conversion
// is supported.
ERROR(unsupported_linear_to_differentiable_conversion,none,
"conversion from '@differentiable(linear)' to '@differentiable' is not "
"yet supported", ())
ERROR(invalid_differentiable_function_conversion_expr,none,
"a '@differentiable' function can only be formed from a reference to a "
"'func' or a literal closure", ())
"a '@differentiable%select{|(linear)}0' function can only be formed from "
"a reference to a 'func' or a literal closure", (bool))
NOTE(invalid_differentiable_function_conversion_parameter,none,
"did you mean to take a '%0' closure?", (StringRef))
ERROR(invalid_autoclosure_forwarding,none,
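
A hedged Swift sketch (not from this diff; identifiers are hypothetical) of the `%select{|(linear)}0` behavior added to the second diagnostic: the same invalid conversion now names the exact target kind, depending on whether the destination type is plain '@differentiable' or '@differentiable(linear)'. The separate linear-to-differentiable diagnostic above is exercised by the new test at the bottom of this PR.

let double = { (x: Float) in 2 * x }

// Target is '@differentiable': the %select argument is false, so the message
// reads "a '@differentiable' function can only be formed from a reference to
// a 'func' or a literal closure".
_ = double as @differentiable (Float) -> Float

// Target is '@differentiable(linear)': the %select argument is true, so the
// message reads "a '@differentiable(linear)' function can only be formed from
// a reference to a 'func' or a literal closure".
_ = double as @differentiable(linear) (Float) -> Float
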
32 changes: 32 additions & 0 deletions include/swift/AST/Expr.h
@@ -3034,6 +3034,16 @@ class DifferentiableFunctionExpr : public ImplicitConversionExpr {
}
};

class LinearFunctionExpr : public ImplicitConversionExpr {
public:
LinearFunctionExpr(Expr *subExpr, Type ty)
: ImplicitConversionExpr(ExprKind::LinearFunction, subExpr, ty) {}

static bool classof(const Expr *E) {
return E->getKind() == ExprKind::LinearFunction;
}
};

class DifferentiableFunctionExtractOriginalExpr
: public ImplicitConversionExpr {
public:
@@ -3045,6 +3055,28 @@ class DifferentiableFunctionExtractOriginalExpr
return E->getKind() == ExprKind::DifferentiableFunctionExtractOriginal;
}
};

class LinearFunctionExtractOriginalExpr : public ImplicitConversionExpr {
public:
LinearFunctionExtractOriginalExpr(Expr *subExpr, Type ty)
: ImplicitConversionExpr(ExprKind::LinearFunctionExtractOriginal,
subExpr, ty) {}

static bool classof(const Expr *E) {
return E->getKind() == ExprKind::LinearFunctionExtractOriginal;
}
};

class LinearToDifferentiableFunctionExpr : public ImplicitConversionExpr {
public:
LinearToDifferentiableFunctionExpr(Expr *subExpr, Type ty)
: ImplicitConversionExpr(
ExprKind::LinearToDifferentiableFunction, subExpr, ty) {}

static bool classof(const Expr *E) {
return E->getKind() == ExprKind::LinearToDifferentiableFunction;
}
};
// SWIFT_ENABLE_TENSORFLOW END

/// Use an opaque type to abstract a value of the underlying concrete type.
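
For orientation, a short Swift sketch (illustrative names, not part of the diff) of the source-level conversion each new implicit-conversion node models, following how CSApply in this patch wraps the sub-expression:

func negate(_ x: Float) -> Float { return -x }

// LinearFunctionExpr: wraps a 'func' reference (or a literal closure) when it
// is converted to a '@differentiable(linear)' function type.
let linearNegate: @differentiable(linear) (Float) -> Float = negate

// LinearFunctionExtractOriginalExpr: wraps a '@differentiable(linear)' value
// when it is converted back to a non-differentiable function type.
let plainNegate: (Float) -> Float = linearNegate

// LinearToDifferentiableFunctionExpr: reserved for '@differentiable(linear)'
// to '@differentiable' conversion; that direction is still diagnosed in Sema
// (TF-908), so the line below does not compile yet.
// let diffNegate: @differentiable (Float) -> Float = linearNegate
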
5 changes: 4 additions & 1 deletion include/swift/AST/ExprNodes.def
@@ -174,8 +174,11 @@ ABSTRACT_EXPR(ImplicitConversion, Expr)
EXPR(UnderlyingToOpaque, ImplicitConversionExpr)
// SWIFT_ENABLE_TENSORFLOW
EXPR(DifferentiableFunction, ImplicitConversionExpr)
EXPR(LinearFunction, ImplicitConversionExpr)
EXPR(DifferentiableFunctionExtractOriginal, ImplicitConversionExpr)
EXPR_RANGE(ImplicitConversion, Load, DifferentiableFunctionExtractOriginal)
EXPR(LinearFunctionExtractOriginal, ImplicitConversionExpr)
EXPR(LinearToDifferentiableFunction, ImplicitConversionExpr)
EXPR_RANGE(ImplicitConversion, Load, LinearToDifferentiableFunction)
// SWIFT_ENABLE_TENSORFLOW END
ABSTRACT_EXPR(ExplicitCast, Expr)
ABSTRACT_EXPR(CheckedCast, ExplicitCastExpr)
2 changes: 1 addition & 1 deletion include/swift/SIL/SILBuilder.h
@@ -521,7 +521,7 @@ class SILBuilder {

LinearFunctionInst *createLinearFunction(
SILLocation Loc, IndexSubset *ParameterIndices, SILValue OriginalFunction,
Optional<SILValue> TransposeFunction) {
Optional<SILValue> TransposeFunction = None) {
return insert(LinearFunctionInst::create(
getModule(), getSILDebugLocation(Loc), ParameterIndices,
OriginalFunction, TransposeFunction, hasOwnership()));
17 changes: 17 additions & 0 deletions lib/AST/ASTDumper.cpp
@@ -2401,12 +2401,29 @@ class PrintExpr : public ExprVisitor<PrintExpr> {
printRec(E->getSubExpr());
PrintWithColorRAII(OS, ParenthesisColor) << ')';
}
void visitLinearFunctionExpr(LinearFunctionExpr *E) {
printCommon(E, "linear_function") << '\n';
printRec(E->getSubExpr());
PrintWithColorRAII(OS, ParenthesisColor) << ')';
}
void visitDifferentiableFunctionExtractOriginalExpr(
DifferentiableFunctionExtractOriginalExpr *E) {
printCommon(E, "differentiable_function_extract_original") << '\n';
printRec(E->getSubExpr());
PrintWithColorRAII(OS, ParenthesisColor) << ')';
}
void visitLinearFunctionExtractOriginalExpr(
LinearFunctionExtractOriginalExpr *E) {
printCommon(E, "linear_function_extract_original") << '\n';
printRec(E->getSubExpr());
PrintWithColorRAII(OS, ParenthesisColor) << ')';
}
void visitLinearToDifferentiableFunctionExpr(
LinearToDifferentiableFunctionExpr *E) {
printCommon(E, "linear_to_differentiable_function") << '\n';
printRec(E->getSubExpr());
PrintWithColorRAII(OS, ParenthesisColor) << ')';
}
// SWIFT_ENABLE_TENSORFLOW END

void visitInOutExpr(InOutExpr *E) {
6 changes: 6 additions & 0 deletions lib/AST/Expr.cpp
@@ -352,7 +352,10 @@ ConcreteDeclRef Expr::getReferencedDecl() const {
PASS_THROUGH_REFERENCE(UnevaluatedInstance, getSubExpr);
// SWIFT_ENABLE_TENSORFLOW
PASS_THROUGH_REFERENCE(DifferentiableFunction, getSubExpr);
PASS_THROUGH_REFERENCE(LinearFunction, getSubExpr);
PASS_THROUGH_REFERENCE(DifferentiableFunctionExtractOriginal, getSubExpr);
PASS_THROUGH_REFERENCE(LinearFunctionExtractOriginal, getSubExpr);
PASS_THROUGH_REFERENCE(LinearToDifferentiableFunction, getSubExpr);
// SWIFT_ENABLE_TENSORFLOW END
PASS_THROUGH_REFERENCE(BridgeToObjC, getSubExpr);
PASS_THROUGH_REFERENCE(BridgeFromObjC, getSubExpr);
@@ -678,7 +681,10 @@ bool Expr::canAppendPostfixExpression(bool appendingPostfixOperator) const {
case ExprKind::UnevaluatedInstance:
// SWIFT_ENABLE_TENSORFLOW
case ExprKind::DifferentiableFunction:
case ExprKind::LinearFunction:
case ExprKind::DifferentiableFunctionExtractOriginal:
case ExprKind::LinearFunctionExtractOriginal:
case ExprKind::LinearToDifferentiableFunction:
// SWIFT_ENABLE_TENSORFLOW END
case ExprKind::EnumIsCase:
case ExprKind::ConditionalBridgeFromObjC:
8 changes: 6 additions & 2 deletions lib/SIL/OwnershipUtils.cpp
@@ -46,8 +46,8 @@ bool swift::isOwnershipForwardingValueKind(SILNodeKind kind) {
case SILNodeKind::DestructureTupleInst:
// SWIFT_ENABLE_TENSORFLOW
case SILNodeKind::DifferentiableFunctionInst:
case SILNodeKind::DifferentiableFunctionExtractInst:
// SWIFT_ENABLE_TENSORFLOW
case SILNodeKind::LinearFunctionInst:
// SWIFT_ENABLE_TENSORFLOW END
return true;
default:
return false;
@@ -62,6 +62,10 @@ bool swift::isGuaranteedForwardingValueKind(SILNodeKind kind) {
case SILNodeKind::StructExtractInst:
case SILNodeKind::OpenExistentialValueInst:
case SILNodeKind::OpenExistentialBoxValueInst:
// SWIFT_ENABLE_TENSORFLOW
case SILNodeKind::DifferentiableFunctionExtractInst:
case SILNodeKind::LinearFunctionExtractInst:
// SWIFT_ENABLE_TENSORFLOW END
return true;
default:
return isOwnershipForwardingValueKind(kind);
31 changes: 31 additions & 0 deletions lib/SILGen/SILGenExpr.cpp
@@ -506,8 +506,14 @@ namespace {
// SWIFT_ENABLE_TENSORFLOW
RValue visitDifferentiableFunctionExpr(DifferentiableFunctionExpr *E,
SGFContext C);
RValue visitLinearFunctionExpr(LinearFunctionExpr *E, SGFContext C);
RValue visitDifferentiableFunctionExtractOriginalExpr(
DifferentiableFunctionExtractOriginalExpr *E, SGFContext C);
RValue visitLinearFunctionExtractOriginalExpr(
LinearFunctionExtractOriginalExpr *E, SGFContext C);
RValue visitLinearToDifferentiableFunctionExpr(
LinearToDifferentiableFunctionExpr *E, SGFContext C);
// SWIFT_ENABLE_TENSORFLOW END
};
} // end anonymous namespace

@@ -5436,6 +5442,15 @@ RValue RValueEmitter::visitDifferentiableFunctionExpr(
return RValue(SGF, E, SGF.emitManagedRValueWithCleanup(diffFunc));
}

RValue RValueEmitter::visitLinearFunctionExpr(
LinearFunctionExpr *E, SGFContext C) {
auto origFunc = SGF.emitRValueAsSingleValue(E->getSubExpr());
auto destTy = SGF.getLoweredType(E->getType()).castTo<SILFunctionType>();
auto *diffFunc = SGF.B.createLinearFunction(
E, destTy->getDifferentiationParameterIndices(), origFunc.forward(SGF));
return RValue(SGF, E, SGF.emitManagedRValueWithCleanup(diffFunc));
}

RValue RValueEmitter::visitDifferentiableFunctionExtractOriginalExpr(
DifferentiableFunctionExtractOriginalExpr *E, SGFContext C) {
auto diffFunc = SGF.emitRValueAsSingleValue(E->getSubExpr());
@@ -5445,6 +5460,22 @@ RValue RValueEmitter::visitDifferentiableFunctionExtractOriginalExpr(
auto ownedOrigFunc = SGF.B.emitCopyValueOperation(E, borrowedOrigFunc);
return RValue(SGF, E, SGF.emitManagedRValueWithCleanup(ownedOrigFunc));
}

RValue RValueEmitter::visitLinearFunctionExtractOriginalExpr(
LinearFunctionExtractOriginalExpr *E, SGFContext C) {
auto diffFunc = SGF.emitRValueAsSingleValue(E->getSubExpr());
auto borrowedDiffFunc = diffFunc.borrow(SGF, E);
auto *borrowedOrigFunc = SGF.B.createLinearFunctionExtract(
E, LinearFunctionExtractee::Original, borrowedDiffFunc.getValue());
auto ownedOrigFunc = SGF.B.emitCopyValueOperation(E, borrowedOrigFunc);
return RValue(SGF, E, SGF.emitManagedRValueWithCleanup(ownedOrigFunc));
}

RValue RValueEmitter::visitLinearToDifferentiableFunctionExpr(
LinearToDifferentiableFunctionExpr *E, SGFContext C) {
// TODO: Implement this.
llvm_unreachable("Unsupported!");
}
// SWIFT_ENABLE_TENSORFLOW END

RValue RValueEmitter::visitTapExpr(TapExpr *E, SGFContext C) {
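
A rough Swift sketch of what the new SILGen visitors lower (assumed behavior inferred from the code above; the functions and names are hypothetical). Forming a '@differentiable(linear)' value emits `linear_function` with no transpose operand, and converting such a value back to a plain function type borrows it, applies a `linear_function_extract` of the original, and copies the result. The linear-to-'@differentiable' direction has no SILGen yet (TF-908).

func apply(_ f: (Float) -> Float, to x: Float) -> Float { return f(x) }

func useLinear(_ f: @differentiable(linear) (Float) -> Float, at x: Float) -> Float {
  // visitLinearFunctionExtractOriginalExpr: 'f' is converted to a
  // non-differentiable '(Float) -> Float', lowered roughly to a
  // 'linear_function_extract' of the original on a borrowed value, then a copy.
  return apply(f, to: x)
}

// visitLinearFunctionExpr: forming the linear value from a closure literal,
// lowered to 'linear_function' with the parameter indices of the target type.
let halve: @differentiable(linear) (Float) -> Float = { $0 / 2 }
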
16 changes: 12 additions & 4 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
@@ -8773,6 +8773,8 @@ void Differentiation::run() {
// A global differentiation context.
ADContext context(*this);

bool errorOccurred = false;

// Register all `@differentiable` attributes and `differentiable_function`
// instructions in the module that trigger differentiation.
for (SILFunction &f : module) {
@@ -8783,10 +8785,18 @@ void Differentiation::run() {
context.getInvokers().insert({diffAttr, invoker});
continue;
}
for (SILBasicBlock &bb : f)
for (SILInstruction &i : bb)
for (SILBasicBlock &bb : f) {
for (SILInstruction &i : bb) {
if (auto *dfi = dyn_cast<DifferentiableFunctionInst>(&i))
context.getDifferentiableFunctionInsts().push_back(dfi);
else if (auto *lfi = dyn_cast<LinearFunctionInst>(&i)) {
astCtx.Diags.diagnose(
lfi->getLoc().getSourceLoc(),
diag::autodiff_conversion_to_linear_function_not_supported);
errorOccurred = true;
}
}
}
}

// If nothing has triggered differentiation, there's nothing to do.
@@ -8802,8 +8812,6 @@ void Differentiation::run() {
return;
}

bool errorOccurred = false;

// Process all `[differentiable]` attributes.
for (auto invokerPair : context.getInvokers()) {
auto *attr = invokerPair.first;
74 changes: 57 additions & 17 deletions lib/Sema/CSApply.cpp
@@ -5883,9 +5883,21 @@ maybeDiagnoseUnsupportedDifferentiableConversion(ConstraintSystem &cs,
auto &tc = cs.getTypeChecker();
Type fromType = cs.getType(expr);
auto fromFnType = fromType->getAs<AnyFunctionType>();
auto isToTypeLinear =
toType->getDifferentiabilityKind() == DifferentiabilityKind::Linear;
// Conversion from a `@differentiable` function to a `@differentiable(linear)`
// function is not allowed, because the from-expression will never be a
// closure expression or a declaration/member reference.
if (fromFnType->getDifferentiabilityKind() == DifferentiabilityKind::Normal &&
toType->getDifferentiabilityKind() == DifferentiabilityKind::Linear) {
tc.diagnose(expr->getLoc(),
diag::invalid_differentiable_function_conversion_expr,
isToTypeLinear);
return;
}
// Conversion from a non-`@differentiable` function to a `@differentiable` is
// only allowed from a closure expression or a declaration/member reference.
if (toType->isDifferentiable() && !fromFnType->isDifferentiable()) {
if (!fromFnType->isDifferentiable() && toType->isDifferentiable()) {
auto maybeDiagnoseFunctionRef = [&](Expr *semanticExpr) {
if (auto *capture = dyn_cast<CaptureListExpr>(semanticExpr))
semanticExpr = capture->getClosureBody();
@@ -5897,20 +5909,15 @@ maybeDiagnoseUnsupportedDifferentiableConversion(ConstraintSystem &cs,
// note with a fix-it.
if (auto *paramDecl = dyn_cast<ParamDecl>(declRef->getDecl())) {
tc.diagnose(expr->getLoc(),
diag::invalid_differentiable_function_conversion_expr);
diag::invalid_differentiable_function_conversion_expr,
isToTypeLinear);
if (paramDecl->getType()->is<AnyFunctionType>()) {
auto *typeRepr = paramDecl->getTypeLoc().getTypeRepr();
while (auto *attributed = dyn_cast<AttributedTypeRepr>(typeRepr))
typeRepr = attributed->getTypeRepr();
std::string attributeString = "@differentiable";
switch (toType->getDifferentiabilityKind()) {
case DifferentiabilityKind::Linear:
if (isToTypeLinear)
attributeString += "(linear)";
break;
case DifferentiabilityKind::Normal:
case DifferentiabilityKind::NonDifferentiable:
break;
}
auto *funcTypeRepr = cast<FunctionTypeRepr>(typeRepr);
auto paramListLoc = funcTypeRepr->getArgsTypeRepr()->getStartLoc();
tc.diagnose(paramDecl->getLoc(),
@@ -5930,7 +5937,8 @@ maybeDiagnoseUnsupportedDifferentiableConversion(ConstraintSystem &cs,
return;
}
tc.diagnose(expr->getLoc(),
diag::invalid_differentiable_function_conversion_expr);
diag::invalid_differentiable_function_conversion_expr,
isToTypeLinear);
};
maybeDiagnoseFunctionRef(getSemanticExprForDeclOrMemberRef(expr));
}
@@ -6583,23 +6591,55 @@ Expr *ExprRewriter::coerceToType(Expr *expr, Type toType,

// SWIFT_ENABLE_TENSORFLOW
auto fromEI = fromFunc->getExtInfo();
auto isFromDifferentiable = fromEI.isDifferentiable();
auto isToDifferentiable = toEI.isDifferentiable();
// Handle implicit conversion from @differentiable.
if (fromEI.isDifferentiable() && !toEI.isDifferentiable()) {
if (isFromDifferentiable && !isToDifferentiable) {
fromFunc = fromFunc->getWithoutDifferentiability()
->castTo<FunctionType>();
expr = cs.cacheType(new (tc.Context)
DifferentiableFunctionExtractOriginalExpr(expr, fromFunc));
switch (fromEI.getDifferentiabilityKind()) {
case DifferentiabilityKind::Normal:
expr = cs.cacheType(new (tc.Context)
DifferentiableFunctionExtractOriginalExpr(expr, fromFunc));
break;
case DifferentiabilityKind::Linear:
expr = cs.cacheType(new (tc.Context)
LinearFunctionExtractOriginalExpr(expr, fromFunc));
break;
case DifferentiabilityKind::NonDifferentiable:
llvm_unreachable("Cannot be NonDifferentiable");
}
}
// Handle implicit conversion to @differentiable.
// Handle implicit conversion from @differentiable(linear) to
// @differentiable.
else if (fromEI.getDifferentiabilityKind() ==
DifferentiabilityKind::Linear &&
toEI.getDifferentiabilityKind() == DifferentiabilityKind::Normal) {
// TODO(TF-908): Create a `LinearToDifferentiableFunctionExpr` and SILGen
// it as thunk application. Remove the diagnostic.
tc.diagnose(expr->getLoc(),
diag::unsupported_linear_to_differentiable_conversion);
}
// Handle implicit conversion from non-@differentiable to @differentiable.
maybeDiagnoseUnsupportedDifferentiableConversion(cs, expr, toFunc);
if (!fromEI.isDifferentiable() && toEI.isDifferentiable()) {
if (!isFromDifferentiable && isToDifferentiable) {
auto newEI =
fromEI.withDifferentiabilityKind(toEI.getDifferentiabilityKind());
fromFunc = FunctionType::get(toFunc->getParams(), fromFunc->getResult())
->withExtInfo(newEI)
->castTo<FunctionType>();
expr = cs.cacheType(new (tc.Context)
DifferentiableFunctionExpr(expr, fromFunc));
switch (toEI.getDifferentiabilityKind()) {
case DifferentiabilityKind::Normal:
expr = cs.cacheType(new (tc.Context)
DifferentiableFunctionExpr(expr, fromFunc));
break;
case DifferentiabilityKind::Linear:
expr = cs.cacheType(new (tc.Context)
LinearFunctionExpr(expr, fromFunc));
break;
case DifferentiabilityKind::NonDifferentiable:
llvm_unreachable("Cannot be NonDifferentiable");
}
}

// If we have a ClosureExpr, then we can safely propagate the 'no escape'
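
One case this hunk refines, sketched below in Swift under the assumption that the note and fix-it behave as in the existing non-linear path (the function name is made up and the diagnostic text is abbreviated): when the value being converted is a function-typed parameter, the error is followed by a note whose fix-it now suggests '@differentiable(linear)' rather than plain '@differentiable' when the destination type is linear.

func callAsTranspose(_ body: @escaping (Float) -> Float) {
  // error: a '@differentiable(linear)' function can only be formed from a
  //        reference to a 'func' or a literal closure
  // note (approximate): did you mean to take a '@differentiable(linear)'
  //        closure? The fix-it inserts the attribute into the declared type
  //        of 'body'.
  _ = body as @differentiable(linear) (Float) -> Float
}
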
11 changes: 11 additions & 0 deletions test/AutoDiff/differentiable_func_type_type_checking.swift
@@ -50,6 +50,17 @@ extension Float {
_ = gradient(of: Float.addOne) // okay
_ = gradient(of: Float(1.0).addOne) // okay

// TODO(TF-908): Remove this test once linear-to-differentiable conversion is supported.
func linearToDifferentiable(_ f: @escaping @differentiable(linear) (Float) -> Float) {
// expected-error @+1 {{conversion from '@differentiable(linear)' to '@differentiable' is not yet supported}}
_ = f as @differentiable (Float) -> Float
}

func differentiableToLinear(_ f: @escaping @differentiable (Float) -> Float) {
// expected-error @+1 {{a '@differentiable(linear)' function can only be formed from a reference to a 'func' or a literal closure}}
_ = f as @differentiable(linear) (Float) -> Float
}

//===----------------------------------------------------------------------===//
// Parameter selection (@nondiff)
//===----------------------------------------------------------------------===//