This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit cf21585

brettkoonce authored and rxwei committed
transformer: upstream api changes (#171)
1 parent ed143c0 commit cf21585

2 files changed, +7 -7 lines changed


Transformer/Model.swift

Lines changed: 5 additions & 5 deletions
@@ -51,9 +51,9 @@ struct FeedForward: Layer {
 }
 
 struct AttentionInput: Differentiable {
-    let query: Tensor<Float>
-    let key: Tensor<Float>
-    let value: Tensor<Float>
+    var query: Tensor<Float>
+    var key: Tensor<Float>
+    var value: Tensor<Float>
 }
 
 @differentiable(wrt: (query, key, value), vjp: _vjpMakeAttentionInput)
@@ -69,8 +69,8 @@ func _vjpMakeAttentionInput(query: Tensor<Float>, key: Tensor<Float>, value: Ten
 }
 
 struct AttentionContext: Differentiable {
-    let key: Tensor<Float>
-    let value: Tensor<Float>
+    var key: Tensor<Float>
+    var value: Tensor<Float>
 }
 
 @differentiable(wrt: (key, value), vjp: _vjpMakeAttentionContext)
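
Why let becomes var: stored properties of a Differentiable struct have to be mutable so the compiler-synthesized conformance can update them when applying gradients or building derivative values; newer Swift for TensorFlow toolchains no longer differentiate through immutable let members, which is presumably the upstream change this commit tracks. A minimal sketch of the resulting usage follows, assuming a toolchain from this era; the toy function attentionSum, the example shapes, and the gradient(at:in:) call are illustrative only and not part of this commit.

import TensorFlow

struct AttentionInput: Differentiable {
    var query: Tensor<Float>
    var key: Tensor<Float>
    var value: Tensor<Float>
}

// A toy differentiable function over the struct; gradients flow into each
// mutable stored property of the derivative value returned by gradient(at:in:).
@differentiable
func attentionSum(_ input: AttentionInput) -> Tensor<Float> {
    return (input.query * input.key + input.value).sum()
}

let input = AttentionInput(query: Tensor(ones: [2, 2]),
                           key: Tensor(ones: [2, 2]),
                           value: Tensor(ones: [2, 2]))
let grads = gradient(at: input, in: attentionSum)  // grads.query, grads.key, grads.value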

Transformer/Operators.swift

Lines changed: 2 additions & 2 deletions
@@ -29,7 +29,7 @@ func gelu<Scalar: TensorFlowFloatingPoint>(_ x: Tensor<Scalar>) -> Tensor<Scalar
 @differentiable(
     wrt: (left, right),
     vjp: _vjpBatchedMatmul
-    where Scalar : Differentiable & FloatingPoint
+    where Scalar : Differentiable & TensorFlowFloatingPoint
 )
 func batchedMatmul<Scalar : Numeric>(
     _ left: Tensor<Scalar>,
@@ -41,7 +41,7 @@ func batchedMatmul<Scalar : Numeric>(
 }
 
 @usableFromInline
-func _vjpBatchedMatmul<Scalar : Differentiable & FloatingPoint>(
+func _vjpBatchedMatmul<Scalar : Differentiable & TensorFlowFloatingPoint>(
     _ left: Tensor<Scalar>,
     _ right: Tensor<Scalar>,
     adjointLeft: Bool,
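
Why the constraint tightens: Tensor's differentiability is defined only for TensorFlow floating-point scalar types, so a VJP that constructs Tensors of the same Scalar needs TensorFlowFloatingPoint rather than plain FloatingPoint. Below is a minimal sketch of the same @differentiable(vjp:) pattern under the new constraint, assuming a toolchain from this era that still accepts the vjp: argument; squareSketch and _vjpSquareSketch are hypothetical names and not part of this commit.

import TensorFlow

@differentiable(vjp: _vjpSquareSketch where Scalar: Differentiable & TensorFlowFloatingPoint)
func squareSketch<Scalar: Numeric>(_ x: Tensor<Scalar>) -> Tensor<Scalar> {
    return x * x
}

// The VJP returns the value plus a pullback; the pullback scales the incoming
// gradient by 2x, which builds Tensors of the same scalar type and therefore
// needs the TensorFlowFloatingPoint constraint.
func _vjpSquareSketch<Scalar: Differentiable & TensorFlowFloatingPoint>(
    _ x: Tensor<Scalar>
) -> (Tensor<Scalar>, (Tensor<Scalar>) -> Tensor<Scalar>) {
    return (x * x, { seed in seed * 2 * x })
}

let grad = gradient(at: Tensor<Float>([1, 2, 3])) { squareSketch($0).sum() }
// grad == [2, 4, 6]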

0 commit comments
