
Commit 5f0a14a

Update eager runner to support AttentionSink (#7149)
* Transform model to be able to use Attention Sink

  Pull Request resolved: #6700

  This PR adds the functions needed to transform the model so that it can use Attention Sink.

  ghstack-source-id: 256108077
  @exported-using-ghexport
  Differential Revision: [D65571289](https://our.internmc.facebook.com/intern/diff/D65571289/)

* Update eager runner to support AttentionSink

  Pull Request resolved: #6921

  This PR updates the eager runner to support AttentionSink. It also fixes issues in the `chat_completion` function so that the position id is handled properly.

  ghstack-source-id: 256108078
  Differential Revision: [D66076486](https://our.internmc.facebook.com/intern/diff/D66076486/)

* Add eval for attention sink (#7150)

  Pull Request resolved: #7070

  This PR adds a function to evaluate the model's perplexity when AttentionSink is enabled. It is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py, which the AttentionSink paper uses to evaluate perplexity with AttentionSink enabled.

  ghstack-source-id: 256108079
  @exported-using-ghexport
  Differential Revision: [D66474732](https://our.internmc.facebook.com/intern/diff/D66474732/)

Co-authored-by: Lunwen He <[email protected]>
1 parent aa67cd9 commit 5f0a14a
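For intuition about what the transformation enables, here is a minimal sketch of the Attention Sink cache policy (an illustration only, not the ExecuTorch implementation; sink_size and window_size follow the names used in the evaluation code later in this diff): the cache keeps the first sink_size positions plus a sliding window of the most recent window_size positions, which is why the eval code asserts max_seq_length == sink_size + window_size.

# Illustration only, not the ExecuTorch implementation.
def kept_positions(num_tokens: int, sink_size: int, window_size: int) -> list:
    # Keep the first `sink_size` positions plus the most recent `window_size`.
    if num_tokens <= sink_size + window_size:
        return list(range(num_tokens))
    return list(range(sink_size)) + list(range(num_tokens - window_size, num_tokens))

# With sink_size=4 and window_size=8, after 20 tokens the cache covers
# positions 0-3 plus 12-19.
print(kept_positions(20, sink_size=4, window_size=8))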

File tree

5 files changed: +97 -13 lines changed

examples/models/llama/TARGETS

Lines changed: 2 additions & 0 deletions
@@ -150,6 +150,8 @@ runtime.python_library(
         "@EXECUTORCH_CLIENTS",
     ],
     deps = [
+        "fbsource//third-party/pypi/tqdm:tqdm",
+        "fbsource//third-party/pypi/datasets:datasets",
         "fbsource//third-party/pypi/lm-eval:lm-eval",
         "fbsource//third-party/pypi/tiktoken:tiktoken",
         ":export_library",

examples/models/llama/eval_llama.py

Lines changed: 9 additions & 2 deletions
@@ -10,7 +10,11 @@

 import torch

-from .eval_llama_lib import build_args_parser, eval_llama
+from .eval_llama_lib import (
+    build_args_parser,
+    eval_llama,
+    eval_llama_with_attention_sink,
+)

 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)
@@ -24,7 +28,10 @@ def main() -> None:
     args = parser.parse_args()
     # Overrides this arg, because evaluation requires full logits.
     args.generate_full_logits = True
-    eval_llama(modelname, args)  # pyre-ignore
+    if args.use_attention_sink:
+        eval_llama_with_attention_sink(modelname, args)  # pyre-ignore
+    else:
+        eval_llama(modelname, args)  # pyre-ignore


 if __name__ == "__main__":

examples/models/llama/eval_llama_lib.py

Lines changed: 64 additions & 0 deletions
@@ -10,6 +10,8 @@
 from typing import Optional, Union

 import torch
+
+from datasets import load_dataset
 from executorch.examples.models.llama.export_llama_lib import (
     get_quantizer_and_quant_params,
 )
@@ -21,6 +23,8 @@
 )
 from executorch.extension.llm.tokenizer.utils import get_tokenizer
 from lm_eval.evaluator import simple_evaluate
+from torch.nn import CrossEntropyLoss
+from tqdm import tqdm

 from .evaluate.eager_eval import EagerEvalWrapper

@@ -280,6 +284,9 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="Save the checkpoint after source transformations, for other evaluation platform to run the same checkpoint.",
     )

+    # Set of parameters specific to AttentionSink.
+    parser.add_argument("--attention_sink_eval_tokens", type=int, default=0)
+
     return parser


@@ -309,3 +316,60 @@ def eval_llama(

     for task, res in eval_results["results"].items():
         print(f"{task}: {res}")
+
+
+def eval_llama_with_attention_sink(model_name: str, args: argparse.ArgumentParser):
+    """
+    Evaluate the model's perplexity when AttentionSink is enabled.
+
+    This is mostly copied from https://github.com/mit-han-lab/streaming-llm/blob/main/examples/eval_long_ppl.py
+    """
+    assert args.use_attention_sink is not None  # pyre-ignore [16]
+    assert args.attention_sink_eval_tokens > 0  # pyre-ignore [16]
+    attention_sink_params = args.use_attention_sink.split(",")
+    assert len(attention_sink_params) == 3
+    sink_size = int(attention_sink_params[0])
+    window_size = int(attention_sink_params[1])
+
+    assert args.max_seq_length == sink_size + window_size  # pyre-ignore [16]
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    manager: LLMEdgeManager = _prepare_for_llama_export(args)
+    model = manager.model.eval().to(device=device)
+    tokenizer = get_tokenizer(args.tokenizer_path)  # pyre-ignore [16]
+
+    eval_data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
+
+    nlls = []
+    loss_fn = CrossEntropyLoss(reduction="none")
+    progress_bar = tqdm(total=args.attention_sink_eval_tokens)
+    input_pos = 0
+    while input_pos < args.attention_sink_eval_tokens:
+        for text in eval_data["text"]:  # pyre-ignore [16]
+            tokens = tokenizer.encode(text, bos=False, eos=False)
+            if len(tokens) <= 0:
+                continue
+            with torch.no_grad():
+                num_tokens = min(
+                    len(tokens) - 1, args.attention_sink_eval_tokens - input_pos
+                )
+                logits = model(
+                    torch.tensor(
+                        [tokens[:num_tokens]], dtype=torch.int64, device=device
+                    ),
+                    torch.tensor([input_pos], dtype=torch.int64, device=device),
+                ).squeeze(dim=0)
+                neg_log_likelihood = loss_fn(
+                    logits,
+                    torch.tensor(
+                        [tokens[1 : num_tokens + 1]], dtype=torch.int64, device=device
+                    ).view(-1),
+                )
+                nlls.append(neg_log_likelihood)
+                input_pos += num_tokens
+                progress_bar.update(num_tokens)
+            if input_pos >= args.attention_sink_eval_tokens:
+                break
+    ppl = torch.exp(torch.cat(nlls).mean())
+    print(f"Perplexity: {ppl.item()}")
+    return ppl.item()
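To make the bookkeeping in eval_llama_with_attention_sink easier to follow, here is a tiny standalone sketch of the perplexity math (illustrative only; random tensors stand in for real model output): each evaluated position contributes one negative log-likelihood via CrossEntropyLoss(reduction="none"), and perplexity is the exponential of the mean NLL.

# Standalone sketch of the perplexity computation (illustrative only).
import torch
from torch.nn import CrossEntropyLoss

loss_fn = CrossEntropyLoss(reduction="none")
nlls = []

# Pretend the model produced logits for 5 positions over a 10-token vocabulary.
logits = torch.randn(5, 10)
targets = torch.randint(0, 10, (5,))  # the true "next token" at each position

nlls.append(loss_fn(logits, targets))  # one NLL per evaluated token
ppl = torch.exp(torch.cat(nlls).mean())
print(f"Perplexity: {ppl.item()}")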

examples/models/llama/runner/eager.py

Lines changed: 5 additions & 1 deletion
@@ -84,7 +84,11 @@ def execute_runner(runner_class: Type[LlamaRunner]) -> None:
     with torch.no_grad():
         runner = runner_class(args)  # pyre-ignore: Missing argument [20]
         generated_tokens = (
-            runner.chat_completion(temperature=args.temperature)
+            runner.chat_completion(
+                max_seq_len=1000000 if args.use_attention_sink else args.max_seq_length,
+                temperature=args.temperature,
+                show_progress=args.show_tokens,
+            )
             if args.chat
             else runner.text_completion(
                 prompt=args.prompt,
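The 1000000 value passed for max_seq_len may look arbitrary; the way I read the change (treat the numbers below as illustrative assumptions), AttentionSink bounds the KV cache at sink_size + window_size entries, so the generation budget no longer has to match the cache size and a very large cap simply stands in for "unbounded".

# Illustrative sketch: the cache footprint stays bounded no matter how many
# tokens have been generated (the sink_size/window_size values are made up).
def cache_entries(tokens_seen: int, sink_size: int, window_size: int) -> int:
    return min(tokens_seen, sink_size + window_size)

for n in (10, 1000, 1000000):
    print(n, "tokens seen ->", cache_entries(n, sink_size=4, window_size=2044), "cache entries")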

examples/models/llama/runner/generation.py

Lines changed: 17 additions & 10 deletions
@@ -168,18 +168,19 @@ def text_completion(

     def chat_completion(
         self,
+        max_seq_len: int,
         temperature: float = 0.6,
         top_p: float = 0.9,
+        show_progress: bool = False,
     ) -> List[int]:
         """
         Perform multi-turn chat with the language model.

         Args:
-            prompt (str): Text prompt for completion.
+            max_seq_len (int): Maximum number of tokens to generate for each prompt.
             temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
             top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
-            echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
-
+            show_progress (bool, optional): Flag indicating whether to show number of tokens generated.
         Returns:
             Generated list of tokens.

@@ -188,20 +189,26 @@ def chat_completion(
         """
         exit_prompt = "exit"
         tokens = []
+        pre_stop_token = []
         prompt = input("Me: ")
         while prompt and prompt != exit_prompt:
             print("LLM: ", end="", flush=True)
-            new_tokens = self.generate(
-                prompt_tokens=self.tokenizer.encode(
-                    self._format_prompt(prompt), bos=True, eos=False
-                ),
-                max_seq_len=self.max_seq_len,
+            prompt_tokens = self.tokenizer.encode(
+                self._format_prompt(prompt), bos=True, eos=False
+            )
+            generated_tokens = self.generate(
+                prompt_tokens=pre_stop_token + prompt_tokens,
+                max_seq_len=max_seq_len,
                 temperature=temperature,
                 top_p=top_p,
-                echo=True,
+                echo=False,
                 pos_base=len(tokens) - 1 if len(tokens) > 0 else 0,
             )
-            tokens.extend(new_tokens)
+            pre_stop_token = generated_tokens[-1:]
+            tokens.extend(prompt_tokens)
+            tokens.extend(generated_tokens)
+            if show_progress:
+                print(f"[Generated {len(tokens)} tokens]")
             prompt = input("Me: ")
         return tokens

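As a rough illustration of the new multi-turn bookkeeping (a toy loop, not the runner API; fake_generate stands in for self.generate), each turn resumes the KV cache at pos_base = len(tokens) - 1 and re-feeds the last generated token together with the next prompt:

# Toy simulation of the updated chat loop's position handling (illustrative).
tokens = []
pre_stop_token = []

def fake_generate(prompt_tokens, pos_base):
    # Stand-in for self.generate(); returns two dummy token ids.
    return [100 + pos_base, 101 + pos_base]

for prompt_tokens in ([1, 2, 3], [4, 5]):  # two simulated user turns
    pos_base = len(tokens) - 1 if len(tokens) > 0 else 0
    generated_tokens = fake_generate(pre_stop_token + prompt_tokens, pos_base)
    pre_stop_token = generated_tokens[-1:]  # carried into the next turn
    tokens.extend(prompt_tokens)
    tokens.extend(generated_tokens)

print(tokens)  # prompts and generated tokens accumulated across both turns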
