Skip to content

Commit 5a73cc1

Browse files
authored
[AutoDiff] Diagnose non-differentiable original function arguments/results. (#26407)
Diagnose non-differentiable original function arguments/results in `VJPEmitter::visitApplyInst`. Perform the check again if original function is specialized via partial application. Resolves TF-687.
1 parent 19c9e5f commit 5a73cc1

File tree

3 files changed

+57
-21
lines changed

3 files changed

+57
-21
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3303,27 +3303,36 @@ class VJPEmitter final
33033303
/*differentiationOrder*/ 1, functionSource);
33043304
}
33053305

3306-
// Check and diagnose non-differentiable arguments.
3307-
for (unsigned paramIndex : range(originalFnTy->getNumParameters())) {
3308-
if (indices.isWrtParameter(paramIndex) &&
3309-
!originalFnTy->getParameters()[paramIndex]
3310-
.getSILStorageType()
3311-
.isDifferentiable(getModule())) {
3312-
context.emitNondifferentiabilityError(
3313-
original, invoker, diag::autodiff_nondifferentiable_argument);
3314-
errorOccurred = true;
3315-
return;
3316-
}
3317-
}
3318-
// Check and diagnose non-differentiable results.
3319-
if (!originalFnTy->getResults()[indices.source]
3320-
.getSILStorageType()
3321-
.isDifferentiable(getModule())) {
3322-
context.emitNondifferentiabilityError(
3323-
original, invoker, diag::autodiff_nondifferentiable_result);
3324-
errorOccurred = true;
3306+
// Check and diagnose non-differentiable original function type.
3307+
auto diagnoseNondifferentiableOriginalFunctionType =
3308+
[&](CanSILFunctionType origFnTy) {
3309+
// Check and diagnose non-differentiable arguments.
3310+
for (unsigned paramIndex : range(originalFnTy->getNumParameters())) {
3311+
if (indices.isWrtParameter(paramIndex) &&
3312+
!originalFnTy->getParameters()[paramIndex]
3313+
.getSILStorageType()
3314+
.isDifferentiable(getModule())) {
3315+
context.emitNondifferentiabilityError(
3316+
ai->getArgumentsWithoutIndirectResults()[paramIndex], invoker,
3317+
diag::autodiff_nondifferentiable_argument);
3318+
errorOccurred = true;
3319+
return true;
3320+
}
3321+
}
3322+
// Check and diagnose non-differentiable results.
3323+
if (!originalFnTy->getResults()[indices.source]
3324+
.getSILStorageType()
3325+
.isDifferentiable(getModule())) {
3326+
context.emitNondifferentiabilityError(
3327+
original, invoker, diag::autodiff_nondifferentiable_result);
3328+
errorOccurred = true;
3329+
return true;
3330+
}
3331+
return false;
3332+
};
3333+
if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy))
33253334
return;
3326-
}
3335+
33273336
// If VJP has not yet been found, emit an `autodiff_function` instruction
33283337
// on the remapped original function operand and `autodiff_function_extract`
33293338
// the VJP. The actual JVP/VJP functions will be populated in the
@@ -3354,6 +3363,10 @@ class VJPEmitter final
33543363
ai->getLoc(), original, substMap, {},
33553364
ParameterConvention::Direct_Guaranteed);
33563365
original = vjpPartialApply;
3366+
originalFnTy = original->getType().castTo<SILFunctionType>();
3367+
// Diagnose if new original function type is non-differentiable.
3368+
if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy))
3369+
return;
33573370
}
33583371

33593372
auto *autoDiffFuncInst = context.createAutoDiffFunction(
@@ -3363,6 +3376,8 @@ class VJPEmitter final
33633376

33643377
// Record the `autodiff_function` instruction.
33653378
context.getAutoDiffFunctionInsts().push_back(autoDiffFuncInst);
3379+
// TODO(TF-689): Make `autodiff_function` store result indices and remove
3380+
// `ADContext::resultIndices`.
33663381
context.getResultIndices()[autoDiffFuncInst] =
33673382
activeResultIndices.front();
33683383

test/AutoDiff/autodiff_diagnostics.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,10 @@ let no_return: @differentiable (Float) -> Float = { x in
252252
// expected-note @+1 {{missing return for differentiation}}
253253
}
254254

255+
//===----------------------------------------------------------------------===//
256+
// Non-differentiable arguments and results
257+
//===----------------------------------------------------------------------===//
258+
255259
// expected-error @+1 {{function is not differentiable}}
256260
@differentiable
257261
// expected-note @+1 {{when differentiating this function definition}}

test/AutoDiff/autodiff_indirect_diagnostics.swift

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,23 @@ struct TF8Struct<Scalar> : TF8Proto where Scalar : FloatingPoint & Differentiabl
7070

7171
_ = gradient(at: 1.0, in: { x in x.squareRoot() })
7272

73+
//===----------------------------------------------------------------------===//
74+
// Non-differentiable arguments and results
75+
//===----------------------------------------------------------------------===//
76+
77+
struct TF_687<T> : Differentiable {
78+
@noDerivative var indirectDummy: T
79+
var base: Float
80+
81+
init(_ base: Float, dummy: T) {
82+
self.base = base
83+
self.indirectDummy = dummy
84+
}
85+
}
86+
// expected-error @+2 {{function is not differentiable}}
87+
// expected-note @+1 {{cannot differentiate through a non-differentiable argument; do you want to use 'withoutDerivative(at:)'?}}
88+
let _: @differentiable (Float) -> TF_687<Any> = { x in TF_687<Any>(x, dummy: x) }
89+
7390
//===----------------------------------------------------------------------===//
7491
// Add `Differentiable` conformance for generic wrt parameters
7592
//===----------------------------------------------------------------------===//
@@ -87,4 +104,4 @@ extension TF_691: Differentiable where Scalar: Differentiable {}
87104

88105
func identity<T>(_ x: TF_691<T>) -> TF_691<T> { x }
89106
let _: @differentiable (Float) -> TF_691<Float> = { x in identity(TF_691(x)) }
90-
let _: @differentiable (Float) -> TF_691<Float> = { x in id(TF_691(x)) }
107+
let _: @differentiable (Float) -> TF_691<Float> = { x in id(TF_691(x)) }

0 commit comments

Comments
 (0)