Skip to content

Commit c8cd24a

Browse files
authored
Merge pull request #35366 from rxwei/remove-differentiable-function-ctor
[AutoDiff] Remove 'differentiableFunction(from:)' and 'linearFunction(from:)'
2 parents d710cfc + f79391b commit c8cd24a

File tree

9 files changed

+13
-338
lines changed

9 files changed

+13
-338
lines changed

include/swift/AST/Builtins.def

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -498,12 +498,6 @@ BUILTIN_SIL_OPERATION(ApplyDerivative, "applyDerivative", Special)
498498
/// applyTranspose
499499
BUILTIN_SIL_OPERATION(ApplyTranspose, "applyTranspose", Special)
500500

501-
/// differentiableFunction
502-
BUILTIN_SIL_OPERATION(DifferentiableFunction, "differentiableFunction", Special)
503-
504-
/// linearFunction
505-
BUILTIN_SIL_OPERATION(LinearFunction, "linearFunction", Special)
506-
507501
/// withUnsafeContinuation<T> : (Builtin.RawUnsafeContinuation -> ()) async -> T
508502
///
509503
/// Unsafely capture the current continuation and pass it to the given

lib/AST/Builtins.cpp

Lines changed: 0 additions & 191 deletions
Original file line numberDiff line numberDiff line change
@@ -1289,179 +1289,6 @@ static ValueDecl *getAutoDiffApplyTransposeFunction(
12891289
return builder.build(Id);
12901290
}
12911291

