Skip to content

Commit 504b794

Browse files
authored
[AutoDiff] Support '@differentiable(linear)' function conversion. (#27687)
- Support the following implicit conversions: - Non-differentiable to `@differentiable(linear)` - Sema: Emit `LinearFunctionExpr`. - SILGen: Lower `LinearFunctionExpr` to `linear_function`. - `@differentiable(linear)` to non-differentiable - Sema: Emit `LinearFunctionExtractOriginalExpr`. - SILGen: Lower `LinearFunctionExtractOriginalExpr` to `linear_function_extract [original]`. - Reject the following implicit conversions: - `@differentiable` to `@differentiable(linear)` - This conversion is not rejected because a `@differentiable` function can never come directly from a closure expression or from a declaration/member reference. - `@differentiable(linear)` to `@differentiable` ([TF-908](https://bugs.swift.org/browse/TF-908)) - This is supported by design, but is not yet implemented due to its complexity. This requires thunking `@differentiable(linear)` to a derivative (JVP) function where the derivative function returns the original result and the same linear map with `@nondiff` parameters partial-applied away. - Properly handle `linear_function`, `differentiable_function_extract`, and `linear_function_extract` in `swift::isGuaranteedForwardingValueKind` and `swift::isOwnershipForwardingValueKind`. This fixed a previously uncaught assertion, which was because `differentiable_function_extract` was in `swift::isOwnershipForwardingValueKind` while it should really be in `swift::isGuaranteedForwardingValueKind`. For complete conversion rules among `@differentiable` functions, see [Differentiable Programming Mega-Proposal - Type conversion](https://github.com/dan-zheng/swift/blob/differentiable-programming/docs/DifferentiableProgramming.md#type-conversion). Resolves [TF-900](https://bugs.swift.org/browse/TF-900).
1 parent bb67311 commit 504b794

15 files changed

+218
-50
lines changed

include/swift/AST/DiagnosticsSIL.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,9 @@ ERROR(constexpr_imported_func_not_onone, none, "imported constant evaluable "
445445
// Automatic differentiation diagnostics
446446
ERROR(autodiff_internal_swift_not_imported,none,
447447
"AD internal error: the Swift module is not imported", ())
448+
ERROR(autodiff_conversion_to_linear_function_not_supported,none,
449+
"conversion to '@differentiable(linear)' function type is not yet "
450+
"supported", ())
448451
ERROR(autodiff_function_not_differentiable_error,none,
449452
"function is not differentiable", ())
450453
ERROR(autodiff_expression_not_differentiable_error,none,

include/swift/AST/DiagnosticsSema.def

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1177,9 +1177,14 @@ ERROR(c_function_pointer_from_generic_function,none,
11771177
"a C function pointer cannot be formed from a reference to a generic "
11781178
"function", ())
11791179
// SWIFT_ENABLE_TENSORFLOW
1180+
// TODO(TF-908): Remove this diagnostic once linear-to-differentiable conversion
1181+
// is supported.
1182+
ERROR(unsupported_linear_to_differentiable_conversion,none,
1183+
"conversion from '@differentiable(linear)' to '@differentiable' is not "
1184+
"yet supported", ())
11801185
ERROR(invalid_differentiable_function_conversion_expr,none,
1181-
"a '@differentiable' function can only be formed from a reference to a "
1182-
"'func' or a literal closure", ())
1186+
"a '@differentiable%select{|(linear)}0' function can only be formed from "
1187+
"a reference to a 'func' or a literal closure", (bool))
11831188
NOTE(invalid_differentiable_function_conversion_parameter,none,
11841189
"did you mean to take a '%0' closure?", (StringRef))
11851190
ERROR(invalid_autoclosure_forwarding,none,

include/swift/AST/Expr.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3034,6 +3034,16 @@ class DifferentiableFunctionExpr : public ImplicitConversionExpr {
30343034
}
30353035
};
30363036

3037+
class LinearFunctionExpr : public ImplicitConversionExpr {
3038+
public:
3039+
LinearFunctionExpr(Expr *subExpr, Type ty)
3040+
: ImplicitConversionExpr(ExprKind::LinearFunction, subExpr, ty) {}
3041+
3042+
static bool classof(const Expr *E) {
3043+
return E->getKind() == ExprKind::LinearFunction;
3044+
}
3045+
};
3046+
30373047
class DifferentiableFunctionExtractOriginalExpr
30383048
: public ImplicitConversionExpr {
30393049
public:
@@ -3045,6 +3055,28 @@ class DifferentiableFunctionExtractOriginalExpr
30453055
return E->getKind() == ExprKind::DifferentiableFunctionExtractOriginal;
30463056
}
30473057
};
3058+
3059+
class LinearFunctionExtractOriginalExpr : public ImplicitConversionExpr {
3060+
public:
3061+
LinearFunctionExtractOriginalExpr(Expr *subExpr, Type ty)
3062+
: ImplicitConversionExpr(ExprKind::LinearFunctionExtractOriginal,
3063+
subExpr, ty) {}
3064+
3065+
static bool classof(const Expr *E) {
3066+
return E->getKind() == ExprKind::LinearFunctionExtractOriginal;
3067+
}
3068+
};
3069+
3070+
class LinearToDifferentiableFunctionExpr : public ImplicitConversionExpr {
3071+
public:
3072+
LinearToDifferentiableFunctionExpr(Expr *subExpr, Type ty)
3073+
: ImplicitConversionExpr(
3074+
ExprKind::LinearToDifferentiableFunction, subExpr, ty) {}
3075+
3076+
static bool classof(const Expr *E) {
3077+
return E->getKind() == ExprKind::LinearToDifferentiableFunction;
3078+
}
3079+
};
30483080
// SWIFT_ENABLE_TENSORFLOW END
30493081

30503082
/// Use an opaque type to abstract a value of the underlying concrete type.

include/swift/AST/ExprNodes.def

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,11 @@ ABSTRACT_EXPR(ImplicitConversion, Expr)
174174
EXPR(UnderlyingToOpaque, ImplicitConversionExpr)
175175
// SWIFT_ENABLE_TENSORFLOW
176176
EXPR(DifferentiableFunction, ImplicitConversionExpr)
177+
EXPR(LinearFunction, ImplicitConversionExpr)
177178
EXPR(DifferentiableFunctionExtractOriginal, ImplicitConversionExpr)
178-
EXPR_RANGE(ImplicitConversion, Load, DifferentiableFunctionExtractOriginal)
179+
EXPR(LinearFunctionExtractOriginal, ImplicitConversionExpr)
180+
EXPR(LinearToDifferentiableFunction, ImplicitConversionExpr)
181+
EXPR_RANGE(ImplicitConversion, Load, LinearToDifferentiableFunction)
179182
// SWIFT_ENABLE_TENSORFLOW END
180183
ABSTRACT_EXPR(ExplicitCast, Expr)
181184
ABSTRACT_EXPR(CheckedCast, ExplicitCastExpr)

include/swift/SIL/SILBuilder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,7 @@ class SILBuilder {
521521

522522
LinearFunctionInst *createLinearFunction(
523523
SILLocation Loc, IndexSubset *ParameterIndices, SILValue OriginalFunction,
524-
Optional<SILValue> TransposeFunction) {
524+
Optional<SILValue> TransposeFunction = None) {
525525
return insert(LinearFunctionInst::create(
526526
getModule(), getSILDebugLocation(Loc), ParameterIndices,
527527
OriginalFunction, TransposeFunction, hasOwnership()));

lib/AST/ASTDumper.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2401,12 +2401,29 @@ class PrintExpr : public ExprVisitor<PrintExpr> {
24012401
printRec(E->getSubExpr());
24022402
PrintWithColorRAII(OS, ParenthesisColor) << ')';
24032403
}
2404+
void visitLinearFunctionExpr(LinearFunctionExpr *E) {
2405+
printCommon(E, "linear_function") << '\n';
2406+
printRec(E->getSubExpr());
2407+
PrintWithColorRAII(OS, ParenthesisColor) << ')';
2408+
}
24042409
void visitDifferentiableFunctionExtractOriginalExpr(
24052410
DifferentiableFunctionExtractOriginalExpr *E) {
24062411
printCommon(E, "differentiable_function_extract_original") << '\n';
24072412
printRec(E->getSubExpr());
24082413
PrintWithColorRAII(OS, ParenthesisColor) << ')';
24092414
}
2415+
void visitLinearFunctionExtractOriginalExpr(
2416+
LinearFunctionExtractOriginalExpr *E) {
2417+
printCommon(E, "linear_function_extract_original") << '\n';
2418+
printRec(E->getSubExpr());
2419+
PrintWithColorRAII(OS, ParenthesisColor) << ')';
2420+
}
2421+
void visitLinearToDifferentiableFunctionExpr(
2422+
LinearToDifferentiableFunctionExpr *E) {
2423+
printCommon(E, "linear_to_differentiable_function") << '\n';
2424+
printRec(E->getSubExpr());
2425+
PrintWithColorRAII(OS, ParenthesisColor) << ')';
2426+
}
24102427
// SWIFT_ENABLE_TENSORFLOW END
24112428

24122429
void visitInOutExpr(InOutExpr *E) {

lib/AST/Expr.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,10 @@ ConcreteDeclRef Expr::getReferencedDecl() const {
352352
PASS_THROUGH_REFERENCE(UnevaluatedInstance, getSubExpr);
353353
// SWIFT_ENABLE_TENSORFLOW
354354
PASS_THROUGH_REFERENCE(DifferentiableFunction, getSubExpr);
355+
PASS_THROUGH_REFERENCE(LinearFunction, getSubExpr);
355356
PASS_THROUGH_REFERENCE(DifferentiableFunctionExtractOriginal, getSubExpr);
357+
PASS_THROUGH_REFERENCE(LinearFunctionExtractOriginal, getSubExpr);
358+
PASS_THROUGH_REFERENCE(LinearToDifferentiableFunction, getSubExpr);
356359
// SWIFT_ENABLE_TENSORFLOW END
357360
PASS_THROUGH_REFERENCE(BridgeToObjC, getSubExpr);
358361
PASS_THROUGH_REFERENCE(BridgeFromObjC, getSubExpr);
@@ -678,7 +681,10 @@ bool Expr::canAppendPostfixExpression(bool appendingPostfixOperator) const {
678681
case ExprKind::UnevaluatedInstance:
679682
// SWIFT_ENABLE_TENSORFLOW
680683
case ExprKind::DifferentiableFunction:
684+
case ExprKind::LinearFunction:
681685
case ExprKind::DifferentiableFunctionExtractOriginal:
686+
case ExprKind::LinearFunctionExtractOriginal:
687+
case ExprKind::LinearToDifferentiableFunction:
682688
// SWIFT_ENABLE_TENSORFLOW END
683689
case ExprKind::EnumIsCase:
684690
case ExprKind::ConditionalBridgeFromObjC:

lib/SIL/OwnershipUtils.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ bool swift::isOwnershipForwardingValueKind(SILNodeKind kind) {
4646
case SILNodeKind::DestructureTupleInst:
4747
// SWIFT_ENABLE_TENSORFLOW
4848
case SILNodeKind::DifferentiableFunctionInst:
49-
case SILNodeKind::DifferentiableFunctionExtractInst:
50-
// SWIFT_ENABLE_TENSORFLOW
49+
case SILNodeKind::LinearFunctionInst:
50+
// SWIFT_ENABLE_TENSORFLOW END
5151
return true;
5252
default:
5353
return false;
@@ -62,6 +62,10 @@ bool swift::isGuaranteedForwardingValueKind(SILNodeKind kind) {
6262
case SILNodeKind::StructExtractInst:
6363
case SILNodeKind::OpenExistentialValueInst:
6464
case SILNodeKind::OpenExistentialBoxValueInst:
65+
// SWIFT_ENABLE_TENSORFLOW
66+
case SILNodeKind::DifferentiableFunctionExtractInst:
67+
case SILNodeKind::LinearFunctionExtractInst:
68+
// SWIFT_ENABLE_TENSORFLOW END
6569
return true;
6670
default:
6771
return isOwnershipForwardingValueKind(kind);

lib/SILGen/SILGenExpr.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,8 +506,14 @@ namespace {
506506
// SWIFT_ENABLE_TENSORFLOW
507507
RValue visitDifferentiableFunctionExpr(DifferentiableFunctionExpr *E,
508508
SGFContext C);
509+
RValue visitLinearFunctionExpr(LinearFunctionExpr *E, SGFContext C);
509510
RValue visitDifferentiableFunctionExtractOriginalExpr(
510511
DifferentiableFunctionExtractOriginalExpr *E, SGFContext C);
512+
RValue visitLinearFunctionExtractOriginalExpr(
513+
LinearFunctionExtractOriginalExpr *E, SGFContext C);
514+
RValue visitLinearToDifferentiableFunctionExpr(
515+
LinearToDifferentiableFunctionExpr *E, SGFContext C);
516+
// SWIFT_ENABLE_TENSORFLOW END
511517
};
512518
} // end anonymous namespace
513519

@@ -5436,6 +5442,15 @@ RValue RValueEmitter::visitDifferentiableFunctionExpr(
54365442
return RValue(SGF, E, SGF.emitManagedRValueWithCleanup(diffFunc));
54375443
}
54385444

5445+
RValue RValueEmitter::visitLinearFunctionExpr(
5446+
LinearFunctionExpr *E, SGFContext C) {
5447+
auto origFunc = SGF.emitRValueAsSingleValue(E->getSubExpr());
5448+
auto destTy = SGF.getLoweredType(E->getType()).castTo<SILFunctionType>();
5449+
auto *diffFunc = SGF.B.createLinearFunction(
5450+
E, destTy->getDifferentiationParameterIndices(), origFunc.forward(SGF));
5451+
return RValue(SGF, E, SGF.emitManagedRValueWithCleanup(diffFunc));
5452+
}
5453+
54395454
RValue RValueEmitter::visitDifferentiableFunctionExtractOriginalExpr(
54405455
DifferentiableFunctionExtractOriginalExpr *E, SGFContext C) {
54415456
auto diffFunc = SGF.emitRValueAsSingleValue(E->getSubExpr());
@@ -5445,6 +5460,22 @@ RValue RValueEmitter::visitDifferentiableFunctionExtractOriginalExpr(
54455460
auto ownedOrigFunc = SGF.B.emitCopyValueOperation(E, borrowedOrigFunc);
54465461
return RValue(SGF, E, SGF.emitManagedRValueWithCleanup(ownedOrigFunc));
54475462
}
5463+
5464+
RValue RValueEmitter::visitLinearFunctionExtractOriginalExpr(
5465+
LinearFunctionExtractOriginalExpr *E, SGFContext C) {
5466+
auto diffFunc = SGF.emitRValueAsSingleValue(E->getSubExpr());
5467+
auto borrowedDiffFunc = diffFunc.borrow(SGF, E);
5468+
auto *borrowedOrigFunc = SGF.B.createLinearFunctionExtract(
5469+
E, LinearFunctionExtractee::Original, borrowedDiffFunc.getValue());
5470+
auto ownedOrigFunc = SGF.B.emitCopyValueOperation(E, borrowedOrigFunc);
5471+
return RValue(SGF, E, SGF.emitManagedRValueWithCleanup(ownedOrigFunc));
5472+
}
5473+
5474+
RValue RValueEmitter::visitLinearToDifferentiableFunctionExpr(
5475+
LinearToDifferentiableFunctionExpr *E, SGFContext C) {
5476+
// TODO: Implement this.
5477+
llvm_unreachable("Unsupported!");
5478+
}
54485479
// SWIFT_ENABLE_TENSORFLOW END
54495480

54505481
RValue RValueEmitter::visitTapExpr(TapExpr *E, SGFContext C) {

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8782,6 +8782,8 @@ void Differentiation::run() {
87828782
// A global differentiation context.
87838783
ADContext context(*this);
87848784

8785+
bool errorOccurred = false;
8786+
87858787
// Register all `@differentiable` attributes and `differentiable_function`
87868788
// instructions in the module that trigger differentiation.
87878789
for (SILFunction &f : module) {
@@ -8792,10 +8794,18 @@ void Differentiation::run() {
87928794
context.getInvokers().insert({diffAttr, invoker});
87938795
continue;
87948796
}
8795-
for (SILBasicBlock &bb : f)
8796-
for (SILInstruction &i : bb)
8797+
for (SILBasicBlock &bb : f) {
8798+
for (SILInstruction &i : bb) {
87978799
if (auto *dfi = dyn_cast<DifferentiableFunctionInst>(&i))
87988800
context.getDifferentiableFunctionInsts().push_back(dfi);
8801+
else if (auto *lfi = dyn_cast<LinearFunctionInst>(&i)) {
8802+
astCtx.Diags.diagnose(
8803+
lfi->getLoc().getSourceLoc(),
8804+
diag::autodiff_conversion_to_linear_function_not_supported);
8805+
errorOccurred = true;
8806+
}
8807+
}
8808+
}
87998809
}
88008810

88018811
// If nothing has triggered differentiation, there's nothing to do.
@@ -8811,8 +8821,6 @@ void Differentiation::run() {
88118821
return;
88128822
}
88138823

8814-
bool errorOccurred = false;
8815-
88168824
// Process all `[differentiable]` attributes.
88178825
for (auto invokerPair : context.getInvokers()) {
88188826
auto *attr = invokerPair.first;

lib/Sema/CSApply.cpp

Lines changed: 57 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5883,9 +5883,21 @@ maybeDiagnoseUnsupportedDifferentiableConversion(ConstraintSystem &cs,
58835883
auto &tc = cs.getTypeChecker();
58845884
Type fromType = cs.getType(expr);
58855885
auto fromFnType = fromType->getAs<AnyFunctionType>();
5886+
auto isToTypeLinear =
5887+
toType->getDifferentiabilityKind() == DifferentiabilityKind::Linear;
5888+
// Conversion from a `@differentiable` function to a `@differentiable(linear)`
5889+
// function is not allowed, because the from-expression will never be a
5890+
// closure expression or a declaration/member reference.
5891+
if (fromFnType->getDifferentiabilityKind() == DifferentiabilityKind::Normal &&
5892+
toType->getDifferentiabilityKind() == DifferentiabilityKind::Linear) {
5893+
tc.diagnose(expr->getLoc(),
5894+
diag::invalid_differentiable_function_conversion_expr,
5895+
isToTypeLinear);
5896+
return;
5897+
}
58865898
// Conversion from a non-`@differentiable` function to a `@differentiable` is
58875899
// only allowed from a closure expression or a declaration/member reference.
5888-
if (toType->isDifferentiable() && !fromFnType->isDifferentiable()) {
5900+
if (!fromFnType->isDifferentiable() && toType->isDifferentiable()) {
58895901
auto maybeDiagnoseFunctionRef = [&](Expr *semanticExpr) {
58905902
if (auto *capture = dyn_cast<CaptureListExpr>(semanticExpr))
58915903
semanticExpr = capture->getClosureBody();
@@ -5897,20 +5909,15 @@ maybeDiagnoseUnsupportedDifferentiableConversion(ConstraintSystem &cs,
58975909
// note with a fix-it.
58985910
if (auto *paramDecl = dyn_cast<ParamDecl>(declRef->getDecl())) {
58995911
tc.diagnose(expr->getLoc(),
5900-
diag::invalid_differentiable_function_conversion_expr);
5912+
diag::invalid_differentiable_function_conversion_expr,
5913+
isToTypeLinear);
59015914
if (paramDecl->getType()->is<AnyFunctionType>()) {
59025915
auto *typeRepr = paramDecl->getTypeLoc().getTypeRepr();
59035916
while (auto *attributed = dyn_cast<AttributedTypeRepr>(typeRepr))
59045917
typeRepr = attributed->getTypeRepr();
59055918
std::string attributeString = "@differentiable";
5906-
switch (toType->getDifferentiabilityKind()) {
5907-
case DifferentiabilityKind::Linear:
5919+
if (isToTypeLinear)
59085920
attributeString += "(linear)";
5909-
break;
5910-
case DifferentiabilityKind::Normal:
5911-
case DifferentiabilityKind::NonDifferentiable:
5912-
break;
5913-
}
59145921
auto *funcTypeRepr = cast<FunctionTypeRepr>(typeRepr);
59155922
auto paramListLoc = funcTypeRepr->getArgsTypeRepr()->getStartLoc();
59165923
tc.diagnose(paramDecl->getLoc(),
@@ -5930,7 +5937,8 @@ maybeDiagnoseUnsupportedDifferentiableConversion(ConstraintSystem &cs,
59305937
return;
59315938
}
59325939
tc.diagnose(expr->getLoc(),
5933-
diag::invalid_differentiable_function_conversion_expr);
5940+
diag::invalid_differentiable_function_conversion_expr,
5941+
isToTypeLinear);
59345942
};
59355943
maybeDiagnoseFunctionRef(getSemanticExprForDeclOrMemberRef(expr));
59365944
}
@@ -6583,23 +6591,55 @@ Expr *ExprRewriter::coerceToType(Expr *expr, Type toType,
65836591

65846592
// SWIFT_ENABLE_TENSORFLOW
65856593
auto fromEI = fromFunc->getExtInfo();
6594+
auto isFromDifferentiable = fromEI.isDifferentiable();
6595+
auto isToDifferentiable = toEI.isDifferentiable();
65866596
// Handle implicit conversion from @differentiable.
6587-
if (fromEI.isDifferentiable() && !toEI.isDifferentiable()) {
6597+
if (isFromDifferentiable && !isToDifferentiable) {
65886598
fromFunc = fromFunc->getWithoutDifferentiability()
65896599
->castTo<FunctionType>();
6590-
expr = cs.cacheType(new (tc.Context)
6591-
DifferentiableFunctionExtractOriginalExpr(expr, fromFunc));
6600+
switch (fromEI.getDifferentiabilityKind()) {
6601+
case DifferentiabilityKind::Normal:
6602+
expr = cs.cacheType(new (tc.Context)
6603+
DifferentiableFunctionExtractOriginalExpr(expr, fromFunc));
6604+
break;
6605+
case DifferentiabilityKind::Linear:
6606+
expr = cs.cacheType(new (tc.Context)
6607+
LinearFunctionExtractOriginalExpr(expr, fromFunc));
6608+
break;
6609+
case DifferentiabilityKind::NonDifferentiable:
6610+
llvm_unreachable("Cannot be NonDifferentiable");
6611+
}
65926612
}
6593-
// Handle implicit conversion to @differentiable.
6613+
// Handle implicit conversion from @differentiable(linear) to
6614+
// @differentiable.
6615+
else if (fromEI.getDifferentiabilityKind() ==
6616+
DifferentiabilityKind::Linear &&
6617+
toEI.getDifferentiabilityKind() == DifferentiabilityKind::Normal) {
6618+
// TODO(TF-908): Create a `LinearToDifferentiableFunctionExpr` and SILGen
6619+
// it as thunk application. Remove the diagnostic.
6620+
tc.diagnose(expr->getLoc(),
6621+
diag::unsupported_linear_to_differentiable_conversion);
6622+
}
6623+
// Handle implicit conversion from non-@differentiable to @differentiable.
65946624
maybeDiagnoseUnsupportedDifferentiableConversion(cs, expr, toFunc);
6595-
if (!fromEI.isDifferentiable() && toEI.isDifferentiable()) {
6625+
if (!isFromDifferentiable && isToDifferentiable) {
65966626
auto newEI =
65976627
fromEI.withDifferentiabilityKind(toEI.getDifferentiabilityKind());
65986628
fromFunc = FunctionType::get(toFunc->getParams(), fromFunc->getResult())
65996629
->withExtInfo(newEI)
66006630
->castTo<FunctionType>();
6601-
expr = cs.cacheType(new (tc.Context)
6602-
DifferentiableFunctionExpr(expr, fromFunc));
6631+
switch (toEI.getDifferentiabilityKind()) {
6632+
case DifferentiabilityKind::Normal:
6633+
expr = cs.cacheType(new (tc.Context)
6634+
DifferentiableFunctionExpr(expr, fromFunc));
6635+
break;
6636+
case DifferentiabilityKind::Linear:
6637+
expr = cs.cacheType(new (tc.Context)
6638+
LinearFunctionExpr(expr, fromFunc));
6639+
break;
6640+
case DifferentiabilityKind::NonDifferentiable:
6641+
llvm_unreachable("Cannot be NonDifferentiable");
6642+
}
66036643
}
66046644

66056645
// If we have a ClosureExpr, then we can safely propagate the 'no escape'

test/AutoDiff/differentiable_func_type_type_checking.swift

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,17 @@ extension Float {
5050
_ = gradient(of: Float.addOne) // okay
5151
_ = gradient(of: Float(1.0).addOne) // okay
5252

53+
// TODO(TF-908): Remove this test once linear-to-differentiable conversion is supported.
54+
func linearToDifferentiable(_ f: @escaping @differentiable(linear) (Float) -> Float) {
55+
// expected-error @+1 {{conversion from '@differentiable(linear)' to '@differentiable' is not yet supported}}
56+
_ = f as @differentiable (Float) -> Float
57+
}
58+
59+
func differentiableToLinear(_ f: @escaping @differentiable (Float) -> Float) {
60+
// expected-error @+1 {{a '@differentiable(linear)' function can only be formed from a reference to a 'func' or a literal closure}}
61+
_ = f as @differentiable(linear) (Float) -> Float
62+
}
63+
5364
//===----------------------------------------------------------------------===//
5465
// Parameter selection (@nondiff)
5566
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)