Skip to content

Commit 163429e

Browse files
dan-zhengbgogul
authored andcommitted
[AutoDiff] NFC: IRGen gardening. (#28238)
Move code from lib/IRGen/GenProto.cpp to dedicated file lib/IRGen/GenDiffWitness.cpp.
1 parent 3a5e486 commit 163429e

File tree

5 files changed

+67
-36
lines changed

5 files changed

+67
-36
lines changed

lib/IRGen/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ add_swift_host_library(swiftIRGen STATIC
2626
GenDecl.cpp
2727
# SWIFT_ENABLE_TENSORFLOW
2828
GenDiffFunc.cpp
29+
GenDiffWitness.cpp
30+
# SWIFT_ENABLE_TENSORFLOW END
2931
GenEnum.cpp
3032
GenExistential.cpp
3133
GenFunc.cpp

lib/IRGen/GenDecl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1068,7 +1068,7 @@ void IRGenerator::emitGlobalTopLevel() {
10681068
// SWIFT_ENABLE_TENSORFLOW
10691069
// Emit differentiability witnesses.
10701070
for (auto &dw :
1071-
PrimaryIGM->getSILModule().getDifferentiabilityWitnessList()) {
1071+
PrimaryIGM->getSILModule().getDifferentiabilityWitnessList()) {
10721072
if (dw.isDeclaration())
10731073
continue;
10741074

lib/IRGen/GenDiffFunc.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,13 @@
99
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
1010
//
1111
//===----------------------------------------------------------------------===//
12+
//
1213
// SWIFT_ENABLE_TENSORFLOW
14+
//
15+
// This file implements IR generation for `@differentiable` function types in
16+
// Swift.
17+
//
18+
//===----------------------------------------------------------------------===//
1319

1420
#include "swift/AST/Decl.h"
1521
#include "swift/AST/Pattern.h"
@@ -32,7 +38,6 @@
3238
using namespace swift;
3339
using namespace irgen;
3440

35-
3641
//----------------------------------------------------------------------------//
3742
// `@differentiable` (non-linear) function type info
3843
//----------------------------------------------------------------------------//

lib/IRGen/GenDiffWitness.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
//===--- GenDiffWitness.cpp - IRGen for differentiability witnesses -------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
//
13+
// SWIFT_ENABLE_TENSORFLOW
14+
//
15+
// This file implements IR generation for SIL differentiability witnesses.
16+
//
17+
//===----------------------------------------------------------------------===//
18+
19+
#include "swift/AST/PrettyStackTrace.h"
20+
#include "swift/SIL/SILDifferentiabilityWitness.h"
21+
22+
#include "ConstantBuilder.h"
23+
#include "IRGenModule.h"
24+
25+
using namespace swift;
26+
using namespace irgen;
27+
28+
void IRGenModule::emitSILDifferentiabilityWitness(
29+
SILDifferentiabilityWitness *dw) {
30+
PrettyStackTraceDifferentiabilityWitness _st(
31+
"emitting differentiability witness for", dw->getKey());
32+
33+
// Don't emit declarations.
34+
if (dw->isDeclaration())
35+
return;
36+
37+
ConstantInitBuilder builder(*this);
38+
auto diffWitnessContents = builder.beginStruct();
39+
40+
// TODO(TF-894): When the differentiation transform canonicalizes all
41+
// differentiability witnesses to have JVP/VJP functions, remove the nullptr
42+
// cases and assert that JVP/VJP functions exist.
43+
if (dw->getJVP()) {
44+
diffWitnessContents.addBitCast(
45+
getAddrOfSILFunction(dw->getJVP(), NotForDefinition), Int8PtrTy);
46+
} else {
47+
diffWitnessContents.addNullPointer(Int8PtrTy);
48+
}
49+
if (dw->getVJP()) {
50+
diffWitnessContents.addBitCast(
51+
getAddrOfSILFunction(dw->getVJP(), NotForDefinition), Int8PtrTy);
52+
} else {
53+
diffWitnessContents.addNullPointer(Int8PtrTy);
54+
}
55+
56+
getAddrOfDifferentiabilityWitness(
57+
dw, diffWitnessContents.finishAndCreateFuture());
58+
}

lib/IRGen/GenProto.cpp

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2166,40 +2166,6 @@ void IRGenModule::emitSILWitnessTable(SILWitnessTable *wt) {
21662166
RequireMetadata);
21672167
}
21682168

2169-
// SWIFT_ENABLE_TENSORFLOW
2170-
void IRGenModule::emitSILDifferentiabilityWitness(
2171-
SILDifferentiabilityWitness *dw) {
2172-
PrettyStackTraceDifferentiabilityWitness _st(
2173-
"emitting differentiability witness for", dw->getKey());
2174-
2175-
// Don't emit declarations.
2176-
if (dw->isDeclaration())
2177-
return;
2178-
2179-
ConstantInitBuilder builder(*this);
2180-
auto diffWitnessContents = builder.beginStruct();
2181-
2182-
// TODO(marcrasi): When the differentiation pass generates JVP/VJP for
2183-
// witnesses, remove the nullptr case and add assertions that the JVP/VJP
2184-
// exist.
2185-
if (dw->getJVP()) {
2186-
diffWitnessContents.addBitCast(
2187-
getAddrOfSILFunction(dw->getJVP(), NotForDefinition), Int8PtrTy);
2188-
} else {
2189-
diffWitnessContents.addNullPointer(Int8PtrTy);
2190-
}
2191-
if (dw->getVJP()) {
2192-
diffWitnessContents.addBitCast(
2193-
getAddrOfSILFunction(dw->getVJP(), NotForDefinition), Int8PtrTy);
2194-
} else {
2195-
diffWitnessContents.addNullPointer(Int8PtrTy);
2196-
}
2197-
2198-
getAddrOfDifferentiabilityWitness(
2199-
dw, diffWitnessContents.finishAndCreateFuture());
2200-
}
2201-
// SWIFT_ENABLE_TENSORFLOW_END
2202-
22032169
/// True if a function's signature in LLVM carries polymorphic parameters.
22042170
/// Generic functions and protocol witnesses carry polymorphic parameters.
22052171
bool irgen::hasPolymorphicParameters(CanSILFunctionType ty) {

0 commit comments

Comments
 (0)