
Commit b4a9e53

Comment out eot condition to generate all tokens
1 parent: 3933cb6

4 files changed: 11 additions, 62 deletions

backends/qualcomm/utils/utils.py

Lines changed: 1 addition & 0 deletions
@@ -823,6 +823,7 @@ def generate_multi_graph_program(
     )
     assert qnn_mgr.Init().value == 0, "failed to load processed bytes"
     binary_info = bytes(qnn_mgr.Compile())
+    print("Checking the size of QNN binary info: ", len(binary_info))
     assert len(binary_info) != 0, "failed to generate QNN context binary"
     graph_names = qnn_mgr.GetGraphNames()
     for graph_name in graph_names:
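
The only functional change here is a debug print of the compiled context binary's size before the existing non-empty assertion. A minimal sketch of the same check as a standalone helper, assuming only the qnn_mgr calls visible in the diff (the helper itself is hypothetical, not part of utils.py):

def compile_context_binary(qnn_mgr) -> bytes:
    # Hypothetical wrapper around the calls shown above.
    binary_info = bytes(qnn_mgr.Compile())
    print("Checking the size of QNN binary info: ", len(binary_info))
    assert len(binary_info) != 0, "failed to generate QNN context binary"
    return binary_info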

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 6 additions & 47 deletions
@@ -82,6 +82,7 @@ def _kv_calibrate(
     _, atten_mask, _, k_caches, v_caches = example_inputs
 
     # TODO: change criteria & support batch inputs if necessary
+    pos = torch.tensor(0, dtype=torch.int32)
     max_cache_len = max_seq_len - 1
 
     token_list = []
@@ -113,42 +114,10 @@ def _kv_calibrate(
                 for i, v_cache in enumerate(v_caches)
             ]
 
-    # token_list = sp_model.encode(user_prompts, bos=True, eos=False)
-
-    user_token_list = [
-        # what is the capital of the united states
-        [128000, 128006, 882, 128007, 271, 12840, 374, 279, 6864, 315, 279, 29292, 5415, 128009, 128006, 78191, 128007, 271],
-        # what is 1 + 1
-        [128000, 128006, 882, 128007, 271, 12840, 374, 220, 16, 489, 220, 16, 128009, 128006, 78191, 128007, 271],
-        # what is the meaning of life
-        [128000, 128006, 882, 128007, 271, 12840, 374, 279, 7438, 315, 2324, 128009, 128006, 78191, 128007, 271],
-    ]
-
-    for token_list in user_token_list:
-        _, atten_mask, _, k_caches, v_caches = copy.deepcopy(example_inputs)
-        pos = torch.tensor(0, dtype=torch.int32)
-        with torch.no_grad():
-            while token_list[-1] != sp_model.eos_id and pos < max_cache_len:
-                logits, new_k_caches, new_v_caches = module(
-                    torch.full((1, 1), token_list[pos], dtype=torch.int32),
-                    atten_mask,
-                    torch.full((1, 1), pos),
-                    *k_caches,
-                    *v_caches,
-                )
-                k_caches = [
-                    torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1)
-                    for i, k_cache in enumerate(k_caches)
-                ]
-                v_caches = [
-                    torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1)
-                    for i, v_cache in enumerate(v_caches)
-                ]
-
-                pos += 1
-                atten_mask[0][-pos - 1] = 0
-                if pos >= len(token_list):
-                    token_list.append(torch.argmax(logits[:, -1], dim=-1).item())
+            pos += 1
+            atten_mask[0][-pos - 1] = 0
+            if pos >= len(token_list):
+                token_list.append(torch.argmax(logits[:, -1], dim=-1).item())
 
     print(f"kv calibration data:\n{tokenizer.decode(token_list)}")

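The restored single-prompt calibration loop drives the model with a sliding-window KV cache: each step drops the oldest cache entry, appends the new key/value slice, opens one more attention-mask position, and switches to greedy argmax decoding once the prompt tokens are exhausted. A self-contained sketch of that update pattern, with toy shapes and random tensors standing in for the real exported Llama module:

import torch

max_seq_len = 8                      # toy sizes, not the real config
max_cache_len = max_seq_len - 1
k_cache = torch.zeros(1, 4, max_cache_len)        # (batch, head_dim, cache_len)
v_cache = torch.zeros(1, max_cache_len, 4)        # (batch, cache_len, head_dim)
atten_mask = torch.full((1, max_seq_len), -1e9)   # fully masked to start

token_list = [1, 2, 3]               # stand-in prompt token ids
pos = 0
while pos < max_cache_len:
    # Stand-in for one call of the exported module: logits plus new K/V slices.
    logits = torch.randn(1, 1, 32)
    new_k = torch.randn(1, 4, 1)
    new_v = torch.randn(1, 1, 4)

    # Slide the caches: drop the oldest entry, append the newest (as in the diff).
    k_cache = torch.cat([k_cache[:, :, 1:], new_k], dim=-1)
    v_cache = torch.cat([v_cache[:, 1:, :], new_v], dim=1)

    pos += 1
    atten_mask[0][-pos - 1] = 0                    # expose one more position
    if pos >= len(token_list):                     # past the prompt: greedy decode
        token_list.append(torch.argmax(logits[:, -1], dim=-1).item())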
@@ -359,17 +328,7 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()):
             max_seq_len=self.llama_meta["get_max_seq_len"],
         )
 
-        fx_graph_module = convert_pt2e(fx_graph_module)
-
-        logging.info("Evaluating the converted model...")
-        calibrate(
-            self.get_example_inputs(self.llama_meta["get_use_kv_cache"]),
-            args.prompt,
-            fx_graph_module,
-            tokenizer_model_path=args.tokenizer_model,
-            max_seq_len=self.llama_meta["get_max_seq_len"],
-        )
-        self.llama_model = fx_graph_module
+        self.llama_model = convert_pt2e(fx_graph_module)
 
     def lowering_modules(
         self,
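
After this change the quantize step is a plain calibrate-then-convert PTQ flow: the prepared FX graph is calibrated once (the calibrate call whose trailing arguments appear as context above), convert_pt2e produces the quantized module, and the extra evaluation pass over the already-converted model is gone. A hedged sketch of that flow using the core PT2E APIs; the toy model, the XNNPACKQuantizer stand-in (the script itself uses a Qualcomm quantizer), and the PyTorch 2.5+ export_for_training capture are assumptions, not the script's exact code:

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

# Toy network standing in for the exported Llama graph module.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 8),)

quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config())

# Capture and prepare (PyTorch 2.5+; older releases used capture_pre_autograd_graph).
captured = torch.export.export_for_training(model, example_inputs).module()
prepared = prepare_pt2e(captured, quantizer)

# Calibrate the *prepared* graph only, with representative inputs.
for _ in range(4):
    prepared(*example_inputs)

# Convert once and keep the result; no second calibration of the converted model.
quantized = convert_pt2e(prepared)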

examples/qualcomm/oss_scripts/llama/runner/runner.cpp

Lines changed: 4 additions & 4 deletions
@@ -404,10 +404,10 @@ Error Runner::generate(
       token_callback(piece_res.get().c_str());
     }
 
-    if (pos >= num_prompt_tokens && eos_id_.count(cur_token) > 0) {
-      ET_LOG(Info, "\nReached to the end of generation");
-      break;
-    }
+    // if (pos >= num_prompt_tokens && eos_id_.count(cur_token) > 0) {
+    //   ET_LOG(Info, "\nReached to the end of generation");
+    //   break;
+    // }
   }
 };

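Commenting out this block removes the runner's only early-exit path: generation no longer stops when an end-of-text token is sampled after the prompt, so the loop runs until the requested sequence length, which is the commit's stated intent. A toy decode loop (in Python for brevity; the names are illustrative, not the runner's actual API) showing what the disabled condition does:

def generate(prompt_tokens, seq_len, step, eos_ids, stop_on_eos=True):
    """Toy decode loop illustrating the early-exit condition the commit disables."""
    tokens = list(prompt_tokens)
    pos = 0
    while pos < seq_len - 1:
        cur_token = tokens[pos]
        next_token = step(cur_token, pos)     # stand-in for one model step
        pos += 1
        if pos >= len(prompt_tokens):         # past the prompt: keep the sample
            tokens.append(next_token)
        # Early exit on end-of-text, mirroring the commented-out check above.
        if stop_on_eos and pos >= len(prompt_tokens) and next_token in eos_ids:
            break
    return tokens

Calling this with stop_on_eos=False mirrors the patched runner: for a prompt shorter than seq_len it keeps producing tokens until seq_len is reached, regardless of end-of-text.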
exir/emit/test/test_emit.py

Lines changed: 0 additions & 11 deletions
@@ -1682,10 +1682,7 @@ def forward(self, x):
         ]
         self.assertEqual(external_map["linear.weight"], 0)
         self.assertEqual(external_map["linear.bias"], 1)
-<<<<<<< HEAD
 
-=======
->>>>>>> c766f0dc0 (Apply calibration patch and deduplicate delegate cache patch)
     def test_delegate_deduplicate(self) -> None:
         class SharedModule(torch.nn.Module):
             def __init__(self):
@@ -1695,10 +1692,6 @@ def __init__(self):
             def forward(self, x):
                 return self.linear(x)
 
-<<<<<<< HEAD
-=======
-
->>>>>>> c766f0dc0 (Apply calibration patch and deduplicate delegate cache patch)
         class Module1(torch.nn.Module):
             def __init__(self, shared_module):
                 super().__init__()
@@ -1707,10 +1700,6 @@ def __init__(self, shared_module):
            def forward(self, x):
                 return self.shared_module(x)
 
-<<<<<<< HEAD
-=======
-
->>>>>>> c766f0dc0 (Apply calibration patch and deduplicate delegate cache patch)
         class Module2(torch.nn.Module):
             def __init__(self, shared_module):
                 super().__init__()
