Skip to content

Commit 483d399

Browse files
authored
[AutoDiff] Temporarily disable non-varied result warning. (#26928)
Previously, a warning and fixit was generated for differentiable functions with non-varied results. ``` tf-775.swift:3:3: warning: result does not depend on differentiation arguments and will always have a zero derivative; do you want to use 'withoutDerivative(at:)'? .zero ^ withoutDerivative(at: ) ``` However, TF-775 exposes that the fixit does not work and the warning is unsilenceable. This patch temporarily disables the warning, as a robust fix requires non-trivial activity analysis changes. A lack of warning is better than an unsilenceable false positive. TF-788 tracks re-enabling the warning. Add regression test to ensure that unsilenceable warnings will not happen again.
1 parent 2f377d9 commit 483d399

File tree

3 files changed

+27
-9
lines changed

3 files changed

+27
-9
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4606,12 +4606,13 @@ class JVPEmitter final
46064606
attr->getIndices().parameters->getNumIndices());
46074607
auto origParamArgs = original->getArgumentsWithoutIndirectResults();
46084608

4609-
// Check if result is not varied.
4609+
// TODO(TF-788): Re-enable non-varied result warning.
4610+
/*
4611+
// Emit a warning and fixit if original result is not varied, because it
4612+
// will always have a zero derivative.
46104613
SmallVector<SILValue, 8> origFormalResults;
46114614
collectAllFormalResultsInTypeOrder(*original, origFormalResults);
46124615
auto origResult = origFormalResults[getIndices().source];
4613-
// Emit warning if original result is not varied, because it will always
4614-
// have a zero derivative.
46154616
if (!activityInfo.isVaried(origResult, getIndices().parameters)) {
46164617
// Emit fixit if original result has a valid source location.
46174618
auto startLoc = origResult.getLoc().getStartSourceLoc();
@@ -4622,6 +4623,7 @@ class JVPEmitter final
46224623
.fixItInsertAfter(endLoc, ")");
46234624
}
46244625
}
4626+
*/
46254627

46264628
auto *diffEntry = getDifferential().getEntryBlock();
46274629
diffBuilder.setInsertionPoint(
@@ -5636,8 +5638,10 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
56365638
SmallVector<SILValue, 8> origFormalResults;
56375639
collectAllFormalResultsInTypeOrder(original, origFormalResults);
56385640
auto origResult = origFormalResults[getIndices().source];
5639-
// Emit warning if original result is not varied, because it will always
5640-
// have a zero derivative.
5641+
// TODO(TF-788): Re-enable non-varied result warning.
5642+
/*
5643+
// Emit a warning and fixit if original result is not varied, because it
5644+
// will always have a zero derivative.
56415645
if (!getActivityInfo().isVaried(origResult, getIndices().parameters)) {
56425646
// Emit fixit if original result has a valid source location.
56435647
auto startLoc = origResult.getLoc().getStartSourceLoc();
@@ -5648,6 +5652,7 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
56485652
.fixItInsertAfter(endLoc, ")");
56495653
}
56505654
}
5655+
*/
56515656
builder.setInsertionPoint(
56525657
pullbackEntry, getNextFunctionLocalAllocationInsertionPoint());
56535658
if (seed->getType().isAddress()) {

test/AutoDiff/autodiff_diagnostics.swift

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,13 @@ _ = gradient(at: NoDerivativeProperty(x: 1, y: 1)) { s -> Float in
5050
return tmp.x
5151
}
5252
_ = gradient(at: NoDerivativeProperty(x: 1, y: 1)) { s in
53-
// expected-warning @+1 {{result does not depend on differentiation arguments and will always have a zero derivative; do you want to use 'withoutDerivative(at:)'?}} {{10-10=withoutDerivative(at:}} {{13-13=)}}
53+
// TODO(TF-788): Re-enable non-varied result warning.
54+
// xpected-warning @+1 {{result does not depend on differentiation arguments and will always have a zero derivative; do you want to use 'withoutDerivative(at:)'?}} {{10-10=withoutDerivative(at:}} {{13-13=)}}
5455
return s.y
5556
}
5657
_ = gradient(at: NoDerivativeProperty(x: 1, y: 1)) {
57-
// expected-warning @+1 {{result does not depend on differentiation arguments and will always have a zero derivative; do you want to use 'withoutDerivative(at:)'?}} {{3-3=withoutDerivative(at:}} {{7-7=)}}
58+
// TODO(TF-788): Re-enable non-varied result warning.
59+
// xpected-warning @+1 {{result does not depend on differentiation arguments and will always have a zero derivative; do you want to use 'withoutDerivative(at:)'?}} {{3-3=withoutDerivative(at:}} {{7-7=)}}
5860
$0.y
5961
}
6062

@@ -295,10 +297,20 @@ func one() -> Float {
295297
}
296298
@differentiable
297299
func nonVariedResult(_ x: Float) -> Float {
298-
// expected-warning @+1 {{result does not depend on differentiation arguments and will always have a zero derivative; do you want to use 'withoutDerivative(at:)'?}} {{10-10=withoutDerivative(at:}} {{15-15=)}}
300+
// TODO(TF-788): Re-enable non-varied result warning.
301+
// xpected-warning @+1 {{result does not depend on differentiation arguments and will always have a zero derivative; do you want to use 'withoutDerivative(at:)'?}} {{10-10=withoutDerivative(at:}} {{15-15=)}}
299302
return one()
300303
}
301304

305+
// Check that `withoutDerivative(at:)` silences the warning.
306+
307+
struct TF_775: Differentiable {
308+
@differentiable(wrt: (self))
309+
func nonVariedResult(_ input: Float) -> Float {
310+
withoutDerivative(at: input)
311+
}
312+
}
313+
302314
//===----------------------------------------------------------------------===//
303315
// Subset parameters
304316
//===----------------------------------------------------------------------===//

test/AutoDiff/forward_mode_diagnostics.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ func one() -> Float {
8080
}
8181
@differentiable
8282
func nonVariedResult(_ x: Float) -> Float {
83-
// expected-warning @+1 2 {{result does not depend on differentiation arguments and will always have a zero derivative; do you want to use 'withoutDerivative(at:)'?}} {{10-10=withoutDerivative(at:}}
83+
// TODO(TF-788): Re-enable non-varied result warning.
84+
// xpected-warning @+1 2 {{result does not depend on differentiation arguments and will always have a zero derivative; do you want to use 'withoutDerivative(at:)'?}} {{10-10=withoutDerivative(at:}}
8485
return one()
8586
}
8687

0 commit comments

Comments
 (0)