Skip to content

Qualcomm AI Engine Direct - Add smart mask kv updator for llama3.2 #7694

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion backends/qualcomm/runtime/QnnManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,9 @@ Error QnnManager::RegisterMem(
const std::shared_ptr<TensorWrapper>& tensor_wrapper) {
SharedBuffer& shared_buffer_manager = SharedBuffer::GetSharedBufferManager();
// Not enable shared buffer
if (!options_->shared_buffer())
if (!options_->shared_buffer()) {
return Error::Internal;
}

if (backend_params_ptr_->qnn_mem_manager_ptr_ == nullptr) {
QNN_EXECUTORCH_LOG_WARN(
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/runtime/QnnManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ class QnnManager {
{Qnn_DataType_t::QNN_DATATYPE_UFIXED_POINT_8,
executorch::aten::ScalarType::Byte},
{Qnn_DataType_t::QNN_DATATYPE_UFIXED_POINT_16,
executorch::aten::ScalarType::Bits16},
executorch::aten::ScalarType::UInt16},
};
};
} // namespace qnn
Expand Down
2 changes: 1 addition & 1 deletion backends/qualcomm/runtime/backends/QnnMemManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class QnnMemManager {
Qnn_DataType_t::QNN_DATATYPE_SFIXED_POINT_16},
{executorch::aten::ScalarType::Byte,
Qnn_DataType_t::QNN_DATATYPE_UFIXED_POINT_8},
{executorch::aten::ScalarType::Bits16,
{executorch::aten::ScalarType::UInt16,
Qnn_DataType_t::QNN_DATATYPE_UFIXED_POINT_16},
};
};
Expand Down
4 changes: 2 additions & 2 deletions examples/qualcomm/oss_scripts/llama/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ list(
${CMAKE_CURRENT_LIST_DIR}/qnn_llama_runner.cpp
${CMAKE_CURRENT_LIST_DIR}/runner/runner.cpp
${CMAKE_CURRENT_LIST_DIR}/runner/runner.h
${CMAKE_CURRENT_LIST_DIR}/runner/io_memory.cpp
${CMAKE_CURRENT_LIST_DIR}/runner/io_memory.h
${CMAKE_CURRENT_LIST_DIR}/runner/io_manager.cpp
${CMAKE_CURRENT_LIST_DIR}/runner/io_manager.h
)

list(
Expand Down
71 changes: 58 additions & 13 deletions examples/qualcomm/oss_scripts/llama/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,42 @@
logging.getLogger().setLevel(logging.INFO)


def smart_mask_updator(atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches):
for i, k_cache in enumerate(k_caches):
k_cache[:, :, pos] = new_k_caches[i][:, :, 0]

for i, v_cache in enumerate(v_caches):
v_cache[:, pos, :] = new_v_caches[i]

atten_mask[0][pos] = 0
pos += 1
return (atten_mask, pos, k_caches, v_caches)


def shift_pointer_updator(
atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
):
k_caches = [
torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1)
for i, k_cache in enumerate(k_caches)
]
v_caches = [
torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1)
for i, v_cache in enumerate(v_caches)
]

pos += 1
atten_mask[0][-pos - 1] = 0
return (atten_mask, pos, k_caches, v_caches)


def _kv_calibrate(
example_inputs,
user_prompts,
module: torch.fx.GraphModule,
tokenizer,
max_seq_len=512,
updator=smart_mask_updator,
):
_, atten_mask, _, k_caches, v_caches = example_inputs

Expand Down Expand Up @@ -105,17 +135,9 @@ def _kv_calibrate(
*k_caches,
*v_caches,
)
k_caches = [
torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1)
for i, k_cache in enumerate(k_caches)
]
v_caches = [
torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1)
for i, v_cache in enumerate(v_caches)
]

pos += 1
atten_mask[0][-pos - 1] = 0
atten_mask, pos, k_caches, v_caches = updator(
atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
)
if pos >= len(token_list):
token_list.append(torch.argmax(logits[:, -1], dim=-1).item())

Expand Down Expand Up @@ -174,6 +196,7 @@ def calibrate(
module: torch.fx.GraphModule,
tokenizer,
max_seq_len=512,
kv_updator=smart_mask_updator,
):
if len(example_inputs) == 2:
_prefill_calibrate(
Expand All @@ -190,6 +213,7 @@ def calibrate(
module,
tokenizer,
max_seq_len,
updator=kv_updator,
)
else:
raise RuntimeError("Get wrong inputs")
Expand Down Expand Up @@ -319,13 +343,15 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()):
self.llama_model, self.inputs, strict=True
).module()
fx_graph_module = prepare_pt2e(fx_graph_module, quantizer)

