Skip to content

Commit ad6b7aa

Browse files
committed
Add DeclAttribute * to SILDifferentiabilityWitness.
Unserialized, to be used for diagnostics. Will revisit later when revamping the differentiation transform.
1 parent f240ed2 commit ad6b7aa

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

include/swift/SIL/SILDifferentiabilityWitness.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#ifndef SWIFT_SIL_SILDIFFERENTIABILITYWITNESS_H
2727
#define SWIFT_SIL_SILDIFFERENTIABILITYWITNESS_H
2828

29+
#include "swift/AST/Attr.h"
2930
#include "swift/AST/AutoDiff.h"
3031
#include "swift/AST/GenericSignature.h"
3132
#include "swift/SIL/SILAllocated.h"
@@ -60,6 +61,11 @@ class SILDifferentiabilityWitness
6061
/// Whether or not this differentiability witness is serialized, which allows
6162
/// devirtualization from another module.
6263
bool serialized;
64+
/// The AST `@differentiable` or `@differentiating` attribute from which the
65+
/// differentiability witness is generated. Used for diagnostics.
66+
/// Null if the differentiability witness is parsed from SIL or if it is
67+
/// deserialized.
68+
DeclAttribute *attribute = nullptr;
6369

6470
static AutoDiffConfig *
6571
getAutoDiffConfig(SILModule &module, IndexSubset *parameterIndices,
@@ -72,18 +78,18 @@ class SILDifferentiabilityWitness
7278
IndexSubset *resultIndices,
7379
GenericSignature *derivativeGenSig,
7480
SILFunction *jvp, SILFunction *vjp,
75-
bool isSerialized)
81+
bool isSerialized, DeclAttribute *attribute)
7682
: module(module), linkage(linkage), originalFunction(originalFunction),
7783
parameterIndices(parameterIndices), resultIndices(resultIndices),
7884
derivativeGenericSignature(derivativeGenSig), jvp(jvp), vjp(vjp),
79-
serialized(isSerialized) {}
85+
serialized(isSerialized), attribute(attribute) {}
8086

8187
public:
8288
static SILDifferentiabilityWitness *create(
8389
SILModule &module, SILLinkage linkage, SILFunction *originalFunction,
8490
IndexSubset *parameterIndices, IndexSubset *resultIndices,
8591
GenericSignature *derivativeGenSig, SILFunction *jvp, SILFunction *vjp,
86-
bool isSerialized);
92+
bool isSerialized, DeclAttribute *attribute = nullptr);
8793

8894
SILDifferentiabilityWitnessKey getKey() const;
8995
SILModule &getModule() const { return module; }
@@ -114,6 +120,7 @@ class SILDifferentiabilityWitness
114120
}
115121
}
116122
bool isSerialized() const { return serialized; }
123+
DeclAttribute *getAttribute() const { return attribute; }
117124

118125
/// Verify that the differentiability witness is well-formed.
119126
void verify(const SILModule &M) const;

lib/SIL/SILDifferentiabilityWitness.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@ SILDifferentiabilityWitness *SILDifferentiabilityWitness::create(
2121
SILModule &module, SILLinkage linkage, SILFunction *originalFunction,
2222
IndexSubset *parameterIndices, IndexSubset *resultIndices,
2323
GenericSignature *derivativeGenSig, SILFunction *jvp, SILFunction *vjp,
24-
bool isSerialized) {
24+
bool isSerialized, DeclAttribute *attribute) {
2525
void *buf = module.allocate(sizeof(SILDifferentiabilityWitness),
2626
alignof(SILDifferentiabilityWitness));
2727
auto *diffWitness = ::new (buf) SILDifferentiabilityWitness(
2828
module, linkage, originalFunction, parameterIndices, resultIndices,
29-
derivativeGenSig, jvp, vjp, isSerialized);
29+
derivativeGenSig, jvp, vjp, isSerialized, attribute);
3030
// Register the differentiability witness in the module.
3131
assert(!module.DifferentiabilityWitnessMap.count(diffWitness->getKey()) &&
3232
"Cannot create duplicate differentiability witness in a module");

0 commit comments

Comments
 (0)