
Commit 071f932

mikekgfb authored and malfet committed

l3 (#120)

* l3
* model load
* support params file
* typo
* typo
1 parent 201ca6e commit 071f932

File tree

3 files changed: +149 -43 lines changed


export.py

Lines changed: 19 additions & 1 deletion
@@ -70,7 +70,13 @@ def main(checkpoint_path, device, quantize = "{ }", args = None):
     print("Loading model ...")
     t0 = time.time()
     model = _load_model(
-        checkpoint_path, device=device, precision=precision, use_tp=False)
+        checkpoint_path,
+        args.checkpoint_dir,
+        args.params_path,
+        device=device,
+        precision=precision,
+        use_tp=False
+    )

     device_sync(device=device) # MKG
     print(f"Time to load model: {time.time() - t0:.02f} seconds")

@@ -152,6 +158,18 @@ def cli():
         default="not_specified",
         help="Model checkpoint path.",
     )
+    parser.add_argument(
+        "--checkpoint-dir",
+        type=Path,
+        default=None,
+        help="Model checkpoint directory.",
+    )
+    parser.add_argument(
+        "--params-path",
+        type=Path,
+        default=None,
+        help="Parameter file path.",
+    )
     parser.add_argument(
         "--output-pte-path",
         type=str,
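For reference, a minimal standalone sketch of how the two new flags parse: only the flag names, types, and defaults come from the diff above; the example values are hypothetical.

# Standalone argparse sketch: both new flags are optional Paths defaulting to None,
# and export.py simply forwards them to _load_model as shown in the diff.
import argparse
from pathlib import Path

parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint-dir", type=Path, default=None, help="Model checkpoint directory.")
parser.add_argument("--params-path", type=Path, default=None, help="Parameter file path.")

args = parser.parse_args(["--checkpoint-dir", "checkpoints/my-model", "--params-path", "params.json"])
print(args.checkpoint_dir, args.params_path)  # checkpoints/my-model params.json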

generate.py

Lines changed: 73 additions & 5 deletions
@@ -274,12 +274,50 @@ def encode_tokens(tokenizer, string, bos=True, device="cuda"):
     return torch.tensor(tokens, dtype=torch.int, device=device)


-def _load_model(checkpoint_path, device, precision, use_tp=False):
+def _load_model(
+    checkpoint_path,
+    checkpoint_dir,
+    params_path,
+    device,
+    precision,
+    use_tp=False
+):
     use_cuda = "cuda" in device
     with torch.device("meta"):
-        model = Transformer.from_name(checkpoint_path.parent.name)
+        if params_path:
+            model = Transformer.from_params(params_path)
+        else:
+            model = Transformer.from_name(checkpoint_path.parent.name)
+
+    # checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
+    cps = []
+    if checkpoint_dir is not None:
+        # Load multiple checkpoints; ignore the single path.
+        checkpoint_path = None
+        for i in range(4):
+            cp_name = f"consolidated.{i}.pth"
+            print(f"Loading {cp_name}")
+            cps.append(
+                torch.load(
+                    os.path.join(checkpoint_dir, cp_name),
+                    map_location=device,
+                    mmap=True,
+                )
+            )
+
+        checkpoint = {}
+        for key in cps[0].keys():
+            if not torch.allclose(cps[0][key], cps[1][key]):
+                values = (cps[0][key], cps[1][key], cps[2][key], cps[3][key])
+                if key.endswith("wo.weight") or key.endswith("w2.weight"):
+                    checkpoint[key] = torch.cat(values, dim=1)
+                else:
+                    checkpoint[key] = torch.cat(values, dim=0)
+            else:
+                checkpoint[key] = cps[0][key]
+    else:
+        checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True, weights_only=True)

-    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
     if "model" in checkpoint and "stories" in str(checkpoint_path):
         checkpoint = checkpoint["model"]

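The merge step above re-assembles tensor-parallel shards: weights that differ across the four consolidated.*.pth files are concatenated, along dim 1 for the output projections (wo.weight, w2.weight, which are split on their input features) and along dim 0 otherwise, while replicated tensors are taken from the first shard. A self-contained sketch of the same rule, using small random tensors instead of real checkpoint files (the shapes and key names below are illustrative):

# Illustrative only: emulate merging 4 tensor-parallel shards with the rule above.
import torch

def merge_shards(cps):
    merged = {}
    for key in cps[0].keys():
        if not torch.allclose(cps[0][key], cps[1][key]):
            values = tuple(cp[key] for cp in cps)
            dim = 1 if key.endswith("wo.weight") or key.endswith("w2.weight") else 0
            merged[key] = torch.cat(values, dim=dim)
        else:
            # Replicated tensors (e.g. norm weights) are identical on every shard.
            merged[key] = cps[0][key]
    return merged

# Fake 4-way shards: a sharded attention output projection and a replicated norm weight.
shards = [
    {"layers.0.attention.wo.weight": torch.randn(8, 2),
     "layers.0.attention_norm.weight": torch.ones(8)}
    for _ in range(4)
]
merged = merge_shards(shards)
print(merged["layers.0.attention.wo.weight"].shape)    # torch.Size([8, 8])
print(merged["layers.0.attention_norm.weight"].shape)  # torch.Size([8])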
@@ -306,6 +344,8 @@ def main(
     top_k: int = 200,
     temperature: float = 0.8,
     checkpoint_path: Optional[Path] = None,
+    checkpoint_dir: Optional[Path] = None,
+    params_path: Optional[Path] = None,
     tokenizer_path: Optional[Path] = None,
     compile: bool = True,
     compile_prefill: bool = False,

@@ -351,7 +391,14 @@ def main(

     print("Loading model ...")
     t0 = time.time()
-    model_ = _load_model(checkpoint_path, device, precision, use_tp)
+    model_ = _load_model(
+        checkpoint_path,
+        checkpoint_dir,
+        params_path,
+        device,
+        precision,
+        use_tp
+    )
     if dso_path:
         assert not model_dtype, f"dtype setting not valid for a DSO model. Specify dtype during export."
         assert quantize is None or quantize == "{ }", f"quantize not valid for exported DSO model. Specify quantization during export."

@@ -390,7 +437,14 @@ def main(
         model.to(dtype=name_to_dtype(model_dtype))

     if is_speculative:
-        draft_model = _load_model(draft_checkpoint_path, device, precision, use_tp)
+        draft_model = _load_model(
+            draft_checkpoint_path,
+            None,
+            None,
+            device,
+            precision,
+            use_tp
+        )
     else:
         draft_model = None


@@ -553,6 +607,18 @@ def cli():
         default=None,
         help="Model checkpoint path.",
     )
+    parser.add_argument(
+        "--checkpoint-dir",
+        type=Path,
+        default=None,
+        help="Model checkpoint directory.",
+    )
+    parser.add_argument(
+        "--params-path",
+        type=Path,
+        default=None,
+        help="Parameter file path.",
+    )
     parser.add_argument(
         "--tokenizer-path",
         type=Path,

@@ -621,6 +687,8 @@ def cli():
         args.top_k,
         args.temperature,
         args.checkpoint_path,
+        args.checkpoint_dir,
+        args.params_path,
         args.tokenizer_path,
         args.compile,
         args.compile_prefill,
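Taken together, the generate.py changes thread the two new arguments through main() while keeping the speculative-decoding draft model on the old single-file path (it passes None for both). A minimal sketch of the resulting call pattern; the wrapper function is hypothetical and assumes generate.py is importable as a module:

# Hypothetical helper, not part of the commit: shows both call styles of the
# extended _load_model signature after this change.
from generate import _load_model  # assumes generate.py is on the import path


def load_main_and_draft(checkpoint_path, checkpoint_dir, params_path,
                        draft_checkpoint_path, device, precision, use_tp=False):
    # Main model: may come from a sharded checkpoint_dir and/or a params.json.
    model = _load_model(checkpoint_path, checkpoint_dir, params_path,
                        device, precision, use_tp)
    # Draft model: always a single-file checkpoint, so the new
    # checkpoint_dir / params_path arguments are passed as None.
    draft_model = None
    if draft_checkpoint_path is not None:
        draft_model = _load_model(draft_checkpoint_path, None, None,
                                  device, precision, use_tp)
    return model, draft_model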

model.py

Lines changed: 57 additions & 37 deletions
@@ -24,23 +24,39 @@ class ModelArgs:
     block_size: int = 2048
     vocab_size: int = 32000
     n_layer: int = 32
-    n_head: int = 32
+    # n_head in gpt-fast
+    n_heads: int = 32
     dim: int = 4096
-    intermediate_size: int = None
+    # hidden dim is intermediate_size in gpt-fast
+    hidden_dim: int = None
     n_local_heads: int = -1
     head_dim: int = 64
     rope_base: float = 10000
     norm_eps: float = 1e-5
-
+    multiple_of = 256
+    ffn_dim_multiplier = None
+
     def __post_init__(self):
         if self.n_local_heads == -1:
-            self.n_local_heads = self.n_head
-        if self.intermediate_size is None:
+            self.n_local_heads = self.n_heads
+        if self.hidden_dim is None:
+            # If hidden_dim is not explicitly set in the ModelArgs,
+            # then calculate implicitly based on dim and
+            # also multiple of `args.multiple_of`
+            multiple_of = self.multiple_of
             hidden_dim = 4 * self.dim
-            n_hidden = int(2 * hidden_dim / 3)
-            self.intermediate_size = find_multiple(n_hidden, 256)
-        self.head_dim = self.dim // self.n_head
+            hidden_dim = int(2 * hidden_dim / 3)
+            if self.ffn_dim_multiplier is not None:
+                hidden_dim = int(self.ffn_dim_multiplier * hidden_dim)
+            self.hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+        self.head_dim = self.dim // self.n_heads

+    @classmethod
+    def from_params(cls, params_path):
+        with open(params_path, "r") as f:
+            params = json.loads(f.read())
+        return cls(**params)
+
     @classmethod
     def from_name(cls, name: str):
         print(f"name {name}")
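The new __post_init__ rule sizes the FFN as roughly 2/3 of 4*dim, optionally scaled by ffn_dim_multiplier, then rounded up to a multiple of multiple_of. A standalone sketch of just that rule (the function name and example values are illustrative, not from the commit):

# Illustrative re-implementation of the hidden_dim rule in __post_init__.
def derive_hidden_dim(dim, multiple_of=256, ffn_dim_multiplier=None):
    hidden_dim = int(2 * (4 * dim) / 3)
    if ffn_dim_multiplier is not None:
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    # Round up to the next multiple of `multiple_of`.
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

print(derive_hidden_dim(4096))                          # 11008
print(derive_hidden_dim(4096, ffn_dim_multiplier=1.3))  # 14336

# ModelArgs.from_params reads a JSON file whose keys should match the ModelArgs
# field names, e.g. {"dim": 288, "n_layer": 6, "n_heads": 6, "vocab_size": 32000};
# leaving hidden_dim out triggers the rule above (dim=288 -> hidden_dim=768 here).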
@@ -70,47 +86,47 @@ def from_name(cls, name: str):
     "CodeLlama-7b-Python-hf": dict(
         block_size=16384, vocab_size=32000, n_layer=32, dim=4096, rope_base=1000000
     ),
-    "7B": dict(n_layer=32, n_head=32, dim=4096),
-    "13B": dict(n_layer=40, n_head=40, dim=5120),
-    "30B": dict(n_layer=60, n_head=52, dim=6656),
+    "7B": dict(n_layer=32, n_heads=32, dim=4096),
+    "13B": dict(n_layer=40, n_heads=40, dim=5120),
+    "30B": dict(n_layer=60, n_heads=52, dim=6656),
     "34B": dict(
         n_layer=48,
-        n_head=64,
+        n_heads=64,
         dim=8192,
         vocab_size=32000,
         n_local_heads=8,
-        intermediate_size=22016,
+        hidden_dim=22016,
         rope_base=1000000,
     ), # CodeLlama-34B-Python-hf
     "70B": dict(
-        n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672
+        n_layer=80, n_heads=64, dim=8192, n_local_heads=8, hidden_dim=28672
     ),
     "Mistral-7B": dict(
         n_layer=32,
-        n_head=32,
+        n_heads=32,
         n_local_heads=8,
         dim=4096,
-        intermediate_size=14336,
+        hidden_dim=14336,
         vocab_size=32000,
     ),
     "Mistral-7B-Instruct-v0.1": dict(
         n_layer=32,
-        n_head=32,
+        n_heads=32,
         n_local_heads=8,
         dim=4096,
-        intermediate_size=14336,
+        hidden_dim=14336,
         vocab_size=32000,
     ),
     "Mistral-7B-Instruct-v0.2": dict(
         n_layer=32,
-        n_head=32,
+        n_heads=32,
         n_local_heads=8,
         dim=4096,
-        intermediate_size=14336,
+        hidden_dim=14336,
         vocab_size=32000,
     ),
-    "stories15M": dict(n_layer=6, n_head=6, dim=288),
-    "stories110M": dict(n_layer=12, n_head=12, dim=768),
+    "stories15M": dict(n_layer=6, n_heads=6, dim=288),
+    "stories110M": dict(n_layer=12, n_heads=12, dim=768),
 }
@@ -160,7 +176,7 @@ def setup_caches(self, max_batch_size, max_seq_length):
             and self.max_batch_size >= max_batch_size
         ):
             return
-        head_dim = self.config.dim // self.config.n_head
+        head_dim = self.config.dim // self.config.n_heads
         max_seq_length = find_multiple(max_seq_length, 8)
         self.max_seq_length = max_seq_length
         self.max_batch_size = max_batch_size

@@ -170,8 +186,8 @@ def setup_caches(self, max_batch_size, max_seq_length):
             )

         freqs_cis = precompute_freqs_cis(
-            self.config.block_size,
-            self.config.dim // self.config.n_head,
+            self.config.dim // self.config.n_heads,
+            self.config.block_size * 2,
             self.config.rope_base,
         )
         self.register_buffer("freqs_cis", freqs_cis, persistent=True)

@@ -202,6 +218,10 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
     def from_name(cls, name: str):
         return cls(ModelArgs.from_name(name))

+    @classmethod
+    def from_params(cls, params_path: str):
+        return cls(ModelArgs.from_params(params_path))
+

 class TransformerBlock(nn.Module):
     def __init__(self, config: ModelArgs) -> None:

@@ -222,19 +242,19 @@ def forward(
 class Attention(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()
-        assert config.dim % config.n_head == 0
+        assert config.dim % config.n_heads == 0

         # key, query, value projections for all heads, but in a batch
-        # total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
+        # total_head_dim = (config.n_heads + 2 * config.n_local_heads) * config.head_dim
         # self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
-        self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=False)
+        self.wq = nn.Linear(config.dim, config.n_heads * config.head_dim, bias=False)
         self.wk = nn.Linear(config.dim, config.n_local_heads * config.head_dim, bias=False)
         self.wv = nn.Linear(config.dim, config.n_local_heads * config.head_dim, bias=False)

         self.wo = nn.Linear(config.dim, config.dim, bias=False)
         self.kv_cache = None

-        self.n_head = config.n_head
+        self.n_heads = config.n_heads
         self.head_dim = config.head_dim
         self.n_local_heads = config.n_local_heads
         self.dim = config.dim

@@ -263,7 +283,7 @@ def forward(
         # kv_size = self.n_local_heads * self.head_dim
         # q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

-        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
+        q = q.view(bsz, seqlen, self.n_heads, self.head_dim)
         k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
         v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

@@ -275,8 +295,8 @@ def forward(
         if self.kv_cache is not None:
             k, v = self.kv_cache.update(input_pos, k, v)

-        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
-        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+        k = k.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)
+        v = v.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)
         y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

@@ -288,9 +308,9 @@ def forward(
 class FeedForward(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
-        self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
-        self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
-        self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
+        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
+        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
+        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)

     def forward(self, x: Tensor) -> Tensor:
         return self.w2(F.silu(self.w1(x)) * self.w3(x))

@@ -309,8 +329,8 @@ def forward(self, x: Tensor) -> Tensor:
         output = self._norm(x.float()).type_as(x)
         return output * self.weight

-
-def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
+# transposed first two arguments to align with model in ET
+def precompute_freqs_cis(n_elem: int, seq_len: int, base: int = 10000) -> Tensor:
     freqs = 1.0 / (
         base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
     )
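The last hunk only swaps the first two parameters of precompute_freqs_cis, and the call site in setup_caches now passes head_dim first and an extended sequence length (block_size * 2) second. The sketch below shape-checks the new argument order; only the signature and the freqs line come from the diff, and the rest of the body is an assumed gpt-fast-style rotary cache, simplified (no dtype cast).

# Simplified sketch; everything after the `freqs` line is an assumption, included
# only to make the new argument order concrete.
import torch

def precompute_freqs_cis(n_elem: int, seq_len: int, base: int = 10000) -> torch.Tensor:
    freqs = 1.0 / (
        base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
    )
    t = torch.arange(seq_len, device=freqs.device)
    angles = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(angles), angles)
    # Store real/imag parts stacked in the last dimension.
    return torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)

# New call order, mirroring setup_caches: head_dim first, then block_size * 2.
head_dim, block_size = 128, 2048
print(precompute_freqs_cis(head_dim, block_size * 2).shape)  # torch.Size([4096, 64, 2])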
