|
22 | 22 | #include "swift/AST/DiagnosticsSIL.h"
|
23 | 23 | #include "swift/AST/ForeignInfo.h"
|
24 | 24 | #include "swift/AST/GenericEnvironment.h"
|
| 25 | +// SWIFT_ENABLE_TENSORFLOW |
| 26 | +#include "swift/AST/GenericSignatureBuilder.h" |
25 | 27 | #include "swift/AST/Module.h"
|
26 | 28 | #include "swift/AST/ProtocolConformance.h"
|
27 | 29 | #include "swift/SIL/SILModule.h"
|
@@ -148,22 +150,68 @@ CanSILFunctionType SILFunctionType::getWithoutDifferentiability() {
|
148 | 150 | getOptionalErrorResult(), getASTContext());
|
149 | 151 | }
|
150 | 152 |
|
| 153 | +// Returns the canonical generic signature for an autodiff associated function |
| 154 | +// given an existing associated function generic signature. All differentiation |
| 155 | +// parameters are constrained to conform to `Differentiable`. |
| 156 | +static CanGenericSignature getAutoDiffAssociatedFunctionGenericSignature( |
| 157 | + CanGenericSignature assocFnGenSig, |
| 158 | + ArrayRef<SILParameterInfo> originalParameters, |
| 159 | + AutoDiffIndexSubset *parameterIndices, SILModule &module) { |
| 160 | + // If associated function has no |
| 161 | + if (!assocFnGenSig) |
| 162 | + return nullptr; |
| 163 | + auto &ctx = module.getASTContext(); |
| 164 | + GenericSignatureBuilder builder(ctx); |
| 165 | + |
| 166 | + // Add associated function generic signature. |
| 167 | + builder.addGenericSignature(assocFnGenSig); |
| 168 | + // Constrain all wrt parameters to conform to `Differentiable`. |
| 169 | + auto source = |
| 170 | + GenericSignatureBuilder::FloatingRequirementSource::forAbstract(); |
| 171 | + auto *diffableProto = ctx.getProtocol(KnownProtocolKind::Differentiable); |
| 172 | + for (unsigned paramIdx : parameterIndices->getIndices()) { |
| 173 | + auto paramType = originalParameters[paramIdx].getType(); |
| 174 | + Requirement req(RequirementKind::Conformance, paramType, |
| 175 | + diffableProto->getDeclaredType()); |
| 176 | + builder.addRequirement(req, source, module.getSwiftModule()); |
| 177 | + } |
| 178 | + return std::move(builder) |
| 179 | + .computeGenericSignature(SourceLoc(), /*allowConcreteGenericParams*/ true) |
| 180 | + ->getCanonicalSignature(); |
| 181 | +} |
| 182 | + |
151 | 183 | CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
|
152 | 184 | AutoDiffIndexSubset *parameterIndices, unsigned resultIndex,
|
153 | 185 | unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind,
|
154 | 186 | SILModule &module, LookupConformanceFn lookupConformance,
|
155 |
| - CanGenericSignature whereClauseGenSig) { |
| 187 | + CanGenericSignature assocFnGenSig) { |
156 | 188 | // JVP: (T...) -> ((R...),
|
157 | 189 | // (T.TangentVector...) -> (R.TangentVector...))
|
158 | 190 | // VJP: (T...) -> ((R...),
|
159 | 191 | // (R.TangentVector...) -> (T.TangentVector...))
|
160 | 192 |
|
161 | 193 | auto &ctx = getASTContext();
|
162 | 194 | auto &typeConverter = module.Types;
|
163 |
| - if (!whereClauseGenSig) |
164 |
| - whereClauseGenSig = getGenericSignature(); |
165 |
| - Lowering::GenericContextScope genericContextScope( |
166 |
| - module.Types, whereClauseGenSig); |
| 195 | + |
| 196 | + // Helper function testing if we are differentiating wrt this index. |
| 197 | + auto isWrtIndex = [&](unsigned index) -> bool { |
| 198 | + return index < parameterIndices->getCapacity() && |
| 199 | + parameterIndices->contains(index); |
| 200 | + }; |
| 201 | + |
| 202 | + // Calculate differentiation parameter infos. |
| 203 | + SmallVector<SILParameterInfo, 4> wrtParams; |
| 204 | + for (auto valueAndIndex : enumerate(getParameters())) |
| 205 | + if (isWrtIndex(valueAndIndex.index())) |
| 206 | + wrtParams.push_back(valueAndIndex.value()); |
| 207 | + |
| 208 | + // Get the canonical associated function generic signature. |
| 209 | + if (!assocFnGenSig) |
| 210 | + assocFnGenSig = getGenericSignature(); |
| 211 | + assocFnGenSig = getAutoDiffAssociatedFunctionGenericSignature( |
| 212 | + assocFnGenSig, getParameters(), parameterIndices, module); |
| 213 | + Lowering::GenericContextScope genericContextScope(module.Types, |
| 214 | + assocFnGenSig); |
167 | 215 |
|
168 | 216 | // Given a type, returns its formal SIL parameter info.
|
169 | 217 | auto getTangentParameterInfoForOriginalResult = [&](
|
@@ -214,18 +262,6 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
|
214 | 262 | return {tanType, conv};
|
215 | 263 | };
|
216 | 264 |
|
217 |
| - // Helper function testing if we are differentiating wrt this index. |
218 |
| - auto isWrtIndex = [&](unsigned index) -> bool { |
219 |
| - return index < parameterIndices->getCapacity() && |
220 |
| - parameterIndices->contains(index); |
221 |
| - }; |
222 |
| - |
223 |
| - // Calculate differentiation parameter infos. |
224 |
| - SmallVector<SILParameterInfo, 4> wrtParams; |
225 |
| - for (auto valueAndIndex : enumerate(getParameters())) |
226 |
| - if (isWrtIndex(valueAndIndex.index())) |
227 |
| - wrtParams.push_back(valueAndIndex.value()); |
228 |
| - |
229 | 265 | CanSILFunctionType closureType;
|
230 | 266 | switch (kind) {
|
231 | 267 | case AutoDiffAssociatedFunctionKind::JVP: {
|
@@ -280,12 +316,12 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
|
280 | 316 | newResults.reserve(getNumResults() + 1);
|
281 | 317 | for (auto &result : getResults()) {
|
282 | 318 | auto mappedResult = result.getWithType(
|
283 |
| - result.getType()->getCanonicalType(whereClauseGenSig)); |
| 319 | + result.getType()->getCanonicalType(assocFnGenSig)); |
284 | 320 | newResults.push_back(mappedResult);
|
285 | 321 | }
|
286 |
| - newResults.push_back({closureType->getCanonicalType(whereClauseGenSig), |
| 322 | + newResults.push_back({closureType->getCanonicalType(assocFnGenSig), |
287 | 323 | ResultConvention::Owned});
|
288 |
| - return SILFunctionType::get(whereClauseGenSig, getExtInfo(), |
| 324 | + return SILFunctionType::get(assocFnGenSig, getExtInfo(), |
289 | 325 | getCoroutineKind(), getCalleeConvention(),
|
290 | 326 | getParameters(), getYields(), newResults,
|
291 | 327 | getOptionalErrorResult(), ctx,
|
|
0 commit comments