@@ -4091,54 +4091,52 @@ inline bool isPackParameter(ParameterConvention conv) {
4091
4091
llvm_unreachable (" bad convention kind" );
4092
4092
}
4093
4093
4094
- // / The differentiability of a SIL function type parameter.
4095
- enum class SILParameterDifferentiability : bool {
4096
- // / Either differentiable or not applicable.
4097
- // /
4098
- // / - If the function type is not `@differentiable`, parameter
4099
- // / differentiability is not applicable. This case is the default value.
4100
- // / - If the function type is `@differentiable`, the function is
4101
- // / differentiable with respect to this parameter.
4102
- DifferentiableOrNotApplicable,
4103
-
4104
- // / Not differentiable: a `@noDerivative` parameter.
4105
- // /
4106
- // / May be applied only to parameters of `@differentiable` function types.
4107
- // / The function type is not differentiable with respect to this parameter.
4108
- NotDifferentiable,
4109
- };
4110
-
4111
4094
// / A parameter type and the rules for passing it.
4112
4095
class SILParameterInfo {
4113
- CanType Type;
4114
- ParameterConvention Convention;
4115
- SILParameterDifferentiability Differentiability;
4096
+ public:
4097
+ enum Flag : uint8_t {
4098
+ None = 0x0 ,
4099
+
4100
+ // / Not differentiable: a `@noDerivative` parameter.
4101
+ // /
4102
+ // / May be applied only to parameters of `@differentiable` function types.
4103
+ // / The function type is not differentiable with respect to this parameter.
4104
+ // /
4105
+ // / If this is not set then the parameter is either differentiable or
4106
+ // / differentiation is not applicable to the parameter:
4107
+ // /
4108
+ // / - If the function type is not `@differentiable`, parameter
4109
+ // / differentiability is not applicable. This case is the default value.
4110
+ // / - If the function type is `@differentiable`, the function is
4111
+ // / differentiable with respect to this parameter.
4112
+ NotDifferentiable = 0x1 ,
4113
+ };
4114
+
4115
+ using Options = OptionSet<Flag>;
4116
+
4117
+ private:
4118
+ CanType type;
4119
+ ParameterConvention convention;
4120
+ Options options;
4116
4121
4117
4122
public:
4118
4123
SILParameterInfo () = default ;// : Ty(), Convention((ParameterConvention)0) {}
4119
- SILParameterInfo (
4120
- CanType type, ParameterConvention conv,
4121
- SILParameterDifferentiability differentiability =
4122
- SILParameterDifferentiability::DifferentiableOrNotApplicable)
4123
- : Type(type), Convention(conv), Differentiability(differentiability) {
4124
+ SILParameterInfo (CanType type, ParameterConvention conv, Options options = {})
4125
+ : type(type), convention(conv), options(options) {
4124
4126
assert (type->isLegalSILType () && " SILParameterInfo has illegal SIL type" );
4125
4127
}
4126
4128
4127
4129
// / Return the unsubstituted parameter type that describes the abstract
4128
4130
// / calling convention of the parameter.
4129
4131
// /
4130
4132
// / For most purposes, you probably want \c getArgumentType .
4131
- CanType getInterfaceType () const {
4132
- return Type;
4133
- }
4134
-
4133
+ CanType getInterfaceType () const { return type; }
4134
+
4135
4135
// / Return the type of a call argument matching this parameter.
4136
4136
// /
4137
4137
// / \c t must refer back to the function type this is a parameter for.
4138
4138
CanType getArgumentType (SILModule &M, const SILFunctionType *t, TypeExpansionContext context) const ;
4139
- ParameterConvention getConvention () const {
4140
- return Convention;
4141
- }
4139
+ ParameterConvention getConvention () const { return convention; }
4142
4140
// Does this parameter convention require indirect storage? This reflects a
4143
4141
// SILFunctionType's formal (immutable) conventions, as opposed to the
4144
4142
// transient SIL conventions that dictate SILValue types.
@@ -4185,14 +4183,65 @@ class SILParameterInfo {
4185
4183
return isGuaranteedParameter (getConvention ());
4186
4184
}
4187
4185
4188
- SILParameterDifferentiability getDifferentiability () const {
4189
- return Differentiability;
4186
+ bool hasOption (Flag flag) const { return options.contains (flag); }
4187
+
4188
+ Options getOptions () const { return options; }
4189
+
4190
+ SILParameterInfo addingOption (Flag flag) const {
4191
+ auto options = getOptions ();
4192
+ options |= flag;
4193
+ return SILParameterInfo (getInterfaceType (), getConvention (), options);
4190
4194
}
4191
4195
4192
- SILParameterInfo getWithDifferentiability (
4193
- SILParameterDifferentiability differentiability) const {
4196
+ SILParameterInfo removingOption (Flag flag) const {
4197
+ auto options = getOptions ();
4198
+ options &= flag;
4199
+ return SILParameterInfo (getInterfaceType (), getConvention (), options);
4200
+ }
4201
+
4202
+ // / Add all flags in \p arg into a copy of this parameter info and return the
4203
+ // / parameter info.
4204
+ // /
4205
+ // / NOTE: You can pass in SILParameterInfo::Flag to this function since said
4206
+ // / type auto converts to Options.
4207
+ SILParameterInfo operator |(Options arg) const {
4208
+ return SILParameterInfo (getInterfaceType (), getConvention (),
4209
+ getOptions () | arg);
4210
+ }
4211
+
4212
+ SILParameterInfo &operator |=(Options arg) {
4213
+ options |= arg;
4214
+ return *this ;
4215
+ }
4216
+
4217
+ // / Copy this parameter and intersect \p arg with the parameters former
4218
+ // / options.
4219
+ // /
4220
+ // / NOTE: You can pass in SILParameterInfo::Flag to this function since said
4221
+ // / type auto converts to Options.
4222
+ SILParameterInfo operator &(Options arg) const {
4194
4223
return SILParameterInfo (getInterfaceType (), getConvention (),
4195
- differentiability);
4224
+ getOptions () & arg);
4225
+ }
4226
+
4227
+ SILParameterInfo &operator &=(Options arg) {
4228
+ options &= arg;
4229
+ return *this ;
4230
+ }
4231
+
4232
+ // / Copy this parameter such that its options contains the set subtraction of
4233
+ // / \p arg from the parameters former options.
4234
+ // /
4235
+ // / NOTE: You can pass in SILParameterInfo::Flag to this function since said
4236
+ // / type auto converts to Options.
4237
+ SILParameterInfo operator -(Options arg) const {
4238
+ return SILParameterInfo (getInterfaceType (), getConvention (),
4239
+ getOptions () - arg);
4240
+ }
4241
+
4242
+ SILParameterInfo &operator -=(Options arg) {
4243
+ options -= arg;
4244
+ return *this ;
4196
4245
}
4197
4246
4198
4247
// / The SIL storage type determines the ABI for arguments based purely on the
@@ -4208,12 +4257,12 @@ class SILParameterInfo {
4208
4257
4209
4258
// / Return a version of this parameter info with the type replaced.
4210
4259
SILParameterInfo getWithInterfaceType (CanType type) const {
4211
- return SILParameterInfo (type, getConvention (), getDifferentiability ());
4260
+ return SILParameterInfo (type, getConvention (), getOptions ());
4212
4261
}
4213
4262
4214
4263
// / Return a version of this parameter info with the convention replaced.
4215
4264
SILParameterInfo getWithConvention (ParameterConvention c) const {
4216
- return SILParameterInfo (getInterfaceType (), c, getDifferentiability ());
4265
+ return SILParameterInfo (getInterfaceType (), c, getOptions ());
4217
4266
}
4218
4267
4219
4268
// / Transform this SILParameterInfo by applying the user-provided
@@ -4243,7 +4292,7 @@ class SILParameterInfo {
4243
4292
void profile (llvm::FoldingSetNodeID &id) {
4244
4293
id.AddPointer (getInterfaceType ().getPointer ());
4245
4294
id.AddInteger ((unsigned )getConvention ());
4246
- id.AddInteger ((unsigned )getDifferentiability ());
4295
+ id.AddInteger ((unsigned )getOptions (). toRaw ());
4247
4296
}
4248
4297
4249
4298
SWIFT_DEBUG_DUMP;
@@ -4259,7 +4308,7 @@ class SILParameterInfo {
4259
4308
bool operator ==(SILParameterInfo rhs) const {
4260
4309
return getInterfaceType () == rhs.getInterfaceType () &&
4261
4310
getConvention () == rhs.getConvention () &&
4262
- getDifferentiability () == rhs.getDifferentiability ();
4311
+ getOptions (). toRaw () == rhs.getOptions (). toRaw ();
4263
4312
}
4264
4313
bool operator !=(SILParameterInfo rhs) const {
4265
4314
return !(*this == rhs);
@@ -4308,64 +4357,114 @@ inline bool isIndirectFormalResult(ResultConvention convention) {
4308
4357
convention == ResultConvention::Pack;
4309
4358
}
4310
4359
4311
- // / The differentiability of a SIL function type result.
4312
- enum class SILResultDifferentiability : bool {
4313
- // / Either differentiable or not applicable.
4314
- // /
4315
- // / - If the function type is not `@differentiable`, result
4316
- // / differentiability is not applicable. This case is the default value.
4317
- // / - If the function type is `@differentiable`, the function is
4318
- // / differentiable with respect to this result.
4319
- DifferentiableOrNotApplicable,
4320
-
4321
- // / Not differentiable: a `@noDerivative` result.
4322
- // /
4323
- // / May be applied only to result of `@differentiable` function types.
4324
- // / The function type is not differentiable with respect to this result.
4325
- NotDifferentiable,
4326
- };
4327
-
4328
4360
// / A result type and the rules for returning it.
4329
4361
class SILResultInfo {
4330
- CanType Type;
4331
- ResultConvention Convention;
4332
- SILResultDifferentiability Differentiability;
4362
+ public:
4363
+ enum Flag : uint8_t {
4364
+ None = 0x0 ,
4365
+
4366
+ // / Not differentiable: a `@noDerivative` result.
4367
+ // /
4368
+ // / May be applied only to result of `@differentiable` function types.
4369
+ // / The function type is not differentiable with respect to this result.
4370
+ // /
4371
+ // / If this is not set then the function is either differentiable or
4372
+ // / differentiability is not applicable. This can occur if:
4373
+ // /
4374
+ // / - The function type is not `@differentiable`, result
4375
+ // / differentiability is not applicable. This case is the default value.
4376
+ // / - The function type is `@differentiable`, the function is
4377
+ // / differentiable with respect to this result.
4378
+ NotDifferentiable = 0x1 ,
4379
+ };
4380
+
4381
+ using Options = OptionSet<Flag>;
4382
+
4383
+ private:
4384
+ CanType type;
4385
+ ResultConvention convention;
4386
+ Options options;
4333
4387
4334
4388
public:
4335
4389
SILResultInfo () = default ;
4336
- SILResultInfo (CanType type, ResultConvention conv,
4337
- SILResultDifferentiability differentiability =
4338
- SILResultDifferentiability::DifferentiableOrNotApplicable)
4339
- : Type(type), Convention(conv), Differentiability(differentiability) {
4390
+ SILResultInfo (CanType type, ResultConvention conv, Options options = {})
4391
+ : type(type), convention(conv), options(options) {
4340
4392
assert (type->isLegalSILType () && " SILResultInfo has illegal SIL type" );
4341
4393
}
4342
4394
4343
4395
// / Return the unsubstituted parameter type that describes the abstract
4344
4396
// / calling convention of the parameter.
4345
4397
// /
4346
4398
// / For most purposes, you probably want \c getReturnValueType .
4347
- CanType getInterfaceType () const {
4348
- return Type;
4349
- }
4350
-
4399
+ CanType getInterfaceType () const { return type; }
4400
+
4351
4401
// / The type of a return value corresponding to this result.
4352
4402
// /
4353
4403
// / \c t must refer back to the function type this is a parameter for.
4354
4404
CanType getReturnValueType (SILModule &M, const SILFunctionType *t,
4355
4405
TypeExpansionContext context) const ;
4356
4406
4357
- ResultConvention getConvention () const {
4358
- return Convention;
4407
+ ResultConvention getConvention () const { return convention; }
4408
+
4409
+ Options getOptions () const { return options; }
4410
+
4411
+ bool hasOption (Flag flag) const { return options.contains (flag); }
4412
+
4413
+ SILResultInfo addingOption (Flag flag) const {
4414
+ auto options = getOptions ();
4415
+ options |= flag;
4416
+ return SILResultInfo (getInterfaceType (), getConvention (), options);
4417
+ }
4418
+
4419
+ SILResultInfo removingOption (Flag flag) const {
4420
+ auto options = getOptions ();
4421
+ options &= flag;
4422
+ return SILResultInfo (getInterfaceType (), getConvention (), options);
4423
+ }
4424
+
4425
+ // / Add all flags in \p arg into a copy of this parameter info and return the
4426
+ // / parameter info.
4427
+ // /
4428
+ // / NOTE: You can pass in SILResultInfo::Flag to this function since said
4429
+ // / type auto converts to Options.
4430
+ SILResultInfo operator |(Options arg) const {
4431
+ return SILResultInfo (getInterfaceType (), getConvention (),
4432
+ getOptions () | arg);
4359
4433
}
4360
4434
4361
- SILResultDifferentiability getDifferentiability () const {
4362
- return Differentiability;
4435
+ SILResultInfo &operator |=(Options arg) {
4436
+ options |= arg;
4437
+ return *this ;
4363
4438
}
4364
4439
4365
- SILResultInfo
4366
- getWithDifferentiability (SILResultDifferentiability differentiability) const {
4440
+ // / Copy this parameter and intersect \p arg with the parameters former
4441
+ // / options.
4442
+ // /
4443
+ // / NOTE: You can pass in SILResultInfo::Flag to this function since said
4444
+ // / type auto converts to Options.
4445
+ SILResultInfo operator &(Options arg) const {
4367
4446
return SILResultInfo (getInterfaceType (), getConvention (),
4368
- differentiability);
4447
+ getOptions () & arg);
4448
+ }
4449
+
4450
+ SILResultInfo &operator &=(Options arg) {
4451
+ options &= arg;
4452
+ return *this ;
4453
+ }
4454
+
4455
+ // / Copy this parameter such that its options contains the set subtraction of
4456
+ // / \p arg from the parameters former options.
4457
+ // /
4458
+ // / NOTE: You can pass in SILResultInfo::Flag to this function since said
4459
+ // / type auto converts to Options.
4460
+ SILResultInfo operator -(Options arg) const {
4461
+ return SILResultInfo (getInterfaceType (), getConvention (),
4462
+ getOptions () - arg);
4463
+ }
4464
+
4465
+ SILResultInfo &operator -=(Options arg) {
4466
+ options -= arg;
4467
+ return *this ;
4369
4468
}
4370
4469
4371
4470
// / The SIL storage type determines the ABI for arguments based purely on the
@@ -4428,9 +4527,9 @@ class SILResultInfo {
4428
4527
}
4429
4528
4430
4529
void profile (llvm::FoldingSetNodeID &id) {
4431
- id.AddPointer (Type .getPointer ());
4530
+ id.AddPointer (type .getPointer ());
4432
4531
id.AddInteger (unsigned (getConvention ()));
4433
- id.AddInteger (unsigned (getDifferentiability ()));
4532
+ id.AddInteger (unsigned (getOptions (). toRaw ()));
4434
4533
}
4435
4534
4436
4535
SWIFT_DEBUG_DUMP;
@@ -4447,8 +4546,8 @@ class SILResultInfo {
4447
4546
getOwnershipKind (SILFunction &, CanSILFunctionType fTy ) const ; // in SILType.cpp
4448
4547
4449
4548
bool operator ==(SILResultInfo rhs) const {
4450
- return Type == rhs.Type && Convention == rhs.Convention
4451
- && Differentiability == rhs.Differentiability ;
4549
+ return type == rhs.type && convention == rhs.convention &&
4550
+ getOptions (). toRaw () == rhs.getOptions (). toRaw () ;
4452
4551
}
4453
4552
bool operator !=(SILResultInfo rhs) const {
4454
4553
return !(*this == rhs);
0 commit comments