Skip to content

Commit ac48feb

Browse files
committed
Add @noDerivative flag to SILParameterInfo.
The `@noDerivative` attribute marks the non-differentiability parameters of a `@differentiable` function type. All parameters except those marked with `@noDerivative` are differentiability parameters. For example, `@differentiable (Float, @noDerivative Float) -> Float` is only differentiable with respect to its first parameter. The `@noDerivative` attribute is represented as a `SILParameterDifferentiability` bit on `SILParameterInfo`. Add round-trip serialization tests. Resolves TF-872.
1 parent 0d7820c commit ac48feb

File tree

8 files changed

+137
-10
lines changed

8 files changed

+137
-10
lines changed

include/swift/AST/Types.h

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3637,13 +3637,35 @@ inline bool isGuaranteedParameter(ParameterConvention conv) {
36373637
llvm_unreachable("bad convention kind");
36383638
}
36393639

3640+
/// The differentiability of a SIL function type parameter.
3641+
enum class SILParameterDifferentiability : unsigned {
3642+
/// Either differentiable or not applicable.
3643+
///
3644+
/// - If the function type is not `@differentiable`, parameter
3645+
/// differentiability is not applicable. This case is the default value.
3646+
/// - If the function type is `@differentiable`, the function is
3647+
/// differentiable with respect to this parameter.
3648+
DifferentiableOrNotApplicable,
3649+
3650+
/// Not differentiable: a `@noDerivative` parameter.
3651+
///
3652+
/// May be applied only to parameters of `@differentiable` function types.
3653+
/// The function type is not differentiable with respect to this parameter.
3654+
NotDifferentiable,
3655+
};
3656+
36403657
/// A parameter type and the rules for passing it.
36413658
class SILParameterInfo {
36423659
llvm::PointerIntPair<CanType, 3, ParameterConvention> TypeAndConvention;
3660+
SILParameterDifferentiability Differentiability : 1;
3661+
36433662
public:
36443663
SILParameterInfo() = default;//: Ty(), Convention((ParameterConvention)0) {}
3645-
SILParameterInfo(CanType type, ParameterConvention conv)
3646-
: TypeAndConvention(type, conv) {
3664+
SILParameterInfo(
3665+
CanType type, ParameterConvention conv,
3666+
SILParameterDifferentiability differentiability =
3667+
SILParameterDifferentiability::DifferentiableOrNotApplicable)
3668+
: TypeAndConvention(type, conv), Differentiability(differentiability) {
36473669
assert(type->isLegalSILType() && "SILParameterInfo has illegal SIL type");
36483670
}
36493671

@@ -3698,6 +3720,16 @@ class SILParameterInfo {
36983720
return isGuaranteedParameter(getConvention());
36993721
}
37003722

3723+
SILParameterDifferentiability getDifferentiability() const {
3724+
return Differentiability;
3725+
}
3726+
3727+
SILParameterInfo getWithDifferentiability(
3728+
SILParameterDifferentiability differentiability) const {
3729+
return SILParameterInfo(getInterfaceType(), getConvention(),
3730+
differentiability);
3731+
}
3732+
37013733
/// The SIL storage type determines the ABI for arguments based purely on the
37023734
/// formal parameter conventions. The actual SIL type for the argument values
37033735
/// may differ in canonical SIL. In particular, opaque values require indirect
@@ -3726,6 +3758,7 @@ class SILParameterInfo {
37263758
void profile(llvm::FoldingSetNodeID &id) {
37273759
id.AddPointer(getInterfaceType().getPointer());
37283760
id.AddInteger((unsigned)getConvention());
3761+
id.AddInteger((unsigned)getDifferentiability());
37293762
}
37303763

37313764
SWIFT_DEBUG_DUMP;
@@ -3739,8 +3772,9 @@ class SILParameterInfo {
37393772
}
37403773

37413774
bool operator==(SILParameterInfo rhs) const {
3742-
return getInterfaceType() == rhs.getInterfaceType()
3743-
&& getConvention() == rhs.getConvention();
3775+
return getInterfaceType() == rhs.getInterfaceType() &&
3776+
getConvention() == rhs.getConvention() &&
3777+
getDifferentiability() == rhs.getDifferentiability();
37443778
}
37453779
bool operator!=(SILParameterInfo rhs) const {
37463780
return !(*this == rhs);

lib/AST/ASTContext.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3327,6 +3327,15 @@ SILFunctionType::SILFunctionType(
33273327
"Cannot return an @noescape function type");
33283328
}
33293329
}
3330+
3331+
// Check that `@noDerivative` parameters only exist on `@differentiable`
3332+
// functions.
3333+
if (!ext.isDifferentiable())
3334+
for (auto param : getParameters())
3335+
assert(param.getDifferentiability() ==
3336+
SILParameterDifferentiability::DifferentiableOrNotApplicable &&
3337+
"non-`@differentiable` function should not have NotDifferentiable "
3338+
"parameter");
33303339
#endif
33313340
}
33323341

lib/AST/ASTPrinter.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4649,6 +4649,13 @@ void SILParameterInfo::print(raw_ostream &OS, const PrintOptions &Opts) const {
46494649
}
46504650
void SILParameterInfo::print(ASTPrinter &Printer,
46514651
const PrintOptions &Opts) const {
4652+
switch (getDifferentiability()) {
4653+
case SILParameterDifferentiability::NotDifferentiable:
4654+
Printer << "@noDerivative ";
4655+
break;
4656+
default:
4657+
break;
4658+
}
46524659
Printer << getStringForParameterConvention(getConvention());
46534660
getInterfaceType().print(Printer, Opts);
46544661
}

lib/Sema/TypeCheckType.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2948,6 +2948,8 @@ SILParameterInfo TypeResolver::resolveSILParameter(
29482948
auto convention = DefaultParameterConvention;
29492949
Type type;
29502950
bool hadError = false;
2951+
auto differentiability =
2952+
SILParameterDifferentiability::DifferentiableOrNotApplicable;
29512953

29522954
if (auto attrRepr = dyn_cast<AttributedTypeRepr>(repr)) {
29532955
auto attrs = attrRepr->getAttrs();
@@ -2973,6 +2975,10 @@ SILParameterInfo TypeResolver::resolveSILParameter(
29732975
checkFor(TypeAttrKind::TAK_owned, ParameterConvention::Direct_Owned);
29742976
checkFor(TypeAttrKind::TAK_guaranteed,
29752977
ParameterConvention::Direct_Guaranteed);
2978+
if (attrs.has(TAK_noDerivative)) {
2979+
attrs.clearAttribute(TAK_noDerivative);
2980+
differentiability = SILParameterDifferentiability::NotDifferentiable;
2981+
}
29762982

29772983
type = resolveAttributedType(attrs, attrRepr->getTypeRepr(), options);
29782984
} else {
@@ -2989,7 +2995,8 @@ SILParameterInfo TypeResolver::resolveSILParameter(
29892995
}
29902996

29912997
if (hadError) type = ErrorType::get(Context);
2992-
return SILParameterInfo(type->getCanonicalType(), convention);
2998+
return SILParameterInfo(type->getCanonicalType(), convention,
2999+
differentiability);
29933000
}
29943001

29953002
bool TypeResolver::resolveSingleSILResult(TypeRepr *repr,

lib/Serialization/Deserialization.cpp

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4503,6 +4503,21 @@ Optional<swift::ParameterConvention> getActualParameterConvention(uint8_t raw) {
45034503
return None;
45044504
}
45054505

4506+
/// Translate from the serialization SILParameterDifferentiability enumerators,
4507+
/// which are guaranteed to be stable, to the AST ones.
4508+
static Optional<swift::SILParameterDifferentiability>
4509+
getActualSILParameterDifferentiability(uint8_t raw) {
4510+
switch (serialization::SILParameterDifferentiability(raw)) {
4511+
#define CASE(ID) \
4512+
case serialization::SILParameterDifferentiability::ID: \
4513+
return swift::SILParameterDifferentiability::ID;
4514+
CASE(DifferentiableOrNotApplicable)
4515+
CASE(NotDifferentiable)
4516+
#undef CASE
4517+
}
4518+
return None;
4519+
}
4520+
45064521
/// Translate from the serialization ResultConvention enumerators,
45074522
/// which are guaranteed to be stable, to the AST ones.
45084523
static
@@ -5144,15 +5159,26 @@ class TypeDeserializer {
51445159
if (!calleeConvention.hasValue())
51455160
MF.fatal();
51465161

5147-
auto processParameter = [&](TypeID typeID, uint64_t rawConvention)
5148-
-> llvm::Expected<SILParameterInfo> {
5162+
auto processParameter =
5163+
[&](TypeID typeID, uint64_t rawConvention,
5164+
uint64_t ramDifferentiability) -> llvm::Expected<SILParameterInfo> {
51495165
auto convention = getActualParameterConvention(rawConvention);
51505166
if (!convention)
51515167
MF.fatal();
51525168
auto type = MF.getTypeChecked(typeID);
51535169
if (!type)
51545170
return type.takeError();
5155-
return SILParameterInfo(type.get()->getCanonicalType(), *convention);
5171+
auto differentiability =
5172+
swift::SILParameterDifferentiability::DifferentiableOrNotApplicable;
5173+
if (diffKind != DifferentiabilityKind::NonDifferentiable) {
5174+
auto differentiabilityOpt =
5175+
getActualSILParameterDifferentiability(ramDifferentiability);
5176+
if (!differentiabilityOpt)
5177+
MF.fatal();
5178+
differentiability = *differentiabilityOpt;
5179+
}
5180+
return SILParameterInfo(type.get()->getCanonicalType(), *convention,
5181+
differentiability);
51565182
};
51575183

51585184
auto processYield = [&](TypeID typeID, uint64_t rawConvention)
@@ -5191,7 +5217,10 @@ class TypeDeserializer {
51915217
for (unsigned i = 0; i != numParams; ++i) {
51925218
auto typeID = variableData[nextVariableDataIndex++];
51935219
auto rawConvention = variableData[nextVariableDataIndex++];
5194-
auto param = processParameter(typeID, rawConvention);
5220+
uint64_t differentiability = 0;
5221+
if (diffKind != DifferentiabilityKind::NonDifferentiable)
5222+
differentiability = variableData[nextVariableDataIndex++];
5223+
auto param = processParameter(typeID, rawConvention, differentiability);
51955224
if (!param)
51965225
return param.takeError();
51975226
allParams.push_back(param.get());

lib/Serialization/ModuleFormat.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ const uint16_t SWIFTMODULE_VERSION_MAJOR = 0;
5555
/// describe what change you made. The content of this comment isn't important;
5656
/// it just ensures a conflict if two people change the module format.
5757
/// Don't worry about adhering to the 80-column limit for this line.
58-
const uint16_t SWIFTMODULE_VERSION_MINOR = 533; // removed @_implicitly_synthesizes_nested_requirement
58+
const uint16_t SWIFTMODULE_VERSION_MINOR = 534; // add SIL parameter differentiability
5959

6060
/// A standard hash seed used for all string hashes in a serialized module.
6161
///
@@ -347,6 +347,13 @@ enum class ParameterConvention : uint8_t {
347347
};
348348
using ParameterConventionField = BCFixed<4>;
349349

350+
// These IDs must \em not be renumbered or reordered without incrementing
351+
// the module version.
352+
enum class SILParameterDifferentiability : uint8_t {
353+
DifferentiableOrNotApplicable,
354+
NotDifferentiable,
355+
};
356+
350357
// These IDs must \em not be renumbered or reordered without incrementing
351358
// the module version.
352359
enum class ResultConvention : uint8_t {

lib/Serialization/Serialization.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3766,6 +3766,17 @@ static uint8_t getRawStableParameterConvention(swift::ParameterConvention pc) {
37663766
llvm_unreachable("bad parameter convention kind");
37673767
}
37683768

3769+
/// Translate from AST SILParameterDifferentiability enum to the Serialization
3770+
/// enum values, which are guaranteed to be stable.
3771+
static uint8_t
3772+
getRawSILParameterDifferentiability(swift::SILParameterDifferentiability pd) {
3773+
switch (pd) {
3774+
SIMPLE_CASE(SILParameterDifferentiability, DifferentiableOrNotApplicable)
3775+
SIMPLE_CASE(SILParameterDifferentiability, NotDifferentiable)
3776+
}
3777+
llvm_unreachable("bad parameter differentiability kind");
3778+
}
3779+
37693780
/// Translate from the AST ResultConvention enum to the
37703781
/// Serialization enum values, which are guaranteed to be stable.
37713782
static uint8_t getRawStableResultConvention(swift::ResultConvention rc) {
@@ -4075,6 +4086,9 @@ class Serializer::TypeSerializer : public TypeVisitor<TypeSerializer> {
40754086
variableData.push_back(S.addTypeRef(param.getInterfaceType()));
40764087
unsigned conv = getRawStableParameterConvention(param.getConvention());
40774088
variableData.push_back(TypeID(conv));
4089+
if (fnTy->isDifferentiable())
4090+
variableData.push_back(TypeID(
4091+
getRawSILParameterDifferentiability(param.getDifferentiability())));
40784092
}
40794093
for (auto yield : fnTy->getYields()) {
40804094
variableData.push_back(S.addTypeRef(yield.getInterfaceType()));

test/AutoDiff/SIL/Serialization/differentiation.swift

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,23 @@ bb0(%0 : $@differentiable(linear) (Float) -> Float):
2626
// CHECK: bb0([[ARG:%.*]] : $@differentiable(linear) (Float) -> Float):
2727
// CHECK: return [[ARG]] : $@differentiable(linear) (Float) -> Float
2828
// CHECK: }
29+
30+
sil @c : $@convention(thin) (@differentiable (Float, @noDerivative Float) -> Float) -> @differentiable (Float, @noDerivative Float) -> Float {
31+
bb0(%0 : $@differentiable (Float, @noDerivative Float) -> Float):
32+
return %0 : $@differentiable (Float, @noDerivative Float) -> Float
33+
}
34+
35+
// CHECK-LABEL: sil @c : $@convention(thin) (@differentiable (Float, @noDerivative Float) -> Float) -> @differentiable (Float, @noDerivative Float) -> Float {
36+
// CHECK: bb0(%0 : $@differentiable (Float, @noDerivative Float) -> Float):
37+
// CHECK: return %0 : $@differentiable (Float, @noDerivative Float) -> Float
38+
// CHECK: }
39+
40+
sil @d : $@convention(thin) (@differentiable(linear) (Float, @noDerivative Float) -> Float) -> @differentiable(linear) (Float, @noDerivative Float) -> Float {
41+
bb0(%0 : $@differentiable(linear) (Float, @noDerivative Float) -> Float):
42+
return %0 : $@differentiable(linear) (Float, @noDerivative Float) -> Float
43+
}
44+
45+
// CHECK-LABEL: sil @d : $@convention(thin) (@differentiable(linear) (Float, @noDerivative Float) -> Float) -> @differentiable(linear) (Float, @noDerivative Float) -> Float {
46+
// CHECK: bb0(%0 : $@differentiable(linear) (Float, @noDerivative Float) -> Float):
47+
// CHECK: return %0 : $@differentiable(linear) (Float, @noDerivative Float) -> Float
48+
// CHECK: }

0 commit comments

Comments
 (0)