Skip to content

Commit b02c692

Browse files
authored
Buckify Llama multimodal export (#7604)
1 parent ee6f2d9 commit b02c692

File tree

9 files changed

+104
-14
lines changed

9 files changed

+104
-14
lines changed

examples/models/llama/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ runtime.python_library(
109109
"//ai_codesign/gen_ai/fast_hadamard_transform:fast_hadamard_transform",
110110
"//caffe2:torch",
111111
"//executorch/backends/vulkan/_passes:vulkan_passes",
112+
"//executorch/exir/passes:init_mutable_pass",
112113
"//executorch/examples/models:model_base",
113114
"//executorch/examples/models:models",
114115
"//executorch/exir/passes:init_mutable_pass",
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
2+
3+
oncall("executorch")
4+
5+
python_library(
6+
name = "multimodal_lib",
7+
srcs = [
8+
"__init__.py",
9+
],
10+
deps = [
11+
"//executorch/examples/models/llama3_2_vision/text_decoder:model",
12+
"//executorch/examples/models/llama3_2_vision/vision_encoder:model",
13+
],
14+
)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
2+
3+
oncall("executorch")
4+
5+
python_library(
6+
name = "model",
7+
srcs = [
8+
"model.py",
9+
],
10+
deps = [
11+
"//caffe2:torch",
12+
"//executorch/examples/models:checkpoint",
13+
"//pytorch/torchtune:lib",
14+
"//executorch/extension/llm/modules:module_lib",
15+
],
16+
)
17+

examples/models/llama3_2_vision/text_decoder/model.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -133,20 +133,6 @@ def __init__(self, **kwargs):
133133
print(unexpected)
134134
print("============= /unexpected ================")
135135

136-
# Prune the output layer if output_prune_map is provided.
137-
output_prune_map = None
138-
if self.output_prune_map_path is not None:
139-
from executorch.examples.models.llama2.source_transformation.prune_output import (
140-
prune_output_vocab,
141-
)
142-
143-
with open(self.output_prune_map_path, "r") as f:
144-
output_prune_map = json.load(f)
145-
# Change keys from string to int (json only supports string keys)
146-
output_prune_map = {int(k): v for (k, v) in output_prune_map.items()}
147-
148-
self.model_ = prune_output_vocab(self.model_, output_prune_map)
149-
150136
if self.use_kv_cache:
151137
print("Setting up KV cache on the model...")
152138
self.model_.setup_caches(
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
2+
3+
oncall("executorch")
4+
5+
python_library(
6+
name = "model",
7+
srcs = [
8+
"__init__.py",
9+
"model.py",
10+
],
11+
deps = [
12+
"//caffe2:torch",
13+
"//executorch/extension/llm/modules:module_lib",
14+
"//pytorch/torchtune:lib",
15+
"//executorch/examples/models:model_base",
16+
],
17+
)

examples/models/llama3_2_vision/vision_encoder/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
# pyre-ignore-all-errors
8+
79
from dataclasses import dataclass, field
810
from typing import Optional
911

extension/llm/modules/TARGETS

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
2+
3+
oncall("executorch")
4+
5+
python_library(
6+
name = "kv_cache",
7+
srcs = [
8+
"kv_cache.py",
9+
],
10+
deps = [
11+
"//caffe2:torch",
12+
"//pytorch/torchtune:lib",
13+
],
14+
)
15+
16+
python_library(
17+
name = "attention",
18+
srcs = [
19+
"attention.py",
20+
],
21+
deps = [
22+
":kv_cache",
23+
"//caffe2:torch",
24+
"//executorch/extension/llm/custom_ops:custom_ops",
25+
"//pytorch/torchtune:lib",
26+
],
27+
)
28+
29+
python_library(
30+
name = "position_embeddings",
31+
srcs = [
32+
"_position_embeddings.py",
33+
],
34+
deps = [
35+
"//caffe2:torch",
36+
],
37+
)
38+
39+
python_library(
40+
name = "module_lib",
41+
srcs = [
42+
"__init__.py",
43+
],
44+
deps= [
45+
":position_embeddings",
46+
":attention",
47+
":kv_cache",
48+
]
49+
)

extension/llm/modules/_position_embeddings.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
# Added torch._check() to make sure guards on symints are enforced.
99
# See https://github.com/pytorch/torchtune/blob/main/torchtune/models/clip/_position_embeddings.py
1010

11+
# pyre-ignore-all-errors
12+
1113
import logging
1214
import math
1315
from typing import Any, Dict, Tuple

extension/llm/modules/attention.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
# pyre-ignore-all-errors
8+
79
import logging
810
from typing import Optional
911

Comments: 0 commit comments (0)