Skip to content

Commit b4ac692

Browse files
committed
[AutoDiff] NFC: Change DifferentiableFunctionExtractee to a top-level type.
Many places in the compiler that are completely unrelated to {{DifferentiableFunctionExtractInst}} are using {{DifferentiableFunctionExtractInst::Extractee}}, including `@differentiable` type lowering (IRGen/GenDiffFunc.cpp). This patch refactors it and renames it to `NormalDifferentiableFunctionTypeComponent` so that it is no longer part of `DifferentiableFunctionInst`. Resolves [TF-904](https://bugs.swift.org/browse/TF-904).
1 parent bb67311 commit b4ac692

File tree

13 files changed

+144
-148
lines changed

13 files changed

+144
-148
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 52 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,58 @@ enum class DifferentiabilityKind: uint8_t {
4545
Linear = 2
4646
};
4747

48-
// TODO(TF-904): Replace `DifferentiableFunctionExtractInst::Extractee`.
49-
enum class NormalDifferentiableFunctionTypeComponent : uint8_t {
50-
Original = 0,
51-
JVP = 1,
52-
VJP = 2
48+
/// The kind of an linear map.
49+
struct AutoDiffLinearMapKind {
50+
enum innerty : uint8_t {
51+
// The differential function.
52+
Differential = 0,
53+
// The pullback function.
54+
Pullback = 1
55+
} rawValue;
56+
57+
AutoDiffLinearMapKind() = default;
58+
AutoDiffLinearMapKind(innerty rawValue) : rawValue(rawValue) {}
59+
operator innerty() const { return rawValue; }
60+
};
61+
62+
/// The kind of a derivative function.
63+
struct AutoDiffDerivativeFunctionKind {
64+
enum innerty : uint8_t {
65+
// The Jacobian-vector products function.
66+
JVP = 0,
67+
// The vector-Jacobian products function.
68+
VJP = 1
69+
} rawValue;
70+
71+
AutoDiffDerivativeFunctionKind() = default;
72+
AutoDiffDerivativeFunctionKind(innerty rawValue) : rawValue(rawValue) {}
73+
AutoDiffDerivativeFunctionKind(AutoDiffLinearMapKind linMapKind)
74+
: rawValue(static_cast<innerty>(linMapKind.rawValue)) {}
75+
explicit AutoDiffDerivativeFunctionKind(StringRef string);
76+
operator innerty() const { return rawValue; }
77+
AutoDiffLinearMapKind getLinearMapKind() {
78+
return (AutoDiffLinearMapKind::innerty)rawValue;
79+
}
80+
};
81+
82+
struct NormalDifferentiableFunctionTypeComponent {
83+
enum innerty : unsigned {
84+
Original = 0,
85+
JVP = 1,
86+
VJP = 2
87+
} rawValue;
88+
89+
NormalDifferentiableFunctionTypeComponent() = default;
90+
NormalDifferentiableFunctionTypeComponent(innerty rawValue)
91+
: rawValue(rawValue) {}
92+
NormalDifferentiableFunctionTypeComponent(
93+
AutoDiffDerivativeFunctionKind kind);
94+
explicit NormalDifferentiableFunctionTypeComponent(unsigned rawValue) :
95+
NormalDifferentiableFunctionTypeComponent((innerty)rawValue) {}
96+
explicit NormalDifferentiableFunctionTypeComponent(StringRef name);
97+
operator innerty() const { return rawValue; }
98+
99+
Optional<AutoDiffDerivativeFunctionKind> getAsDerivativeFunctionKind() const;
53100
};
54101

55102
struct LinearDifferentiableFunctionTypeComponent {
@@ -196,40 +243,6 @@ inline llvm::raw_ostream &operator<<(llvm::raw_ostream &s,
196243
return s;
197244
}
198245

199-
/// The kind of an linear map.
200-
struct AutoDiffLinearMapKind {
201-
enum innerty : uint8_t {
202-
// The differential function.
203-
Differential = 0,
204-
// The pullback function.
205-
Pullback = 1
206-
} rawValue;
207-
208-
AutoDiffLinearMapKind() = default;
209-
AutoDiffLinearMapKind(innerty rawValue) : rawValue(rawValue) {}
210-
operator innerty() const { return rawValue; }
211-
};
212-
213-
/// The kind of a derivative function.
214-
struct AutoDiffDerivativeFunctionKind {
215-
enum innerty : uint8_t {
216-
// The Jacobian-vector products function.
217-
JVP = 0,
218-
// The vector-Jacobian products function.
219-
VJP = 1
220-
} rawValue;
221-
222-
AutoDiffDerivativeFunctionKind() = default;
223-
AutoDiffDerivativeFunctionKind(innerty rawValue) : rawValue(rawValue) {}
224-
AutoDiffDerivativeFunctionKind(AutoDiffLinearMapKind linMapKind)
225-
: rawValue(static_cast<innerty>(linMapKind.rawValue)) {}
226-
explicit AutoDiffDerivativeFunctionKind(StringRef string);
227-
operator innerty() const { return rawValue; }
228-
AutoDiffLinearMapKind getLinearMapKind() {
229-
return (AutoDiffLinearMapKind::innerty)rawValue;
230-
}
231-
};
232-
233246
/// Identifies an autodiff derivative function configuration:
234247
/// - Parameter indices.
235248
/// - Result indices.

include/swift/SIL/SILBuilder.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,7 @@ class SILBuilder {
528528
}
529529

530530
DifferentiableFunctionExtractInst *createDifferentiableFunctionExtract(
531-
SILLocation Loc, DifferentiableFunctionExtractee Extractee,
531+
SILLocation Loc, NormalDifferentiableFunctionTypeComponent Extractee,
532532
SILValue TheFunction) {
533533
return insert(new (getModule()) DifferentiableFunctionExtractInst(
534534
getModule(), getSILDebugLocation(Loc), Extractee, TheFunction));
@@ -546,7 +546,7 @@ class SILBuilder {
546546
SILValue TheFunction) {
547547
return insert(new (getModule()) DifferentiableFunctionExtractInst(
548548
getModule(), getSILDebugLocation(Loc),
549-
DifferentiableFunctionExtractee::Original, TheFunction));
549+
NormalDifferentiableFunctionTypeComponent::Original, TheFunction));
550550
}
551551

552552
BuiltinInst *createBuiltin(SILLocation Loc, Identifier Name, SILType ResultTy,

include/swift/SIL/SILInstruction.h

Lines changed: 10 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7967,42 +7967,29 @@ class DifferentiableFunctionExtractInst
79677967
: public InstructionBase<
79687968
SILInstructionKind::DifferentiableFunctionExtractInst,
79697969
SingleValueInstruction> {
7970-
public:
7971-
struct Extractee {
7972-
enum innerty : unsigned {
7973-
Original = 0,
7974-
JVP = 1,
7975-
VJP = 2
7976-
} rawValue;
7977-
Extractee() = default;
7978-
Extractee(innerty rawValue) : rawValue(rawValue) {}
7979-
explicit Extractee(unsigned rawValue) : Extractee((innerty)rawValue) {}
7980-
Extractee(AutoDiffDerivativeFunctionKind kind);
7981-
explicit Extractee(StringRef name);
7982-
operator innerty() const { return rawValue; }
7983-
7984-
Optional<AutoDiffDerivativeFunctionKind>
7985-
getExtracteeAsDerivativeFunction() const;
7986-
};
7987-
79887970
private:
79897971
/// The extractee.
7990-
Extractee extractee;
7972+
NormalDifferentiableFunctionTypeComponent extractee;
79917973
/// The list containing the `@differentiable` function operand.
79927974
FixedOperandList<1> operands;
79937975

79947976
static SILType
7995-
getExtracteeType(SILValue function, Extractee extractee, SILModule &module);
7977+
getExtracteeType(
7978+
SILValue function, NormalDifferentiableFunctionTypeComponent extractee,
7979+
SILModule &module);
79967980

79977981
public:
79987982
explicit DifferentiableFunctionExtractInst(
7999-
SILModule &module, SILDebugLocation debugLoc, Extractee extractee,
7983+
SILModule &module, SILDebugLocation debugLoc,
7984+
NormalDifferentiableFunctionTypeComponent extractee,
80007985
SILValue theFunction);
80017986

8002-
Extractee getExtractee() const { return extractee; }
7987+
NormalDifferentiableFunctionTypeComponent getExtractee() const {
7988+
return extractee;
7989+
}
80037990

80047991
AutoDiffDerivativeFunctionKind getDerivativeFunctionKind() const {
8005-
auto kind = extractee.getExtracteeAsDerivativeFunction();
7992+
auto kind = extractee.getAsDerivativeFunctionKind();
80067993
assert(kind);
80077994
return *kind;
80087995
}
@@ -8012,9 +7999,6 @@ class DifferentiableFunctionExtractInst
80127999
MutableArrayRef<Operand> getAllOperands() { return operands.asArray(); }
80138000
};
80148001

8015-
typedef DifferentiableFunctionExtractInst::Extractee
8016-
DifferentiableFunctionExtractee;
8017-
80188002
/// `linear_function_extract` - given an `@differentiable(linear)` function
80198003
/// representing a bundle of the original function and the transpose function,
80208004
/// extract the specified function.
@@ -8047,8 +8031,6 @@ class LinearFunctionExtractInst
80478031
ArrayRef<Operand> getAllOperands() const { return operands.asArray(); }
80488032
MutableArrayRef<Operand> getAllOperands() { return operands.asArray(); }
80498033
};
8050-
8051-
typedef LinearDifferentiableFunctionTypeComponent LinearFunctionExtractee;
80528034
// SWIFT_ENABLE_TENSORFLOW END
80538035

80548036
// This is defined out of line to work around the fact that this depends on

lib/AST/AutoDiff.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,33 @@ AutoDiffDerivativeFunctionKind(StringRef string) {
3232
rawValue = *result;
3333
}
3434

35+
NormalDifferentiableFunctionTypeComponent::
36+
NormalDifferentiableFunctionTypeComponent(AutoDiffDerivativeFunctionKind kind) {
37+
switch (kind) {
38+
case AutoDiffDerivativeFunctionKind::JVP: rawValue = JVP; return;
39+
case AutoDiffDerivativeFunctionKind::VJP: rawValue = VJP; return;
40+
}
41+
}
42+
43+
NormalDifferentiableFunctionTypeComponent::
44+
NormalDifferentiableFunctionTypeComponent(StringRef string) {
45+
Optional<innerty> result = llvm::StringSwitch<Optional<innerty>>(string)
46+
.Case("original", Original)
47+
.Case("jvp", JVP)
48+
.Case("vjp", VJP);
49+
assert(result && "Invalid string");
50+
rawValue = *result;
51+
}
52+
53+
Optional<AutoDiffDerivativeFunctionKind>
54+
NormalDifferentiableFunctionTypeComponent::getAsDerivativeFunctionKind() const {
55+
switch (rawValue) {
56+
case Original: return None;
57+
case JVP: return {AutoDiffDerivativeFunctionKind::JVP};
58+
case VJP: return {AutoDiffDerivativeFunctionKind::VJP};
59+
}
60+
}
61+
3562
LinearDifferentiableFunctionTypeComponent::
3663
LinearDifferentiableFunctionTypeComponent(StringRef string) {
3764
Optional<innerty> result =

lib/IRGen/GenDiffFunc.cpp

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -41,34 +41,34 @@ class DifferentiableFuncFieldInfo final
4141
: public RecordField<DifferentiableFuncFieldInfo> {
4242
public:
4343
DifferentiableFuncFieldInfo(
44-
DifferentiableFunctionExtractee component, const TypeInfo &type,
44+
NormalDifferentiableFunctionTypeComponent component, const TypeInfo &type,
4545
IndexSubset *parameterIndices)
4646
: RecordField(type), component(component),
4747
parameterIndices(parameterIndices) {}
4848

4949
/// The field index.
50-
const DifferentiableFunctionExtractee component;
50+
const NormalDifferentiableFunctionTypeComponent component;
5151

5252
/// The parameter indices.
5353
IndexSubset *parameterIndices;
5454

5555
std::string getFieldName() const {
5656
switch (component) {
57-
case DifferentiableFunctionExtractee::Original:
57+
case NormalDifferentiableFunctionTypeComponent::Original:
5858
return "original";
59-
case DifferentiableFunctionExtractee::JVP:
59+
case NormalDifferentiableFunctionTypeComponent::JVP:
6060
return "jvp";
61-
case DifferentiableFunctionExtractee::VJP:
61+
case NormalDifferentiableFunctionTypeComponent::VJP:
6262
return "vjp";
6363
}
6464
}
6565

6666
SILType getType(IRGenModule &IGM, SILType t) const {
6767
auto fnTy = t.castTo<SILFunctionType>();
6868
auto origFnTy = fnTy->getWithoutDifferentiability();
69-
if (component == DifferentiableFunctionExtractee::Original)
69+
if (component == NormalDifferentiableFunctionTypeComponent::Original)
7070
return SILType::getPrimitiveObjectType(origFnTy);
71-
auto kind = *component.getExtracteeAsDerivativeFunction();
71+
auto kind = *component.getAsDerivativeFunctionKind();
7272
auto assocTy = origFnTy->getAutoDiffDerivativeFunctionType(
7373
parameterIndices, /*resultIndex*/ 0, kind,
7474
IGM.getSILTypes(), LookUpConformanceInModule(IGM.getSwiftModule()));
@@ -79,8 +79,8 @@ class DifferentiableFuncFieldInfo final
7979
class DifferentiableFuncTypeInfo final
8080
: public RecordTypeInfo<DifferentiableFuncTypeInfo, LoadableTypeInfo,
8181
DifferentiableFuncFieldInfo> {
82-
using super =
83-
RecordTypeInfo<DifferentiableFuncTypeInfo, LoadableTypeInfo, DifferentiableFuncFieldInfo>;
82+
using super = RecordTypeInfo<DifferentiableFuncTypeInfo, LoadableTypeInfo,
83+
DifferentiableFuncFieldInfo>;
8484

8585
public:
8686
DifferentiableFuncTypeInfo(
@@ -117,7 +117,7 @@ class DifferentiableFuncTypeInfo final
117117

118118
class DifferentiableFuncTypeBuilder
119119
: public RecordTypeBuilder<DifferentiableFuncTypeBuilder, DifferentiableFuncFieldInfo,
120-
DifferentiableFunctionExtractee> {
120+
NormalDifferentiableFunctionTypeComponent> {
121121

122122
SILFunctionType *originalType;
123123
IndexSubset *parameterIndices;
@@ -151,15 +151,15 @@ class DifferentiableFuncTypeBuilder
151151
}
152152

153153
DifferentiableFuncFieldInfo getFieldInfo(
154-
unsigned index, DifferentiableFunctionExtractee component,
154+
unsigned index, NormalDifferentiableFunctionTypeComponent component,
155155
const TypeInfo &fieldTI) {
156156
return DifferentiableFuncFieldInfo(component, fieldTI, parameterIndices);
157157
}
158158

159-
SILType getType(DifferentiableFunctionExtractee component) {
160-
if (component == DifferentiableFunctionExtractee::Original)
159+
SILType getType(NormalDifferentiableFunctionTypeComponent component) {
160+
if (component == NormalDifferentiableFunctionTypeComponent::Original)
161161
return SILType::getPrimitiveObjectType(originalType->getCanonicalType());
162-
auto kind = *component.getExtracteeAsDerivativeFunction();
162+
auto kind = *component.getAsDerivativeFunctionKind();
163163
auto assocTy = originalType->getAutoDiffDerivativeFunctionType(
164164
parameterIndices, /*resultIndex*/ 0, kind, IGM.getSILTypes(),
165165
LookUpConformanceInModule(IGM.getSwiftModule()));
@@ -320,9 +320,9 @@ class LinearFuncTypeBuilder
320320
const TypeInfo *
321321
TypeConverter::convertNormalDifferentiableFunctionType(SILFunctionType *type) {
322322
DifferentiableFuncTypeBuilder builder(IGM, type);
323-
return builder.layout({DifferentiableFunctionExtractee::Original,
324-
DifferentiableFunctionExtractee::JVP,
325-
DifferentiableFunctionExtractee::VJP});
323+
return builder.layout({NormalDifferentiableFunctionTypeComponent::Original,
324+
NormalDifferentiableFunctionTypeComponent::JVP,
325+
NormalDifferentiableFunctionTypeComponent::VJP});
326326
}
327327

328328
const TypeInfo *

lib/ParseSIL/ParseSIL.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3041,7 +3041,7 @@ bool SILParser::parseSILInstruction(SILBuilder &B) {
30413041
case SILInstructionKind::DifferentiableFunctionExtractInst: {
30423042
// Parse the rest of the instruction: an extractee, a differentiable
30433043
// function operand, and a debug location.
3044-
DifferentiableFunctionExtractee extractee;
3044+
NormalDifferentiableFunctionTypeComponent extractee;
30453045
StringRef extracteeNames[3] = {"original", "jvp", "vjp"};
30463046
SILValue functionOperand;
30473047
SourceLoc lastLoc;

lib/SIL/SILInstructions.cpp

Lines changed: 7 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -664,45 +664,16 @@ LinearFunctionInst *LinearFunctionInst::create(
664664
HasOwnership);
665665
}
666666

667-
DifferentiableFunctionExtractInst::Extractee::Extractee(
668-
AutoDiffDerivativeFunctionKind kind) {
669-
switch (kind) {
670-
case AutoDiffDerivativeFunctionKind::JVP:
671-
rawValue = JVP;
672-
return;
673-
case AutoDiffDerivativeFunctionKind::VJP:
674-
rawValue = VJP;
675-
return;
676-
}
677-
}
678-
679-
DifferentiableFunctionExtractInst::Extractee::Extractee(StringRef string) {
680-
Optional<innerty> result = llvm::StringSwitch<Optional<innerty>>(string)
681-
.Case("original", Original)
682-
.Case("jvp", JVP)
683-
.Case("vjp", VJP);
684-
assert(result && "Invalid string");
685-
rawValue = *result;
686-
}
687-
688-
Optional<AutoDiffDerivativeFunctionKind>
689-
DifferentiableFunctionExtractInst::Extractee::
690-
getExtracteeAsDerivativeFunction() const {
691-
switch (rawValue) {
692-
case Original: return None;
693-
case JVP: return {AutoDiffDerivativeFunctionKind::JVP};
694-
case VJP: return {AutoDiffDerivativeFunctionKind::VJP};
695-
}
696-
}
697-
698667
SILType DifferentiableFunctionExtractInst::
699-
getExtracteeType(SILValue function, Extractee extractee, SILModule &module) {
668+
getExtracteeType(
669+
SILValue function, NormalDifferentiableFunctionTypeComponent extractee,
670+
SILModule &module) {
700671
auto fnTy = function->getType().castTo<SILFunctionType>();
701672
assert(fnTy->getDifferentiabilityKind() == DifferentiabilityKind::Normal);
702673
auto originalFnTy = fnTy->getWithoutDifferentiability();
703-
auto kindOpt = extractee.getExtracteeAsDerivativeFunction();
674+
auto kindOpt = extractee.getAsDerivativeFunctionKind();
704675
if (!kindOpt) {
705-
assert(extractee == Extractee::Original);
676+
assert(extractee == NormalDifferentiableFunctionTypeComponent::Original);
706677
return SILType::getPrimitiveObjectType(originalFnTy);
707678
}
708679
auto resultFnTy = originalFnTy->getAutoDiffDerivativeFunctionType(
@@ -713,8 +684,8 @@ getExtracteeType(SILValue function, Extractee extractee, SILModule &module) {
713684
}
714685

715686
DifferentiableFunctionExtractInst::DifferentiableFunctionExtractInst(
716-
SILModule &module, SILDebugLocation debugLoc, Extractee extractee,
717-
SILValue theFunction)
687+
SILModule &module, SILDebugLocation debugLoc,
688+
NormalDifferentiableFunctionTypeComponent extractee, SILValue theFunction)
718689
: InstructionBase(debugLoc,
719690
getExtracteeType(theFunction, extractee, module)),
720691
extractee(extractee), operands(this, theFunction) {}

0 commit comments

Comments
 (0)