Skip to content

Commit 9b319cf

Browse files
committed
Simplifying test case.
1 parent 3550b99 commit 9b319cf

File tree

1 file changed

+6
-21
lines changed

1 file changed

+6
-21
lines changed

test/AutoDiff/validation-test/inout_control_flow.swift

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
// RUN: %target-run-simple-swift
22
// REQUIRES: executable_test
33

4-
// Would fail due to unavailability of swift_autoDiffCreateLinearMapContext.
5-
// UNSUPPORTED: use_os_stdlib
6-
// UNSUPPORTED: back_deployment_runtime
7-
84
import StdlibUnittest
95
import _Differentiation
106

@@ -43,19 +39,8 @@ InoutControlFlowTests.test("MutatingBeforeControlFlow") {
4339
}
4440

4541
// SR-14053
46-
protocol NumericDifferentiable : Numeric, Differentiable {
47-
@differentiable(reverse) static func *(lhs: Self, rhs: Self) -> Self
48-
}
49-
50-
extension Float: NumericDifferentiable {}
51-
52-
struct Model2<T: NumericDifferentiable>: Differentiable {
53-
var first: T
54-
var second: T
55-
}
56-
5742
@differentiable(reverse)
58-
func adjust<T: NumericDifferentiable>(model: inout Model2<T>, multiplier: T) {
43+
func adjust(model: inout Model, multiplier: Float) {
5944
model.first = model.second * multiplier
6045

6146
// Dummy no-op if block, required to introduce control flow.
@@ -64,21 +49,21 @@ func adjust<T: NumericDifferentiable>(model: inout Model2<T>, multiplier: T) {
6449
}
6550

6651
@differentiable(reverse)
67-
func loss2(model: Model2<Float>, multiplier: Float) -> Float {
52+
func loss2(model: Model, multiplier: Float) -> Float {
6853
var model = model
6954
adjust(model: &model, multiplier: multiplier)
7055
return model.first
7156
}
7257

7358
InoutControlFlowTests.test("InoutParameterWithControlFlow") {
74-
var model = Model2<Float>(first: 1, second: 3)
59+
var model = Model(first: 1, second: 3)
7560
let grad = gradient(at: model, 5.0, of: loss2)
7661
expectEqual(0, grad.0.first)
7762
expectEqual(5, grad.0.second)
7863
}
7964

8065
@differentiable(reverse)
81-
func adjust2<T: NumericDifferentiable>(multiplier: T, model: inout Model2<T>) {
66+
func adjust2(multiplier: Float, model: inout Model) {
8267
model.first = model.second * multiplier
8368

8469
// Dummy no-op if block, required to introduce control flow.
@@ -87,14 +72,14 @@ func adjust2<T: NumericDifferentiable>(multiplier: T, model: inout Model2<T>) {
8772
}
8873

8974
@differentiable(reverse)
90-
func loss3(model: Model2<Float>, multiplier: Float) -> Float {
75+
func loss3(model: Model, multiplier: Float) -> Float {
9176
var model = model
9277
adjust2(multiplier: multiplier, model: &model)
9378
return model.first
9479
}
9580

9681
InoutControlFlowTests.test("LaterInoutParameterWithControlFlow") {
97-
var model = Model2<Float>(first: 1, second: 3)
82+
var model = Model(first: 1, second: 3)
9883
let grad = gradient(at: model, 5.0, of: loss3)
9984
expectEqual(0, grad.0.first)
10085
expectEqual(5, grad.0.second)

0 commit comments

Comments
 (0)