Skip to content

Commit 39fb55b

Browse files
authored
Merge pull request #28156 from marcrasi/diff-fn-ty-ast-bits
2 parents cd3ada5 + 6413f43 commit 39fb55b

File tree

17 files changed

+267
-63
lines changed

17 files changed

+267
-63
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,26 @@
22
//
33
// This source file is part of the Swift.org open source project
44
//
5-
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
5+
// Copyright (c) 2019 Apple Inc. and the Swift project authors
66
// Licensed under Apache License v2.0 with Runtime Library Exception
77
//
88
// See https://swift.org/LICENSE.txt for license information
99
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
1010
//
1111
//===----------------------------------------------------------------------===//
1212
//
13-
// SWIFT_ENABLE_TENSORFLOW
1413
// This file defines AST support for automatic differentiation.
1514
//
1615
//===----------------------------------------------------------------------===//
1716

1817
#ifndef SWIFT_AST_AUTODIFF_H
1918
#define SWIFT_AST_AUTODIFF_H
2019

20+
#include <cstdint>
21+
22+
#include "swift/AST/Identifier.h"
2123
#include "swift/AST/IndexSubset.h"
24+
#include "swift/Basic/SourceLoc.h"
2225
#include "swift/Basic/Range.h"
2326

2427
namespace swift {
@@ -86,6 +89,12 @@ class ParsedAutoDiffParameter {
8689
}
8790
};
8891

92+
enum class DifferentiabilityKind : uint8_t {
93+
NonDifferentiable = 0,
94+
Normal = 1,
95+
Linear = 2
96+
};
97+
8998
} // end namespace swift
9099

91100
#endif // SWIFT_AST_AUTODIFF_H

include/swift/AST/Types.h

Lines changed: 53 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#ifndef SWIFT_TYPES_H
1818
#define SWIFT_TYPES_H
1919

20+
#include "swift/AST/AutoDiff.h"
2021
#include "swift/AST/DeclContext.h"
2122
#include "swift/AST/GenericParamKey.h"
2223
#include "swift/AST/Identifier.h"
@@ -301,8 +302,8 @@ class alignas(1 << TypeAlignInBits) TypeBase {
301302
}
302303