1292-
static ValueDecl *getDifferentiableFunctionConstructor(
1293-
ASTContext &Context, Identifier Id, unsigned arity, bool throws) {
1294-
assert(arity >= 1);
1295-
unsigned numGenericParams = 1 + arity;
1296-
BuiltinFunctionBuilder builder(Context, numGenericParams);
1297-
// Get the `Differentiable` and `AdditiveArithmetic` protocols.
1298-
auto *diffableProto =
1299-
Context.getProtocol(KnownProtocolKind::Differentiable);
1300-
auto *tangentVectorDecl =
1301-
diffableProto->getAssociatedType(Context.Id_TangentVector);
1302-
assert(tangentVectorDecl);
1303-
// Create type parameters and add conformance constraints.
1304-
auto origResultGen = makeGenericParam(arity);
1305-
builder.addConformanceRequirement(origResultGen, diffableProto);
1306-
SmallVector<decltype(origResultGen), 2> fnArgGens;
1307-
for (auto i : range(arity)) {
1308-
auto T = makeGenericParam(i);
1309-
builder.addConformanceRequirement(T, diffableProto);
1310-
fnArgGens.push_back(T);
1311-
}
1312-
1313-
BuiltinFunctionBuilder::LambdaGenerator origFnGen {
1314-
[=, &fnArgGens](BuiltinFunctionBuilder &builder) -> Type {
1315-
SmallVector<FunctionType::Param, 2> params;
1316-
for (auto &paramGen : fnArgGens)
1317-
params.push_back(FunctionType::Param(paramGen.build(builder)));
1318-
return FunctionType::get(params, origResultGen.build(builder))
1319-
->withExtInfo(FunctionType::ExtInfoBuilder(
1320-
FunctionTypeRepresentation::Swift, throws)
1321-
.build());
1322-
}
1323-
};
1324-
1325-
BuiltinFunctionBuilder::LambdaGenerator jvpGen {
1326-
[=, &fnArgGens, &Context](BuiltinFunctionBuilder &builder) -> Type {
1327-
SmallVector<FunctionType::Param, 2> params;
1328-
for (auto &paramGen : fnArgGens)
1329-
params.push_back(FunctionType::Param(paramGen.build(builder)));
1330-
auto origResultType = origResultGen.build(builder);
1331-
SmallVector<FunctionType::Param, 2> differentialParams;
1332-
for (auto &param : params) {
1333-
auto tanType = DependentMemberType::get(
1334-
param.getPlainType(), tangentVectorDecl);
1335-
differentialParams.push_back(FunctionType::Param(tanType));
1336-
}
1337-
auto differentialResultType = DependentMemberType::get(
1338-
origResultType, tangentVectorDecl);
1339-
auto differentialType =
1340-
FunctionType::get({differentialParams}, differentialResultType);
1341-
auto jvpResultType = TupleType::get(
1342-
{TupleTypeElt(origResultType, Context.Id_value),
1343-
TupleTypeElt(differentialType, Context.Id_differential)}, Context);
1344-
return FunctionType::get(params, jvpResultType)
1345-
->withExtInfo(FunctionType::ExtInfoBuilder(
1346-
FunctionTypeRepresentation::Swift, throws)
1347-
.build());
1348-
}
1349-
};
1350-
1351-
BuiltinFunctionBuilder::LambdaGenerator vjpGen {
1352-
[=, &fnArgGens, &Context](BuiltinFunctionBuilder &builder) -> Type {
1353-
SmallVector<FunctionType::Param, 2> params;
1354-
for (auto &paramGen : fnArgGens)
1355-
params.push_back(FunctionType::Param(paramGen.build(builder)));
1356-
auto origResultType = origResultGen.build(builder);
1357-
SmallVector<TupleTypeElt, 2> pullbackResultTupleElts;
1358-
for (auto &param : params) {
1359-
auto tanType = DependentMemberType::get(
1360-
param.getPlainType(), tangentVectorDecl);
1361-
pullbackResultTupleElts.push_back(TupleTypeElt(tanType));
1362-
}
1363-
auto pullbackParam = FunctionType::Param(
1364-
DependentMemberType::get(origResultType, tangentVectorDecl));
1365-
auto pullbackType = FunctionType::get(
1366-
{pullbackParam},
1367-
pullbackResultTupleElts.size() == 1
1368-
? pullbackResultTupleElts.front().getType()
1369-
: TupleType::get(pullbackResultTupleElts, Context));
1370-
auto vjpResultType = TupleType::get(
1371-
{TupleTypeElt(origResultType, Context.Id_value),
1372-
TupleTypeElt(pullbackType, Context.Id_pullback)}, Context);
1373-
return FunctionType::get(params, vjpResultType)
1374-
->withExtInfo(FunctionType::ExtInfoBuilder(
1375-
FunctionTypeRepresentation::Swift, throws)
1376-
.build());
1377-
}
1378-
};
1379-
1380-
BuiltinFunctionBuilder::LambdaGenerator resultGen {
1381-
[&](BuiltinFunctionBuilder &builder) -> Type {
1382-
auto origFnType = origFnGen.build(builder)->castTo<FunctionType>();
1383-
return origFnType->withExtInfo(
1384-
origFnType->getExtInfo()
1385-
.intoBuilder()
1386-
.withDifferentiabilityKind(DifferentiabilityKind::Normal)
1387-
.build());
1388-
}
1389-
};
1390-
1391-
builder.addParameter(origFnGen, ValueOwnership::Owned);
1392-
builder.addParameter(jvpGen, ValueOwnership::Owned);
1393-
builder.addParameter(vjpGen, ValueOwnership::Owned);
1394-
builder.setResult(resultGen);
1395-
return builder.build(Id);
1396-
}
1397-
1398-
static ValueDecl *getLinearFunctionConstructor(
1399-
ASTContext &Context, Identifier Id, unsigned arity, bool throws) {
1400-
assert(arity >= 1);
1401-
unsigned numGenericParams = 1 + arity;
1402-
BuiltinFunctionBuilder builder(Context, numGenericParams);
1403-
// Get the `Differentiable` and `AdditiveArithmetic` protocols.
1404-
auto *diffableProto =
1405-
Context.getProtocol(KnownProtocolKind::Differentiable);
1406-
auto *addArithProto =
1407-
Context.getProtocol(KnownProtocolKind::AdditiveArithmetic);
1408-
// Create type parameters and add conformance constraints.
1409-
auto origResultGen = makeGenericParam(arity);
1410-
builder.addConformanceRequirement(origResultGen, diffableProto);
1411-
builder.addConformanceRequirement(origResultGen, addArithProto);
1412-
SmallVector<decltype(origResultGen), 2> fnArgGens;
1413-
for (auto i : range(arity)) {
1414-
auto T = makeGenericParam(i);
1415-
builder.addConformanceRequirement(T, diffableProto);
1416-
builder.addConformanceRequirement(T, addArithProto);
1417-
fnArgGens.push_back(T);
1418-
}
1419-
1420-
BuiltinFunctionBuilder::LambdaGenerator origFnGen {
1421-
[=, &fnArgGens](BuiltinFunctionBuilder &builder) -> Type {
1422-
SmallVector<FunctionType::Param, 2> params;
1423-
for (auto &paramGen : fnArgGens)
1424-
params.push_back(FunctionType::Param(paramGen.build(builder)));
1425-
return FunctionType::get(params, origResultGen.build(builder))
1426-
->withExtInfo(FunctionType::ExtInfoBuilder(
1427-
FunctionTypeRepresentation::Swift, throws)
1428-
.build());
1429-
}
1430-
};
1431-
1432-
BuiltinFunctionBuilder::LambdaGenerator transposeFnGen {
1433-
[=, &fnArgGens, &Context](BuiltinFunctionBuilder &builder) -> Type {
1434-
auto origResultType = origResultGen.build(builder);
1435-
SmallVector<TupleTypeElt, 2> resultTupleElts;
1436-
for (auto &paramGen : fnArgGens)
1437-
resultTupleElts.push_back(paramGen.build(builder));
1438-
return FunctionType::get(
1439-
{FunctionType::Param(origResultType)},
1440-
resultTupleElts.size() == 1
1441-
? resultTupleElts.front().getType()
1442-
: TupleType::get(resultTupleElts, Context));
1443-
}
1444-
};
1445-
1446-
BuiltinFunctionBuilder::LambdaGenerator resultGen {
1447-
[&](BuiltinFunctionBuilder &builder) -> Type {
1448-
auto origFnType = origFnGen.build(builder)->castTo<FunctionType>();
1449-
return origFnType->withExtInfo(
1450-
origFnType->getExtInfo()
1451-
.intoBuilder()
1452-
.withDifferentiabilityKind(DifferentiabilityKind::Linear)
1453-
.build());
1454-
}
1455-
};
1456-
1457-
builder.addParameter(origFnGen, ValueOwnership::Owned);
1458-
builder.addParameter(transposeFnGen, ValueOwnership::Owned);
1459-
builder.setResult(resultGen);
1460-
return builder.build(Id);
1461-
}
1462-
1463-
1464-
14651292
static ValueDecl *getGlobalStringTablePointer(ASTContext &Context,
14661293
Identifier Id) {
14671294
// String -> Builtin.RawPointer
@@ -2403,22 +2230,6 @@ ValueDecl *swift::getBuiltinValueDecl(ASTContext &Context, Identifier Id) {
24032230
return nullptr;
24042231
return getAutoDiffApplyTransposeFunction(Context, Id, arity, throws);
24052232
}
2406-
if (OperationName.startswith("differentiableFunction_")) {
2407-
unsigned arity;
2408-
bool throws;
2409-
if (!autodiff::getBuiltinDifferentiableOrLinearFunctionConfig(
2410-
OperationName, arity, throws))
2411-
return nullptr;
2412-
return getDifferentiableFunctionConstructor(Context, Id, arity, throws);
2413-
}
2414-
if (OperationName.startswith("linearFunction_")) {
2415-
unsigned arity;
2416-
bool throws;
2417-
if (!autodiff::getBuiltinDifferentiableOrLinearFunctionConfig(
2418-
OperationName, arity, throws))
2419-
return nullptr;
2420-
return getLinearFunctionConstructor(Context, Id, arity, throws);
2421-
}
24222233

24232234
auto BV = llvm::StringSwitch<BuiltinValueKind>(OperationName)
24242235
#define BUILTIN(id, name, Attrs) .Case(name, BuiltinValueKind::id)
@@ -2702,8 +2513,6 @@ ValueDecl *swift::getBuiltinValueDecl(ASTContext &Context, Identifier Id) {
27022513

27032514
case BuiltinValueKind::ApplyDerivative:
27042515
case BuiltinValueKind::ApplyTranspose:
2705-
case BuiltinValueKind::DifferentiableFunction:
2706-
case BuiltinValueKind::LinearFunction:
27072516
llvm_unreachable("Handled above");
27082517

27092518
case BuiltinValueKind::OnFastPath:

lib/SIL/IR/SILModule.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -366,10 +366,6 @@ const BuiltinInfo &SILModule::getBuiltinInfo(Identifier ID) {
366366
Info.ID = BuiltinValueKind::ApplyDerivative;
367367
else if (OperationName.startswith("applyTranspose_"))
368368
Info.ID = BuiltinValueKind::ApplyTranspose;
369-
else if (OperationName.startswith("differentiableFunction_"))
370-
Info.ID = BuiltinValueKind::DifferentiableFunction;
371-
else if (OperationName.startswith("linearFunction_"))
372-
Info.ID = BuiltinValueKind::LinearFunction;
373369
else
374370
Info.ID = llvm::StringSwitch<BuiltinValueKind>(OperationName)
375371
#define BUILTIN(id, name, attrs) .Case(name, BuiltinValueKind::id)

stdlib/public/Differentiation/DifferentialOperators.swift

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,10 @@ import Swift
1919
// Transpose
2020

2121
@inlinable
22-
public func transpose<T, R>(
22+
public func _transpose<T, R>(
2323
of body: @escaping @differentiable(linear) (T) -> R
2424
) -> @differentiable(linear) (R) -> T {
25-
let original = body as (T) -> R
26-
let transpose = { x in Builtin.applyTranspose_arity1(body, x) }
27-
return Builtin.linearFunction_arity1(transpose, original)
25+
fatalError("Transpose is unimplemented and unsupported")
2826
}
2927

3028
// Value with differential

stdlib/public/Differentiation/DifferentiationUtilities.swift

Lines changed: 0 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -17,63 +17,6 @@
1717

1818
import Swift
1919

20-
//===----------------------------------------------------------------------===//
21-
// Differentiable function creation
22-
//===----------------------------------------------------------------------===//
23-
24-
/// Create a differentiable function from a vector-Jacobian products function.
25-
@inlinable
26-
public func differentiableFunction<T : Differentiable, R : Differentiable>(
27-
from vjp: @escaping (T)
28-
-> (value: R, pullback: (R.TangentVector) -> T.TangentVector)
29-
) -> @differentiable (T) -> R {
30-
Builtin.differentiableFunction_arity1(
31-
/*original*/ { vjp($0).value },
32-
/*jvp*/ { _ in
33-
fatalError("""
34-
Functions formed with `differentiableFunction(from:)` cannot yet \
35-
be used with differential-producing differential operators.
36-
""")
37-
},
38-
/*vjp*/ vjp)
39-
}
40-
41-
/// Create a differentiable function from a vector-Jacobian products function.
42-
@inlinable
43-
public func differentiableFunction<T, U, R>(
44-
from vjp: @escaping (T, U)
45-
-> (value: R, pullback: (R.TangentVector)
46-
-> (T.TangentVector, U.TangentVector))
47-
) -> @differentiable (T, U) -> R {
48-
Builtin.differentiableFunction_arity2(
49-
/*original*/ { vjp($0, $1).value },
50-
/*jvp*/ { _, _ in
51-
fatalError("""
52-
Functions formed with `differentiableFunction(from:)` cannot yet \
53-
be used with differential-producing differential operators.
54-
""")
55-
},
56-
/*vjp*/ vjp)
57-
}
58-
59-
/// Create a differentiable function from a vector-Jacobian products function.
60-
@inlinable
61-
public func differentiableFunction<T, U, V, R>(
62-
from vjp: @escaping (T, U, V)
63-
-> (value: R, pullback: (R.TangentVector)
64-
-> (T.TangentVector, U.TangentVector, V.TangentVector))
65-
) -> @differentiable (T, U, V) -> R {
66-
Builtin.differentiableFunction_arity3(
67-
/*original*/ { vjp($0, $1, $2).value },
68-
/*jvp*/ { _, _, _ in
69-
fatalError("""
70-
Functions formed with `differentiableFunction(from:)` cannot yet \
71-
be used with differential-producing differential operators.
72-
""")
73-
},
74-
/*vjp*/ vjp)
75-
}
76-
7720
//===----------------------------------------------------------------------===//
7821
// Derivative customization
7922
//===----------------------------------------------------------------------===//

test/AutoDiff/SILGen/autodiff_builtins.swift

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -127,32 +127,6 @@ func applyTranspose_f_indirect_arity1<T: AdditiveArithmetic & Differentiable>(_
127127
// CHECK: bb0([[OUT_PARAM:%.*]] : $*T, [[X:%.*]] : $*T):
128128
// CHECK: [[RESULT:%.*]] = apply [[TRANSPOSE:%.*]]([[OUT_PARAM]], [[X]])
129129

130-
// MARK: - differentiableFunction
131-
132-
@_silgen_name("differentiableFunction_f_direct_arity1")
133-
func differentiableFunction_f_direct_arity1() -> @differentiable (Float) -> Float {
134-
return Builtin.differentiableFunction_arity1(f_direct_arity1, f_direct_arity1_jvp, f_direct_arity1_vjp)
135-
}
136-
// CHECK-LABEL: sil{{.*}}@differentiableFunction_f_direct_arity1
137-
// CHECK: [[DIFF_FN:%.*]] = differentiable_function
138-
// CHECK: return [[DIFF_FN]]
139-
140-
// MARK: - linearFunction
141-
// TODO(TF-1142): Add linear_funcion to this test when it exists.
142-
143-
@_silgen_name("linearFunction_f_direct_arity1")
144-
func linearFunction_f_direct_arity1() -> @differentiable(linear) (Float) -> Float {
145-
return Builtin.linearFunction_arity1(f_direct_arity1, f_direct_arity1)
146-
}
147-
// CHECK-LABEL: sil{{.*}}@linearFunction_f_direct_arity1
148-
// CHECK: bb0:
149-
// CHECK: [[ORIG1:%.*]] = function_ref @f_direct_arity1 : $@convention(thin) (Float) -> Float
150-
// CHECK: [[THICK_ORIG1:%.*]] = thin_to_thick_function [[ORIG1]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float
151-
// CHECK: [[ORIG2:%.*]] = function_ref @f_direct_arity1 : $@convention(thin) (Float) -> Float
152-
// CHECK: [[THICK_ORIG2:%.*]] = thin_to_thick_function [[ORIG2]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float
153-
// CHECK: [[LINEAR:%.*]] = linear_function [parameters 0] [[THICK_ORIG1]] : $@callee_guaranteed (Float) -> Float with_transpose [[THICK_ORIG2]] : $@callee_guaranteed (Float) -> Float
154-
// CHECK: return [[LINEAR]] : $@differentiable(linear) @callee_guaranteed (Float) -> Float
155-
156130
struct ExamplePullbackStruct<T: Differentiable> {
157131
var pb0: (T.TangentVector) -> T.TangentVector
158132
}

0 commit comments

Comments
 (0)