@@ -357,32 +357,29 @@ class StructDecl;
357
357
class TupleType ;
358
358
class EnumDecl ;
359
359
360
- // / A type that represents the tangent space of a differentiable type .
361
- class TangentSpace {
360
+ // / A type that represents a vector space .
361
+ class VectorSpace {
362
362
public:
363
363
// / A tangent space kind.
364
364
enum class Kind {
365
365
// / `Builtin.FP<...>`.
366
- BuiltinRealScalar,
367
- // / A type that conforms to `FloatingPoint`.
368
- RealScalar,
369
- // / A type that conforms to `VectorNumeric` where the associated
370
- // / `ScalarElement` conforms to `FloatingPoint`.
371
- RealVector,
372
- // / A product of tangent spaces as a struct.
373
- ProductStruct,
374
- // / A product of tangent spaces as a tuple.
375
- ProductTuple,
376
- // / A sum of tangent spaces.
377
- Sum
366
+ BuiltinFloat,
367
+ // / A type that conforms to `VectorNumeric`.
368
+ Vector,
369
+ // / A product of vector spaces as a struct.
370
+ Struct,
371
+ // / A product of vector spaces as a tuple.
372
+ Tuple,
373
+ // / A union of vector spaces.
374
+ Enum
378
375
};
379
376
380
377
private:
381
378
Kind kind;
382
379
union Value {
383
380
// BuiltinRealScalar
384
381
BuiltinFloatType *builtinFPType;
385
- // RealScalar or RealVector
382
+ // RealVector
386
383
NominalTypeDecl *realNominalType;
387
384
// ProductStruct
388
385
StructDecl *structDecl;
@@ -398,67 +395,55 @@ class TangentSpace {
398
395
Value (EnumDecl *enumDecl) : enumDecl (enumDecl) {}
399
396
} value;
400
397
401
- TangentSpace (Kind kind, Value value)
398
+ VectorSpace (Kind kind, Value value)
402
399
: kind(kind), value(value) {}
403
400
404
401
public:
405
- TangentSpace () = delete ;
402
+ VectorSpace () = delete ;
406
403
407
- static TangentSpace
404
+ static VectorSpace
408
405
getBuiltinRealScalarSpace (BuiltinFloatType *builtinFP) {
409
- return {Kind::BuiltinRealScalar , builtinFP};
406
+ return {Kind::BuiltinFloat , builtinFP};
410
407
}
411
- static TangentSpace getRealScalarSpace (NominalTypeDecl *typeDecl) {
412
- return {Kind::RealScalar , typeDecl};
408
+ static VectorSpace getRealVectorSpace (NominalTypeDecl *typeDecl) {
409
+ return {Kind::Vector , typeDecl};
413
410
}
414
- static TangentSpace getRealVectorSpace (NominalTypeDecl *typeDecl ) {
415
- return {Kind::RealVector, typeDecl };
411
+ static VectorSpace getStruct (StructDecl *structDecl ) {
412
+ return {Kind::Struct, structDecl };
416
413
}
417
- static TangentSpace getProductStruct (StructDecl *structDecl ) {
418
- return {Kind::ProductStruct, structDecl };
414
+ static VectorSpace getTuple (TupleType *tupleTy ) {
415
+ return {Kind::Tuple, tupleTy };
419
416
}
420
- static TangentSpace getProductTuple (TupleType *tupleTy) {
421
- return {Kind::ProductTuple, tupleTy};
422
- }
423
- static TangentSpace getSum (EnumDecl *enumDecl) {
424
- return {Kind::Sum, enumDecl};
417
+ static VectorSpace getEnum (EnumDecl *enumDecl) {
418
+ return {Kind::Enum, enumDecl};
425
419
}
426
420
427
- bool isBuiltinRealScalarSpace () const {
428
- return kind == Kind::BuiltinRealScalar ;
421
+ bool isBuiltinFloat () const {
422
+ return kind == Kind::BuiltinFloat ;
429
423
}
430
- bool isRealScalarSpace () const { return kind == Kind::RealScalar; }
431
- bool isRealVectorSpace () const { return kind == Kind::RealVector; }
432
- bool isProductStruct () const { return kind == Kind::ProductStruct; }
433
- bool isProductTuple () const { return kind == Kind::ProductTuple; }
424
+ bool isVector () const { return kind == Kind::Vector; }
425
+ bool isStruct () const { return kind == Kind::Struct; }
426
+ bool isTuple () const { return kind == Kind::Tuple; }
434
427
435
428
Kind getKind () const { return kind; }
436
- BuiltinFloatType *getBuiltinRealScalarSpace () const {
437
- assert (kind == Kind::BuiltinRealScalar );
429
+ BuiltinFloatType *getBuiltinFloat () const {
430
+ assert (kind == Kind::BuiltinFloat );
438
431
return value.builtinFPType ;
439
432
}
440
- NominalTypeDecl *getRealScalarSpace () const {
441
- assert (kind == Kind::RealScalar);
442
- return value.realNominalType ;
443
- }
444
- NominalTypeDecl *getRealVectorSpace () const {
445
- assert (kind == Kind::RealVector);
446
- return value.realNominalType ;
447
- }
448
- NominalTypeDecl *getRealScalarOrVectorSpace () const {
449
- assert (kind == Kind::RealScalar || kind == Kind::RealVector);
433
+ NominalTypeDecl *getVector () const {
434
+ assert (kind == Kind::Vector);
450
435
return value.realNominalType ;
451
436
}
452
- StructDecl *getProductStruct () const {
453
- assert (kind == Kind::ProductStruct );
437
+ StructDecl *getStruct () const {
438
+ assert (kind == Kind::Struct );
454
439
return value.structDecl ;
455
440
}
456
- TupleType *getProductTuple () const {
457
- assert (kind == Kind::ProductTuple );
441
+ TupleType *getTuple () const {
442
+ assert (kind == Kind::Tuple );
458
443
return value.tupleType ;
459
444
}
460
- EnumDecl *getSum () const {
461
- assert (kind == Kind::Sum );
445
+ EnumDecl *getEnum () const {
446
+ assert (kind == Kind::Enum );
462
447
return value.enumDecl ;
463
448
}
464
449
};
0 commit comments