Skip to content

Commit 4ccdfa2

Browse files
committed
Merge remote-tracking branch 'origin/master' into master-next
2 parents b1399db + 62f6686 commit 4ccdfa2

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+1814
-112
lines changed

docs/SIL.rst

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5774,6 +5774,67 @@ The rules on generic substitutions are identical to those of ``apply``.
57745774
Differentiable Programming
57755775
~~~~~~~~~~~~~~~~~~~~~~~~~~
57765776

5777+
differentiable_function
5778+
```````````````````````
5779+
::
5780+
5781+
sil-instruction ::= 'differentiable_function'
5782+
sil-differentiable-function-parameter-indices
5783+
sil-value ':' sil-type
5784+
sil-differentiable-function-derivative-functions-clause?
5785+
5786+
sil-differentiable-function-parameter-indices ::=
5787+
'[' 'parameters' [0-9]+ (' ' [0-9]+)* ']'
5788+
sil-differentiable-derivative-functions-clause ::=
5789+
'with_derivative'
5790+
'{' sil-value ':' sil-type ',' sil-value ':' sil-type '}'
5791+
5792+
differentiable_function [parameters 0] %0 : $(T) -> T \
5793+
with_derivative {%1 : $(T) -> (T, (T) -> T), %2 : $(T) -> (T, (T) -> T)}
5794+
5795+
Creates a ``@differentiable`` function from an original function operand and
5796+
derivative function operands (optional). There are two derivative function
5797+
kinds: a Jacobian-vector products (JVP) function and a vector-Jacobian products
5798+
(VJP) function.
5799+
5800+
``[parameters ...]`` specifies parameter indices that the original function is
5801+
differentiable with respect to.
5802+
5803+
The ``with_derivative`` clause specifies the derivative function operands
5804+
associated with the original function.
5805+
5806+
The differentiation transformation canonicalizes all `differentiable_function`
5807+
instructions, generating derivative functions if necessary to fill in derivative
5808+
function operands.
5809+
5810+
In raw SIL, the ``with_derivative`` clause is optional. In canonical SIL, the
5811+
``with_derivative`` clause is mandatory.
5812+
5813+
5814+
differentiable_function_extract
5815+
```````````````````````````````
5816+
::
5817+
5818+
sil-instruction ::= 'differentiable_function_extract'
5819+
'[' sil-differentiable-function-extractee ']'
5820+
sil-value ':' sil-type
5821+
('as' sil-type)?
5822+
5823+
sil-differentiable-function-extractee ::= 'original' | 'jvp' | 'vjp'
5824+
5825+
differentiable_function_extract [original] %0 : $@differentiable (T) -> T
5826+
differentiable_function_extract [jvp] %0 : $@differentiable (T) -> T
5827+
differentiable_function_extract [vjp] %0 : $@differentiable (T) -> T
5828+
differentiable_function_extract [jvp] %0 : $@differentiable (T) -> T \
5829+
as $(@in_constant T) -> (T, (T.TangentVector) -> T.TangentVector)
5830+
5831+
Extracts the original function or a derivative function from the given
5832+
``@differentiable`` function. The extractee is one of the following:
5833+
``[original]``, ``[jvp]``, or ``[vjp]``.
5834+
5835+
In lowered SIL, an explicit extractee type may be provided. This is currently
5836+
used by the LoadableByAddress transformation, which rewrites function types.
5837+
57775838
differentiability_witness_function
57785839
``````````````````````````````````
57795840
::

include/swift/AST/AutoDiff.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,41 @@ struct AutoDiffDerivativeFunctionKind {
7575
}
7676
};
7777

78+
/// A component of a SIL `@differentiable` function-typed value.
79+
struct NormalDifferentiableFunctionTypeComponent {
80+
enum innerty : unsigned { Original = 0, JVP = 1, VJP = 2 } rawValue;
81+
82+
NormalDifferentiableFunctionTypeComponent() = default;
83+
NormalDifferentiableFunctionTypeComponent(innerty rawValue)
84+
: rawValue(rawValue) {}
85+
NormalDifferentiableFunctionTypeComponent(
86+
AutoDiffDerivativeFunctionKind kind);
87+
explicit NormalDifferentiableFunctionTypeComponent(unsigned rawValue)
88+
: NormalDifferentiableFunctionTypeComponent((innerty)rawValue) {}
89+
explicit NormalDifferentiableFunctionTypeComponent(StringRef name);
90+
operator innerty() const { return rawValue; }
91+
92+
/// Returns the derivative function kind, if the component is a derivative
93+
/// function.
94+
Optional<AutoDiffDerivativeFunctionKind> getAsDerivativeFunctionKind() const;
95+
};
96+
97+
/// A component of a SIL `@differentiable(linear)` function-typed value.
98+
struct LinearDifferentiableFunctionTypeComponent {
99+
enum innerty : unsigned {
100+
Original = 0,
101+
Transpose = 1,
102+
} rawValue;
103+
104+
LinearDifferentiableFunctionTypeComponent() = default;
105+
LinearDifferentiableFunctionTypeComponent(innerty rawValue)
106+
: rawValue(rawValue) {}
107+
explicit LinearDifferentiableFunctionTypeComponent(unsigned rawValue)
108+
: LinearDifferentiableFunctionTypeComponent((innerty)rawValue) {}
109+
explicit LinearDifferentiableFunctionTypeComponent(StringRef name);
110+
operator innerty() const { return rawValue; }
111+
};
112+
78113
/// A derivative function configuration, uniqued in `ASTContext`.
79114
/// Identifies a specific derivative function given an original function.
80115
class AutoDiffDerivativeFunctionIdentifier : public llvm::FoldingSetNode {

include/swift/AST/Decl.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7063,6 +7063,28 @@ class PrecedenceGroupDecl : public Decl {
70637063
}
70647064
};
70657065

