@@ -111,30 +111,10 @@ public struct GoModelOutput: Differentiable {
111
111
public let logits : Tensor < Float >
112
112
}
113
113
114
- // This might be needed when we add training to work around an AD bug for memberwise initializers
115
- // @differentiable(wrt: (policy, value, logits), vjp: _vjpMakeGoModelOutput)
116
- // func makeGoModelOutput(
117
- // policy: Tensor<Float>, value: Tensor<Float>, logits: Tensor<Float>)
118
- // -> GoModelOutput {
119
- // return GoModelOutput(policy: policy, value: value, logits: logits)
120
- // }
121
- // func _vjpMakeGoModelOutput(
122
- // policy: Tensor<Float>, value: Tensor<Float>, logits: Tensor<Float>)
123
- // -> (GoModelOutput, (GoModelOutput.CotangentVector)
124
- // -> (Tensor<Float>, Tensor<Float>, Tensor<Float>)) {
125
- // let result = GoModelOutput(policy: policy, value: value, logits: logits)
126
- // return (result, { seed in (seed.policy, seed.value, seed.logits) })
127
- // }
128
-
129
114
public struct GoModel : Layer {
130
115
@noDerivative let configuration : ModelConfiguration
131
116
var initialConv : ConvBN
132
- // TODO(jekbradbury): support differentiation wrt residualBlocks
133
- // [T] where T: Differentiable doesn't (shouldn't?) conform to Differentiable,
134
- // so we will likely need a LayerArray<T> where T: Layer type. But this
135
- // itself won't work until we have better generics support, and even then
136
- // T can't be an existential Layer. So it's @noDerivative for now.
137
- @noDerivative var residualBlocks : [ ResidualIdentityBlock ]
117
+ var residualBlocks : [ ResidualIdentityBlock ]
138
118
var policyConv : ConvBN
139
119
var policyDense : Dense < Float >
140
120
var valueConv : ConvBN
0 commit comments