logging.info("Quantizing the model...")
calibrate(
self.get_example_inputs(self.llama_meta["get_use_kv_cache"]),
args.prompt,
fx_graph_module,
tokenizer=tokenizer,
max_seq_len=self.llama_meta["get_max_seq_len"],
kv_updator=args.kv_updator,
)

self.llama_model = convert_pt2e(fx_graph_module)
Expand All @@ -337,6 +363,7 @@ def lowering_modules(
use_fp16=False,
soc_model=QcomChipset.SM8650,
num_sharding=0,
shared_buffer=False,
):
executorch_config = ExecutorchBackendConfig(
# For shared buffer, user must pass the memory address
Expand All @@ -357,7 +384,7 @@ def lowering_modules(
compiler_specs = generate_qnn_executorch_compiler_spec(
soc_model=soc_model,
backend_options=backend_options,
shared_buffer=False,
shared_buffer=shared_buffer,
)
skip_node_op_set = {"llama.fallback.default"}
partitioner = QnnPartitioner(
Expand Down Expand Up @@ -530,6 +557,7 @@ def compile(args, pte_filename, tokenizer):
use_fp16=use_fp16,
soc_model=get_soc_to_chipset_map()[args.model],
num_sharding=args.num_sharding,
shared_buffer=args.shared_buffer,
)
quant_attrs = llama_instance_list[0].get_quant_attrs()
else:
Expand Down Expand Up @@ -564,7 +592,7 @@ def compile(args, pte_filename, tokenizer):
generate_qnn_executorch_compiler_spec(
soc_model=get_soc_to_chipset_map()[args.model],
backend_options=backend_options,
shared_buffer=True,
shared_buffer=args.shared_buffer,
multiple_graphs=True,
graph_name=graph_name,
)
Expand Down Expand Up @@ -736,6 +764,7 @@ def inference(args, quant_attrs, pte_filename, runtime_tokenizer_path, pre_gen_p
f"--system_prompt '{args.system_prompt}'",
f"--logits_scale {quant_attrs['scale']}",
f"--logits_offset {quant_attrs['zero_point']}",
f"--kv_updator {'SmartMask' if args.kv_updator == smart_mask_updator else 'ShiftPointer'}",
]
)
runner_cmd = " ".join(
Expand Down Expand Up @@ -907,6 +936,14 @@ def main():
type=int,
)

parser.add_argument(
"--kv_updator",
help="Choose how to update kv cache during runtime",
choices=["smart_mask", "shift_pointer"],
default="smart_mask",
type=str,
)

args = parser.parse_args()
if args.compile_only and args.pre_gen_pte:
exit("Cannot set both compile_only and pre_gen_pte as true")
Expand Down Expand Up @@ -941,6 +978,14 @@ def main():
else:
raise RuntimeError(f"Unknown llama_model: {args.llama_model}.")

if args.kv_updator == "smart_mask":
args.shared_buffer = True
args.kv_updator = smart_mask_updator
elif args.kv_updator == "shift_pointer":
args.kv_updator = shift_pointer_updator
else:
exit(f"Using an unkown kv update {args.kv_updator}")

if args.pre_gen_pte:
quant_attrs = json.load(
open(f"{args.pre_gen_pte}/{pte_filename}_quant_attrs.txt")
Expand Down
7 changes: 6 additions & 1 deletion examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ DEFINE_int32(
"0: PromptProcessor(prefill) / 1: TokenGenerator(kv) / 2: HybridMode (prefill+kv)");
DEFINE_double(logits_scale, 0.0, "Logits scale");
DEFINE_int32(logits_offset, 0, "Logits offset");
DEFINE_string(
kv_updator,
"How to update kv cache. Choose between SmartMask and ShiftPointer",
"SmartMask");

int main(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
Expand All @@ -62,7 +66,8 @@ int main(int argc, char** argv) {
FLAGS_logits_scale,
FLAGS_logits_offset,
FLAGS_temperature,
FLAGS_eval_mode);
FLAGS_eval_mode,
FLAGS_kv_updator);
std::vector<char> buf;
buf.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char
std::ofstream fout(FLAGS_output_path.c_str());
Expand Down
Loading
Loading