10
10
from typing import Any , Dict
11
11
12
12
import torch
13
-
14
- from executorch .examples .models .model_base import EagerModelBase
15
- from torchtune .models .llama3_2_vision ._convert_weights import llama3_vision_meta_to_tune
16
- from torchtune .models .llama3_2_vision ._component_builders import llama3_2_vision_decoder
17
13
from executorch .examples .models .checkpoint import (
18
- get_default_model_resource_dir ,
19
14
get_checkpoint_dtype ,
15
+ get_default_model_resource_dir ,
20
16
)
21
17
18
+ from executorch .examples .models .model_base import EagerModelBase
19
+ from torchtune .models .llama3_2_vision ._component_builders import llama3_2_vision_decoder
20
+ from torchtune .models .llama3_2_vision ._convert_weights import llama3_vision_meta_to_tune
21
+
22
22
23
23
def to_decoder_checkpoint(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extract and reformat the decoder-related weights from the checkpoint.

    The checkpoint contains weight names prefixed with "encoder"/"decoder",
    such as "encoder.layer.etc" or "decoder.norm.scale". To load the text
    decoder on its own, the "decoder" prefix needs to be removed.
    """
    decoder_weights: Dict[str, Any] = {}
    for name, tensor in checkpoint.items():
        if not name.startswith("decoder"):
            continue
        # Drop the leading "decoder" segment, keeping the rest of the dotted path.
        decoder_weights[name.partition(".")[2]] = tensor
    return decoder_weights

30
35
31
36
class Llama3_2Decoder (EagerModelBase ):
32
37
"""
@@ -36,7 +41,9 @@ class Llama3_2Decoder(EagerModelBase):
36
41
def __init__ (self , ** kwargs ):
37
42
# Set member vars from kwargs.
38
43
self .max_seq_len = kwargs .get ("max_seq_len" , 8192 )
39
- self .encoder_max_seq_len = kwargs .get ("encoder_max_seq_len" , int (4 * (448 / 14 ) ** 2 + 1 ))
44
+ self .encoder_max_seq_len = kwargs .get (
45
+ "encoder_max_seq_len" , int (4 * (448 / 14 ) ** 2 + 1 )
46
+ )
40
47
self .generate_full_logits = kwargs .get ("generate_full_logits" , False )
41
48
self .enable_dynamic_shape = kwargs .get ("enable_dynamic_shape" , False )
42
49
self .output_prune_map_path = kwargs .get ("output_prune_map_path" , None )
@@ -46,7 +53,6 @@ def __init__(self, **kwargs):
46
53
self .verbose = kwargs .get ("verbose" , False )
47
54
self .args = kwargs .get ("args" , None )
48
55
49
-
50
56
ckpt_dir = get_default_model_resource_dir (__file__ )
51
57
# Single checkpoint file.
52
58
checkpoint_path = kwargs .get ("checkpoint" , ckpt_dir / "demo_rand_params.pth" )
@@ -57,7 +63,9 @@ def __init__(self, **kwargs):
57
63
# Load checkpoint and params.
58
64
device = "cpu"
59
65
if checkpoint_dir is not None :
60
- raise NotImplementedError ("Sharded checkpoint not yet supported for Llama3_2Decoder." )
66
+ raise NotImplementedError (
67
+ "Sharded checkpoint not yet supported for Llama3_2Decoder."
68
+ )
61
69
else :
62
70
checkpoint = torch .load (checkpoint_path , map_location = device , mmap = True )
63
71
checkpoint = llama3_vision_meta_to_tune (checkpoint )
@@ -107,7 +115,9 @@ def __init__(self, **kwargs):
107
115
# Prune the output layer if output_prune_map is provided.
108
116
output_prune_map = None
109
117
if self .output_prune_map_path is not None :
110
- from executorch .examples .models .llama2 .source_transformation .prune_output import prune_output_vocab
118
+ from executorch .examples .models .llama2 .source_transformation .prune_output import (
119
+ prune_output_vocab ,
120
+ )
111
121
112
122
with open (self .output_prune_map_path , "r" ) as f :
113
123
output_prune_map = json .load (f )
@@ -123,9 +133,7 @@ def get_eager_model(self) -> torch.nn.Module:
123
133
return self .model_ .to (torch .float16 )
124
134
125
135
def get_example_inputs(self):
    """Return example positional inputs for export: a 1-tuple holding a (1, 64) tensor of token ids."""
    tokens = torch.ones(1, 64, dtype=torch.long)  # positional inputs
    return (tokens,)
129
137
130
138
def get_example_kwarg_inputs (self ):
131
139
# TODO: add input_pos and mask when after making cache work.
@@ -137,7 +145,7 @@ def get_example_kwarg_inputs(self):
137
145
}
138
146
139
147
def get_dynamic_shapes (self ):
140
- dim = torch .export .Dim ("token_dim" , min = 1 ,max = self .max_seq_len )
148
+ dim = torch .export .Dim ("token_dim" , min = 1 , max = self .max_seq_len )
141
149
dynamic_shapes = {
142
150
"tokens" : {0 : 1 , 1 : dim },
143
151
# "encoder_input": {0:1, 1:dim_enc, 2:4096},
@@ -146,4 +154,3 @@ def get_dynamic_shapes(self):
146
154
# "input_pos" : {0: dim},
147
155
}
148
156
return dynamic_shapes
149
-
0 commit comments