@@ -1289,179 +1289,6 @@ static ValueDecl *getAutoDiffApplyTransposeFunction(
1289
1289
return builder.build (Id);
1290
1290
}
1291
1291
1292
- static ValueDecl *getDifferentiableFunctionConstructor (
1293
- ASTContext &Context, Identifier Id, unsigned arity, bool throws) {
1294
- assert (arity >= 1 );
1295
- unsigned numGenericParams = 1 + arity;
1296
- BuiltinFunctionBuilder builder (Context, numGenericParams);
1297
- // Get the `Differentiable` and `AdditiveArithmetic` protocols.
1298
- auto *diffableProto =
1299
- Context.getProtocol (KnownProtocolKind::Differentiable);
1300
- auto *tangentVectorDecl =
1301
- diffableProto->getAssociatedType (Context.Id_TangentVector );
1302
- assert (tangentVectorDecl);
1303
- // Create type parameters and add conformance constraints.
1304
- auto origResultGen = makeGenericParam (arity);
1305
- builder.addConformanceRequirement (origResultGen, diffableProto);
1306
- SmallVector<decltype (origResultGen), 2 > fnArgGens;
1307
- for (auto i : range (arity)) {
1308
- auto T = makeGenericParam (i);
1309
- builder.addConformanceRequirement (T, diffableProto);
1310
- fnArgGens.push_back (T);
1311
- }
1312
-
1313
- BuiltinFunctionBuilder::LambdaGenerator origFnGen {
1314
- [=, &fnArgGens](BuiltinFunctionBuilder &builder) -> Type {
1315
- SmallVector<FunctionType::Param, 2 > params;
1316
- for (auto ¶mGen : fnArgGens)
1317
- params.push_back (FunctionType::Param (paramGen.build (builder)));
1318
- return FunctionType::get (params, origResultGen.build (builder))
1319
- ->withExtInfo (FunctionType::ExtInfoBuilder (
1320
- FunctionTypeRepresentation::Swift, throws)
1321
- .build ());
1322
- }
1323
- };
1324
-
1325
- BuiltinFunctionBuilder::LambdaGenerator jvpGen {
1326
- [=, &fnArgGens, &Context](BuiltinFunctionBuilder &builder) -> Type {
1327
- SmallVector<FunctionType::Param, 2 > params;
1328
- for (auto ¶mGen : fnArgGens)
1329
- params.push_back (FunctionType::Param (paramGen.build (builder)));
1330
- auto origResultType = origResultGen.build (builder);
1331
- SmallVector<FunctionType::Param, 2 > differentialParams;
1332
- for (auto ¶m : params) {
1333
- auto tanType = DependentMemberType::get (
1334
- param.getPlainType (), tangentVectorDecl);
1335
- differentialParams.push_back (FunctionType::Param (tanType));
1336
- }
1337
- auto differentialResultType = DependentMemberType::get (
1338
- origResultType, tangentVectorDecl);
1339
- auto differentialType =
1340
- FunctionType::get ({differentialParams}, differentialResultType);
1341
- auto jvpResultType = TupleType::get (
1342
- {TupleTypeElt (origResultType, Context.Id_value ),
1343
- TupleTypeElt (differentialType, Context.Id_differential )}, Context);
1344
- return FunctionType::get (params, jvpResultType)
1345
- ->withExtInfo (FunctionType::ExtInfoBuilder (
1346
- FunctionTypeRepresentation::Swift, throws)
1347
- .build ());
1348
- }
1349
- };
1350
-
1351
- BuiltinFunctionBuilder::LambdaGenerator vjpGen {
1352
- [=, &fnArgGens, &Context](BuiltinFunctionBuilder &builder) -> Type {
1353
- SmallVector<FunctionType::Param, 2 > params;
1354
- for (auto ¶mGen : fnArgGens)
1355
- params.push_back (FunctionType::Param (paramGen.build (builder)));
1356
- auto origResultType = origResultGen.build (builder);
1357
- SmallVector<TupleTypeElt, 2 > pullbackResultTupleElts;
1358
- for (auto ¶m : params) {
1359
- auto tanType = DependentMemberType::get (
1360
- param.getPlainType (), tangentVectorDecl);
1361
- pullbackResultTupleElts.push_back (TupleTypeElt (tanType));
1362
- }
1363
- auto pullbackParam = FunctionType::Param (
1364
- DependentMemberType::get (origResultType, tangentVectorDecl));
1365
- auto pullbackType = FunctionType::get (
1366
- {pullbackParam},
1367
- pullbackResultTupleElts.size () == 1
1368
- ? pullbackResultTupleElts.front ().getType ()
1369
- : TupleType::get (pullbackResultTupleElts, Context));
1370
- auto vjpResultType = TupleType::get (
1371
- {TupleTypeElt (origResultType, Context.Id_value ),
1372
- TupleTypeElt (pullbackType, Context.Id_pullback )}, Context);
1373
- return FunctionType::get (params, vjpResultType)
1374
- ->withExtInfo (FunctionType::ExtInfoBuilder (
1375
- FunctionTypeRepresentation::Swift, throws)
1376
- .build ());
1377
- }
1378
- };
1379
-
1380
- BuiltinFunctionBuilder::LambdaGenerator resultGen {
1381
- [&](BuiltinFunctionBuilder &builder) -> Type {
1382
- auto origFnType = origFnGen.build (builder)->castTo <FunctionType>();
1383
- return origFnType->withExtInfo (
1384
- origFnType->getExtInfo ()
1385
- .intoBuilder ()
1386
- .withDifferentiabilityKind (DifferentiabilityKind::Normal)
1387
- .build ());
1388
- }
1389
- };
1390
-
1391
- builder.addParameter (origFnGen, ValueOwnership::Owned);
1392
- builder.addParameter (jvpGen, ValueOwnership::Owned);
1393
- builder.addParameter (vjpGen, ValueOwnership::Owned);
1394
- builder.setResult (resultGen);
1395
- return builder.build (Id);
1396
- }
1397
-
1398
- static ValueDecl *getLinearFunctionConstructor (
1399
- ASTContext &Context, Identifier Id, unsigned arity, bool throws) {
1400
- assert (arity >= 1 );
1401
- unsigned numGenericParams = 1 + arity;
1402
- BuiltinFunctionBuilder builder (Context, numGenericParams);
1403
- // Get the `Differentiable` and `AdditiveArithmetic` protocols.
1404
- auto *diffableProto =
1405
- Context.getProtocol (KnownProtocolKind::Differentiable);
1406
- auto *addArithProto =
1407
- Context.getProtocol (KnownProtocolKind::AdditiveArithmetic);
1408
- // Create type parameters and add conformance constraints.
1409
- auto origResultGen = makeGenericParam (arity);
1410
- builder.addConformanceRequirement (origResultGen, diffableProto);
1411
- builder.addConformanceRequirement (origResultGen, addArithProto);
1412
- SmallVector<decltype (origResultGen), 2 > fnArgGens;
1413
- for (auto i : range (arity)) {
1414
- auto T = makeGenericParam (i);
1415
- builder.addConformanceRequirement (T, diffableProto);
1416
- builder.addConformanceRequirement (T, addArithProto);
1417
- fnArgGens.push_back (T);
1418
- }
1419
-
1420
- BuiltinFunctionBuilder::LambdaGenerator origFnGen {
1421
- [=, &fnArgGens](BuiltinFunctionBuilder &builder) -> Type {
1422
- SmallVector<FunctionType::Param, 2 > params;
1423
- for (auto ¶mGen : fnArgGens)
1424
- params.push_back (FunctionType::Param (paramGen.build (builder)));
1425
- return FunctionType::get (params, origResultGen.build (builder))
1426
- ->withExtInfo (FunctionType::ExtInfoBuilder (
1427
- FunctionTypeRepresentation::Swift, throws)
1428
- .build ());
1429
- }
1430
- };
1431
-
1432
- BuiltinFunctionBuilder::LambdaGenerator transposeFnGen {
1433
- [=, &fnArgGens, &Context](BuiltinFunctionBuilder &builder) -> Type {
1434
- auto origResultType = origResultGen.build (builder);
1435
- SmallVector<TupleTypeElt, 2 > resultTupleElts;
1436
- for (auto ¶mGen : fnArgGens)
1437
- resultTupleElts.push_back (paramGen.build (builder));
1438
- return FunctionType::get (
1439
- {FunctionType::Param (origResultType)},
1440
- resultTupleElts.size () == 1
1441
- ? resultTupleElts.front ().getType ()
1442
- : TupleType::get (resultTupleElts, Context));
1443
- }
1444
- };
1445
-
1446
- BuiltinFunctionBuilder::LambdaGenerator resultGen {
1447
- [&](BuiltinFunctionBuilder &builder) -> Type {
1448
- auto origFnType = origFnGen.build (builder)->castTo <FunctionType>();
1449
- return origFnType->withExtInfo (
1450
- origFnType->getExtInfo ()
1451
- .intoBuilder ()
1452
- .withDifferentiabilityKind (DifferentiabilityKind::Linear)
1453
- .build ());
1454
- }
1455
- };
1456
-
1457
- builder.addParameter (origFnGen, ValueOwnership::Owned);
1458
- builder.addParameter (transposeFnGen, ValueOwnership::Owned);
1459
- builder.setResult (resultGen);
1460
- return builder.build (Id);
1461
- }
1462
-
1463
-
1464
-
1465
1292
static ValueDecl *getGlobalStringTablePointer (ASTContext &Context,
1466
1293
Identifier Id) {
1467
1294
// String -> Builtin.RawPointer
@@ -2403,22 +2230,6 @@ ValueDecl *swift::getBuiltinValueDecl(ASTContext &Context, Identifier Id) {
2403
2230
return nullptr ;
2404
2231
return getAutoDiffApplyTransposeFunction (Context, Id, arity, throws);
2405
2232
}
2406
- if (OperationName.startswith (" differentiableFunction_" )) {
2407
- unsigned arity;
2408
- bool throws;
2409
- if (!autodiff::getBuiltinDifferentiableOrLinearFunctionConfig (
2410
- OperationName, arity, throws))
2411
- return nullptr ;
2412
- return getDifferentiableFunctionConstructor (Context, Id, arity, throws);
2413
- }
2414
- if (OperationName.startswith (" linearFunction_" )) {
2415
- unsigned arity;
2416
- bool throws;
2417
- if (!autodiff::getBuiltinDifferentiableOrLinearFunctionConfig (
2418
- OperationName, arity, throws))
2419
- return nullptr ;
2420
- return getLinearFunctionConstructor (Context, Id, arity, throws);
2421
- }
2422
2233
2423
2234
auto BV = llvm::StringSwitch<BuiltinValueKind>(OperationName)
2424
2235
#define BUILTIN (id, name, Attrs ) .Case(name, BuiltinValueKind::id)
@@ -2702,8 +2513,6 @@ ValueDecl *swift::getBuiltinValueDecl(ASTContext &Context, Identifier Id) {
2702
2513
2703
2514
case BuiltinValueKind::ApplyDerivative:
2704
2515
case BuiltinValueKind::ApplyTranspose:
2705
- case BuiltinValueKind::DifferentiableFunction:
2706
- case BuiltinValueKind::LinearFunction:
2707
2516
llvm_unreachable (" Handled above" );
2708
2517
2709
2518
case BuiltinValueKind::OnFastPath:
0 commit comments