@@ -4092,54 +4092,50 @@ inline bool isPackParameter(ParameterConvention conv) {
4092
4092
llvm_unreachable (" bad convention kind" );
4093
4093
}
4094
4094
4095
- // / The differentiability of a SIL function type parameter.
4096
- enum class SILParameterDifferentiability : bool {
4097
- // / Either differentiable or not applicable.
4098
- // /
4099
- // / - If the function type is not `@differentiable`, parameter
4100
- // / differentiability is not applicable. This case is the default value.
4101
- // / - If the function type is `@differentiable`, the function is
4102
- // / differentiable with respect to this parameter.
4103
- DifferentiableOrNotApplicable,
4104
-
4105
- // / Not differentiable: a `@noDerivative` parameter.
4106
- // /
4107
- // / May be applied only to parameters of `@differentiable` function types.
4108
- // / The function type is not differentiable with respect to this parameter.
4109
- NotDifferentiable,
4110
- };
4111
-
4112
4095
// / A parameter type and the rules for passing it.
4113
4096
class SILParameterInfo {
4114
- CanType Type;
4115
- ParameterConvention Convention;
4116
- SILParameterDifferentiability Differentiability;
4097
+ public:
4098
+ enum Flag : uint8_t {
4099
+ // / Not differentiable: a `@noDerivative` parameter.
4100
+ // /
4101
+ // / May be applied only to parameters of `@differentiable` function types.
4102
+ // / The function type is not differentiable with respect to this parameter.
4103
+ // /
4104
+ // / If this is not set then the parameter is either differentiable or
4105
+ // / differentiation is not applicable to the parameter:
4106
+ // /
4107
+ // / - If the function type is not `@differentiable`, parameter
4108
+ // / differentiability is not applicable. This case is the default value.
4109
+ // / - If the function type is `@differentiable`, the function is
4110
+ // / differentiable with respect to this parameter.
4111
+ NotDifferentiable = 0x1 ,
4112
+ };
4113
+
4114
+ using Options = OptionSet<Flag>;
4115
+
4116
+ private:
4117
+ CanType type;
4118
+ ParameterConvention convention;
4119
+ Options options;
4117
4120
4118
4121
public:
4119
4122
SILParameterInfo () = default ;// : Ty(), Convention((ParameterConvention)0) {}
4120
- SILParameterInfo (
4121
- CanType type, ParameterConvention conv,
4122
- SILParameterDifferentiability differentiability =
4123
- SILParameterDifferentiability::DifferentiableOrNotApplicable)
4124
- : Type(type), Convention(conv), Differentiability(differentiability) {
4123
+ SILParameterInfo (CanType type, ParameterConvention conv, Options options = {})
4124
+ : type(type), convention(conv), options(options) {
4125
4125
assert (type->isLegalSILType () && " SILParameterInfo has illegal SIL type" );
4126
4126
}
4127
4127
4128
4128
// / Return the unsubstituted parameter type that describes the abstract
4129
4129
// / calling convention of the parameter.
4130
4130
// /
4131
4131
// / For most purposes, you probably want \c getArgumentType .
4132
- CanType getInterfaceType () const {
4133
- return Type;
4134
- }
4135
-
4132
+ CanType getInterfaceType () const { return type; }
4133
+
4136
4134
// / Return the type of a call argument matching this parameter.
4137
4135
// /
4138
4136
// / \c t must refer back to the function type this is a parameter for.
4139
4137
CanType getArgumentType (SILModule &M, const SILFunctionType *t, TypeExpansionContext context) const ;
4140
- ParameterConvention getConvention () const {
4141
- return Convention;
4142
- }
4138
+ ParameterConvention getConvention () const { return convention; }
4143
4139
// Does this parameter convention require indirect storage? This reflects a
4144
4140
// SILFunctionType's formal (immutable) conventions, as opposed to the
4145
4141
// transient SIL conventions that dictate SILValue types.
@@ -4186,14 +4182,65 @@ class SILParameterInfo {
4186
4182
return isGuaranteedParameter (getConvention ());
4187
4183
}
4188
4184
4189
- SILParameterDifferentiability getDifferentiability () const {
4190
- return Differentiability;
4185
+ bool hasOption (Flag flag) const { return options.contains (flag); }
4186
+
4187
+ Options getOptions () const { return options; }
4188
+
4189
+ SILParameterInfo addingOption (Flag flag) const {
4190
+ auto options = getOptions ();
4191
+ options |= flag;
4192
+ return SILParameterInfo (getInterfaceType (), getConvention (), options);
4193
+ }
4194
+
4195
+ SILParameterInfo removingOption (Flag flag) const {
4196
+ auto options = getOptions ();
4197
+ options &= flag;
4198
+ return SILParameterInfo (getInterfaceType (), getConvention (), options);
4199
+ }
4200
+
4201
+ // / Add all flags in \p arg into a copy of this parameter info and return the
4202
+ // / parameter info.
4203
+ // /
4204
+ // / NOTE: You can pass in SILParameterInfo::Flag to this function since said
4205
+ // / type auto converts to Options.
4206
+ SILParameterInfo operator |(Options arg) const {
4207
+ return SILParameterInfo (getInterfaceType (), getConvention (),
4208
+ getOptions () | arg);
4209
+ }
4210
+
4211
+ SILParameterInfo &operator |=(Options arg) {
4212
+ options |= arg;
4213
+ return *this ;
4191
4214
}
4192
4215
4193
- SILParameterInfo getWithDifferentiability (
4194
- SILParameterDifferentiability differentiability) const {
4216
+ // / Copy this parameter and intersect \p arg with the parameters former
4217
+ // / options.
4218
+ // /
4219
+ // / NOTE: You can pass in SILParameterInfo::Flag to this function since said
4220
+ // / type auto converts to Options.
4221
+ SILParameterInfo operator &(Options arg) const {
4195
4222
return SILParameterInfo (getInterfaceType (), getConvention (),
4196
- differentiability);
4223
+ getOptions () & arg);
4224
+ }
4225
+
4226
+ SILParameterInfo &operator &=(Options arg) {
4227
+ options &= arg;
4228
+ return *this ;
4229
+ }
4230
+
4231
+ // / Copy this parameter such that its options contains the set subtraction of
4232
+ // / \p arg from the parameters former options.
4233
+ // /
4234
+ // / NOTE: You can pass in SILParameterInfo::Flag to this function since said
4235
+ // / type auto converts to Options.
4236
+ SILParameterInfo operator -(Options arg) const {
4237
+ return SILParameterInfo (getInterfaceType (), getConvention (),
4238
+ getOptions () - arg);
4239
+ }
4240
+
4241
+ SILParameterInfo &operator -=(Options arg) {
4242
+ options -= arg;
4243
+ return *this ;
4197
4244
}
4198
4245
4199
4246
// / The SIL storage type determines the ABI for arguments based purely on the
@@ -4209,12 +4256,12 @@ class SILParameterInfo {
4209
4256
4210
4257
// / Return a version of this parameter info with the type replaced.
4211
4258
SILParameterInfo getWithInterfaceType (CanType type) const {
4212
- return SILParameterInfo (type, getConvention (), getDifferentiability ());
4259
+ return SILParameterInfo (type, getConvention (), getOptions ());
4213
4260
}
4214
4261
4215
4262
// / Return a version of this parameter info with the convention replaced.
4216
4263
SILParameterInfo getWithConvention (ParameterConvention c) const {
4217
- return SILParameterInfo (getInterfaceType (), c, getDifferentiability ());
4264
+ return SILParameterInfo (getInterfaceType (), c, getOptions ());
4218
4265
}
4219
4266
4220
4267
// / Transform this SILParameterInfo by applying the user-provided
@@ -4244,7 +4291,7 @@ class SILParameterInfo {
4244
4291
void profile (llvm::FoldingSetNodeID &id) {
4245
4292
id.AddPointer (getInterfaceType ().getPointer ());
4246
4293
id.AddInteger ((unsigned )getConvention ());
4247
- id.AddInteger ((unsigned )getDifferentiability ());
4294
+ id.AddInteger ((unsigned )getOptions (). toRaw ());
4248
4295
}
4249
4296
4250
4297
SWIFT_DEBUG_DUMP;
@@ -4260,7 +4307,7 @@ class SILParameterInfo {
4260
4307
bool operator ==(SILParameterInfo rhs) const {
4261
4308
return getInterfaceType () == rhs.getInterfaceType () &&
4262
4309
getConvention () == rhs.getConvention () &&
4263
- getDifferentiability () == rhs.getDifferentiability ( );
4310
+ getOptions (). containsOnly ( rhs.getOptions () );
4264
4311
}
4265
4312
bool operator !=(SILParameterInfo rhs) const {
4266
4313
return !(*this == rhs);
@@ -4309,64 +4356,112 @@ inline bool isIndirectFormalResult(ResultConvention convention) {
4309
4356
convention == ResultConvention::Pack;
4310
4357
}
4311
4358
4312
- // / The differentiability of a SIL function type result.
4313
- enum class SILResultDifferentiability : bool {
4314
- // / Either differentiable or not applicable.
4315
- // /
4316
- // / - If the function type is not `@differentiable`, result
4317
- // / differentiability is not applicable. This case is the default value.
4318
- // / - If the function type is `@differentiable`, the function is
4319
- // / differentiable with respect to this result.
4320
- DifferentiableOrNotApplicable,
4321
-
4322
- // / Not differentiable: a `@noDerivative` result.
4323
- // /
4324
- // / May be applied only to result of `@differentiable` function types.
4325
- // / The function type is not differentiable with respect to this result.
4326
- NotDifferentiable,
4327
- };
4328
-
4329
4359
// / A result type and the rules for returning it.
4330
4360
class SILResultInfo {
4331
- CanType Type;
4332
- ResultConvention Convention;
4333
- SILResultDifferentiability Differentiability;
4361
+ public:
4362
+ enum Flag : uint8_t {
4363
+ // / Not differentiable: a `@noDerivative` result.
4364
+ // /
4365
+ // / May be applied only to result of `@differentiable` function types.
4366
+ // / The function type is not differentiable with respect to this result.
4367
+ // /
4368
+ // / If this is not set then the function is either differentiable or
4369
+ // / differentiability is not applicable. This can occur if:
4370
+ // /
4371
+ // / - The function type is not `@differentiable`, result
4372
+ // / differentiability is not applicable. This case is the default value.
4373
+ // / - The function type is `@differentiable`, the function is
4374
+ // / differentiable with respect to this result.
4375
+ NotDifferentiable = 0x1 ,
4376
+ };
4377
+
4378
+ using Options = OptionSet<Flag>;
4379
+
4380
+ private:
4381
+ CanType type;
4382
+ ResultConvention convention;
4383
+ Options options;
4334
4384
4335
4385
public:
4336
4386
SILResultInfo () = default ;
4337
- SILResultInfo (CanType type, ResultConvention conv,
4338
- SILResultDifferentiability differentiability =
4339
- SILResultDifferentiability::DifferentiableOrNotApplicable)
4340
- : Type(type), Convention(conv), Differentiability(differentiability) {
4387
+ SILResultInfo (CanType type, ResultConvention conv, Options options = {})
4388
+ : type(type), convention(conv), options(options) {
4341
4389
assert (type->isLegalSILType () && " SILResultInfo has illegal SIL type" );
4342
4390
}
4343
4391
4344
4392
// / Return the unsubstituted parameter type that describes the abstract
4345
4393
// / calling convention of the parameter.
4346
4394
// /
4347
4395
// / For most purposes, you probably want \c getReturnValueType .
4348
- CanType getInterfaceType () const {
4349
- return Type;
4350
- }
4351
-
4396
+ CanType getInterfaceType () const { return type; }
4397
+
4352
4398
// / The type of a return value corresponding to this result.
4353
4399
// /
4354
4400
// / \c t must refer back to the function type this is a parameter for.
4355
4401
CanType getReturnValueType (SILModule &M, const SILFunctionType *t,
4356
4402
TypeExpansionContext context) const ;
4357
4403
4358
- ResultConvention getConvention () const {
4359
- return Convention;
4404
+ ResultConvention getConvention () const { return convention; }
4405
+
4406
+ Options getOptions () const { return options; }
4407
+
4408
+ bool hasOption (Flag flag) const { return options.contains (flag); }
4409
+
4410
+ SILResultInfo addingOption (Flag flag) const {
4411
+ auto options = getOptions ();
4412
+ options |= flag;
4413
+ return SILResultInfo (getInterfaceType (), getConvention (), options);
4414
+ }
4415
+
4416
+ SILResultInfo removingOption (Flag flag) const {
4417
+ auto options = getOptions ();
4418
+ options &= flag;
4419
+ return SILResultInfo (getInterfaceType (), getConvention (), options);
4420
+ }
4421
+
4422
+ // / Add all flags in \p arg into a copy of this parameter info and return the
4423
+ // / parameter info.
4424
+ // /
4425
+ // / NOTE: You can pass in SILResultInfo::Flag to this function since said
4426
+ // / type auto converts to Options.
4427
+ SILResultInfo operator |(Options arg) const {
4428
+ return SILResultInfo (getInterfaceType (), getConvention (),
4429
+ getOptions () | arg);
4360
4430
}
4361
4431
4362
- SILResultDifferentiability getDifferentiability () const {
4363
- return Differentiability;
4432
+ SILResultInfo &operator |=(Options arg) {
4433
+ options |= arg;
4434
+ return *this ;
4364
4435
}
4365
4436
4366
- SILResultInfo
4367
- getWithDifferentiability (SILResultDifferentiability differentiability) const {
4437
+ // / Copy this parameter and intersect \p arg with the parameters former
4438
+ // / options.
4439
+ // /
4440
+ // / NOTE: You can pass in SILResultInfo::Flag to this function since said
4441
+ // / type auto converts to Options.
4442
+ SILResultInfo operator &(Options arg) const {
4368
4443
return SILResultInfo (getInterfaceType (), getConvention (),
4369
- differentiability);
4444
+ getOptions () & arg);
4445
+ }
4446
+
4447
+ SILResultInfo &operator &=(Options arg) {
4448
+ options &= arg;
4449
+ return *this ;
4450
+ }
4451
+
4452
+ // / Copy this parameter such that its options contains the set subtraction of
4453
+ // / \p arg from the parameters former options.
4454
+ // /
4455
+ // / NOTE: You can pass in SILResultInfo::Flag to this function since said
4456
+ // / type auto converts to Options.
4457
+ SILResultInfo operator -(Options arg) const {
4458
+ return SILResultInfo (getInterfaceType (), getConvention (),
4459
+ getOptions () - arg);
4460
+ }
4461
+
4462
+ SILResultInfo &operator -=(Options arg) {
4463
+ options -= arg;
4464
+ return *this ;
4370
4465
}
4371
4466
4372
4467
// / The SIL storage type determines the ABI for arguments based purely on the
@@ -4429,9 +4524,9 @@ class SILResultInfo {
4429
4524
}
4430
4525
4431
4526
void profile (llvm::FoldingSetNodeID &id) {
4432
- id.AddPointer (Type .getPointer ());
4527
+ id.AddPointer (type .getPointer ());
4433
4528
id.AddInteger (unsigned (getConvention ()));
4434
- id.AddInteger (unsigned (getDifferentiability ()));
4529
+ id.AddInteger (unsigned (getOptions (). toRaw ()));
4435
4530
}
4436
4531
4437
4532
SWIFT_DEBUG_DUMP;
@@ -4448,8 +4543,8 @@ class SILResultInfo {
4448
4543
getOwnershipKind (SILFunction &, CanSILFunctionType fTy ) const ; // in SILType.cpp
4449
4544
4450
4545
bool operator ==(SILResultInfo rhs) const {
4451
- return Type == rhs.Type && Convention == rhs.Convention
4452
- && Differentiability == rhs.Differentiability ;
4546
+ return type == rhs.type && convention == rhs.convention &&
4547
+ getOptions (). containsOnly ( rhs.getOptions ()) ;
4453
4548
}
4454
4549
bool operator !=(SILResultInfo rhs) const {
4455
4550
return !(*this == rhs);
0 commit comments