Skip to content

Commit 6c5b7c9

Browse files
committed
[sil] Change SILParameterInfo/SILResultInfo's differentiability parameter to use an OptionSet so we can add other options.
I am doing this in preparation for adding options to SILParameterInfo/ SILResultInfo that state that a parameter/result is transferring. Even though I could have just introduced a new bit here, I instead streamlined the interface of SILParameterInfo/SILResultInfo to use an OptionSet instead of individual bits to make it easier to add new flags here. The reason why it is easier is that along API (e.x.: function argument) boundaries one does not have to marshal each field or pass each field. Instead one can just pass the whole OptionSet as an opaque thing. Using this I was able to change serialization/deserialization of SILParameterInfo/SILResultInfo so that one does not need to update them if one adds new fields! The reason why I am doing this for both SILParameterInfo/SILResultInfo in the same commit is because they share code in the demangler that I did not want to have to duplicate in an intervening commit. By changing them both at the same type, I didn't have to change anything without an actual need to. I am doing this in a separate commit from adding transferring support so I can validate correctness using the tests for the options already supported (currently only differentiability).
1 parent 4ea643f commit 6c5b7c9

16 files changed

+462
-299
lines changed

include/swift/AST/TypeDifferenceVisitor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ class CanTypeDifferenceVisitor : public CanTypePairVisitor<Impl, bool> {
291291
bool visitComponent(CanType type1, CanType type2,
292292
SILParameterInfo param1, SILParameterInfo param2) {
293293
if (param1.getConvention() != param2.getConvention() ||
294-
param1.getDifferentiability() != param2.getDifferentiability())
294+
param1.getOptions().toRaw() != param2.getOptions().toRaw())
295295
return asImpl().visitDifferentTypeStructure(type1, type2);
296296

297297
return asImpl().visit(param1.getInterfaceType(),

include/swift/AST/Types.h

Lines changed: 179 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -4091,54 +4091,52 @@ inline bool isPackParameter(ParameterConvention conv) {
40914091
llvm_unreachable("bad convention kind");
40924092
}
40934093

4094-
/// The differentiability of a SIL function type parameter.
4095-
enum class SILParameterDifferentiability : bool {
4096-
/// Either differentiable or not applicable.
4097-
///
4098-
/// - If the function type is not `@differentiable`, parameter
4099-
/// differentiability is not applicable. This case is the default value.
4100-
/// - If the function type is `@differentiable`, the function is
4101-
/// differentiable with respect to this parameter.
4102-
DifferentiableOrNotApplicable,
4103-
4104-
/// Not differentiable: a `@noDerivative` parameter.
4105-
///
4106-
/// May be applied only to parameters of `@differentiable` function types.
4107-
/// The function type is not differentiable with respect to this parameter.
4108-
NotDifferentiable,
4109-
};
4110-
41114094
/// A parameter type and the rules for passing it.
41124095
class SILParameterInfo {
4113-
CanType Type;
4114-
ParameterConvention Convention;
4115-
SILParameterDifferentiability Differentiability;
4096+
public:
4097+
enum Flag : uint8_t {
4098+
None = 0x0,
4099+
4100+
/// Not differentiable: a `@noDerivative` parameter.
4101+
///
4102+
/// May be applied only to parameters of `@differentiable` function types.
4103+
/// The function type is not differentiable with respect to this parameter.
4104+
///
4105+
/// If this is not set then the parameter is either differentiable or
4106+
/// differentiation is not applicable to the parameter:
4107+
///
4108+
/// - If the function type is not `@differentiable`, parameter
4109+
/// differentiability is not applicable. This case is the default value.
4110+
/// - If the function type is `@differentiable`, the function is
4111+
/// differentiable with respect to this parameter.
4112+
NotDifferentiable = 0x1,
4113+
};
4114+
4115+
using Options = OptionSet<Flag>;
4116+
4117+
private:
4118+
CanType type;
4119+
ParameterConvention convention;
4120+
Options options;
41164121

41174122
public:
41184123
SILParameterInfo() = default;//: Ty(), Convention((ParameterConvention)0) {}
4119-
SILParameterInfo(
4120-
CanType type, ParameterConvention conv,
4121-
SILParameterDifferentiability differentiability =
4122-
SILParameterDifferentiability::DifferentiableOrNotApplicable)
4123-
: Type(type), Convention(conv), Differentiability(differentiability) {
4124+
SILParameterInfo(CanType type, ParameterConvention conv, Options options = {})
4125+
: type(type), convention(conv), options(options) {
41244126
assert(type->isLegalSILType() && "SILParameterInfo has illegal SIL type");
41254127
}
41264128

41274129
/// Return the unsubstituted parameter type that describes the abstract
41284130
/// calling convention of the parameter.
41294131
///
41304132
/// For most purposes, you probably want \c getArgumentType .
4131-
CanType getInterfaceType() const {
4132-
return Type;
4133-
}
4134-
4133+
CanType getInterfaceType() const { return type; }
4134+
41354135
/// Return the type of a call argument matching this parameter.
41364136
///
41374137
/// \c t must refer back to the function type this is a parameter for.
41384138
CanType getArgumentType(SILModule &M, const SILFunctionType *t, TypeExpansionContext context) const;
4139-
ParameterConvention getConvention() const {
4140-
return Convention;
4141-
}
4139+
ParameterConvention getConvention() const { return convention; }
41424140
// Does this parameter convention require indirect storage? This reflects a
41434141
// SILFunctionType's formal (immutable) conventions, as opposed to the
41444142
// transient SIL conventions that dictate SILValue types.
@@ -4185,14 +4183,65 @@ class SILParameterInfo {
41854183
return isGuaranteedParameter(getConvention());
41864184
}
41874185

4188-
SILParameterDifferentiability getDifferentiability() const {
4189-
return Differentiability;
4186+
bool hasOption(Flag flag) const { return options.contains(flag); }
4187+
4188+
Options getOptions() const { return options; }
4189+
4190+
SILParameterInfo addingOption(Flag flag) const {
4191+
auto options = getOptions();
4192+
options |= flag;
4193+
return SILParameterInfo(getInterfaceType(), getConvention(), options);
41904194
}
41914195

4192-
SILParameterInfo getWithDifferentiability(
4193-
SILParameterDifferentiability differentiability) const {
4196+
SILParameterInfo removingOption(Flag flag) const {
4197+
auto options = getOptions();
4198+
options &= flag;
4199+
return SILParameterInfo(getInterfaceType(), getConvention(), options);
4200+
}
4201+
4202+
/// Add all flags in \p arg into a copy of this parameter info and return the
4203+
/// parameter info.
4204+
///
4205+
/// NOTE: You can pass in SILParameterInfo::Flag to this function since said
4206+
/// type auto converts to Options.
4207+
SILParameterInfo operator|(Options arg) const {
4208+
return SILParameterInfo(getInterfaceType(), getConvention(),
4209+
getOptions() | arg);
4210+
}
4211+
4212+
SILParameterInfo &operator|=(Options arg) {
4213+
options |= arg;
4214+
return *this;
4215+
}
4216+
4217+
/// Copy this parameter and intersect \p arg with the parameters former
4218+
/// options.
4219+
///
4220+
/// NOTE: You can pass in SILParameterInfo::Flag to this function since said
4221+
/// type auto converts to Options.
4222+
SILParameterInfo operator&(Options arg) const {
41944223
return SILParameterInfo(getInterfaceType(), getConvention(),
4195-
differentiability);
4224+
getOptions() & arg);
4225+
}
4226+
4227+
SILParameterInfo &operator&=(Options arg) {
4228+
options &= arg;
4229+
return *this;
4230+
}
4231+
4232+
/// Copy this parameter such that its options contains the set subtraction of
4233+
/// \p arg from the parameters former options.
4234+
///
4235+
/// NOTE: You can pass in SILParameterInfo::Flag to this function since said
4236+
/// type auto converts to Options.
4237+
SILParameterInfo operator-(Options arg) const {
4238+
return SILParameterInfo(getInterfaceType(), getConvention(),
4239+
getOptions() - arg);
4240+
}
4241+
4242+
SILParameterInfo &operator-=(Options arg) {
4243+
options -= arg;
4244+
return *this;
41964245
}
41974246

41984247
/// The SIL storage type determines the ABI for arguments based purely on the
@@ -4208,12 +4257,12 @@ class SILParameterInfo {
42084257

42094258
/// Return a version of this parameter info with the type replaced.
42104259
SILParameterInfo getWithInterfaceType(CanType type) const {
4211-
return SILParameterInfo(type, getConvention(), getDifferentiability());
4260+
return SILParameterInfo(type, getConvention(), getOptions());
42124261
}
42134262

42144263
/// Return a version of this parameter info with the convention replaced.
42154264
SILParameterInfo getWithConvention(ParameterConvention c) const {
4216-
return SILParameterInfo(getInterfaceType(), c, getDifferentiability());
4265+
return SILParameterInfo(getInterfaceType(), c, getOptions());
42174266
}
42184267

42194268
/// Transform this SILParameterInfo by applying the user-provided
@@ -4243,7 +4292,7 @@ class SILParameterInfo {
42434292
void profile(llvm::FoldingSetNodeID &id) {
42444293
id.AddPointer(getInterfaceType().getPointer());
42454294
id.AddInteger((unsigned)getConvention());
4246-
id.AddInteger((unsigned)getDifferentiability());
4295+
id.AddInteger((unsigned)getOptions().toRaw());
42474296
}
42484297

42494298
SWIFT_DEBUG_DUMP;
@@ -4259,7 +4308,7 @@ class SILParameterInfo {
42594308
bool operator==(SILParameterInfo rhs) const {
42604309
return getInterfaceType() == rhs.getInterfaceType() &&
42614310
getConvention() == rhs.getConvention() &&
4262-
getDifferentiability() == rhs.getDifferentiability();
4311+
getOptions().toRaw() == rhs.getOptions().toRaw();
42634312
}
42644313
bool operator!=(SILParameterInfo rhs) const {
42654314
return !(*this == rhs);
@@ -4308,64 +4357,114 @@ inline bool isIndirectFormalResult(ResultConvention convention) {
43084357
convention == ResultConvention::Pack;
43094358
}
43104359

4311-
/// The differentiability of a SIL function type result.
4312-
enum class SILResultDifferentiability : bool {
4313-
/// Either differentiable or not applicable.
4314-
///
4315-
/// - If the function type is not `@differentiable`, result
4316-
/// differentiability is not applicable. This case is the default value.
4317-
/// - If the function type is `@differentiable`, the function is
4318-
/// differentiable with respect to this result.
4319-
DifferentiableOrNotApplicable,
4320-
4321-
/// Not differentiable: a `@noDerivative` result.
4322-
///
4323-
/// May be applied only to result of `@differentiable` function types.
4324-
/// The function type is not differentiable with respect to this result.
4325-
NotDifferentiable,
4326-
};
4327-
43284360
/// A result type and the rules for returning it.
43294361
class SILResultInfo {
4330-
CanType Type;
4331-
ResultConvention Convention;
4332-
SILResultDifferentiability Differentiability;
4362+
public:
4363+
enum Flag : uint8_t {
4364+
None = 0x0,
4365+
4366+
/// Not differentiable: a `@noDerivative` result.
4367+
///
4368+
/// May be applied only to result of `@differentiable` function types.
4369+
/// The function type is not differentiable with respect to this result.
4370+
///
4371+
/// If this is not set then the function is either differentiable or
4372+
/// differentiability is not applicable. This can occur if:
4373+
///
4374+
/// - The function type is not `@differentiable`, result
4375+
/// differentiability is not applicable. This case is the default value.
4376+
/// - The function type is `@differentiable`, the function is
4377+
/// differentiable with respect to this result.
4378+
NotDifferentiable = 0x1,
4379+
};
4380+
4381+
using Options = OptionSet<Flag>;
4382+
4383+
private:
4384+
CanType type;
4385+
ResultConvention convention;
4386+
Options options;
43334387

43344388
public:
43354389
SILResultInfo() = default;
4336-
SILResultInfo(CanType type, ResultConvention conv,
4337-
SILResultDifferentiability differentiability =
4338-
SILResultDifferentiability::DifferentiableOrNotApplicable)
4339-
: Type(type), Convention(conv), Differentiability(differentiability) {
4390+
SILResultInfo(CanType type, ResultConvention conv, Options options = {})
4391+
: type(type), convention(conv), options(options) {
43404392
assert(type->isLegalSILType() && "SILResultInfo has illegal SIL type");
43414393
}
43424394

43434395
/// Return the unsubstituted parameter type that describes the abstract
43444396
/// calling convention of the parameter.
43454397
///
43464398
/// For most purposes, you probably want \c getReturnValueType .
4347-
CanType getInterfaceType() const {
4348-
return Type;
4349-
}
4350-
4399+
CanType getInterfaceType() const { return type; }
4400+
43514401
/// The type of a return value corresponding to this result.
43524402
///
43534403
/// \c t must refer back to the function type this is a parameter for.
43544404
CanType getReturnValueType(SILModule &M, const SILFunctionType *t,
43554405
TypeExpansionContext context) const;
43564406

4357-
ResultConvention getConvention() const {
4358-
return Convention;
4407+
ResultConvention getConvention() const { return convention; }
4408+
4409+
Options getOptions() const { return options; }
4410+
4411+
bool hasOption(Flag flag) const { return options.contains(flag); }
4412+
4413+
SILResultInfo addingOption(Flag flag) const {
4414+
auto options = getOptions();
4415+
options |= flag;
4416+
return SILResultInfo(getInterfaceType(), getConvention(), options);
4417+
}
4418+
4419+
SILResultInfo removingOption(Flag flag) const {
4420+
auto options = getOptions();
4421+
options &= flag;
4422+
return SILResultInfo(getInterfaceType(), getConvention(), options);
4423+
}
4424+
4425+
/// Add all flags in \p arg into a copy of this parameter info and return the
4426+
/// parameter info.
4427+
///
4428+
/// NOTE: You can pass in SILResultInfo::Flag to this function since said
4429+
/// type auto converts to Options.
4430+
SILResultInfo operator|(Options arg) const {
4431+
return SILResultInfo(getInterfaceType(), getConvention(),
4432+
getOptions() | arg);
43594433
}
43604434

4361-
SILResultDifferentiability getDifferentiability() const {
4362-
return Differentiability;
4435+
SILResultInfo &operator|=(Options arg) {
4436+
options |= arg;
4437+
return *this;
43634438
}
43644439

4365-
SILResultInfo
4366-
getWithDifferentiability(SILResultDifferentiability differentiability) const {
4440+
/// Copy this parameter and intersect \p arg with the parameters former
4441+
/// options.
4442+
///
4443+
/// NOTE: You can pass in SILResultInfo::Flag to this function since said
4444+
/// type auto converts to Options.
4445+
SILResultInfo operator&(Options arg) const {
43674446
return SILResultInfo(getInterfaceType(), getConvention(),
4368-
differentiability);
4447+
getOptions() & arg);
4448+
}
4449+
4450+
SILResultInfo &operator&=(Options arg) {
4451+
options &= arg;
4452+
return *this;
4453+
}
4454+
4455+
/// Copy this parameter such that its options contains the set subtraction of
4456+
/// \p arg from the parameters former options.
4457+
///
4458+
/// NOTE: You can pass in SILResultInfo::Flag to this function since said
4459+
/// type auto converts to Options.
4460+
SILResultInfo operator-(Options arg) const {
4461+
return SILResultInfo(getInterfaceType(), getConvention(),
4462+
getOptions() - arg);
4463+
}
4464+
4465+
SILResultInfo &operator-=(Options arg) {
4466+
options -= arg;
4467+
return *this;
43694468
}
43704469

43714470
/// The SIL storage type determines the ABI for arguments based purely on the
@@ -4428,9 +4527,9 @@ class SILResultInfo {
44284527
}
44294528

44304529
void profile(llvm::FoldingSetNodeID &id) {
4431-
id.AddPointer(Type.getPointer());
4530+
id.AddPointer(type.getPointer());
44324531
id.AddInteger(unsigned(getConvention()));
4433-
id.AddInteger(unsigned(getDifferentiability()));
4532+
id.AddInteger(unsigned(getOptions().toRaw()));
44344533
}
44354534

44364535
SWIFT_DEBUG_DUMP;
@@ -4447,8 +4546,8 @@ class SILResultInfo {
44474546
getOwnershipKind(SILFunction &, CanSILFunctionType fTy) const; // in SILType.cpp
44484547

44494548
bool operator==(SILResultInfo rhs) const {
4450-
return Type == rhs.Type && Convention == rhs.Convention
4451-
&& Differentiability == rhs.Differentiability;
4549+
return type == rhs.type && convention == rhs.convention &&
4550+
getOptions().toRaw() == rhs.getOptions().toRaw();
44524551
}
44534552
bool operator!=(SILResultInfo rhs) const {
44544553
return !(*this == rhs);

0 commit comments

Comments
 (0)