Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit ae8eac0

Browse files
koen-dejongherxwei
authored andcommitted
use enable in _vjpCausallyMasked (#218)
Fixes #217
1 parent 41c4a7a commit ae8eac0

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

Transformer/Model.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ func causallyMasked(_ dotProducts: Tensor<Float>, enable: Bool = false) -> Tenso
102102
// causal mask is intentionally invisible to differentiation
103103
func _vjpCausallyMasked(_ dotProducts: Tensor<Float>, enable: Bool)
104104
-> (Tensor<Float>, (Tensor<Float>) -> Tensor<Float>) {
105-
return (causallyMasked(dotProducts), identity)
105+
return (causallyMasked(dotProducts, enable: enable), identity)
106106
}
107107

108108
struct Attention: ParameterlessLayer {

0 commit comments

Comments
 (0)