Skip to content

Commit 5af2b01

Browse files
committed
Update to use LinearDifferentiableFunctionTypeComponent.
1 parent 76721ab commit 5af2b01

File tree

6 files changed

+42
-38
lines changed

6 files changed

+42
-38
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,19 @@ enum class NormalDifferentiableFunctionTypeComponent : uint8_t {
5252
VJP = 2
5353
};
5454

55-
enum class LinearDifferentiableFunctionTypeComponent : uint8_t {
56-
Original = 0,
57-
Transpose = 1
55+
struct LinearDifferentiableFunctionTypeComponent {
56+
enum innerty : unsigned {
57+
Original = 0,
58+
Transpose = 1,
59+
} rawValue;
60+
61+
LinearDifferentiableFunctionTypeComponent() = default;
62+
LinearDifferentiableFunctionTypeComponent(innerty rawValue)
63+
: rawValue(rawValue) {}
64+
explicit LinearDifferentiableFunctionTypeComponent(unsigned rawValue) :
65+
LinearDifferentiableFunctionTypeComponent((innerty)rawValue) {}
66+
explicit LinearDifferentiableFunctionTypeComponent(StringRef name);
67+
operator innerty() const { return rawValue; }
5868
};
5969

6070
class ParsedAutoDiffParameter {

include/swift/SIL/SILBuilder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -535,7 +535,7 @@ class SILBuilder {
535535
}
536536

537537
LinearFunctionExtractInst *createLinearFunctionExtract(
538-
SILLocation Loc, LinearFunctionExtractee Extractee,
538+
SILLocation Loc, LinearDifferentiableFunctionTypeComponent Extractee,
539539
SILValue TheFunction) {
540540
return insert(new (getModule()) LinearFunctionExtractInst(
541541
getModule(), getSILDebugLocation(Loc), Extractee, TheFunction));

include/swift/SIL/SILInstruction.h

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8022,42 +8022,33 @@ class LinearFunctionExtractInst
80228022
: public InstructionBase<
80238023
SILInstructionKind::LinearFunctionExtractInst,
80248024
SingleValueInstruction> {
8025-
public:
8026-
struct Extractee {
8027-
enum innerty : unsigned {
8028-
Original = 0,
8029-
Transpose = 1
8030-
} rawValue;
8031-
Extractee() = default;
8032-
Extractee(innerty rawValue) : rawValue(rawValue) {}
8033-
explicit Extractee(unsigned rawValue) : Extractee((innerty)rawValue) {}
8034-
explicit Extractee(StringRef name);
8035-
operator innerty() const { return rawValue; }
8036-
};
8037-
80388025
private:
80398026
/// The extractee.
8040-
Extractee extractee;
8027+
LinearDifferentiableFunctionTypeComponent extractee;
80418028
/// The list containing the `@differentiable(linear)` function operand.
80428029
FixedOperandList<1> operands;
80438030

80448031
static SILType
8045-
getExtracteeType(SILValue function, Extractee extractee, SILModule &module);
8032+
getExtracteeType(SILValue function,
8033+
LinearDifferentiableFunctionTypeComponent extractee,
8034+
SILModule &module);
80468035

80478036
public:
80488037
explicit LinearFunctionExtractInst(
8049-
SILModule &module, SILDebugLocation debugLoc, Extractee extractee,
8038+
SILModule &module, SILDebugLocation debugLoc,
8039+
LinearDifferentiableFunctionTypeComponent extractee,
80508040
SILValue theFunction);
80518041

8052-
Extractee getExtractee() const { return extractee; }
8042+
LinearDifferentiableFunctionTypeComponent getExtractee() const {
8043+
return extractee;
8044+
}
80538045

80548046
SILValue getFunctionOperand() const { return operands[0].get(); }
80558047
ArrayRef<Operand> getAllOperands() const { return operands.asArray(); }
80568048
MutableArrayRef<Operand> getAllOperands() { return operands.asArray(); }
80578049
};
80588050

8059-
typedef LinearFunctionExtractInst::Extractee
8060-
LinearFunctionExtractee;
8051+
typedef LinearDifferentiableFunctionTypeComponent LinearFunctionExtractee;
80618052
// SWIFT_ENABLE_TENSORFLOW END
80628053

80638054
// This is defined out of line to work around the fact that this depends on

lib/AST/AutoDiff.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,16 @@ AutoDiffDerivativeFunctionKind(StringRef string) {
3232
rawValue = *result;
3333
}
3434

35+
LinearDifferentiableFunctionTypeComponent::
36+
LinearDifferentiableFunctionTypeComponent(StringRef string) {
37+
Optional<innerty> result =
38+
llvm::StringSwitch<Optional<innerty>>(string)
39+
.Case("original", Original)
40+
.Case("transpose", Transpose);
41+
assert(result && "Invalid string");
42+
rawValue = *result;
43+
}
44+
3545
// TODO(TF-874): This helper is inefficient and should be removed. Unwrapping at
3646
// most once (for curried method types) is sufficient.
3747
static void unwrapCurryLevels(AnyFunctionType *fnTy,

lib/ParseSIL/ParseSIL.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3064,7 +3064,7 @@ bool SILParser::parseSILInstruction(SILBuilder &B) {
30643064
case SILInstructionKind::LinearFunctionExtractInst: {
30653065
// Parse the rest of the instruction: an extractee, a linear function
30663066
// operand, and a debug location.
3067-
LinearFunctionExtractee extractee;
3067+
LinearDifferentiableFunctionTypeComponent extractee;
30683068
StringRef extracteeNames[2] = {"original", "transpose"};
30693069
SILValue functionOperand;
30703070
SourceLoc lastLoc;

lib/SIL/SILInstructions.cpp

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -633,15 +633,6 @@ SILType LinearFunctionInst::getLinearFunctionType(
633633
return SILType::getPrimitiveObjectType(diffTy);
634634
}
635635

636-
LinearFunctionExtractInst::Extractee::Extractee(StringRef string) {
637-
Optional<innerty> result =
638-
llvm::StringSwitch<Optional<innerty>>(string)
639-
.Case("original", Original)
640-
.Case("transpose", Transpose);
641-
assert(result && "Invalid string");
642-
rawValue = *result;
643-
}
644-
645636
LinearFunctionInst::LinearFunctionInst(
646637
SILDebugLocation Loc, IndexSubset *ParameterIndices,
647638
SILValue OriginalFunction, Optional<SILValue> TransposeFunction,
@@ -729,14 +720,16 @@ DifferentiableFunctionExtractInst::DifferentiableFunctionExtractInst(
729720
extractee(extractee), operands(this, theFunction) {}
730721

731722
SILType LinearFunctionExtractInst::
732-
getExtracteeType(SILValue function, Extractee extractee, SILModule &module) {
723+
getExtracteeType(
724+
SILValue function, LinearDifferentiableFunctionTypeComponent extractee,
725+
SILModule &module) {
733726
auto fnTy = function->getType().castTo<SILFunctionType>();
734727
assert(fnTy->getDifferentiabilityKind() == DifferentiabilityKind::Linear);
735728
auto originalFnTy = fnTy->getWithoutDifferentiability();
736729
switch (extractee) {
737-
case Extractee::Original:
730+
case LinearDifferentiableFunctionTypeComponent::Original:
738731
return SILType::getPrimitiveObjectType(originalFnTy);
739-
case Extractee::Transpose:
732+
case LinearDifferentiableFunctionTypeComponent::Transpose:
740733
auto transposeFnTy = originalFnTy->getAutoDiffTransposeFunctionType(
741734
fnTy->getDifferentiationParameterIndices(), module.Types,
742735
LookUpConformanceInModule(module.getSwiftModule()));
@@ -745,8 +738,8 @@ getExtracteeType(SILValue function, Extractee extractee, SILModule &module) {
745738
}
746739

747740
LinearFunctionExtractInst::LinearFunctionExtractInst(
748-
SILModule &module, SILDebugLocation debugLoc, Extractee extractee,
749-
SILValue theFunction)
741+
SILModule &module, SILDebugLocation debugLoc,
742+
LinearDifferentiableFunctionTypeComponent extractee, SILValue theFunction)
750743
: InstructionBase(debugLoc,
751744
getExtracteeType(theFunction, extractee, module)),
752745
extractee(extractee), operands(this, theFunction) {}

0 commit comments

Comments
 (0)