Commit c766f0d

Apply calibration patch and deduplicate delegate cache patch
1 parent 7010a11 commit c766f0d

File tree: 6 files changed (+152, −56 lines)

examples/qualcomm/oss_scripts/llama3_2/llama.py

Lines changed: 83 additions & 49 deletions

@@ -73,37 +73,49 @@ def _kv_calibrate(
     max_seq_len=512,
 ):
     sp_model = get_tokenizer(tokenizer_model_path)
-    _, atten_mask, _, k_caches, v_caches = example_inputs

     # TODO: change criteria & support batch inputs if necessary
-    pos = torch.tensor(0, dtype=torch.int32)
     max_cache_len = max_seq_len - 1
-    token_list = sp_model.encode(user_prompts, bos=True, eos=False)

-    with torch.no_grad():
-        while token_list[-1] != sp_model.eos_id and pos < max_cache_len:
-            logits, new_k_caches, new_v_caches = module(
-                torch.full((1, 1), token_list[pos], dtype=torch.int32),
-                atten_mask,
-                torch.full((1, 1), pos),
-                *k_caches,
-                *v_caches,
-            )
-            k_caches = [
-                torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1)
-                for i, k_cache in enumerate(k_caches)
-            ]
-            v_caches = [
-                torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1)
-                for i, v_cache in enumerate(v_caches)
-            ]
-
-            pos += 1
-            atten_mask[0][-pos - 1] = 0
-            if pos >= len(token_list):
-                token_list.append(torch.argmax(logits[:, -1], dim=-1).item())

-    print(f"calibration data:\n{sp_model.decode(token_list)}")
+    # token_list = sp_model.encode(user_prompts, bos=True, eos=False)
+
+    user_token_list = [
+        # what is the capital of the united states
+        [128000, 128006, 882, 128007, 271, 12840, 374, 279, 6864, 315, 279, 29292, 5415, 128009, 128006, 78191, 128007, 271],
+        # what is 1 + 1
+        [128000, 128006, 882, 128007, 271, 12840, 374, 220, 16, 489, 220, 16, 128009, 128006, 78191, 128007, 271],
+        # what is the meaning of life
+        [128000, 128006, 882, 128007, 271, 12840, 374, 279, 7438, 315, 2324, 128009, 128006, 78191, 128007, 271],
+    ]
+
+    for token_list in user_token_list:
+        _, atten_mask, _, k_caches, v_caches = copy.deepcopy(example_inputs)
+        pos = torch.tensor(0, dtype=torch.int32)
+        with torch.no_grad():
+            while token_list[-1] != sp_model.eos_id and pos < max_cache_len:
+                logits, new_k_caches, new_v_caches = module(
+                    torch.full((1, 1), token_list[pos], dtype=torch.int32),
+                    atten_mask,
+                    torch.full((1, 1), pos),
+                    *k_caches,
+                    *v_caches,
+                )
+                k_caches = [
+                    torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1)
+                    for i, k_cache in enumerate(k_caches)
+                ]
+                v_caches = [
+                    torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1)
+                    for i, v_cache in enumerate(v_caches)
+                ]
+
+                pos += 1
+                atten_mask[0][-pos - 1] = 0
+                if pos >= len(token_list):
+                    token_list.append(torch.argmax(logits[:, -1], dim=-1).item())
+
+        logging.info(f"calibration data:\n{sp_model.decode(token_list)}")


 def _prefill_calibrate(
@@ -114,32 +126,44 @@ def _prefill_calibrate(
     max_seq_len=512,
 ):
     sp_model = get_tokenizer(tokenizer_model_path)
-    _, atten_mask = example_inputs
     max_cache_len = max_seq_len - 1

     # TODO: change criteria & support batch inputs if necessary
-    token_list = sp_model.encode(user_prompts, bos=True, eos=False)
-    token_list = torch.tensor(token_list)[:max_cache_len].reshape(1, -1)
-    last_prompt_pos = token_list.numel()
-    if last_prompt_pos < max_cache_len:
-        token_list = torch.cat(
-            [
-                token_list,
-                torch.zeros((1, max_cache_len - last_prompt_pos), dtype=torch.int32),
-            ],
-            dim=1,
-        )
-    else:
-        token_list = token_list[:, :max_cache_len]
-
-    with torch.no_grad():
-        logits, new_k_caches, new_v_caches = module(
-            token_list,
-            atten_mask,
-        )
-        predict = [torch.argmax(logits[:, last_prompt_pos - 1], dim=-1).item()]
+
+    # token_list = sp_model.encode(user_prompts, bos=True, eos=False)
+
+    user_token_list = [
+        # what is the capital of the united states
+        [128000, 128006, 882, 128007, 271, 12840, 374, 279, 6864, 315, 279, 29292, 5415, 128009, 128006, 78191, 128007, 271],
+        # what is 1 + 1
+        [128000, 128006, 882, 128007, 271, 12840, 374, 220, 16, 489, 220, 16, 128009, 128006, 78191, 128007, 271],
+        # what is the meaning of life
+        [128000, 128006, 882, 128007, 271, 12840, 374, 279, 7438, 315, 2324, 128009, 128006, 78191, 128007, 271],
+    ]
+
+    for token_list in user_token_list:
+        _, atten_mask = copy.deepcopy(example_inputs)
+        token_list = torch.tensor(token_list)[:max_cache_len].reshape(1, -1)
+        last_prompt_pos = token_list.numel()
+        if last_prompt_pos < max_cache_len:
+            token_list = torch.cat(
+                [
+                    token_list,
+                    torch.zeros((1, max_cache_len - last_prompt_pos), dtype=torch.int32),
+                ],
+                dim=1,
+            )
+        else:
+            token_list = token_list[:, :max_cache_len]

-    print(f"calibration data:\n{sp_model.decode(predict)}")
+        with torch.no_grad():
+            logits, new_k_caches, new_v_caches = module(
+                token_list,
+                atten_mask,
+            )
+            predict = [torch.argmax(logits[:, last_prompt_pos - 1], dim=-1).item()]
+
+        logging.info(f"calibration data:\n{sp_model.decode(predict)}")


 def calibrate(
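A note on the calibration loops above: each KV-cache decode step drops the oldest cache entry and appends the slice returned by the module, so the cache length stays pinned at max_seq_len - 1 while pos advances. Below is a self-contained sketch of that rolling update on toy tensors; the shapes are purely illustrative, not the ones the Llama runner actually uses.

```python
import torch

# Illustrative shapes only: one "head" with head_dim=4 and cache_len=8.
# In the script above, k_cache is laid out (batch, head_dim, cache_len) and
# v_cache is (batch, cache_len, head_dim), which is why the updates
# concatenate along dim=-1 and dim=1 respectively.
k_cache = torch.zeros(1, 4, 8)
v_cache = torch.zeros(1, 8, 4)

new_k = torch.randn(1, 4, 1)  # key slice produced by the current decode step
new_v = torch.randn(1, 1, 4)  # value slice produced by the current decode step

# Drop the oldest entry, append the newest: the cache length never grows.
k_cache = torch.cat([k_cache[:, :, 1:], new_k], dim=-1)
v_cache = torch.cat([v_cache[:, 1:, :], new_v], dim=1)

assert k_cache.shape == (1, 4, 8)
assert v_cache.shape == (1, 8, 4)
```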
@@ -249,7 +273,17 @@ def quantize(self, quant_dtype, args, custom_annotations=()):
             max_seq_len=self.llama_meta["get_max_seq_len"],
         )

-        self.llama_model = convert_pt2e(fx_graph_module)
+        fx_graph_module = convert_pt2e(fx_graph_module)
+
+        logging.info("Evaluating the converted model...")
+        calibrate(
+            self.get_example_inputs(self.llama_meta["get_use_kv_cache"]),
+            args.prompt,
+            fx_graph_module,
+            tokenizer_model_path=args.tokenizer_model,
+            max_seq_len=self.llama_meta["get_max_seq_len"],
+        )
+        self.llama_model = fx_graph_module

     def lowering_modules(
         self,
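The quantize() change keeps the convert_pt2e result in fx_graph_module and re-runs calibrate on it, so the same prompts double as a quick post-conversion sanity check before lowering. For orientation, here is a rough sketch of the surrounding PT2E prepare/calibrate/convert flow. It uses the generic torch.ao entry points and the XNNPACK quantizer purely for illustration (the Qualcomm script wires in its own quantizer and annotations), and it assumes a PyTorch recent enough to ship export_for_training.

```python
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)
from torch.export import export_for_training


class TinyModel(torch.nn.Module):
    """Stand-in for the Llama module; the real script exports the full model."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)


example_inputs = (torch.randn(1, 8),)
graph_module = export_for_training(TinyModel(), example_inputs).module()

quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config())

prepared = prepare_pt2e(graph_module, quantizer)
prepared(*example_inputs)   # calibration: run representative inputs
converted = convert_pt2e(prepared)
converted(*example_inputs)  # re-run the inputs on the converted model, as the patch does
```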

exir/_serialize/_program.py

Lines changed: 6 additions & 2 deletions

@@ -224,6 +224,7 @@ def _extract_delegate_segments(
     """
     remaining_inline: List[BackendDelegateInlineData] = []
     inline_indices_seen: set[int] = set()
+    segment_index_map: dict[bytes, int] = {}
     for plan in program.execution_plan:
         for delegate in plan.delegates:
             if delegate.processed.location != DataLocation.INLINE:
@@ -249,8 +250,11 @@
             inline_indices_seen.add(delegate.processed.index)
             if inline.data:
                 # Move the delegate data out of the program.
-                segment_index = len(segments)
-                segments.append(Cord(inline.data))
+                segment_index = segment_index_map.get(inline.data)
+                if segment_index is None:
+                    segment_index = len(segments)
+                    segments.append(Cord(inline.data))
+                    segment_index_map[inline.data] = segment_index
                 delegate.processed = BackendDelegateDataReference(
                     location=DataLocation.SEGMENT,
                     index=segment_index,
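The _extract_delegate_segments change keys each emitted payload by its raw bytes, so identical delegate blobs referenced from different execution plans collapse into a single segment. A stripped-down sketch of the same pattern in plain Python (generic types, not the ExecuTorch Cord/Program machinery):

```python
from typing import Dict, List


def assign_segments(payloads: List[bytes]) -> List[int]:
    """Map each payload to a segment index, reusing segments for duplicate bytes."""
    segments: List[bytes] = []
    segment_index_map: Dict[bytes, int] = {}
    indices: List[int] = []
    for data in payloads:
        segment_index = segment_index_map.get(data)
        if segment_index is None:
            segment_index = len(segments)
            segments.append(data)
            segment_index_map[data] = segment_index
        indices.append(segment_index)
    return indices


# Two identical blobs share one segment; the distinct blob gets its own.
assert assign_segments([b"blob-a", b"blob-a", b"blob-b"]) == [0, 0, 1]
```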

exir/backend/test/demos/rpc/TARGETS

Lines changed: 1 addition & 0 deletions

@@ -28,6 +28,7 @@ runtime.python_library(
     ],
     visibility = [
         "//executorch/exir/backend/test/...",
+        "//executorch/exir/emit/test/...",
     ],
     deps = [
         ":executor_backend_preprocess",

exir/emit/_emitter.py

Lines changed: 9 additions & 4 deletions

@@ -122,6 +122,8 @@ class _ProgramState:
     # Delegate data stored directly in the flatbuffer. Pointed to by BackendDelegateDataReference,
     # and should be copied to Program.backend_delegate_data.
     backend_delegate_data: List[BackendDelegateInlineData] = field(default_factory=list)
+    # Delegate cache that is used across all entry points.
+    backend_delegate_data_cache: Dict[bytes, int] = field(default_factory=dict)

     # Constants are optionally stored in external files.
     # Aggregate unique external constants into one buffer.
@@ -1112,10 +1114,13 @@ def _emit_delegate(
         if delegate_index is None:
             # Allocate an entry for the data. TODO(T150113674): Reuse any duplicate entries if
             # present.
-            data_index: int = len(self.program_state.backend_delegate_data)
-            self.program_state.backend_delegate_data.append(
-                BackendDelegateInlineData(data=processed_bytes)
-            )
+            data_index: Optional[int] = self.program_state.backend_delegate_data_cache.get(processed_bytes)
+            if data_index is None:
+                data_index = len(self.program_state.backend_delegate_data)
+                self.program_state.backend_delegate_data_cache[processed_bytes] = data_index
+                self.program_state.backend_delegate_data.append(
+                    BackendDelegateInlineData(data=processed_bytes)
+                )

         backend_delegate = BackendDelegate(
             id=lowered_module.backend_id,
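On the emitter side, _emit_delegate now consults a program-wide cache before appending inline delegate data, so methods emitted back to back reuse the same entry whenever their processed_bytes match. A minimal sketch of that shared-state idea, with plain stand-ins for _ProgramState and BackendDelegateInlineData:

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ProgramState:
    """Plain stand-in for _ProgramState; the real one holds much more."""

    backend_delegate_data: List[bytes] = field(default_factory=list)
    # Cache shared across all entry points (methods) of the program.
    backend_delegate_data_cache: Dict[bytes, int] = field(default_factory=dict)


def emit_delegate_data(state: ProgramState, processed_bytes: bytes) -> int:
    """Return the data index for processed_bytes, appending only if unseen."""
    data_index = state.backend_delegate_data_cache.get(processed_bytes)
    if data_index is None:
        data_index = len(state.backend_delegate_data)
        state.backend_delegate_data_cache[processed_bytes] = data_index
        state.backend_delegate_data.append(processed_bytes)
    return data_index


state = ProgramState()
assert emit_delegate_data(state, b"payload") == 0  # first method emits the data
assert emit_delegate_data(state, b"payload") == 0  # second method reuses index 0
assert len(state.backend_delegate_data) == 1
```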

exir/emit/test/TARGETS

Lines changed: 1 addition & 0 deletions

@@ -16,6 +16,7 @@ python_unittest(
         "//executorch/exir:lib",
         "//executorch/exir:print_program",
         "//executorch/exir:schema",
+        "//executorch/exir/backend/test/demos/rpc:executor_backend_partitioner",
         "//executorch/exir/backend:backend_api",
         "//executorch/exir/emit:lib",
         "//executorch/exir/passes:const_prop_pass",

exir/emit/test/test_emit.py

Lines changed: 52 additions & 1 deletion

@@ -26,6 +26,9 @@
 from executorch.exir._serialize._program import deserialize_pte_binary
 from executorch.exir.backend.backend_api import to_backend
 from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
+from executorch.exir.backend.test.demos.rpc.executor_backend_partitioner import (
+    ExecutorBackendPartitioner,
+)
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.emit import emit_program  # noqa
 from executorch.exir.error import InternalError
@@ -60,7 +63,7 @@
 from functorch.experimental import control_flow
 from torch import nn

-from torch.export import Dim, export
+from torch.export import Dim, export, export_for_training


 class WrapperModule(torch.nn.Module):
@@ -1626,3 +1629,51 @@ def forward(self, x):
         ]
         self.assertEqual(external_map["linear.weight"], 0)
         self.assertEqual(external_map["linear.bias"], 1)
+    def test_delegate_deduplicate(self) -> None:
+        class SharedModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(2, 2)
+
+            def forward(self, x):
+                return self.linear(x)
+
+
+        class Module1(torch.nn.Module):
+            def __init__(self, shared_module):
+                super().__init__()
+                self.shared_module = shared_module
+
+            def forward(self, x):
+                return self.shared_module(x)
+
+
+        class Module2(torch.nn.Module):
+            def __init__(self, shared_module):
+                super().__init__()
+                self.shared_module = shared_module
+
+            def forward(self, x):
+                return self.shared_module(x)
+
+        shared_module = SharedModule()
+        module_1 = Module1(shared_module)
+        module_2 = Module2(shared_module)
+        example_inputs = (torch.randn(2, 2),)
+        module_1(*example_inputs)
+        module_2(*example_inputs)
+
+        ep1 = export_for_training(module_1, example_inputs)
+        ep2 = export_for_training(module_2, example_inputs)
+
+        edge_program_manager = exir.to_edge(
+            {"forward1": ep1, "forward2": ep2},
+            compile_config=exir.EdgeCompileConfig(
+                _check_ir_validity=False, _use_edge_ops=True
+            ),
+        )
+
+        edge_program_manager = edge_program_manager.to_backend(ExecutorBackendPartitioner()).to_executorch()
+
+        # Check that there is only one delegate because two methods are exactly the same
+        self.assertEqual(len(edge_program_manager.executorch_program.backend_delegate_data), 1)
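Taken together with the serializer change in _program.py, the emitter cache means the two methods that lower the identical SharedModule share a single delegate payload end to end, which is exactly what the final assertEqual on backend_delegate_data verifies.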
