
Add proper pt2e calibration #5095


Merged

cccclai merged 10 commits into main from pt2e_calibration on Sep 6, 2024

Conversation

cccclai
Contributor

@cccclai cccclai commented Sep 4, 2024

Currently pt2e calibration is a dummy calibration; this change uses a proper calibration for eval.

Command line to evaluate:

python -m examples.models.llama2.eval_llama  -t /data/users/chenlai/models/stories/tokenizer.model -p /data/users/chenlai/models/stories/params.json -c /data/users/chenlai/models/stories/stories110M.pt --pt2e_quantize qnn_16a4w --qnn -kv --disable_dynamic_shape  --max_seq_len 16  --limit 1 --calibration_tasks "wikitext" --calibration_limit 1 --calibration_seq_length 16 -t /data/users/chenlai/models/stories/tokenizer.model 

Command line for export:

python -m examples.models.llama2.export_llama  -t /data/users/chenlai/models/stories/tokenizer.model -p /data/users/chenlai/models/stories/params.json -c /data/users/chenlai/models/stories/stories110M.pt --pt2e_quantize qnn_16a4w --qnn -kv --disable_dynamic_shape  --max_seq_len 16  --limit 1 --calibration_tasks "wikitext" --calibration_limit 1 --calibration_seq_length 16 -t /data/users/chenlai/models/stories/tokenizer.model 
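
For context, a rough sketch of the change from dummy to proper calibration (prepare_pt2e, convert_pt2e, composed_quantizer, pre_autograd_graph_module, and example_inputs all appear in the builder code this PR touches; load_calibration_samples below is a placeholder helper, not a real API):

from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

prepared = prepare_pt2e(pre_autograd_graph_module, composed_quantizer)

# Before: a single dummy forward pass with example inputs.
# prepared(*example_inputs)

# After: run real tokenized samples (e.g. wikitext) through the prepared
# module so the observers see realistic activation ranges.
for tokens, pos in load_calibration_samples(   # placeholder helper
    tasks=["wikitext"], limit=1, seq_length=16
):
    prepared(tokens, pos)

quantized = convert_pt2e(prepared)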


pytorch-bot bot commented Sep 4, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/5095

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit d4d7cfa with merge base 9739609:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Sep 4, 2024
@shewu-quic
Collaborator

I think we might need to calibrate some special tokens in input template for llama3 instruct.

@cccclai
Contributor Author

cccclai commented Sep 5, 2024

I think we might need to calibrate some special tokens in input template for llama3 instruct.

I see. Should I leave it to you for the proper calibration pr or would you prefer me to amend this one?

@shewu-quic
Collaborator

shewu-quic commented Sep 5, 2024

I see. Should I leave it to you for the proper calibration pr or would you prefer me to amend this one?

If you can, I’d appreciate it!

@cccclai
Contributor Author

cccclai commented Sep 5, 2024

calibrate some special tokens in input template for llama3 instruct.

Sure happy to help. Mind sharing what extra calibration you did?

@shewu-quic
Collaborator

shewu-quic commented Sep 5, 2024

calibrate some special tokens in input template for llama3 instruct.

Sure happy to help. Mind sharing what extra calibration you did?

Sure, I'm just calibrating a prompt that contains an input template.

    def eval_once(self, module: torch.fx.GraphModule, string: str = "Once upon a time", max_len: int = 128):
        tokenizer = SimpleTokenizer(self.tokenizer_path)
        # TODO: change criteria & support batch inputs if necessary
        pos = torch.tensor(0, dtype=torch.int64)
        token_list = [tokenizer.bos_id] + tokenizer.encode(string)

        with torch.no_grad():
            # Greedy-decode one token at a time; each forward pass feeds the
            # observers inserted by prepare_pt2e with real activation values.
            while token_list[-1] != tokenizer.eos_id and pos < max_len:
                logits = module(
                    torch.full((1, 1), token_list[pos]),
                    torch.tensor((pos,)),
                )
                pos += 1
                if pos >= len(token_list):
                    token_list.append(torch.argmax(logits[:], dim=-1).item())
...
    def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManager":
        ...
        # Calibration: run a prompt containing the llama3-instruct input template
        # so its special tokens are observed during calibration.
        self.eval_once(m, string="<|start_header_id|>system<|end_header_id|>\n\nYou are a cute and funny chatbot answering questions<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nTell me about Meta<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", max_len=128)

@cccclai cccclai requested a review from kimishpatel September 5, 2024 22:17
# Batch process the whole sequence.
logits = self._model(inps[:, : self._max_seq_length], pos_tensor)
return logits
if not self._dynamic_shape:
Contributor

When do we not enable dynamic shape?

Contributor

I believe dynamic_shape defaults to true now. We can probably ignore the False case here.

Contributor Author

We actually disable dynamic shape for all other backends except xnnpack: https://github.com/pytorch/executorch/tree/main/examples/models/llama2#optional-smaller-models-delegated-to-other-backends

Particularly for QNN, we can only do static shape for now.

@@ -190,7 +252,26 @@ def pt2e_quantize(self, quantizers: Optional[List[Quantizer]]) -> "LLMEdgeManage
), "Please run capture_pre_autograd_graph first"
m = prepare_pt2e(self.pre_autograd_graph_module, composed_quantizer)
# Calibrate
m(*self.example_inputs)
logging.info(f"Calibrating with tasks: {self.calibration_tasks}, limit: {self.calibration_limit}, seq_length: {self.calibration_seq_length}, tokenizer_path: {self.tokenizer_path}")
Contributor

Is this log info duplicated with the log at line 263? Maybe remove this line?

token_list = [tokenizer.bos_id] + tokenizer.encode(string, bos=True, eos=False)

with torch.no_grad():
while token_list[-1] != tokenizer.eos_id and pos < max_len:
Contributor

Curious why not batch prefill here to make it faster?

Collaborator

I think if we batch prefill here, it should also check dynamic shape.

Contributor Author

Here we use the graph module instead of the eager model for calibration, and the graph module is captured with a static shape. Since it is captured with a fixed shape, batch prefill (which requires dynamic shape) isn't possible.

Collaborator

@shewu-quic shewu-quic left a comment

Just minor fixes.
Overall looks good to me.

Contributor

@kimishpatel kimishpatel left a comment

So instead of modifying export_llama_lib, the way I think you want to do this is:

  1. Introduce QNNRunEvalWrapper like here https://github.com/pytorch/executorch/blob/main/examples/models/llama2/eval_llama_lib.py#L124
  2. On the LLMEdgeManager instance returned from _prepare_for_llama_export, call the quantization APIs manually, similar to https://github.com/pytorch/executorch/blob/main/examples/models/llama2/eval_llama_lib.py#L145. That is, export the model and then, instead of calling pt2e_quantize, call prepare_pt2e manually. Such a model can then be returned as an nn.Module.
  3. Wrap the module in QNNRunEvalWrapper like other wrappers.

Why this way?
I think this way you are not polluting export_llama_lib with eval-related concerns, and eval-related code still remains within the eval pipeline (see the sketch below).
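
A minimal sketch of that flow (prepare_pt2e/convert_pt2e and the capture_pre_autograd_graph/pre_autograd_graph_module names come from the builder code in this PR; the QNNRunEvalWrapper constructor arguments and the run_eval driver are assumptions for illustration, not the actual ExecuTorch API):

from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

def build_calibrated_module(llm_manager, quantizer, calib_args):
    # Export first, then prepare for quantization manually instead of
    # calling llm_manager.pt2e_quantize().
    llm_manager.capture_pre_autograd_graph()
    prepared = prepare_pt2e(llm_manager.pre_autograd_graph_module, quantizer)

    # Wrap the prepared graph module so the eval pipeline can drive it;
    # running eval doubles as calibration because the observers record
    # activation statistics on every forward pass.
    wrapper = QNNRunEvalWrapper(            # hypothetical constructor args
        model=prepared,
        tokenizer_path=calib_args.tokenizer_path,
        max_seq_length=calib_args.calibration_seq_length,
    )
    run_eval(                               # hypothetical eval driver
        wrapper,
        tasks=calib_args.calibration_tasks,
        limit=calib_args.calibration_limit,
    )

    # Lower to the quantized graph only after calibration has run.
    return convert_pt2e(prepared)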

@@ -166,19 +166,25 @@ def build_args_parser() -> argparse.ArgumentParser:
nargs="+",
type=str,
default=None,
help="Tasks for GPTQ calibration",
help="Tasks for GPTQ calibration from lm_eval",
Contributor

nit: For future reference, separate out unrelated fixes

@cccclai
Contributor Author

cccclai commented Sep 6, 2024

@kimishpatel thanks for the detailed review. I applied the suggestions in the latest commit; please take another look.

  1. Introduce QNNRunEvalWrapper like here https://github.com/pytorch/executorch/blob/main/examples/models/llama2/eval_llama_lib.py#L124

I introduced a graph module run eval wrapper in the new commit instead, as I feel it's not QNN-specific. The actual calibration data might be different, but that can be controlled by the args.

On the LLMEdgeManager instance returned from _prepare_for_llama_export, call the quantization APIs manually, similar to https://github.com/pytorch/executorch/blob/main/examples/models/llama2/eval_llama_lib.py#L145. That is, export the model and then, instead of calling pt2e_quantize, call prepare_pt2e manually. Such a model can then be returned as an nn.Module.

I'm not exactly sure what that means. Do you mean having a separate calibrate API in LLMEdgeManager, and then doing .prepare_pt2e().calibrate().convert_pt2e()? Also, it's a graph module, not an nn module here.

Wrap the module in QNNRunEvalWrapper like other wrappers.

Did this for the GraphModuleEvalWrapper

@kimishpatel
Contributor

I introduced a graph module run eval wrapper in the new commit instead, as I feel it's not QNN-specific. The actual calibration data might be different, but that can be controlled by the args.

Sounds good

@kimishpatel
Contributor

Also, it's a graph module, not an nn module here.

A graph module is an nn.Module.

@kimishpatel
Contributor

I'm not exactly sure what that means. Do you mean having a separate calibrate API in LLMEdgeManager?

No, the opposite. I don't think it makes sense to have a calibrate API on LLMEdgeManager. I meant more like:

llm_manager ...
graph_module = prepare_pt2e(llm_manager.exported_program().graph_module)
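
Spelling that out a little (a sketch; prepare_pt2e takes the captured graph module plus a quantizer, and pre_autograd_graph_module is the attribute shown in the builder diff above — the rest is placeholder):

from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

prepared = prepare_pt2e(llm_manager.pre_autograd_graph_module, quantizer)
# ...run calibration inputs through `prepared` from the eval pipeline...
quantized = convert_pt2e(prepared)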

@@ -167,6 +178,69 @@ def capture_pre_autograd_graph(self) -> "LLMEdgeManager":
)
return self

def pt2e_calibrate(
Contributor

This method should not be part of the builder at all. It is meant to produce a model, not calibrate.

Hence my suggestion was to move the functionality of this method either inside GraphModuleEvalWrapper or something else.

Contributor

@kimishpatel kimishpatel left a comment

discussed and looks good

@cccclai cccclai merged commit 7122d31 into main Sep 6, 2024
35 of 36 checks passed
@cccclai cccclai deleted the pt2e_calibration branch September 6, 2024 17:10
@cccclai cccclai restored the pt2e_calibration branch September 6, 2024 18:25
cccclai added a commit that referenced this pull request Sep 6, 2024
cccclai added a commit that referenced this pull request Sep 6, 2024
Revert "Add proper pt2e calibration (#5095)"

This reverts commit 7122d31.
cccclai added a commit to cccclai/executorch-1 that referenced this pull request Sep 6, 2024
Summary:
See discussion in pytorch#5095

Reland because of internal failure

Differential Revision: D62323396
cccclai added a commit to cccclai/executorch-1 that referenced this pull request Sep 7, 2024
Summary:
Pull Request resolved: pytorch#5152

See discussion in pytorch#5095

Reland because of internal failure

Differential Revision: D62323396
Labels: CLA Signed
5 participants