Commit e145bd1

Lint

1 parent e0c4b8a

2 files changed: +23 −16 lines

examples/models/llama2/export_llama_lib.py

Lines changed: 1 addition & 1 deletion
@@ -817,7 +817,7 @@ def _load_llama_model(
         generate_full_logits=generate_full_logits,
         fairseq2=weight_type == WeightType.FAIRSEQ2,
         max_seq_len=max_seq_len,
-        enable_dynamic_shape=enable_dynamic_shape ,
+        enable_dynamic_shape=enable_dynamic_shape,
         output_prune_map_path=output_prune_map_path,
         args=args,
     )

examples/models/llama3_2_vision/model.py

Lines changed: 22 additions & 15 deletions
@@ -10,23 +10,28 @@
 from typing import Any, Dict
 
 import torch
-
-from executorch.examples.models.model_base import EagerModelBase
-from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
-from torchtune.models.llama3_2_vision._component_builders import llama3_2_vision_decoder
 from executorch.examples.models.checkpoint import (
-    get_default_model_resource_dir,
     get_checkpoint_dtype,
+    get_default_model_resource_dir,
 )
 
+from executorch.examples.models.model_base import EagerModelBase
+from torchtune.models.llama3_2_vision._component_builders import llama3_2_vision_decoder
+from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
+
 
 def to_decoder_checkpoint(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
     """
     Extracts and formats the decoder-related weights from the checkpoint. The checkpoint contains
     weight names prefixed with "encoder"/"decoder", such as "encoder.layer.etc" or "decoder.norm.scale".
     To load the text decoder on its own, the "decoder" prefix needs to be removed.
     """
-    return {".".join(weight.split(".")[1:]): value for weight, value in checkpoint.items() if weight.startswith("decoder")}
+    return {
+        ".".join(weight.split(".")[1:]): value
+        for weight, value in checkpoint.items()
+        if weight.startswith("decoder")
+    }
+
 
 class Llama3_2Decoder(EagerModelBase):
     """
@@ -36,7 +41,9 @@ class Llama3_2Decoder(EagerModelBase):
     def __init__(self, **kwargs):
         # Set member vars from kwargs.
         self.max_seq_len = kwargs.get("max_seq_len", 8192)
-        self.encoder_max_seq_len = kwargs.get("encoder_max_seq_len", int(4 * (448 / 14) ** 2 + 1))
+        self.encoder_max_seq_len = kwargs.get(
+            "encoder_max_seq_len", int(4 * (448 / 14) ** 2 + 1)
+        )
         self.generate_full_logits = kwargs.get("generate_full_logits", False)
         self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
         self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
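
For reference, the default encoder_max_seq_len evaluates to 4097. Reading 448 as the tile resolution and 14 as the patch size is an assumption; the arithmetic itself is checkable:

    # 448 / 14 = 32 patches per side; 32 ** 2 = 1024 patches per tile.
    # 4 tiles * 1024 + 1 = 4097.
    assert int(4 * (448 / 14) ** 2 + 1) == 4097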
@@ -46,7 +53,6 @@ def __init__(self, **kwargs):
         self.verbose = kwargs.get("verbose", False)
         self.args = kwargs.get("args", None)
 
-
         ckpt_dir = get_default_model_resource_dir(__file__)
         # Single checkpoint file.
         checkpoint_path = kwargs.get("checkpoint", ckpt_dir / "demo_rand_params.pth")
@@ -57,7 +63,9 @@ def __init__(self, **kwargs):
         # Load checkpoint and params.
         device = "cpu"
         if checkpoint_dir is not None:
-            raise NotImplementedError("Sharded checkpoint not yet supported for Llama3_2Decoder.")
+            raise NotImplementedError(
+                "Sharded checkpoint not yet supported for Llama3_2Decoder."
+            )
         else:
             checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
             checkpoint = llama3_vision_meta_to_tune(checkpoint)
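
A note on the context above: torch.load is called with mmap=True, which memory-maps the checkpoint file instead of reading it all into RAM, so tensor storages are paged in on first access. A minimal sketch (the path is hypothetical):

    import torch

    # Keeps peak memory low for large checkpoints; tensor data loads lazily on access.
    state_dict = torch.load("demo_rand_params.pth", map_location="cpu", mmap=True)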
@@ -107,7 +115,9 @@ def __init__(self, **kwargs):
         # Prune the output layer if output_prune_map is provided.
         output_prune_map = None
         if self.output_prune_map_path is not None:
-            from executorch.examples.models.llama2.source_transformation.prune_output import prune_output_vocab
+            from executorch.examples.models.llama2.source_transformation.prune_output import (
+                prune_output_vocab,
+            )
 
             with open(self.output_prune_map_path, "r") as f:
                 output_prune_map = json.load(f)
@@ -123,9 +133,7 @@ def get_eager_model(self) -> torch.nn.Module:
         return self.model_.to(torch.float16)
 
     def get_example_inputs(self):
-        return (
-            torch.ones(1, 64, dtype=torch.long),  # positional inputs
-        )
+        return (torch.ones(1, 64, dtype=torch.long),)  # positional inputs
 
     def get_example_kwarg_inputs(self):
         # TODO: add input_pos and mask after making cache work.
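
The collapsed return deliberately keeps the trailing comma: (x,) is a one-element tuple, while (x) is just x with grouping parentheses, and export entry points take positional example inputs as a tuple. A minimal check:

    import torch

    t = torch.ones(1, 64, dtype=torch.long)
    assert isinstance((t,), tuple)      # trailing comma builds a 1-tuple
    assert not isinstance((t), tuple)   # bare parentheses are only grouping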
@@ -137,7 +145,7 @@ def get_example_kwarg_inputs(self):
         }
 
     def get_dynamic_shapes(self):
-        dim = torch.export.Dim("token_dim", min=1,max=self.max_seq_len)
+        dim = torch.export.Dim("token_dim", min=1, max=self.max_seq_len)
         dynamic_shapes = {
             "tokens": {0: 1, 1: dim},
             # "encoder_input": {0:1, 1:dim_enc, 2:4096},
@@ -146,4 +154,3 @@ def get_dynamic_shapes(self):
             # "input_pos" : {0: dim},
         }
         return dynamic_shapes
-
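
Roughly how these helpers fit together downstream; this is a sketch against the plain torch.export API, not the actual ExecuTorch export pipeline, and it assumes the bundled demo checkpoint is present:

    import torch

    m = Llama3_2Decoder()
    ep = torch.export.export(
        m.get_eager_model().eval(),
        m.get_example_inputs(),  # (tokens,) with shape (1, 64)
        dynamic_shapes=m.get_dynamic_shapes(),  # token dim may range up to max_seq_len
    )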
