This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 35c20e8

leoxzhao authored and saeta committed
Load config variables from the hparams.json file so that Transformer can work with the bigger GPT-2 models (this is the "staged release"; as of now, only 117M and 345M are available). (#154)
1 parent 67f7c64 commit 35c20e8

File tree (2 files changed, +48 / -38 lines)

Transformer/PythonCheckpointReader.swift
Transformer/main.swift

Transformer/PythonCheckpointReader.swift

Lines changed: 38 additions & 34 deletions
@@ -14,32 +14,22 @@
 
 import TensorFlow
 
-struct Config {
+struct Config : Codable {
     let vocabSize: Int
     let contextSize: Int
     let embeddingSize: Int
     let headCount: Int
     let layerCount: Int
-}
 
-extension Config {
-    init(dictionary: [String: Int]) {
-        vocabSize = dictionary["n_vocab"]!
-        contextSize = dictionary["n_ctx"]!
-        embeddingSize = dictionary["n_embd"]!
-        headCount = dictionary["n_head"]!
-        layerCount = dictionary["n_layer"]!
+    enum CodingKeys: String, CodingKey {
+        case vocabSize = "n_vocab"
+        case contextSize = "n_ctx"
+        case embeddingSize = "n_embd"
+        case headCount = "n_head"
+        case layerCount = "n_layer"
     }
 }
 
-let config = Config(dictionary: [
-    "n_vocab": 50257,
-    "n_ctx": 1024,
-    "n_embd": 768,
-    "n_head": 12,
-    "n_layer": 12
-])
-
 func readTensor<Scalar: TensorFlowScalar>(
     fromPath path: String,
     name: String,
@@ -55,18 +45,23 @@ func readTensor<Scalar: TensorFlowScalar>(
 }
 
 protocol InitializableFromPythonCheckpoint {
-    init(contentsOfPythonCheckpointFile path: String, scope: String)
+    init(contentsOfPythonCheckpointFile path: String, config: Config, scope: String)
 }
 
 extension Dense: InitializableFromPythonCheckpoint {
-    init(contentsOfPythonCheckpointFile path: String, scope: String) {
+    init(contentsOfPythonCheckpointFile path: String, config: Config, scope: String) {
         let kernel = readTensor(fromPath: path, name: scope + "/w", scalarType: Scalar.self)
         self.init(
             weight: kernel.squeezingShape(at: 0),
             bias: readTensor(fromPath: path, name: scope + "/b", scalarType: Scalar.self),
             activation: identity)
     }
-    init(contentsOfPythonCheckpointFile path: String, scope: String, activation: String) {
+    init(
+        contentsOfPythonCheckpointFile path: String,
+        config: Config,
+        scope: String,
+        activation: String
+    ) {
         let kernel = readTensor(fromPath: path, name: scope + "/w", scalarType: Scalar.self)
         self.init(
             weight: kernel.squeezingShape(at: 0),
@@ -76,7 +71,7 @@ extension Dense: InitializableFromPythonCheckpoint {
 }
 
 extension LayerNorm: InitializableFromPythonCheckpoint {
-    init(contentsOfPythonCheckpointFile path: String, scope: String) {
+    init(contentsOfPythonCheckpointFile path: String, config: Config, scope: String) {
         self.init(
             offset: readTensor(fromPath: path, name: scope + "/b", scalarType: Scalar.self),
             scale: readTensor(fromPath: path, name: scope + "/g", scalarType: Scalar.self),
@@ -86,57 +81,66 @@ extension LayerNorm: InitializableFromPythonCheckpoint {
 }
 
 extension MultiHeadAttention: InitializableFromPythonCheckpoint {
-    init(contentsOfPythonCheckpointFile path: String, scope: String) {
+    init(contentsOfPythonCheckpointFile path: String, config: Config, scope: String) {
         attention = Attention(
             size: config.embeddingSize / config.headCount,
             causal: true,
             dropProbability: 0.2)
         wqkv = TimeDistributed(Dense<Float>(
             contentsOfPythonCheckpointFile: path,
+            config: config,
             scope: scope + "/c_attn"))
         wo = TimeDistributed(Dense<Float>(
             contentsOfPythonCheckpointFile: path,
+            config: config,
             scope: scope + "/c_proj"))
-        headCount = 12
+        headCount = config.headCount
     }
 }
 
 extension FeedForward: InitializableFromPythonCheckpoint {
-    init(contentsOfPythonCheckpointFile path: String, scope: String) {
+    init(contentsOfPythonCheckpointFile path: String, config: Config, scope: String) {
         dense1 = TimeDistributed(Dense<Float>(
             contentsOfPythonCheckpointFile: path,
-            scope: scope + "/c_fc", activation: "gelu"))
+            config: config,
+            scope: scope + "/c_fc",
+            activation: "gelu"))
         dense2 = TimeDistributed(Dense<Float>(
             contentsOfPythonCheckpointFile: path,
+            config: config,
             scope: scope + "/c_proj"))
         dropout = Dropout(probability: 0.2)
     }
 }
 
 extension EncoderLayer: InitializableFromPythonCheckpoint {
-    init(contentsOfPythonCheckpointFile path: String, scope: String) {
+    init(contentsOfPythonCheckpointFile path: String, config: Config, scope: String) {
         selfAttention = MultiHeadAttention(
-            contentsOfPythonCheckpointFile: path,
-            scope: scope + "/attn")
+            contentsOfPythonCheckpointFile: path, config: config, scope: scope + "/attn")
         selfAttentionDropout = Dropout(probability: 0.2)
-        selfAttentionNorm = LayerNorm(contentsOfPythonCheckpointFile: path, scope: scope + "/ln_1")
-        feedForward = FeedForward(contentsOfPythonCheckpointFile: path, scope: scope + "/mlp")
+        selfAttentionNorm = LayerNorm(
+            contentsOfPythonCheckpointFile: path, config: config, scope: scope + "/ln_1")
+        feedForward = FeedForward(
+            contentsOfPythonCheckpointFile: path, config: config, scope: scope + "/mlp")
         feedForwardDropout = Dropout(probability: 0.2)
-        feedForwardNorm = LayerNorm(contentsOfPythonCheckpointFile: path, scope: scope + "/ln_2")
+        feedForwardNorm = LayerNorm(
+            contentsOfPythonCheckpointFile: path, config: config, scope: scope + "/ln_2")
     }
 }
 
 extension TransformerLM: InitializableFromPythonCheckpoint {
-    init(contentsOfPythonCheckpointFile path: String, scope: String) {
+    init(contentsOfPythonCheckpointFile path: String, config: Config, scope: String) {
         embedding = Embedding(
             weight: readTensor(fromPath: path, name: scope + "/wte", scalarType: Float.self))
         positionalEmbeddings = readTensor(
             fromPath: path,
             name: scope + "/wpe",
             scalarType: Float.self)
         layers = (0..<config.layerCount).map { i in
-            EncoderLayer(contentsOfPythonCheckpointFile: path, scope: scope + "/h\(i)")
+            EncoderLayer(
+                contentsOfPythonCheckpointFile: path, config: config, scope: scope + "/h\(i)")
         }
-        norm = LayerNorm(contentsOfPythonCheckpointFile: path, scope: scope + "/ln_f")
+        norm = LayerNorm(
+            contentsOfPythonCheckpointFile: path, config: config, scope: scope + "/ln_f")
     }
 }
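Taken together, the changes in this file drop the hard-coded 117M configuration and make Config decodable straight from the hparams.json that ships with each GPT-2 checkpoint, with CodingKeys translating the JSON field names (n_vocab, n_ctx, n_embd, n_head, n_layer) into the Swift property names. A minimal, self-contained sketch of that decoding (not part of the commit), using the same values the removed hard-coded dictionary held for 117M:

import Foundation

// Mirror of the Config struct introduced above, repeated so the sketch stands alone.
struct Config: Codable {
    let vocabSize: Int
    let contextSize: Int
    let embeddingSize: Int
    let headCount: Int
    let layerCount: Int

    enum CodingKeys: String, CodingKey {
        case vocabSize = "n_vocab"
        case contextSize = "n_ctx"
        case embeddingSize = "n_embd"
        case headCount = "n_head"
        case layerCount = "n_layer"
    }
}

// Inline JSON carrying the same values the removed hard-coded dictionary used for 117M;
// the real models/117M/hparams.json is read from disk instead (see main.swift below).
let hparamsJSON = """
{"n_vocab": 50257, "n_ctx": 1024, "n_embd": 768, "n_head": 12, "n_layer": 12}
"""

let config = try JSONDecoder().decode(Config.self, from: Data(hparamsJSON.utf8))
print(config.embeddingSize / config.headCount)  // per-head attention size: 64

Because the Codable conformance is synthesized, JSONDecoder only looks for the five declared keys and ignores anything else the file might contain.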

Transformer/main.swift

Lines changed: 10 additions & 4 deletions
@@ -14,13 +14,19 @@
 
 import Python
 import TensorFlow
+import Foundation
 
+let modelName = "117M"
 let sys = Python.import("sys")
 sys.path = sys.path + ["."]
-let encoder = Python.import("encoder").get_encoder("117M")
-
-let checkpoint = "models/117M/model.ckpt"
-let model = TransformerLM(contentsOfPythonCheckpointFile: checkpoint, scope: "model")
+let encoder = Python.import("encoder").get_encoder(modelName)
+
+let checkpoint = "models/\(modelName)/model.ckpt"
+let configFile = "models/\(modelName)/hparams.json"
+let configData = try Data.init(contentsOf: URL(fileURLWithPath: configFile))
+let config = try JSONDecoder().decode(Config.self, from: configData)
+let model = TransformerLM(
+    contentsOfPythonCheckpointFile: checkpoint, config: config, scope: "model")
 
 let start_token = Int32(encoder.encoder["<|endoftext|>"])!
 var tokens = Tensor(shape: [1, 1], scalars: [start_token])
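Since both the checkpoint and config paths are now derived from modelName, switching the example to the other staged-release model should only require changing that constant, assuming the corresponding files have been downloaded into models/345M/. A hedged sketch of that switch with an explicit existence check (not part of the commit; the guard and its message are illustrative):

import Foundation

// Hypothetical: point the example at the larger released model.
let modelName = "345M"  // assumes models/345M/hparams.json and model.ckpt were fetched beforehand
let configFile = "models/\(modelName)/hparams.json"

// Fail early with a readable message if the config for this model is missing.
guard FileManager.default.fileExists(atPath: configFile) else {
    fatalError("No hparams.json at \(configFile); download the \(modelName) model first.")
}

// Config is the Codable struct defined in PythonCheckpointReader.swift above.
let configData = try Data(contentsOf: URL(fileURLWithPath: configFile))
let config = try JSONDecoder().decode(Config.self, from: configData)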
