Commit 3933cb6

Apply calibration patch and deduplicate delegate cache patch

1 parent e00eaea

2 files changed: +58 additions, −6 deletions

examples/qualcomm/oss_scripts/llama/llama.py (47 additions, 6 deletions)
```diff
@@ -82,7 +82,6 @@ def _kv_calibrate(
     _, atten_mask, _, k_caches, v_caches = example_inputs

     # TODO: change criteria & support batch inputs if necessary
-    pos = torch.tensor(0, dtype=torch.int32)
     max_cache_len = max_seq_len - 1

     token_list = []
```
```diff
@@ -114,10 +113,42 @@ def _kv_calibrate(
             for i, v_cache in enumerate(v_caches)
         ]

-        pos += 1
-        atten_mask[0][-pos - 1] = 0
-        if pos >= len(token_list):
-            token_list.append(torch.argmax(logits[:, -1], dim=-1).item())
+    # token_list = sp_model.encode(user_prompts, bos=True, eos=False)
+
+    user_token_list = [
+        # what is the capital of the united states
+        [128000, 128006, 882, 128007, 271, 12840, 374, 279, 6864, 315, 279, 29292, 5415, 128009, 128006, 78191, 128007, 271],
+        # what is 1 + 1
+        [128000, 128006, 882, 128007, 271, 12840, 374, 220, 16, 489, 220, 16, 128009, 128006, 78191, 128007, 271],
+        # what is the meaning of life
+        [128000, 128006, 882, 128007, 271, 12840, 374, 279, 7438, 315, 2324, 128009, 128006, 78191, 128007, 271],
+    ]
+
+    for token_list in user_token_list:
+        _, atten_mask, _, k_caches, v_caches = copy.deepcopy(example_inputs)
+        pos = torch.tensor(0, dtype=torch.int32)
+        with torch.no_grad():
+            while token_list[-1] != sp_model.eos_id and pos < max_cache_len:
+                logits, new_k_caches, new_v_caches = module(
+                    torch.full((1, 1), token_list[pos], dtype=torch.int32),
+                    atten_mask,
+                    torch.full((1, 1), pos),
+                    *k_caches,
+                    *v_caches,
+                )
+                k_caches = [
+                    torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1)
+                    for i, k_cache in enumerate(k_caches)
+                ]
+                v_caches = [
+                    torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1)
+                    for i, v_cache in enumerate(v_caches)
+                ]
+
+                pos += 1
+                atten_mask[0][-pos - 1] = 0
+                if pos >= len(token_list):
+                    token_list.append(torch.argmax(logits[:, -1], dim=-1).item())

     print(f"kv calibration data:\n{tokenizer.decode(token_list)}")
```
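For context on the new loop: each hard-coded ID list looks like a Llama 3 chat-template encoding of the commented prompt (128000 and 128009 are the Llama 3 `<|begin_of_text|>` and `<|eot_id|>` special tokens), and the loop greedily decodes each prompt while sliding the fixed-size KV caches one slot per step. Below is a minimal sketch of that sliding-window update, with hypothetical shapes chosen only to match the `torch.cat` calls above (K caches keep the sequence axis last, V caches keep it in dim 1):

```python
import torch

# Hypothetical shapes matching the slicing in the diff: K caches are
# (n_heads, head_dim, seq) so dim=-1 is the sequence axis, and V caches
# are (n_heads, seq, head_dim) so dim=1 is the sequence axis.
n_heads, head_dim, seq = 4, 8, 16
k_cache = torch.zeros(n_heads, head_dim, seq)
v_cache = torch.zeros(n_heads, seq, head_dim)
new_k = torch.rand(n_heads, head_dim, 1)  # K projection of the newest token
new_v = torch.rand(n_heads, 1, head_dim)  # V projection of the newest token

# Drop the oldest cache slot and append the newest one, so the cache stays
# a fixed size while the attention mask is opened one position per step.
k_cache = torch.cat([k_cache[:, :, 1:], new_k], dim=-1)
v_cache = torch.cat([v_cache[:, 1:, :], new_v], dim=1)
assert k_cache.shape == (n_heads, head_dim, seq)
assert v_cache.shape == (n_heads, seq, head_dim)
```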
```diff
@@ -328,7 +359,17 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()):
             max_seq_len=self.llama_meta["get_max_seq_len"],
         )

-        self.llama_model = convert_pt2e(fx_graph_module)
+        fx_graph_module = convert_pt2e(fx_graph_module)
+
+        logging.info("Evaluating the converted model...")
+        calibrate(
+            self.get_example_inputs(self.llama_meta["get_use_kv_cache"]),
+            args.prompt,
+            fx_graph_module,
+            tokenizer_model_path=args.tokenizer_model,
+            max_seq_len=self.llama_meta["get_max_seq_len"],
+        )
+        self.llama_model = fx_graph_module

     def lowering_modules(
         self,
```
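The `quantize()` change re-runs calibration after `convert_pt2e`, so the converted graph is exercised once before lowering. A hedged sketch of that before-and-after flow, with `quantizer` and `calibrate` standing in for the script's own objects:

```python
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

def quantize_with_eval(captured_module, quantizer, calibrate, example_inputs):
    # prepare_pt2e inserts observers; calibration populates their ranges.
    prepared = prepare_pt2e(captured_module, quantizer)
    calibrate(prepared, example_inputs)
    # convert_pt2e folds the observed ranges into quant/dequant ops; the
    # patch then calibrates once more to evaluate the converted model.
    converted = convert_pt2e(prepared)
    calibrate(converted, example_inputs)
    return converted
```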

exir/emit/test/test_emit.py (11 additions, 0 deletions)
```diff
@@ -1682,7 +1682,10 @@ def forward(self, x):
         ]
         self.assertEqual(external_map["linear.weight"], 0)
         self.assertEqual(external_map["linear.bias"], 1)
+<<<<<<< HEAD

+=======
+>>>>>>> c766f0dc0 (Apply calibration patch and deduplicate delegate cache patch)
     def test_delegate_deduplicate(self) -> None:
         class SharedModule(torch.nn.Module):
             def __init__(self):
```
```diff
@@ -1692,6 +1695,10 @@ def __init__(self):
             def forward(self, x):
                 return self.linear(x)

+<<<<<<< HEAD
+=======
+
+>>>>>>> c766f0dc0 (Apply calibration patch and deduplicate delegate cache patch)
         class Module1(torch.nn.Module):
             def __init__(self, shared_module):
                 super().__init__()
```
```diff
@@ -1700,6 +1707,10 @@ def __init__(self, shared_module):
             def forward(self, x):
                 return self.shared_module(x)

+<<<<<<< HEAD
+=======
+
+>>>>>>> c766f0dc0 (Apply calibration patch and deduplicate delegate cache patch)
         class Module2(torch.nn.Module):
             def __init__(self, shared_module):
                 super().__init__()
```
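The conflict blocks above land around `test_delegate_deduplicate`, whose setup (visible in the context lines) builds two wrapper modules around one shared submodule, the pattern a deduplicated delegate cache should store only once. A minimal sketch of that sharing pattern, reconstructed from the context lines only (the `Linear(4, 4)` dimensions are a hypothetical stand-in for the test's actual layer):

```python
import torch

class SharedModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)  # hypothetical dims

    def forward(self, x):
        return self.linear(x)

class Module1(torch.nn.Module):
    def __init__(self, shared_module):
        super().__init__()
        self.shared_module = shared_module

    def forward(self, x):
        return self.shared_module(x)

class Module2(torch.nn.Module):
    def __init__(self, shared_module):
        super().__init__()
        self.shared_module = shared_module

    def forward(self, x):
        return self.shared_module(x)

shared = SharedModule()
m1, m2 = Module1(shared), Module2(shared)
assert m1.shared_module is m2.shared_module  # one instance, two owners
```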