7066+
/// The fixity of an OperatorDecl.
7067+
enum class OperatorFixity : uint8_t {
7068+
Infix,
7069+
Prefix,
7070+
Postfix
7071+
};
7072+
7073+
inline void simple_display(llvm::raw_ostream &out, OperatorFixity fixity) {
7074+
switch (fixity) {
7075+
case OperatorFixity::Infix:
7076+
out << "infix";
7077+
return;
7078+
case OperatorFixity::Prefix:
7079+
out << "prefix";
7080+
return;
7081+
case OperatorFixity::Postfix:
7082+
out << "postfix";
7083+
return;
7084+
}
7085+
llvm_unreachable("Unhandled case in switch");
7086+
}
7087+
70667088
/// Abstract base class of operator declarations.
70677089
class OperatorDecl : public Decl {
70687090
SourceLoc OperatorLoc, NameLoc;
@@ -7088,6 +7110,21 @@ class OperatorDecl : public Decl {
70887110
: Decl(kind, DC), OperatorLoc(OperatorLoc), NameLoc(NameLoc), name(Name),
70897111
DesignatedNominalTypes(DesignatedNominalTypes) {}
70907112

7113+
/// Retrieve the operator's fixity, corresponding to the concrete subclass
7114+
/// of the OperatorDecl.
7115+
OperatorFixity getFixity() const {
7116+
switch (getKind()) {
7117+
#define DECL(Id, Name) case DeclKind::Id: llvm_unreachable("Not an operator!");
7118+
#define OPERATOR_DECL(Id, Name)
7119+
#include "swift/AST/DeclNodes.def"
7120+
case DeclKind::InfixOperator:
7121+
return OperatorFixity::Infix;
7122+
case DeclKind::PrefixOperator:
7123+
return OperatorFixity::Prefix;
7124+
case DeclKind::PostfixOperator:
7125+
return OperatorFixity::Postfix;
7126+
}
7127+
}
70917128

70927129
SourceLoc getOperatorLoc() const { return OperatorLoc; }
70937130
SourceLoc getNameLoc() const { return NameLoc; }

include/swift/AST/DiagnosticsParse.def

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1628,6 +1628,17 @@ ERROR(sil_autodiff_expected_parameter_index,PointsToFirstBadToken,
16281628
"expected the index of a parameter to differentiate with respect to", ())
16291629
ERROR(sil_autodiff_expected_result_index,PointsToFirstBadToken,
16301630
"expected the index of a result to differentiate from", ())
1631+
ERROR(sil_inst_autodiff_operand_list_expected_lbrace,PointsToFirstBadToken,
1632+
"expected '{' to start a derivative function list", ())
1633+
ERROR(sil_inst_autodiff_operand_list_expected_comma,PointsToFirstBadToken,
1634+
"expected ',' between operands in a derivative function list", ())
1635+
ERROR(sil_inst_autodiff_operand_list_expected_rbrace,PointsToFirstBadToken,
1636+
"expected '}' to start a derivative function list", ())
1637+
ERROR(sil_inst_autodiff_expected_differentiable_extractee_kind,PointsToFirstBadToken,
1638+
"expected an extractee kind attribute, which can be one of '[original]', "
1639+
"'[jvp]', and '[vjp]'", ())
1640+
ERROR(sil_inst_autodiff_expected_function_type_operand,PointsToFirstBadToken,
1641+
"expected an operand of a function type", ())
16311642
ERROR(sil_inst_autodiff_expected_differentiability_witness_kind,PointsToFirstBadToken,
16321643
"expected a differentiability witness kind, which can be one of '[jvp]', "
16331644
"'[vjp]', or '[transpose]'", ())

include/swift/AST/FileUnit.h

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ class FileUnit : public DeclContext {
3030
#pragma clang diagnostic pop
3131
virtual void anchor();
3232

33+
friend class DirectOperatorLookupRequest;
34+
friend class DirectPrecedenceGroupLookupRequest;
35+
3336
// FIXME: Stick this in a PointerIntPair.
3437
const FileUnitKind Kind;
3538

@@ -107,6 +110,25 @@ class FileUnit : public DeclContext {
107110
const ModuleDecl *importedModule,
108111
SmallVectorImpl<Identifier> &spiGroups) const {};
109112

113+
protected:
114+
/// Look up an operator declaration. Do not call directly, use
115+
/// \c DirectOperatorLookupRequest instead.
116+
///
117+
/// \param name The operator name ("+", ">>", etc.)
118+
///
119+
/// \param fixity One of Prefix, Infix, or Postfix.
120+
virtual void
121+
lookupOperatorDirect(Identifier name, OperatorFixity fixity,
122+
TinyPtrVector<OperatorDecl *> &results) const {}
123+
124+
/// Look up a precedence group. Do not call directly, use
125+
/// \c DirectPrecedenceGroupLookupRequest instead.
126+
///
127+
/// \param name The precedence group name.
128+
virtual void lookupPrecedenceGroupDirect(
129+
Identifier name, TinyPtrVector<PrecedenceGroupDecl *> &results) const {}
130+
131+
public:
110132
/// Returns the comment attached to the given declaration.
111133
///
112134
/// This function is an implementation detail for comment serialization.
@@ -342,22 +364,6 @@ class LoadedFile : public FileUnit {
342364
return StringRef();
343365
}
344366

345-
/// Look up an operator declaration.
346-
///
347-
/// \param name The operator name ("+", ">>", etc.)
348-
///
349-
/// \param fixity One of PrefixOperator, InfixOperator, or PostfixOperator.
350-
virtual OperatorDecl *lookupOperator(Identifier name, DeclKind fixity) const {
351-
return nullptr;
352-
}
353-
354-
/// Look up a precedence group.
355-
///
356-
/// \param name The precedence group name.
357-
virtual PrecedenceGroupDecl *lookupPrecedenceGroup(Identifier name) const {
358-
return nullptr;
359-
}
360-
361367
/// Returns the Swift module that overlays a Clang module.
362368
virtual ModuleDecl *getOverlayModule() const { return nullptr; }
363369

include/swift/AST/NameLookupRequests.h

Lines changed: 66 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
#include "swift/AST/SimpleRequest.h"
2020
#include "swift/AST/ASTTypeIDs.h"
21+
#include "swift/AST/FileUnit.h"
2122
#include "swift/AST/Identifier.h"
2223
#include "swift/Basic/Statistic.h"
2324
#include "llvm/ADT/Hashing.h"
@@ -518,29 +519,53 @@ class DirectLookupRequest
518519

519520
class OperatorLookupDescriptor final {
520521
public:
521-
SourceFile *SF;
522+
using Storage = llvm::PointerUnion<FileUnit *, ModuleDecl *>;
523+
Storage fileOrModule;
522524
Identifier name;
523525
bool isCascading;
524526
SourceLoc diagLoc;
525527

526-
OperatorLookupDescriptor(SourceFile *SF, Identifier name, bool isCascading,
527-
SourceLoc diagLoc)
528-
: SF(SF), name(name), isCascading(isCascading), diagLoc(diagLoc) {}
528+
private:
529+
OperatorLookupDescriptor(Storage fileOrModule, Identifier name,
530+
bool isCascading, SourceLoc diagLoc)
531+
: fileOrModule(fileOrModule), name(name), isCascading(isCascading),
532+
diagLoc(diagLoc) {}
533+
534+
public:
535+
/// Retrieves the files to perform lookup in.
536+
ArrayRef<FileUnit *> getFiles() const;
537+
538+
/// If this is for a module lookup, returns the module. Otherwise returns
539+
/// \c nullptr.
540+
ModuleDecl *getModule() const {
541+
return fileOrModule.dyn_cast<ModuleDecl *>();
542+
}
529543

530544
friend llvm::hash_code hash_value(const OperatorLookupDescriptor &desc) {
531-
return llvm::hash_combine(desc.SF, desc.name, desc.isCascading);
545+
return llvm::hash_combine(desc.fileOrModule, desc.name, desc.isCascading);
532546
}
533547

534548
friend bool operator==(const OperatorLookupDescriptor &lhs,
535549
const OperatorLookupDescriptor &rhs) {
536-
return lhs.SF == rhs.SF && lhs.name == rhs.name &&
550+
return lhs.fileOrModule == rhs.fileOrModule && lhs.name == rhs.name &&
537551
lhs.isCascading == rhs.isCascading;
538552
}
539553

540554
friend bool operator!=(const OperatorLookupDescriptor &lhs,
541555
const OperatorLookupDescriptor &rhs) {
542556
return !(lhs == rhs);
543557
}
558+
559+
static OperatorLookupDescriptor forFile(FileUnit *file, Identifier name,
560+
bool isCascading, SourceLoc diagLoc) {
561+
return OperatorLookupDescriptor(file, name, isCascading, diagLoc);
562+
}
563+
564+
static OperatorLookupDescriptor forModule(ModuleDecl *mod, Identifier name,
565+
bool isCascading,
566+
SourceLoc diagLoc) {
567+
return OperatorLookupDescriptor(mod, name, isCascading, diagLoc);
568+
}
544569
};
545570

546571
void simple_display(llvm::raw_ostream &out,
@@ -572,6 +597,41 @@ using LookupInfixOperatorRequest = LookupOperatorRequest<InfixOperatorDecl>;
572597
using LookupPostfixOperatorRequest = LookupOperatorRequest<PostfixOperatorDecl>;
573598
using LookupPrecedenceGroupRequest = LookupOperatorRequest<PrecedenceGroupDecl>;
574599

600+
/// Looks up an operator in a given file or module without looking through
601+
/// imports.
602+
class DirectOperatorLookupRequest
603+
: public SimpleRequest<DirectOperatorLookupRequest,
604+
TinyPtrVector<OperatorDecl *>(
605+
OperatorLookupDescriptor, OperatorFixity),
606+
CacheKind::Uncached> {
607+
public:
608+
using SimpleRequest::SimpleRequest;
609+
610+
private:
611+
friend SimpleRequest;
612+
613+
llvm::Expected<TinyPtrVector<OperatorDecl *>>
614+
evaluate(Evaluator &evaluator, OperatorLookupDescriptor descriptor,
615+
OperatorFixity fixity) const;
616+
};
617+
618+
/// Looks up an precedencegroup in a given file or module without looking
619+
/// through imports.
620+
class DirectPrecedenceGroupLookupRequest
621+
: public SimpleRequest<DirectPrecedenceGroupLookupRequest,
622+
TinyPtrVector<PrecedenceGroupDecl *>(
623+
OperatorLookupDescriptor),
624+
CacheKind::Uncached> {
625+
public:
626+
using SimpleRequest::SimpleRequest;
627+
628+
private:
629+
friend SimpleRequest;
630+
631+
llvm::Expected<TinyPtrVector<PrecedenceGroupDecl *>>
632+
evaluate(Evaluator &evaluator, OperatorLookupDescriptor descriptor) const;
633+
};
634+
575635
#define SWIFT_TYPEID_ZONE NameLookup
576636
#define SWIFT_TYPEID_HEADER "swift/AST/NameLookupTypeIDZone.def"
577637
#include "swift/Basic/DefineTypeIDZone.h"

include/swift/AST/NameLookupTypeIDZone.def

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,13 @@ SWIFT_REQUEST(NameLookup, CustomAttrNominalRequest,
2424
SWIFT_REQUEST(NameLookup, DirectLookupRequest,
2525
TinyPtrVector<ValueDecl *>(DirectLookupDescriptor), Uncached,
2626
NoLocationInfo)
27+
SWIFT_REQUEST(NameLookup, DirectOperatorLookupRequest,
28+
TinyPtrVector<OperatorDecl *>(OperatorLookupDescriptor,
29+
OperatorFixity),
30+
Uncached, NoLocationInfo)
31+
SWIFT_REQUEST(NameLookup, DirectPrecedenceGroupLookupRequest,
32+
TinyPtrVector<PrecedenceGroupDecl *>(OperatorLookupDescriptor),
33+
Uncached, NoLocationInfo)
2734
SWIFT_REQUEST(NameLookup, ExpandASTScopeRequest,
2835
ast_scope::ASTScopeImpl* (ast_scope::ASTScopeImpl*, ast_scope::ScopeCreator*),
2936
SeparatelyCached,

include/swift/AST/SourceFile.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,16 @@ class SourceFile final : public FileUnit {
436436
ObjCSelector selector,
437437
SmallVectorImpl<AbstractFunctionDecl *> &results) const override;
438438

439+
protected:
440+
virtual void
441+
lookupOperatorDirect(Identifier name, OperatorFixity fixity,
442+
TinyPtrVector<OperatorDecl *> &results) const override;
443+
444+
virtual void lookupPrecedenceGroupDirect(
445+
Identifier name,
446+
TinyPtrVector<PrecedenceGroupDecl *> &results) const override;
447+
448+
public:
439449
virtual void getTopLevelDecls(SmallVectorImpl<Decl*> &results) const override;
440450

441451
virtual void

0 commit comments

Comments
 (0)