Skip to content

Commit 6413f43

Browse files
author
Marc Rasi
committed
[AutoDiff upstream] AST bits for @differentiable fn ty
1 parent 631305f commit 6413f43

File tree

17 files changed

+267
-63
lines changed

17 files changed

+267
-63
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,26 @@
22
//
33
// This source file is part of the Swift.org open source project
44
//
5-
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
5+
// Copyright (c) 2019 Apple Inc. and the Swift project authors
66
// Licensed under Apache License v2.0 with Runtime Library Exception
77
//
88
// See https://swift.org/LICENSE.txt for license information
99
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
1010
//
1111
//===----------------------------------------------------------------------===//
1212
//
13-
// SWIFT_ENABLE_TENSORFLOW
1413
// This file defines AST support for automatic differentiation.
1514
//
1615
//===----------------------------------------------------------------------===//
1716

1817
#ifndef SWIFT_AST_AUTODIFF_H
1918
#define SWIFT_AST_AUTODIFF_H
2019

20+
#include <cstdint>
21+
22+
#include "swift/AST/Identifier.h"
2123
#include "swift/AST/IndexSubset.h"
24+
#include "swift/Basic/SourceLoc.h"
2225
#include "swift/Basic/Range.h"
2326

2427
namespace swift {
@@ -86,6 +89,12 @@ class ParsedAutoDiffParameter {
8689
}
8790
};
8891

92+
enum class DifferentiabilityKind : uint8_t {
93+
NonDifferentiable = 0,
94+
Normal = 1,
95+
Linear = 2
96+
};
97+
8998
} // end namespace swift
9099

91100
#endif // SWIFT_AST_AUTODIFF_H

include/swift/AST/Types.h

Lines changed: 53 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#ifndef SWIFT_TYPES_H
1818
#define SWIFT_TYPES_H
1919

20+
#include "swift/AST/AutoDiff.h"
2021
#include "swift/AST/DeclContext.h"
2122
#include "swift/AST/GenericParamKey.h"
2223
#include "swift/AST/Identifier.h"
@@ -300,8 +301,8 @@ class alignas(1 << TypeAlignInBits) TypeBase {
300301
}
301302