303304
protected:
304-
enum { NumAFTExtInfoBits = 6 };
305-
enum { NumSILExtInfoBits = 6 };
305+
enum { NumAFTExtInfoBits = 8 };
306+
enum { NumSILExtInfoBits = 8 };
306307
union { uint64_t OpaqueBits;
307308

308309
SWIFT_INLINE_BITFIELD_BASE(TypeBase, bitmax(NumTypeKindBits,8) +
@@ -2879,14 +2880,16 @@ class AnyFunctionType : public TypeBase {
28792880
// If bits are added or removed, then TypeBase::AnyFunctionTypeBits
28802881
// and NumMaskBits must be updated, and they must match.
28812882
//
2882-
// |representation|noEscape|throws|
2883-
// | 0 .. 3 | 4 | 5 |
2883+
// |representation|noEscape|throws|differentiability|
2884+
// | 0 .. 3 | 4 | 5 | 6 .. 7 |
28842885
//
28852886
enum : unsigned {
2886-
RepresentationMask = 0xF << 0,
2887-
NoEscapeMask = 1 << 4,
2888-
ThrowsMask = 1 << 5,
2889-
NumMaskBits = 6
2887+
RepresentationMask = 0xF << 0,
2888+
NoEscapeMask = 1 << 4,
2889+
ThrowsMask = 1 << 5,
2890+
DifferentiabilityMaskOffset = 6,
2891+
DifferentiabilityMask = 0x3 << DifferentiabilityMaskOffset,
2892+
NumMaskBits = 8
28902893
};
28912894

28922895
unsigned Bits; // Naturally sized for speed.
@@ -2909,13 +2912,24 @@ class AnyFunctionType : public TypeBase {
29092912
// Constructor with no defaults.
29102913
ExtInfo(Representation Rep,
29112914
bool IsNoEscape,
2912-
bool Throws)
2915+
bool Throws,
2916+
DifferentiabilityKind DiffKind)
29132917
: ExtInfo(Rep, Throws) {
29142918
Bits |= (IsNoEscape ? NoEscapeMask : 0);
2919+
Bits |= ((unsigned)DiffKind << DifferentiabilityMaskOffset) &
2920+
DifferentiabilityMask;
29152921
}
29162922

29172923
bool isNoEscape() const { return Bits & NoEscapeMask; }
29182924
bool throws() const { return Bits & ThrowsMask; }
2925+
bool isDifferentiable() const {
2926+
return getDifferentiabilityKind() >
2927+
DifferentiabilityKind::NonDifferentiable;
2928+
}
2929+
DifferentiabilityKind getDifferentiabilityKind() const {
2930+
return DifferentiabilityKind((Bits & DifferentiabilityMask) >>
2931+
DifferentiabilityMaskOffset);
2932+
}
29192933
Representation getRepresentation() const {
29202934
unsigned rawRep = Bits & RepresentationMask;
29212935
assert(rawRep <= unsigned(Representation::Last)
@@ -3073,6 +3087,11 @@ class AnyFunctionType : public TypeBase {
30733087
return getExtInfo().throws();
30743088
}
30753089

3090+
bool isDifferentiable() const { return getExtInfo().isDifferentiable(); }
3091+
DifferentiabilityKind getDifferentiabilityKind() const {
3092+
return getExtInfo().getDifferentiabilityKind();
3093+
}
3094+
30763095
/// Returns a new function type exactly like this one but with the ExtInfo
30773096
/// replaced.
30783097
AnyFunctionType *withExtInfo(ExtInfo info) const;
@@ -3731,14 +3750,16 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
37313750
// If bits are added or removed, then TypeBase::SILFunctionTypeBits
37323751
// and NumMaskBits must be updated, and they must match.
37333752

3734-
// |representation|pseudogeneric| noescape |
3735-
// | 0 .. 3 | 4 | 5 |
3753+
// |representation|pseudogeneric| noescape |differentiability|
3754+
// | 0 .. 3 | 4 | 5 | 6 .. 7 |
37363755
//
37373756
enum : unsigned {
37383757
RepresentationMask = 0xF << 0,
37393758
PseudogenericMask = 1 << 4,
37403759
NoEscapeMask = 1 << 5,
3741-
NumMaskBits = 6
3760+
DifferentiabilityMaskOffset = 6,
3761+
DifferentiabilityMask = 0x3 << DifferentiabilityMaskOffset,
3762+
NumMaskBits = 8
37423763
};
37433764

37443765
unsigned Bits; // Naturally sized for speed.
@@ -3752,10 +3773,13 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
37523773
ExtInfo() : Bits(0) { }
37533774

37543775
// Constructor for polymorphic type.
3755-
ExtInfo(Representation rep, bool isPseudogeneric, bool isNoEscape) {
3776+
ExtInfo(Representation rep, bool isPseudogeneric, bool isNoEscape,
3777+
DifferentiabilityKind diffKind) {
37563778
Bits = ((unsigned) rep) |
37573779
(isPseudogeneric ? PseudogenericMask : 0) |
3758-
(isNoEscape ? NoEscapeMask : 0);
3780+
(isNoEscape ? NoEscapeMask : 0) |
3781+
(((unsigned)diffKind << DifferentiabilityMaskOffset) &
3782+
DifferentiabilityMask);
37593783
}
37603784

37613785
/// Is this function pseudo-generic? A pseudo-generic function
@@ -3765,6 +3789,16 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
37653789
// Is this function guaranteed to be no-escape by the type system?
37663790
bool isNoEscape() const { return Bits & NoEscapeMask; }
37673791

3792+
bool isDifferentiable() const {
3793+
return getDifferentiabilityKind() !=
3794+
DifferentiabilityKind::NonDifferentiable;
3795+
}
3796+
3797+
DifferentiabilityKind getDifferentiabilityKind() const {
3798+
return DifferentiabilityKind((Bits & DifferentiabilityMask) >>
3799+
DifferentiabilityMaskOffset);
3800+
}
3801+
37683802
/// What is the abstract representation of this function value?
37693803
Representation getRepresentation() const {
37703804
return Representation(Bits & RepresentationMask);
@@ -4169,6 +4203,11 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
41694203
getRepresentation() == SILFunctionTypeRepresentation::Thick;
41704204
}
41714205

4206+
bool isDifferentiable() const { return getExtInfo().isDifferentiable(); }
4207+
DifferentiabilityKind getDifferentiabilityKind() const {
4208+
return getExtInfo().getDifferentiabilityKind();
4209+
}
4210+
41724211
bool isNoReturnFunction(SILModule &M) const; // Defined in SILType.cpp
41734212

41744213
/// Create a SILFunctionType with the same parameters, results, and attributes as this one, but with

lib/AST/ASTDemangler.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -479,9 +479,9 @@ Type ASTBuilder::createImplFunctionType(
479479
break;
480480
}
481481

482-
auto einfo = SILFunctionType::ExtInfo(representation,
483-
flags.isPseudogeneric(),
484-
!flags.isEscaping());
482+
auto einfo = SILFunctionType::ExtInfo(
483+
representation, flags.isPseudogeneric(), !flags.isEscaping(),
484+
DifferentiabilityKind::NonDifferentiable);
485485

486486
llvm::SmallVector<SILParameterInfo, 8> funcParams;
487487
llvm::SmallVector<SILYieldInfo, 8> funcYields;

lib/AST/ASTPrinter.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3782,6 +3782,14 @@ class TypePrinter : public TypeVisitor<TypePrinter> {
37823782
if (Options.SkipAttributes)
37833783
return;
37843784

3785+
if (!Options.excludeAttrKind(TAK_differentiable) &&
3786+
info.isDifferentiable()) {
3787+
if (info.getDifferentiabilityKind() == DifferentiabilityKind::Linear) {
3788+
Printer << "@differentiable(linear) ";
3789+
} else {
3790+
Printer << "@differentiable ";
3791+
}
3792+
}
37853793

37863794
if (Options.PrintFunctionRepresentationAttrs &&
37873795
!Options.excludeAttrKind(TAK_convention) &&
@@ -3826,6 +3834,15 @@ class TypePrinter : public TypeVisitor<TypePrinter> {
38263834
if (Options.SkipAttributes)
38273835
return;
38283836

3837+
if (!Options.excludeAttrKind(TAK_differentiable) &&
3838+
info.isDifferentiable()) {
3839+
if (info.getDifferentiabilityKind() == DifferentiabilityKind::Linear) {
3840+
Printer << "@differentiable(linear) ";
3841+
} else {
3842+
Printer << "@differentiable ";
3843+
}
3844+
}
3845+
38293846
if (Options.PrintFunctionRepresentationAttrs &&
38303847
!Options.excludeAttrKind(TAK_convention) &&
38313848
info.getRepresentation() != SILFunctionType::Representation::Thick) {

lib/SILGen/SILGen.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -416,10 +416,10 @@ SILGenModule::getKeyPathProjectionCoroutine(bool isReadAccess,
416416
: ParameterConvention::Indirect_In_Guaranteed },
417417
};
418418

419-
auto extInfo =
420-
SILFunctionType::ExtInfo(SILFunctionTypeRepresentation::Thin,
421-
/*pseudogeneric*/false,
422-
/*non-escaping*/false);
419+
auto extInfo = SILFunctionType::ExtInfo(
420+
SILFunctionTypeRepresentation::Thin,
421+
/*pseudogeneric*/ false,
422+
/*non-escaping*/ false, DifferentiabilityKind::NonDifferentiable);
423423

424424
auto functionTy = SILFunctionType::get(sig, extInfo,
425425
SILCoroutineKind::YieldOnce,

lib/SILGen/SILGenExpr.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2684,7 +2684,8 @@ static SILFunction *getOrCreateKeyPathGetter(SILGenModule &SGM,
26842684
auto signature = SILFunctionType::get(genericSig,
26852685
SILFunctionType::ExtInfo(SILFunctionType::Representation::Thin,
26862686
/*pseudogeneric*/ false,
2687-
/*noescape*/ false),
2687+
/*noescape*/ false,
2688+
DifferentiabilityKind::NonDifferentiable),
26882689
SILCoroutineKind::None,
26892690
ParameterConvention::Direct_Unowned,
26902691
params, {}, result, None,
@@ -2828,7 +2829,8 @@ static SILFunction *getOrCreateKeyPathSetter(SILGenModule &SGM,
28282829
auto signature = SILFunctionType::get(genericSig,
28292830
SILFunctionType::ExtInfo(SILFunctionType::Representation::Thin,
28302831
/*pseudogeneric*/ false,
2831-
/*noescape*/ false),
2832+
/*noescape*/ false,
2833+
DifferentiabilityKind::NonDifferentiable),
28322834
SILCoroutineKind::None,
28332835
ParameterConvention::Direct_Unowned,
28342836
params, {}, {}, None,
@@ -3004,7 +3006,8 @@ getOrCreateKeyPathEqualsAndHash(SILGenModule &SGM,
30043006
auto signature = SILFunctionType::get(genericSig,
30053007
SILFunctionType::ExtInfo(SILFunctionType::Representation::Thin,
30063008
/*pseudogeneric*/ false,
3007-
/*noescape*/ false),
3009+
/*noescape*/ false,
3010+
DifferentiabilityKind::NonDifferentiable),
30083011
SILCoroutineKind::None,
30093012
ParameterConvention::Direct_Unowned,
30103013
params, /*yields*/ {}, results, None,
@@ -3180,7 +3183,8 @@ getOrCreateKeyPathEqualsAndHash(SILGenModule &SGM,
31803183
auto signature = SILFunctionType::get(genericSig,
31813184
SILFunctionType::ExtInfo(SILFunctionType::Representation::Thin,
31823185
/*pseudogeneric*/ false,
3183-
/*noescape*/ false),
3186+
/*noescape*/ false,
3187+
DifferentiabilityKind::NonDifferentiable),
31843188
SILCoroutineKind::None,
31853189
ParameterConvention::Direct_Unowned,
31863190
params, /*yields*/ {}, results, None,

lib/SILOptimizer/Transforms/Outliner.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,8 @@ CanSILFunctionType BridgedProperty::getOutlinedFunctionType(SILModule &M) {
288288
ResultConvention::Owned));
289289
auto ExtInfo =
290290
SILFunctionType::ExtInfo(SILFunctionType::Representation::Thin,
291-
/*pseudogeneric*/ false, /*noescape*/ false);
291+
/*pseudogeneric*/ false, /*noescape*/ false,
292+
DifferentiabilityKind::NonDifferentiable);
292293
auto FunctionType = SILFunctionType::get(
293294
nullptr, ExtInfo, SILCoroutineKind::None,
294295
ParameterConvention::Direct_Unowned, Parameters, /*yields*/ {},
@@ -1108,10 +1109,10 @@ CanSILFunctionType ObjCMethodCall::getOutlinedFunctionType(SILModule &M) {
11081109
OrigSigIdx++;
11091110
}
11101111

1111-
auto ExtInfo =
1112-
SILFunctionType::ExtInfo(SILFunctionType::Representation::Thin,
1113-
/*pseudogeneric*/ false,
1114-
/*noescape*/ false);
1112+
auto ExtInfo = SILFunctionType::ExtInfo(
1113+
SILFunctionType::Representation::Thin,
1114+
/*pseudogeneric*/ false,
1115+
/*noescape*/ false, DifferentiabilityKind::NonDifferentiable);
11151116

11161117
SmallVector<SILResultInfo, 4> Results;
11171118
// If we don't have a bridged return we changed from @autoreleased to @owned

lib/SILOptimizer/UtilityPasses/BugReducerTester.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,13 @@ class BugReducerTester : public SILFunctionTransform {
8383
ResultInfoArray.push_back(
8484
SILResultInfo(EmptyTupleCanType, ResultConvention::Unowned));
8585
auto FuncType = SILFunctionType::get(
86-
nullptr, SILFunctionType::ExtInfo(SILFunctionType::Representation::Thin,
87-
false /*isPseudoGeneric*/,
88-
false /*noescape*/),
86+
nullptr,
87+
SILFunctionType::ExtInfo(SILFunctionType::Representation::Thin,
88+
false /*isPseudoGeneric*/, false /*noescape*/,
89+
DifferentiabilityKind::NonDifferentiable),
8990
SILCoroutineKind::None, ParameterConvention::Direct_Unowned,
90-
ArrayRef<SILParameterInfo>(), ArrayRef<SILYieldInfo>(),
91-
ResultInfoArray, None,
92-
SubstitutionMap(), false,
91+
ArrayRef<SILParameterInfo>(), ArrayRef<SILYieldInfo>(), ResultInfoArray,
92+
None, SubstitutionMap(), false,
9393
getFunction()->getModule().getASTContext());
9494

9595
SILOptFunctionBuilder FunctionBuilder(*this);

lib/Sema/ConstraintSystem.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1643,7 +1643,8 @@ resolveOverloadForDeclWithSpecialTypeCheckingSemantics(ConstraintSystem &CS,
16431643
auto bodyClosure = FunctionType::get(arg, result,
16441644
FunctionType::ExtInfo(FunctionType::Representation::Swift,
16451645
/*noescape*/ true,
1646-
/*throws*/ true));
1646+
/*throws*/ true,
1647+
DifferentiabilityKind::NonDifferentiable));
16471648
FunctionType::Param args[] = {
16481649
FunctionType::Param(noescapeClosure),
16491650
FunctionType::Param(bodyClosure, CS.getASTContext().getIdentifier("do")),
@@ -1652,7 +1653,8 @@ resolveOverloadForDeclWithSpecialTypeCheckingSemantics(ConstraintSystem &CS,
16521653
refType = FunctionType::get(args, result,
16531654
FunctionType::ExtInfo(FunctionType::Representation::Swift,
16541655
/*noescape*/ false,
1655-
/*throws*/ true));
1656+
/*throws*/ true,
1657+
DifferentiabilityKind::NonDifferentiable));
16561658
openedFullType = refType;
16571659
return true;
16581660
}
@@ -1675,15 +1677,17 @@ resolveOverloadForDeclWithSpecialTypeCheckingSemantics(ConstraintSystem &CS,
16751677
auto bodyClosure = FunctionType::get(bodyArgs, result,
16761678
FunctionType::ExtInfo(FunctionType::Representation::Swift,
16771679
/*noescape*/ true,
1678-
/*throws*/ true));
1680+
/*throws*/ true,
1681+
DifferentiabilityKind::NonDifferentiable));
16791682
FunctionType::Param args[] = {
16801683
FunctionType::Param(existentialTy),
16811684
FunctionType::Param(bodyClosure, CS.getASTContext().getIdentifier("do")),
16821685
};
16831686
refType = FunctionType::get(args, result,
16841687
FunctionType::ExtInfo(FunctionType::Representation::Swift,
16851688
/*noescape*/ false,
1686-
/*throws*/ true));
1689+
/*throws*/ true,
1690+
DifferentiabilityKind::NonDifferentiable));
16871691
openedFullType = refType;
16881692
return true;
16891693
}

0 commit comments

Comments
 (0)