Add proper pt2e calibration #5095
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/5095
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New Failure as of commit d4d7cfa with merge base 9739609.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
I think we might need to calibrate some special tokens in the input template for Llama 3 Instruct.
I see. Should I leave it to you for the proper calibration PR, or would you prefer me to amend this one?
If you can, I’d appreciate it!
Sure, happy to help. Mind sharing what extra calibration you did?
Sure, I'm just calibrating a prompt that contains an input template.

def eval_once(self, module: torch.fx.GraphModule, string: str = "Once upon a time", max_len: int = 128):
    tokenizer = SimpleTokenizer(self.tokenizer_path)
    # TODO: change criteria & support batch inputs if necessary
    pos = torch.tensor(0, dtype=torch.int64)
    token_list = [tokenizer.bos_id] + tokenizer.encode(string)
    with torch.no_grad():
        while token_list[-1] != tokenizer.eos_id and pos < max_len:
            logits = module(
                torch.full((1, 1), token_list[pos]),
                torch.tensor((pos, )),
            )
            pos += 1
            if pos >= len(token_list):
                token_list.append(torch.argmax(logits[:], dim=-1).item())
    ...

def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
    ...
    # Calibration
    self.eval_once(m, string="<|start_header_id|>system<|end_header_id|>\n\nYou are a cute and funny chatbot answering questions<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nTell me about Meta<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", max_len=128)
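For context, here is a minimal sketch of where a calibration call like this sits in the pt2e flow. Only `prepare_pt2e` and `convert_pt2e` are the real `torch.ao.quantization.quantize_pt2e` APIs; the surrounding names (`pre_autograd_graph_module`, `composed_quantizer`, `eval_once`, `chat_template_prompt`) are placeholders taken from the snippets in this thread, not a definitive implementation.

```python
# Sketch only: placeholder names, assuming a graph module captured before
# autograd and a composed quantizer as in the builder diff below.
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

prepared = prepare_pt2e(pre_autograd_graph_module, composed_quantizer)
# Decoding a template prompt through the prepared module lets the observers
# record activation ranges, including for the special template tokens.
eval_once(prepared, string=chat_template_prompt, max_len=128)
quantized = convert_pt2e(prepared)
```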
# Batch process the whole sequence.
logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
return logits
if not self._dynamic_shape:
When do we not enable dynamic shape?
I believe dynamic_shape defaults to true now. We can probably ignore the False case here.
We actually disable dynamic shape for all backends other than XNNPACK: https://github.com/pytorch/executorch/tree/main/examples/models/llama2#optional-smaller-models-delegated-to-other-backends
In particular, for QNN we can only do static shape for now.
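As a side note, here is a small illustration of the static vs. dynamic distinction using the public torch.export API. This is not the PR's capture code; the toy module and the dimension bounds are assumptions.

```python
import torch
from torch.export import Dim, export

class Toy(torch.nn.Module):
    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return tokens * 2

tokens = torch.zeros((1, 8), dtype=torch.long)

# Static capture: every later call must use exactly a (1, 8) input,
# which is what backends like QNN require today.
static_program = export(Toy(), (tokens,))

# Dynamic capture: the sequence dimension may vary at runtime, which is what
# makes batch prefill of an arbitrary-length prompt possible (e.g. on XNNPACK).
seq = Dim("seq", min=2, max=128)
dynamic_program = export(Toy(), (tokens,), dynamic_shapes=({1: seq},))
```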
extension/llm/export/builder.py
Outdated
@@ -190,7 +252,26 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManage
), "Please run capture_pre_autograd_graph first"
m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
# Calibrate
m(*self.example_inputs)
logging.info(f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, seq_length: {self.calibration_seq_length}, tokenizer_path: {self.tokenizer_path}")
Is this log info a duplicate of the log at line 263? Maybe remove this line?
extension/llm/export/builder.py
Outdated
token_list = [tokenizer.bos_id] + tokenizer.encode(string, bos=True, eos=False)

with torch.no_grad():
    while token_list[-1] != tokenizer.eos_id and pos < max_len:
For Llama 3, eos_id is actually eot_id. BTW, https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/discussions/73
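A hedged sketch of what that stop condition could look like; `tokenizer.eot_id` is an assumption about the tokenizer's attributes, not something this PR guarantees.

```python
# Accept either stop token: Llama 3 Instruct turns end with <|eot_id|>,
# while <|end_of_text|> (eos_id) may never appear in chat-formatted output.
stop_ids = {tokenizer.eos_id}
eot_id = getattr(tokenizer, "eot_id", None)  # assumed attribute
if eot_id is not None:
    stop_ids.add(eot_id)

with torch.no_grad():
    while token_list[-1] not in stop_ids and pos < max_len:
        ...  # same decode-one-token loop as in the snippet above
```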
Curious why not batch prefill here to make it faster?
I think if we batch prefill here, it should also check the dynamic shape.
Here we use the graph module instead of the eager model for calibration. The graph module is captured with a fixed (static) shape, and batch prefill requires dynamic shape.
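A small illustration of that constraint; the (1, 1) capture shape and the placeholder `module` are assumptions for the sketch, not the PR's exact code.

```python
import torch

prompt = torch.tensor([[1, 2, 3, 4]])  # (1, 4) prompt token ids
positions = torch.arange(prompt.shape[1])

# `module` stands for the captured graph module (placeholder). A module
# captured with a fixed (1, 1) token input can only take one token per call,
# so calibration walks the prompt position by position.
for i in range(prompt.shape[1]):
    logits = module(prompt[:, i : i + 1], positions[i : i + 1])

# Batch prefill would be a single module(prompt, positions) call, which needs
# the graph to have been captured with a dynamic sequence dimension.
```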
Just minor fixes. Overall looks good to me.
Updated: a187cf8 to 21d3974
So instead of modifying export_llama_lib, the way I think you want to do this is:
- Introduce a QNNRunEvalWrapper like here: https://github.com/pytorch/executorch/blob/main/examples/models/llama2/eval_llama_lib.py#L124
- On the LLMEdgeManager instance returned from _prepare_for_llama_export, call the quantization APIs manually, similar to https://github.com/pytorch/executorch/blob/main/examples/models/llama2/eval_llama_lib.py#L145, where you export the model and then, instead of calling pt2e_quantize, call prepare_pt2e manually. Such a model can then be returned as an nn.Module.
- Wrap the module in QNNRunEvalWrapper like the other wrappers.
Why this way? I think this way you are not polluting export_llama_lib with eval-related concerns, and eval-related code still remains within the eval pipeline.
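A rough sketch of that suggested flow. Only prepare_pt2e, _prepare_for_llama_export, and capture_pre_autograd_graph are names from the codebase and the links above; QNNRunEvalWrapper and the glue code are hypothetical.

```python
from torch.ao.quantization.quantize_pt2e import prepare_pt2e

manager = _prepare_for_llama_export(modelname, args)  # LLMEdgeManager
manager.capture_pre_autograd_graph()
prepared = prepare_pt2e(manager.pre_autograd_graph_module, composed_quantizer)

# Evaluating the prepared module doubles as calibration: every forward pass in
# the eval loop feeds the observers, so export_llama_lib stays free of eval code.
eval_wrapper = QNNRunEvalWrapper(  # hypothetical wrapper, mirroring the others
    model=prepared,
    tokenizer=tokenizer,
    max_seq_length=args.max_seq_length,
)
```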
@@ -166,19 +166,25 @@ def build_args_parser() -> argparse.ArgumentParser:
nargs="+",
type=str,
default=None,
help="Tasks for GPTQ calibration",
help="Tasks for GPTQ calibration from lm_eval",
nit: For future reference, separate out unrelated fixes
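For reference, a sketch of the calibration flags under discussion; the names match the logging line quoted earlier in this thread, but the defaults and help text here are illustrative.

```python
# Inside build_args_parser(); `parser` is the argparse.ArgumentParser it builds.
parser.add_argument(
    "--calibration_tasks",
    nargs="+",
    type=str,
    default=None,
    help="Tasks for calibration from lm_eval",
)
parser.add_argument(
    "--calibration_limit",
    type=int,
    default=None,
    help="Number of samples used for calibration",
)
parser.add_argument(
    "--calibration_seq_length",
    type=int,
    default=None,
    help="Sequence length used during calibration",
)
```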
@kimishpatel thanks for the detailed review. Applied the suggestion in the latest commit, please take another look.
I introduced a graph module run eval wrapper in the new commit instead, as I feel it's not QNN-specific. The actual calibration data might be different, but that can be controlled by the args.
I'm not exactly sure what that means. Do you mean having a separate calibrate API in LLMEdgeManager?
Did this for the GraphModuleEvalWrapper.
Sounds good.
A graph module is an nn.Module.
No, the opposite. I don't think it makes sense to have a calibrate API on LLMEdgeManager. I meant more like ...
@@ -167,6 +178,69 @@ def capture_pre_autograd_graph(self) -> "LLMEdgeManager":
)
return self

def pt2e_calibrate(
This method should not be part of the builder at all; the builder is meant to produce a model, not calibrate it.
Hence my suggestion was to move the functionality of this method either inside GraphModuleEvalWrapper or somewhere else.
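One way to read "move it inside GraphModuleEvalWrapper", sketched below. The class name comes from this thread, EagerEvalWrapper is the existing wrapper in eval_llama_lib.py (its import path and constructor signature are assumed), and the method body plus the (batch, 1, vocab) logits shape are assumptions, not the final implementation.

```python
import torch
from executorch.examples.models.llama2.eval_llama_lib import EagerEvalWrapper  # assumed import path

class GraphModuleEvalWrapper(EagerEvalWrapper):
    def __init__(self, model, tokenizer, max_seq_length):
        super().__init__(model=model, tokenizer=tokenizer, max_seq_length=max_seq_length)

    def _model_call(self, inps):
        # The graph module is captured with a static shape, so feed one position
        # per call; each call also calibrates the quantization observers.
        per_token_logits = []
        for pos in range(inps.shape[1]):
            logits = self._model(inps[:, pos : pos + 1], torch.tensor([pos]))
            per_token_logits.append(logits)  # assumed (batch, 1, vocab) each
        return torch.cat(per_token_logits, dim=1)
```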
Discussed and looks good.
Summary: See discussion in pytorch#5095. Reland because of internal failure. Differential Revision: D62323396
Summary: Pull Request resolved: pytorch#5152. See discussion in pytorch#5095. Reland because of internal failure. Differential Revision: D62323396
Currently, pt2e calibration is a dummy calibration; this change uses a proper calibration for eval.
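In code terms, the difference is roughly the following; names are illustrative, and the "before" line is the one visible in the builder.py diff earlier in this thread.

```python
# Before: dummy calibration, a single forward pass with the example inputs,
# so observers see only one (unrepresentative) activation range.
m(*example_inputs)

# After: proper calibration, streaming real tokenized text through the
# prepared module before convert_pt2e, e.g. lm_eval task samples or a
# chat-template prompt decoded token by token.
for tokens, positions in calibration_batches:  # placeholder iterable
    m(tokens, positions)
```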
Command line for evaluate:
Command line for export: