
Commit 1c0c17c

Update on "Transform model to be able to use Attention Sink"
This PR adds the necessary functions for transforming the model to be able to use Attention Sink.

Differential Revision: [D65571289](https://our.internmc.facebook.com/intern/diff/D65571289/)

[ghstack-poisoned]
1 parent 8a46c77 commit 1c0c17c

File tree: 2 files changed (+8, -8 lines)


examples/models/llama/export_llama_lib.py

Lines changed: 2 additions & 2 deletions
@@ -434,9 +434,9 @@ def build_args_parser() -> argparse.ArgumentParser:
 
     parser.add_argument(
         "--use_attention_sink",
-        default="4,2044,1024",
+        default=None,
         type=str,
-        help="Use attention sink to have fluent multi-round conversation. '<sink_size>,<window_size>,<batch_eviction_size>'"
+        help="Use attention sink to have fluent multi-round conversation. '<sink_size>,<window_size>,<batch_eviction_size>', e.g., '4,2044,1024'.",
     )
 
     parser.add_argument(
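
With this change the flag defaults to None, so attention sink stays disabled unless a value string is passed explicitly. A minimal sketch of how the documented '<sink_size>,<window_size>,<batch_eviction_size>' format decomposes into three integers (the parsing mirrors the split(",") used in model.py below; the variable names are illustrative):

# Illustrative parsing of the --use_attention_sink value (example taken from the help text).
value = "4,2044,1024"  # '<sink_size>,<window_size>,<batch_eviction_size>'
sink_size, window_size, eviction_batch_size = (int(x) for x in value.split(","))
assert (sink_size, window_size, eviction_batch_size) == (4, 2044, 1024)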

examples/models/llama/model.py

Lines changed: 6 additions & 6 deletions
@@ -200,11 +200,10 @@ def __init__(self, **kwargs):
             )
 
             sanitize_checkpoint_from_pre_quantization(checkpoint)
-
-        if hasattr(self.args, "use_attention_sink"):
-            from .source_transformation.sink_attention import (
-                enable_attention_sink,
-            )
+
+        if hasattr(self.args, "use_attention_sink") and self.args.use_attention_sink:
+            from .source_transformation.attention_sink import enable_attention_sink
+
             attention_sink_params = self.args.use_attention_sink.split(",")
             assert len(attention_sink_params) == 3
 
@@ -213,7 +212,8 @@ def __init__(self, **kwargs):
                 params=model_args,
                 sink_size=int(attention_sink_params[0]),
                 window_size=int(attention_sink_params[1]),
-                eviction_batch_size=int(attention_sink_params[2]))
+                eviction_batch_size=int(attention_sink_params[2]),
+            )
 
         # assign=True: load params/buffers by assignment instead of performing an in-place copy.
         # Because we are using device="meta", tensors do not have memory associated with them
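
Taken together, the two hunks make the rewrite opt-in: the guard now also requires the flag to carry a value, the import points at the attention_sink module, and the call gains a trailing comma. A minimal sketch of the resulting flow inside __init__, where the positional target passed to enable_attention_sink and the self.model_ assignment are assumptions not visible in these hunks:

# Minimal sketch, assuming `self.model_` is the eager Transformer being transformed
# and `model_args` its ModelArgs; only the keyword arguments below appear in the diff.
if hasattr(self.args, "use_attention_sink") and self.args.use_attention_sink:
    from .source_transformation.attention_sink import enable_attention_sink

    attention_sink_params = self.args.use_attention_sink.split(",")
    assert len(attention_sink_params) == 3
    self.model_ = enable_attention_sink(
        self.model_,  # assumed positional target; not shown in the hunks above
        params=model_args,
        sink_size=int(attention_sink_params[0]),
        window_size=int(attention_sink_params[1]),
        eviction_batch_size=int(attention_sink_params[2]),
    )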

0 commit comments
