 
 import TensorFlow
 
-struct Config {
+struct Config: Codable {
   let vocabSize: Int
   let contextSize: Int
   let embeddingSize: Int
   let headCount: Int
   let layerCount: Int
-}
 
-extension Config {
-  init(dictionary: [String: Int]) {
-    vocabSize = dictionary["n_vocab"]!
-    contextSize = dictionary["n_ctx"]!
-    embeddingSize = dictionary["n_embd"]!
-    headCount = dictionary["n_head"]!
-    layerCount = dictionary["n_layer"]!
+  enum CodingKeys: String, CodingKey {
+    case vocabSize = "n_vocab"
+    case contextSize = "n_ctx"
+    case embeddingSize = "n_embd"
+    case headCount = "n_head"
+    case layerCount = "n_layer"
   }
 }
 
-let config = Config(dictionary: [
-  "n_vocab": 50257,
-  "n_ctx": 1024,
-  "n_embd": 768,
-  "n_head": 12,
-  "n_layer": 12
-])
-
 func readTensor<Scalar: TensorFlowScalar>(
   fromPath path: String,
   name: String,
@@ -55,18 +45,23 @@ func readTensor<Scalar: TensorFlowScalar>(
 }
 
 protocol InitializableFromPythonCheckpoint {
-  init(contentsOfPythonCheckpointFile path: String, scope: String)
+  init(contentsOfPythonCheckpointFile path: String, config: Config, scope: String)
 }
 
 extension Dense: InitializableFromPythonCheckpoint {
-  init(contentsOfPythonCheckpointFile path: String, scope: String) {
+  init(contentsOfPythonCheckpointFile path: String, config: Config, scope: String) {
     let kernel = readTensor(fromPath: path, name: scope + "/w", scalarType: Scalar.self)
     self.init(
       weight: kernel.squeezingShape(at: 0),
       bias: readTensor(fromPath: path, name: scope + "/b", scalarType: Scalar.self),
       activation: identity)
   }
-  init(contentsOfPythonCheckpointFile path: String, scope: String, activation: String) {
+  init(
+    contentsOfPythonCheckpointFile path: String,
+    config: Config,
+    scope: String,
+    activation: String
+  ) {
     let kernel = readTensor(fromPath: path, name: scope + "/w", scalarType: Scalar.self)
     self.init(
       weight: kernel.squeezingShape(at: 0),
@@ -76,7 +71,7 @@ extension Dense: InitializableFromPythonCheckpoint {
 }
 
 extension LayerNorm: InitializableFromPythonCheckpoint {
-  init(contentsOfPythonCheckpointFile path: String, scope: String) {
+  init(contentsOfPythonCheckpointFile path: String, config: Config, scope: String) {
     self.init(
       offset: readTensor(fromPath: path, name: scope + "/b", scalarType: Scalar.self),
       scale: readTensor(fromPath: path, name: scope + "/g", scalarType: Scalar.self),
@@ -86,57 +81,66 @@ extension LayerNorm: InitializableFromPythonCheckpoint {
 }
 
 extension MultiHeadAttention: InitializableFromPythonCheckpoint {
-  init(contentsOfPythonCheckpointFile path: String, scope: String) {
+  init(contentsOfPythonCheckpointFile path: String, config: Config, scope: String) {
     attention = Attention(
       size: config.embeddingSize / config.headCount,
       causal: true,
       dropProbability: 0.2)
     wqkv = TimeDistributed(Dense<Float>(
       contentsOfPythonCheckpointFile: path,
+      config: config,
       scope: scope + "/c_attn"))
     wo = TimeDistributed(Dense<Float>(
       contentsOfPythonCheckpointFile: path,
+      config: config,
       scope: scope + "/c_proj"))
-    headCount = 12
+    headCount = config.headCount
   }
 }
 
 extension FeedForward: InitializableFromPythonCheckpoint {
-  init(contentsOfPythonCheckpointFile path: String, scope: String) {
+  init(contentsOfPythonCheckpointFile path: String, config: Config, scope: String) {
     dense1 = TimeDistributed(Dense<Float>(
       contentsOfPythonCheckpointFile: path,
-      scope: scope + "/c_fc", activation: "gelu"))
+      config: config,
+      scope: scope + "/c_fc",
+      activation: "gelu"))
     dense2 = TimeDistributed(Dense<Float>(
       contentsOfPythonCheckpointFile: path,
+      config: config,
       scope: scope + "/c_proj"))
     dropout = Dropout(probability: 0.2)
   }
 }
 
 extension EncoderLayer: InitializableFromPythonCheckpoint {
-  init(contentsOfPythonCheckpointFile path: String, scope: String) {
+  init(contentsOfPythonCheckpointFile path: String, config: Config, scope: String) {
     selfAttention = MultiHeadAttention(
-      contentsOfPythonCheckpointFile: path,
-      scope: scope + "/attn")
+      contentsOfPythonCheckpointFile: path, config: config, scope: scope + "/attn")
     selfAttentionDropout = Dropout(probability: 0.2)
-    selfAttentionNorm = LayerNorm(contentsOfPythonCheckpointFile: path, scope: scope + "/ln_1")
-    feedForward = FeedForward(contentsOfPythonCheckpointFile: path, scope: scope + "/mlp")
+    selfAttentionNorm = LayerNorm(
+      contentsOfPythonCheckpointFile: path, config: config, scope: scope + "/ln_1")
+    feedForward = FeedForward(
+      contentsOfPythonCheckpointFile: path, config: config, scope: scope + "/mlp")
     feedForwardDropout = Dropout(probability: 0.2)
-    feedForwardNorm = LayerNorm(contentsOfPythonCheckpointFile: path, scope: scope + "/ln_2")
+    feedForwardNorm = LayerNorm(
+      contentsOfPythonCheckpointFile: path, config: config, scope: scope + "/ln_2")
   }
 }
 
 extension TransformerLM: InitializableFromPythonCheckpoint {
-  init(contentsOfPythonCheckpointFile path: String, scope: String) {
+  init(contentsOfPythonCheckpointFile path: String, config: Config, scope: String) {
     embedding = Embedding(
       weight: readTensor(fromPath: path, name: scope + "/wte", scalarType: Float.self))
     positionalEmbeddings = readTensor(
       fromPath: path,
       name: scope + "/wpe",
       scalarType: Float.self)
     layers = (0..<config.layerCount).map { i in
-      EncoderLayer(contentsOfPythonCheckpointFile: path, scope: scope + "/h\(i)")
+      EncoderLayer(
+        contentsOfPythonCheckpointFile: path, config: config, scope: scope + "/h\(i)")
     }
-    norm = LayerNorm(contentsOfPythonCheckpointFile: path, scope: scope + "/ln_f")
+    norm = LayerNorm(
+      contentsOfPythonCheckpointFile: path, config: config, scope: scope + "/ln_f")
   }
 }
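With Config now conforming to Codable, the deleted hard-coded dictionary can be replaced by decoding the checkpoint's own hyperparameter JSON. Below is a minimal sketch of the resulting call site. The inline JSON mirrors the hparams.json file from the OpenAI GPT-2 117M release, using the same values the removed literal hard-coded; the checkpoint path and the "model" scope are assumptions based on the GPT-2 release layout, not part of this diff.

import Foundation

// hparams.json contents for GPT-2 117M, inlined here so the sketch is
// self-contained; in practice this would be read from the checkpoint directory.
let json = """
{"n_vocab": 50257, "n_ctx": 1024, "n_embd": 768, "n_head": 12, "n_layer": 12}
"""

// The CodingKeys enum maps the Python-style JSON names (n_vocab, n_ctx, ...)
// onto the Swift property names, so the synthesized Codable conformance
// handles the whole decode.
let config = try! JSONDecoder().decode(Config.self, from: Data(json.utf8))

// Thread the decoded config through the new checkpoint-loading initializer.
// Hypothetical path; "model" is the top-level variable scope GPT-2
// TensorFlow checkpoints use.
let model = TransformerLM(
  contentsOfPythonCheckpointFile: "models/117M/model.ckpt",
  config: config,
  scope: "model")

Decoding the config rather than hard-coding it means the same loader works unchanged for the larger GPT-2 checkpoints (345M, 774M, 1558M), whose layer, head, and embedding counts differ only in hparams.json.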