302303
protected:
303-
enum { NumAFTExtInfoBits = 6 };
304-
enum { NumSILExtInfoBits = 6 };
304+
enum { NumAFTExtInfoBits = 8 };
305+
enum { NumSILExtInfoBits = 8 };
305306
union { uint64_t OpaqueBits;
306307

307308
SWIFT_INLINE_BITFIELD_BASE(TypeBase, bitmax(NumTypeKindBits,8) +
@@ -2875,14 +2876,16 @@ class AnyFunctionType : public TypeBase {
28752876
// If bits are added or removed, then TypeBase::AnyFunctionTypeBits
28762877
// and NumMaskBits must be updated, and they must match.
28772878
//
2878-
// |representation|noEscape|throws|
2879-
// | 0 .. 3 | 4 | 5 |
2879+
// |representation|noEscape|throws|differentiability|
2880+
// | 0 .. 3 | 4 | 5 | 6 .. 7 |
28802881
//
28812882
enum : unsigned {
2882-
RepresentationMask = 0xF << 0,
2883-
NoEscapeMask = 1 << 4,
2884-
ThrowsMask = 1 << 5,
2885-
NumMaskBits = 6
2883+
RepresentationMask = 0xF << 0,
2884+
NoEscapeMask = 1 << 4,
2885+
ThrowsMask = 1 << 5,
2886+
DifferentiabilityMaskOffset = 6,
2887+
DifferentiabilityMask = 0x3 << DifferentiabilityMaskOffset,
2888+
NumMaskBits = 8
28862889
};
28872890

28882891
unsigned Bits; // Naturally sized for speed.
@@ -2905,13 +2908,24 @@ class AnyFunctionType : public TypeBase {
29052908
// Constructor with no defaults.
29062909
ExtInfo(Representation Rep,
29072910
bool IsNoEscape,
2908-
bool Throws)
2911+
bool Throws,
2912+
DifferentiabilityKind DiffKind)
29092913
: ExtInfo(Rep, Throws) {
29102914
Bits |= (IsNoEscape ? NoEscapeMask : 0);
2915+
Bits |= ((unsigned)DiffKind << DifferentiabilityMaskOffset) &
2916+
DifferentiabilityMask;
29112917
}
29122918

29132919
bool isNoEscape() const { return Bits & NoEscapeMask; }
29142920
bool throws() const { return Bits & ThrowsMask; }
2921+
bool isDifferentiable() const {
2922+
return getDifferentiabilityKind() >
2923+
DifferentiabilityKind::NonDifferentiable;
2924+
}
2925+
DifferentiabilityKind getDifferentiabilityKind() const {
2926+
return DifferentiabilityKind((Bits & DifferentiabilityMask) >>
2927+
DifferentiabilityMaskOffset);
2928+
}
29152929
Representation getRepresentation() const {
29162930
unsigned rawRep = Bits & RepresentationMask;
29172931
assert(rawRep <= unsigned(Representation::Last)
@@ -3069,6 +3083,11 @@ class AnyFunctionType : public TypeBase {
30693083
return getExtInfo().throws();
30703084
}
30713085

3086+
bool isDifferentiable() const { return getExtInfo().isDifferentiable(); }
3087+
DifferentiabilityKind getDifferentiabilityKind() const {
3088+
return getExtInfo().getDifferentiabilityKind();
3089+
}
3090+
30723091
/// Returns a new function type exactly like this one but with the ExtInfo
30733092
/// replaced.
30743093
AnyFunctionType *withExtInfo(ExtInfo info) const;
@@ -3716,14 +3735,16 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
37163735
// If bits are added or removed, then TypeBase::SILFunctionTypeBits
37173736
// and NumMaskBits must be updated, and they must match.
37183737

3719-
// |representation|pseudogeneric| noescape |
3720-
// | 0 .. 3 | 4 | 5 |
3738+
// |representation|pseudogeneric| noescape |differentiability|
3739+
// | 0 .. 3 | 4 | 5 | 6 .. 7 |
37213740
//
37223741
enum : unsigned {
37233742
RepresentationMask = 0xF << 0,
37243743
PseudogenericMask = 1 << 4,
37253744
NoEscapeMask = 1 << 5,
3726-
NumMaskBits = 6
3745+
DifferentiabilityMaskOffset = 6,
3746+
DifferentiabilityMask = 0x3 << DifferentiabilityMaskOffset,
3747+
NumMaskBits = 8
37273748
};
37283749

37293750
unsigned Bits; // Naturally sized for speed.
@@ -3737,10 +3758,13 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
37373758
ExtInfo() : Bits(0) { }
37383759

37393760
// Constructor for polymorphic type.
3740-
ExtInfo(Representation rep, bool isPseudogeneric, bool isNoEscape) {
3761+
ExtInfo(Representation rep, bool isPseudogeneric, bool isNoEscape,
3762+
DifferentiabilityKind diffKind) {
37413763
Bits = ((unsigned) rep) |
37423764
(isPseudogeneric ? PseudogenericMask : 0) |
3743-
(isNoEscape ? NoEscapeMask : 0);
3765+
(isNoEscape ? NoEscapeMask : 0) |
3766+
(((unsigned)diffKind << DifferentiabilityMaskOffset) &
3767+
DifferentiabilityMask);
37443768
}
37453769

37463770
/// Is this function pseudo-generic? A pseudo-generic function
@@ -3750,6 +3774,16 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
37503774
// Is this function guaranteed to be no-escape by the type system?
37513775
bool isNoEscape() const { return Bits & NoEscapeMask; }
37523776

3777+
bool isDifferentiable() const {
3778+
return getDifferentiabilityKind() !=
3779+
DifferentiabilityKind::NonDifferentiable;
3780+
}
3781+
3782+
DifferentiabilityKind getDifferentiabilityKind() const {
3783+
return DifferentiabilityKind((Bits & DifferentiabilityMask) >>
3784+
DifferentiabilityMaskOffset);
3785+
}
3786+
37533787
/// What is the abstract representation of this function value?
37543788
Representation getRepresentation() const {
37553789
return Representation(Bits & RepresentationMask);
@@ -4154,6 +4188,11 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
41544188
getRepresentation() == SILFunctionTypeRepresentation::Thick;
41554189
}
41564190

4191+
bool isDifferentiable() const { return getExtInfo().isDifferentiable(); }
4192+
DifferentiabilityKind getDifferentiabilityKind() const {
4193+
return getExtInfo().getDifferentiabilityKind();
4194+
}
4195+
41574196
bool isNoReturnFunction(SILModule &M) const; // Defined in SILType.cpp
41584197

41594198
/// Create a SILFunctionType with the same parameters, results, and attributes as this one, but with

lib/AST/ASTDemangler.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -479,9 +479,9 @@ Type ASTBuilder::createImplFunctionType(
479479
break;
480480
}
481481

482-
auto einfo = SILFunctionType::ExtInfo(representation,
483-
flags.isPseudogeneric(),
484-
!flags.isEscaping());
482+
auto einfo = SILFunctionType::ExtInfo(
483+
representation, flags.isPseudogeneric(), !flags.isEscaping(),
484+
DifferentiabilityKind::NonDifferentiable);
485485

486486
llvm::SmallVector<SILParameterInfo, 8> funcParams;
487487
llvm::SmallVector<SILYieldInfo, 8> funcYields;

lib/AST/ASTPrinter.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3789,6 +3789,14 @@ class TypePrinter : public TypeVisitor<TypePrinter> {
37893789
if (Options.SkipAttributes)
37903790
return;
37913791

3792+
if (!Options.excludeAttrKind(TAK_differentiable) &&
3793+
info.isDifferentiable()) {
3794+
if (info.getDifferentiabilityKind() == DifferentiabilityKind::Linear) {
3795+
Printer << "@differentiable(linear) ";
3796+
} else {
3797+
Printer << "@differentiable ";
3798+
}
3799+
}
37923800

37933801
if (Options.PrintFunctionRepresentationAttrs &&
37943802
!Options.excludeAttrKind(TAK_convention) &&
@@ -3833,6 +3841,15 @@ class TypePrinter : public TypeVisitor<TypePrinter> {
38333841
if (Options.SkipAttributes)
38343842
return;
38353843

3844+
if (!Options.excludeAttrKind(TAK_differentiable) &&
3845+
info.isDifferentiable()) {
3846+
if (info.getDifferentiabilityKind() == DifferentiabilityKind::Linear) {
3847+
Printer << "@differentiable(linear) ";
3848+
} else {
3849+
Printer << "@differentiable ";
3850+
}
3851+
}
3852+
38363853
if (Options.PrintFunctionRepresentationAttrs &&
38373854
!Options.excludeAttrKind(TAK_convention) &&
38383855
info.getRepresentation() != SILFunctionType::Representation::Thick) {

lib/SILGen/SILGen.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -415,10 +415,10 @@ SILGenModule::getKeyPathProjectionCoroutine(bool isReadAccess,
415415
: ParameterConvention::Indirect_In_Guaranteed },
416416
};
417417

418-
auto extInfo =
419-
SILFunctionType::ExtInfo(SILFunctionTypeRepresentation::Thin,
420-
/*pseudogeneric*/false,
421-
/*non-escaping*/false);
418+
auto extInfo = SILFunctionType::ExtInfo(
419+
SILFunctionTypeRepresentation::Thin,
420+
/*pseudogeneric*/ false,
421+
/*non-escaping*/ false, DifferentiabilityKind::NonDifferentiable);
422422

423423
auto functionTy = SILFunctionType::get(sig, extInfo,
424424
SILCoroutineKind::YieldOnce,

lib/SILGen/SILGenExpr.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2669,7 +2669,8 @@ static SILFunction *getOrCreateKeyPathGetter(SILGenModule &SGM,
26692669
auto signature = SILFunctionType::get(genericSig,
26702670
SILFunctionType::ExtInfo(SILFunctionType::Representation::Thin,
26712671
/*pseudogeneric*/ false,
2672-
/*noescape*/ false),
2672+
/*noescape*/ false,
2673+
DifferentiabilityKind::NonDifferentiable),
26732674
SILCoroutineKind::None,
26742675
ParameterConvention::Direct_Unowned,
26752676
params, {}, result, None,
@@ -2811,7 +2812,8 @@ static SILFunction *getOrCreateKeyPathSetter(SILGenModule &SGM,
28112812
auto signature = SILFunctionType::get(genericSig,
28122813
SILFunctionType::ExtInfo(SILFunctionType::Representation::Thin,
28132814
/*pseudogeneric*/ false,
2814-
/*noescape*/ false),
2815+
/*noescape*/ false,
2816+
DifferentiabilityKind::NonDifferentiable),
28152817
SILCoroutineKind::None,
28162818
ParameterConvention::Direct_Unowned,
28172819
params, {}, {}, None,
@@ -2987,7 +2989,8 @@ getOrCreateKeyPathEqualsAndHash(SILGenModule &SGM,
29872989
auto signature = SILFunctionType::get(genericSig,
29882990
SILFunctionType::ExtInfo(SILFunctionType::Representation::Thin,
29892991
/*pseudogeneric*/ false,
2990-
/*noescape*/ false),
2992+
/*noescape*/ false,
2993+
DifferentiabilityKind::NonDifferentiable),
29912994
SILCoroutineKind::None,
29922995
ParameterConvention::Direct_Unowned,
29932996
params, /*yields*/ {}, results, None,
@@ -3162,7 +3165,8 @@ getOrCreateKeyPathEqualsAndHash(SILGenModule &SGM,
31623165
auto signature = SILFunctionType::get(genericSig,
31633166
SILFunctionType::ExtInfo(SILFunctionType::Representation::Thin,
31643167
/*pseudogeneric*/ false,
3165-
/*noescape*/ false),
3168+
/*noescape*/ false,
3169+
DifferentiabilityKind::NonDifferentiable),
31663170
SILCoroutineKind::None,
31673171
ParameterConvention::Direct_Unowned,
31683172
params, /*yields*/ {}, results, None,

lib/SILOptimizer/Transforms/Outliner.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,8 @@ CanSILFunctionType BridgedProperty::getOutlinedFunctionType(SILModule &M) {
288288
ResultConvention::Owned));
289289
auto ExtInfo =
290290
SILFunctionType::ExtInfo(SILFunctionType::Representation::Thin,
291-
/*pseudogeneric*/ false, /*noescape*/ false);
291+
/*pseudogeneric*/ false, /*noescape*/ false,
292+
DifferentiabilityKind::NonDifferentiable);
292293
auto FunctionType = SILFunctionType::get(
293294
nullptr, ExtInfo, SILCoroutineKind::None,
294295
ParameterConvention::Direct_Unowned, Parameters, /*yields*/ {},
@@ -1108,10 +1109,10 @@ CanSILFunctionType ObjCMethodCall::getOutlinedFunctionType(SILModule &M) {
11081109
OrigSigIdx++;
11091110
}
11101111

1111-
auto ExtInfo =
1112-
SILFunctionType::ExtInfo(SILFunctionType::Representation::Thin,
1113-
/*pseudogeneric*/ false,
1114-
/*noescape*/ false);
1112+
auto ExtInfo = SILFunctionType::ExtInfo(
1113+
SILFunctionType::Representation::Thin,
1114+
/*pseudogeneric*/ false,
1115+
/*noescape*/ false, DifferentiabilityKind::NonDifferentiable);
11151116

11161117
SmallVector<SILResultInfo, 4> Results;
11171118
// If we don't have a bridged return we changed from @autoreleased to @owned

lib/SILOptimizer/UtilityPasses/BugReducerTester.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,13 @@ class BugReducerTester : public SILFunctionTransform {
8383
ResultInfoArray.push_back(
8484
SILResultInfo(EmptyTupleCanType, ResultConvention::Unowned));
8585
auto FuncType = SILFunctionType::get(
86-
nullptr, SILFunctionType::ExtInfo(SILFunctionType::Representation::Thin,
87-
false /*isPseudoGeneric*/,
88-
false /*noescape*/),
86+
nullptr,
87+
SILFunctionType::ExtInfo(SILFunctionType::Representation::Thin,
88+
false /*isPseudoGeneric*/, false /*noescape*/,
89+
DifferentiabilityKind::NonDifferentiable),
8990
SILCoroutineKind::None, ParameterConvention::Direct_Unowned,
90-
ArrayRef<SILParameterInfo>(), ArrayRef<SILYieldInfo>(),
91-
ResultInfoArray, None,
92-
SubstitutionMap(), false,
91+
ArrayRef<SILParameterInfo>(), ArrayRef<SILYieldInfo>(), ResultInfoArray,
92+
None, SubstitutionMap(), false,
9393
getFunction()->getModule().getASTContext());
9494

9595
SILOptFunctionBuilder FunctionBuilder(*this);

lib/Sema/ConstraintSystem.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1642,7 +1642,8 @@ resolveOverloadForDeclWithSpecialTypeCheckingSemantics(ConstraintSystem &CS,
16421642
auto bodyClosure = FunctionType::get(arg, result,
16431643
FunctionType::ExtInfo(FunctionType::Representation::Swift,
16441644
/*noescape*/ true,
1645-
/*throws*/ true));
1645+
/*throws*/ true,
1646+
DifferentiabilityKind::NonDifferentiable));
16461647
FunctionType::Param args[] = {
16471648
FunctionType::Param(noescapeClosure),
16481649
FunctionType::Param(bodyClosure, CS.getASTContext().getIdentifier("do")),
@@ -1651,7 +1652,8 @@ resolveOverloadForDeclWithSpecialTypeCheckingSemantics(ConstraintSystem &CS,
16511652
refType = FunctionType::get(args, result,
16521653
FunctionType::ExtInfo(FunctionType::Representation::Swift,
16531654
/*noescape*/ false,
1654-
/*throws*/ true));
1655+
/*throws*/ true,
1656+
DifferentiabilityKind::NonDifferentiable));
16551657
openedFullType = refType;
16561658
return true;
16571659
}
@@ -1674,15 +1676,17 @@ resolveOverloadForDeclWithSpecialTypeCheckingSemantics(ConstraintSystem &CS,
16741676
auto bodyClosure = FunctionType::get(bodyArgs, result,
16751677
FunctionType::ExtInfo(FunctionType::Representation::Swift,
16761678
/*noescape*/ true,
1677-
/*throws*/ true));
1679+
/*throws*/ true,
1680+
DifferentiabilityKind::NonDifferentiable));
16781681
FunctionType::Param args[] = {
16791682
FunctionType::Param(existentialTy),
16801683
FunctionType::Param(bodyClosure, CS.getASTContext().getIdentifier("do")),
16811684
};
16821685
refType = FunctionType::get(args, result,
16831686
FunctionType::ExtInfo(FunctionType::Representation::Swift,
16841687
/*noescape*/ false,
1685-
/*throws*/ true));
1688+
/*throws*/ true,
1689+
DifferentiabilityKind::NonDifferentiable));
16861690
openedFullType = refType;
16871691
return true;
16881692
}

0 commit comments

Comments
 (0)