Skip to content

Commit 5663e3c

Browse files
[MTK] Add support for Llama 3.2 and code updates to align with current ET API for dynamic dim
Differential Revision: D65986967
Pull Request resolved: #6726
1 parent c734ad4 commit 5663e3c

File tree

12 files changed

+825,352
−13
lines changed

12 files changed

+825,352
−13
lines changed

examples/mediatek/aot_utils/llm_utils/sanity_checks.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,9 @@ def check_weights_exist(weight_dir):
204204
f"No weight files found in {weight_dir}! Weight files should be either .bin or .safetensors file types."
205205
)
206206
safetensors_l = [f for f in os.listdir(weight_dir) if f.endswith(".safetensors")]
207-
bin_l = [f for f in os.listdir(weight_dir) if f.endswith(".bin")]
207+
bin_l = [
208+
f for f in os.listdir(weight_dir) if f.endswith(".bin") and "embedding" not in f
209+
]
208210
if len(safetensors_l) & len(bin_l):
209211
raise RuntimeError(
210212
"Weights should only be in either .bin or .safetensors format, not both."

examples/mediatek/model_export_scripts/llama.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,9 @@ def main():
419419
print(f"Max Num Token: {max_num_token}")
420420
print(f"Max Cache Size: {max_cache_size}")
421421

422+
if args.dataset is not None:
423+
embedding_layer = get_embedding_layer(config, weight_dir, state_dict)
424+
422425
# Instantiate model chunks
423426
print("Instantiating submodels")
424427
models = []
@@ -437,7 +440,6 @@ def main():
437440
cal_dataset = None
438441
if args.dataset is not None:
439442
cal_dataset = load_dataset("text", data_files=args.dataset, split="train")
440-
embedding_layer = get_embedding_layer(config, weight_dir, state_dict)
441443
master_rot_emb = get_master_rot_emb(config, dtype=torch.float32)
442444
if args.preformatter is not None:
443445
cal_dataset = cal_dataset.map(

examples/mediatek/models/llm_models/modeling_common.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -675,8 +675,8 @@ def load_weights(self, state_dict, state_dict_start_idx):
675675
)
676676
else:
677677
if self.config.tie_word_embeddings:
678-
lm_head_weight_key = "embed_tokens.weight"
679-
lm_head_bias_key = "embed_tokens.bias"
678+
lm_head_weight_key = f"{prefix}embed_tokens.weight"
679+
lm_head_bias_key = f"{prefix}embed_tokens.bias"
680680
else:
681681
lm_head_weight_key = "lm_head.weight"
682682
lm_head_bias_key = "lm_head.bias"
@@ -751,15 +751,16 @@ def get_example_inputs(
751751
for _ in range(2 * self.num_blocks)
752752
],
753753
)
754+
# Specify dims that would be dynamic during calibration
754755
# Note: Assume cache size fixed shape as torch dynamic shape cannot handle dim 3 being
755756
# combination of 2 dynamic dims
756757
if get_dym_shape:
757758
nt = Dim("num_token", max=num_token)
758759
cache_dims = tuple(({} for _ in range(2 * self.num_blocks)))
759760
dynamic_shapes = (
760-
{0: None, 1: nt, 2: None},
761-
{0: None, 1: None, 2: nt, 3: nt + cache_size},
762-
{0: None, 1: None, 2: nt, 3: None},
761+
{0: Dim.STATIC, 1: nt, 2: Dim.STATIC},
762+
{0: Dim.STATIC, 1: Dim.STATIC, 2: nt, 3: nt + cache_size},
763+
{0: Dim.STATIC, 1: Dim.STATIC, 2: nt, 3: Dim.STATIC},
763764
cache_dims,
764765
)
765766
return example_inputs, dynamic_shapes
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
{
2+
"architectures": [
3+
"LlamaForCausalLM"
4+
],
5+
"bos_token_id": 128000,
6+
"eos_token_id": 128001,
7+
"head_dim": 64,
8+
"hidden_size": 2048,
9+
"initializer_range": 0.02,
10+
"intermediate_size": 8192,
11+
"max_position_embeddings": 131072,
12+
"model_type": "llama",
13+
"num_attention_heads": 32,
14+
"num_hidden_layers": 16,
15+
"num_key_value_heads": 8,
16+
"rms_norm_eps": 1e-05,
17+
"rope_theta": 500000.0,
18+
"tie_word_embeddings": true,
19+
"torch_dtype": "bfloat16",
20+
"transformers_version": "4.45.0.dev0",
21+
"vocab_size": 128256,
22+
"tokenizer": "pretrained_fast"
23+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
"bos_token": {
3+
"content": "<|begin_of_text|>",
4+
"lstrip": false,
5+
"normalized": false,
6+
"rstrip": false,
7+
"single_word": false
8+
},
9+
"eos_token": {
10+
"content": "<|eot_id|>",
11+
"lstrip": false,
12+
"normalized": false,
13+
"rstrip": false,
14+
"single_word": false
15+
}
16+
}

0 commit comments

Comments (0)