Qualcomm AI Engine Direct - Model sharding for LLM #4923
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/4923
Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit 777fa22 with merge base 3fb03dc.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hi @cccclai,
Thank you for adding this change! I believe it's the last piece?
Thanks for your prompt review. It should work for the llama model now.
```python
# Imports assume the executorch source layout; adjust to your tree if needed.
import torch

from executorch.backends.qualcomm.utils.constants import QCOM_ENCODING, QCOM_QUANT_ATTRS
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class ReplaceIndexPutInput(ExportPass):
    """
    Index put input workaround for quantized module
    """

    # Map each dequantize op to its matching quantize op so the mutable buffer
    # node can carry quantize-style encoding metadata.
    dq_q_map = {
        # per tensor
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor: exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
        # per channel
        exir_ops.edge.quantized_decomposed.dequantize_per_channel.default: exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
    }

    def __init__(self, edge_program: torch.export.ExportedProgram):
        super(ReplaceIndexPutInput, self).__init__()
        self.edge_program = edge_program

    def call(self, graph_module: torch.fx.GraphModule):
        graph = graph_module.graph
        for node in graph.nodes:
            if node.target == exir_ops.edge.aten.index_put.default:
                if (copy_node := list(node.users)[0]) and copy_node.target == exir_ops.edge.aten.copy.default:
                    m_buffer_node = copy_node.args[0]
                    bad_frozen_node = node.args[0]
                    # Propagate quantization attributes from the frozen input to
                    # the mutable buffer node, remapping dq encoding to q encoding.
                    if QCOM_QUANT_ATTRS in bad_frozen_node.meta:
                        m_buffer_node.meta[QCOM_QUANT_ATTRS] = bad_frozen_node.meta[QCOM_QUANT_ATTRS]
                        m_buffer_node.meta[QCOM_QUANT_ATTRS][QCOM_ENCODING] = self.dq_q_map[m_buffer_node.meta[QCOM_QUANT_ATTRS][QCOM_ENCODING]]
                    with graph.inserting_after(bad_frozen_node):
                        node.replace_input_with(bad_frozen_node, m_buffer_node)
                else:
                    continue

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
```
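For context, this is roughly how such an ExportPass could be applied to an edge program's graph module (a minimal sketch only; the actual registration in the Qualcomm backend's pass pipeline may differ, and `edge_program` here is just an assumed ExportedProgram from the to_edge flow):

```python
# Hypothetical usage sketch: run the pass over an already-lowered edge program.
replace_pass = ReplaceIndexPutInput(edge_program)
result = replace_pass.call(edge_program.graph_module)
if result.modified:
    print("index_put inputs were rewritten to use the mutable buffer node")
```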
Force-pushed from 19adc79 to 20d1e12.
Hello team, I was trying to quantize the Llama 3.1 8B model against the QNN backend. After pulling in this PR, I hit the issue below:
Oops, sorry about that. It seems `get_n_layers` was removed from the metadata. Let me fix it.
For LLM, the model size is too large to fit in device memory for inference. Therefore, we need to divide the model into a few parts to avoid inference-time out-of-memory errors.
Summary:
- Use a custom fallback op to split the graph
- Add spill-fill feature
- Add model sharding argument for QNN
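To illustrate the sharding idea at a high level (an illustrative sketch only, not the code in this PR; the helper name and shard layout are hypothetical), splitting a decoder into N contiguous groups of layers could look like:

```python
from typing import List


def shard_layer_indices(num_layers: int, num_shards: int) -> List[List[int]]:
    """Split decoder layer indices into `num_shards` contiguous groups.

    Hypothetical helper for illustration only: each group would be lowered and
    executed as a separate QNN graph, with a fallback op marking the boundary.
    """
    base, extra = divmod(num_layers, num_shards)
    shards, start = [], 0
    for i in range(num_shards):
        size = base + (1 if i < extra else 0)
        shards.append(list(range(start, start + size)))
        start += size
    return shards


# Example: a 32-layer Llama-style decoder split into 4 shards of 8 layers each.
print(shard_layer_indices(32, 4))
```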
Force-pushed from 20d1e12 to 777fa22.
I can help with that. Sorry, I don't have access to my regular machines recently and had a difficult time reproducing the QNN flow on an Ubuntu Linux (WSL) system. But this change should be simple enough and doesn't require setting up QNN dependencies.
Just want to confirm on this: any specific reason it's needed? Is it because the graph is better?
The past KV cache is eliminated (frozen) with capture_pre_autograd_graph after convert_pt2e(fold_quantize=True).
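For reference, a minimal sketch of the pt2e flow being discussed (the quantizer setup, `model`, and `example_inputs` are illustrative assumptions; the QnnQuantizer configuration in the actual llama example may differ):

```python
import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer

# Capture the eager model into a pre-autograd ATen graph.
captured = capture_pre_autograd_graph(model, example_inputs)

# Insert observers, calibrate, then convert. With fold_quantize=True the
# quantize ops on frozen parameters/buffers are constant-folded, so the past
# KV cache buffer ends up frozen, which motivates the ReplaceIndexPutInput
# workaround shown above.
quantizer = QnnQuantizer()
prepared = prepare_pt2e(captured, quantizer)
prepared(*example_inputs)  # calibration
quantized = convert_pt2e(prepared, fold_quantize=True)
```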
Did you expect changes like this: #4942? Also wanted to check: if we replace
Yes, it is similar to our change, but we don't set "strict". Does it affect anything?
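For reference on the `strict` flag being discussed, a minimal torch.export sketch (`model` and `example_inputs` are placeholder names; whether #4942 actually sets the flag this way is not confirmed here):

```python
import torch
from torch.export import export

# strict=False uses the non-strict (tracing-based) export path instead of
# TorchDynamo's strict graph capture; whether that matters here is the question.
exported_program = export(model, example_inputs, strict=False)
```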
Hi @cccclai, @shewu-quic, I am trying to load the Llama 3.1 8B model, sharded into 4 parts, in the QNN example, but the loading fails with memory buffer issues. The same workflow works for the Llama 2 7B model without any sharding. I wonder if it is related to this diff and whether anything needs to be updated for model loading on-device? This is on a OnePlus 12 with 16GB RAM, so memory size shouldn't be the issue. Here is the full log:
Hi @WuhanMonkey, if you saw exactly pdId 2, then it's related to the system. It's hard to do anything on the application side. I can only suggest:
Thank you. Besides the OnePlus phone, can it work on the S24+ without any issues?