Skip to content

Commit 574d835

Browse files
author
Chun-I Tsai
committed
shared buffer + smart mask
- Add flag to use smart mask or shift pointer - Add llama3_2 python with smart mask updator - Change Memory class to IoMgrBase - Change HybridMemory class to ShiftPointerIoMgr
1 parent cf8d0cf commit 574d835

File tree

12 files changed

+1334
-566
lines changed

12 files changed

+1334
-566
lines changed

backends/qualcomm/runtime/QnnManager.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,9 @@ Error QnnManager::RegisterMem(
154154
const std::shared_ptr<TensorWrapper>& tensor_wrapper) {
155155
SharedBuffer& shared_buffer_manager = SharedBuffer::GetSharedBufferManager();
156156
// Shared buffer is not enabled
157-
if (!options_->shared_buffer())
157+
if (!options_->shared_buffer()) {
158158
return Error::Internal;
159+
}
159160

160161
if (backend_params_ptr_->qnn_mem_manager_ptr_ == nullptr) {
161162
QNN_EXECUTORCH_LOG_WARN(

backends/qualcomm/runtime/QnnManager.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ class QnnManager {
145145
{Qnn_DataType_t::QNN_DATATYPE_UFIXED_POINT_8,
146146
executorch::aten::ScalarType::Byte},
147147
{Qnn_DataType_t::QNN_DATATYPE_UFIXED_POINT_16,
148-
executorch::aten::ScalarType::Bits16},
148+
executorch::aten::ScalarType::UInt16},
149149
};
150150
};
151151
} // namespace qnn

backends/qualcomm/runtime/backends/QnnMemManager.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ class QnnMemManager {
7777
Qnn_DataType_t::QNN_DATATYPE_SFIXED_POINT_16},
7878
{executorch::aten::ScalarType::Byte,
7979
Qnn_DataType_t::QNN_DATATYPE_UFIXED_POINT_8},
80-
{executorch::aten::ScalarType::Bits16,
80+
{executorch::aten::ScalarType::UInt16,
8181
Qnn_DataType_t::QNN_DATATYPE_UFIXED_POINT_16},
8282
};
8383
};

examples/qualcomm/oss_scripts/llama2/runner/runner.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ class IoMemMgr {
195195
{executorch::aten::ScalarType::Short, sizeof(int16_t)},
196196
{executorch::aten::ScalarType::Byte, sizeof(uint8_t)},
197197
{executorch::aten::ScalarType::Bits16, sizeof(uint16_t)},
198+
{executorch::aten::ScalarType::UInt16, sizeof(uint16_t)},
198199
};
199200
};
200201

examples/qualcomm/oss_scripts/llama3_2/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ list(
2828
${CMAKE_CURRENT_LIST_DIR}/qnn_llama3_2_runner.cpp
2929
${CMAKE_CURRENT_LIST_DIR}/runner/runner.cpp
3030
${CMAKE_CURRENT_LIST_DIR}/runner/runner.h
31-
${CMAKE_CURRENT_LIST_DIR}/runner/io_memory.cpp
32-
${CMAKE_CURRENT_LIST_DIR}/runner/io_memory.h
31+
${CMAKE_CURRENT_LIST_DIR}/runner/io_manager.cpp
32+
${CMAKE_CURRENT_LIST_DIR}/runner/io_manager.h
3333
)
3434

3535
list(

examples/qualcomm/oss_scripts/llama3_2/llama.py

Lines changed: 59 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,42 @@
6666
logging.getLogger().setLevel(logging.INFO)
6767

6868

69+
def smart_mask_updator(atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches):
    """Advance the KV caches one step using the smart-mask scheme.

    Each layer's freshly computed key/value slice is written in place into
    slot ``pos`` of its pre-allocated cache, the attention mask is opened at
    that position, and the position counter is advanced.

    Returns the (mutated) atten_mask, the advanced pos, and the caches.
    """
    # Write the new key column into position `pos` of every layer's k cache.
    for layer_idx, k_cache in enumerate(k_caches):
        k_cache[:, :, pos] = new_k_caches[layer_idx][:, :, 0]

    # Write the new value row into position `pos` of every layer's v cache.
    for layer_idx, v_cache in enumerate(v_caches):
        v_cache[:, pos, :] = new_v_caches[layer_idx]

    # Unmask the slot that was just filled, then move to the next position.
    atten_mask[0][pos] = 0
    pos += 1
    return (atten_mask, pos, k_caches, v_caches)
79+
80+
81+
def shift_pointer_updator(
    atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
):
    """Advance the KV caches one step using the shift-pointer scheme.

    Rather than writing in place, each layer's cache is rebuilt by dropping
    its oldest entry and appending the newly computed key/value slice, so the
    cache length stays constant. The attention mask is opened from the tail
    end to match, and the position counter is advanced.

    Returns the new atten_mask, the advanced pos, and the rebuilt caches.
    """
    # Rebuild every k cache: drop the oldest column, append the new one.
    k_caches = [
        torch.cat([cache[:, :, 1:], new_k_caches[layer_idx]], dim=-1)
        for layer_idx, cache in enumerate(k_caches)
    ]
    # Rebuild every v cache: drop the oldest row, append the new one.
    v_caches = [
        torch.cat([cache[:, 1:, :], new_v_caches[layer_idx]], dim=1)
        for layer_idx, cache in enumerate(v_caches)
    ]

    # Advance, then unmask counting back from the end of the window.
    pos += 1
    atten_mask[0][-pos - 1] = 0
    return (atten_mask, pos, k_caches, v_caches)
96+
97+
6998
def _kv_calibrate(
7099
example_inputs,
71100
user_prompts,
72101
module: torch.fx.GraphModule,
73102
tokenizer_model_path="tokenizer.model",
74103
max_seq_len=512,
104+
updator=smart_mask_updator,
75105
):
76106
sp_model = get_tokenizer(tokenizer_model_path)
77107
_, atten_mask, _, k_caches, v_caches = example_inputs
@@ -92,17 +122,9 @@ def _kv_calibrate(
92122
*k_caches,
93123
*v_caches,
94124
)
95-
k_caches = [
96-
torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1)
97-
for i, k_cache in enumerate(k_caches)
98-
]
99-
v_caches = [
100-
torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1)
101-
for i, v_cache in enumerate(v_caches)
102-
]
103-
104-
pos += 1
105-
atten_mask[0][-pos - 1] = 0
125+
atten_mask, pos, k_caches, v_caches = updator(
126+
atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
127+
)
106128
if pos >= len(token_list):
107129
token_list.append(torch.argmax(logits[:, -1], dim=-1).item())
108130

@@ -153,6 +175,7 @@ def calibrate(
153175
module: torch.fx.GraphModule,
154176
tokenizer_model_path="tokenizer.model",
155177
max_seq_len=512,
178+
kv_updator=smart_mask_updator,
156179
):
157180
if len(example_inputs) == 2:
158181
_prefill_calibrate(
@@ -169,6 +192,7 @@ def calibrate(
169192
module,
170193
tokenizer_model_path,
171194
max_seq_len,
195+
updator=kv_updator,
172196
)
173197
else:
174198
raise RuntimeError("Get wrong inputs")
@@ -298,13 +322,15 @@ def quantize(self, quant_dtype, args, custom_annotations=()):
298322
self.llama_model, self.inputs, strict=True
299323
).module()
300324
fx_graph_module = prepare_pt2e(fx_graph_module, quantizer)
325+
301326
logging.info("Quantizing the model...")
302327
calibrate(
303328
self.get_example_inputs(self.llama_meta["get_use_kv_cache"]),
304329
args.prompt,
305330
fx_graph_module,
306331
tokenizer_model_path=args.tokenizer_model,
307332
max_seq_len=self.llama_meta["get_max_seq_len"],
333+
kv_updator=args.kv_updator,
308334
)
309335

310336
self.llama_model = convert_pt2e(fx_graph_module)
@@ -316,6 +342,7 @@ def lowering_modules(
316342
use_fp16=False,
317343
soc_model=QcomChipset.SM8650,
318344
num_sharding=0,
345+
shared_buffer=False,
319346
):
320347
executorch_config = ExecutorchBackendConfig(
321348
# For shared buffer, user must pass the memory address
@@ -336,7 +363,7 @@ def lowering_modules(
336363
compiler_specs = generate_qnn_executorch_compiler_spec(
337364
soc_model=soc_model,
338365
backend_options=backend_options,
339-
shared_buffer=False,
366+
shared_buffer=shared_buffer,
340367
)
341368
skip_node_op_set = {"llama.fallback.default"}
342369
partitioner = QnnPartitioner(
@@ -366,7 +393,7 @@ def lowering_modules(
366393
if num_sharding > 0:
367394
update_spill_fill_size(edge_prog_mgr.exported_program())
368395
exec_prog_mgr = edge_prog_mgr.to_executorch(config=executorch_config)
369-
with open(f"{work_space}/{pte_filename}.pte", "wb") as file:
396+
with open(f"{work_space}/{self.pte_filename}.pte", "wb") as file:
370397
exec_prog_mgr.write_to_file(file)
371398

372399
def get_example_inputs(self, use_kv_cache=True):
@@ -491,6 +518,7 @@ def compile(args, pte_filename):
491518
use_fp16=use_fp16,
492519
soc_model=get_soc_to_chipset_map()[args.model],
493520
num_sharding=args.num_sharding,
521+
shared_buffer=args.shared_buffer,
494522
)
495523
quant_attrs = llama_instance_list[0].get_quant_attrs()
496524
else:
@@ -525,7 +553,7 @@ def compile(args, pte_filename):
525553
generate_qnn_executorch_compiler_spec(
526554
soc_model=get_soc_to_chipset_map()[args.model],
527555
backend_options=backend_options,
528-
shared_buffer=True,
556+
shared_buffer=args.shared_buffer,
529557
multiple_graphs=True,
530558
graph_name=graph_name,
531559
)
@@ -697,6 +725,7 @@ def inference(args, quant_attrs, pte_filename, pre_gen_pte=""):
697725
f"--system_prompt '{args.system_prompt}'",
698726
f"--logits_scale {quant_attrs['scale']}",
699727
f"--logits_offset {quant_attrs['zero_point']}",
728+
f"--kv_updator {'SmartMask' if args.kv_updator == smart_mask_updator else 'ShiftPointer'}",
700729
]
701730
)
702731
runner_cmd = " ".join(
@@ -862,6 +891,14 @@ def main():
862891
type=int,
863892
)
864893

894+
parser.add_argument(
895+
"--kv_updator",
896+
help="Choose how to update kv cache during runtime",
897+
choices=["smart_mask", "shift_pointer"],
898+
default="smart_mask",
899+
type=str,
900+
)
901+
865902
args = parser.parse_args()
866903
if args.compile_only and args.pre_gen_pte:
867904
exit("Cannot set both compile_only and pre_gen_pte as true")
@@ -878,6 +915,14 @@ def main():
878915
else:
879916
raise RuntimeError(f"No such model_mode {args.model_mode}.")
880917

918+
if args.kv_updator == "smart_mask":
919+
args.shared_buffer = True
920+
args.kv_updator = smart_mask_updator
921+
elif args.kv_updator == "shift_pointer":
922+
args.kv_updator = shift_pointer_updator
923+
else:
924+
exit(f"Using an unkown kv update {args.kv_updator}")
925+
881926
if args.pre_gen_pte:
882927
quant_attrs = json.load(
883928
open(f"{args.pre_gen_pte}/{pte_filename}_quant_attrs.txt")

examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ DEFINE_int32(
5050
"0: PromptProcessor(prefill) / 1: TokenGenerator(kv) / 2: HybridMode (prefill+kv)");
5151
DEFINE_double(logits_scale, 0.0, "Logits scale");
5252
DEFINE_int32(logits_offset, 0, "Logits offset");
53+
// gflags DEFINE_string takes (name, default_value, help): the default comes
// second and the human-readable description third. The original had them
// swapped, making the flag default to the help sentence.
DEFINE_string(
    kv_updator,
    "SmartMask",
    "How to update kv cache. Choose between SmartMask and ShiftPointer");
5357

5458
int main(int argc, char** argv) {
5559
gflags::ParseCommandLineFlags(&argc, &argv, true);
@@ -61,7 +65,8 @@ int main(int argc, char** argv) {
6165
FLAGS_logits_scale,
6266
FLAGS_logits_offset,
6367
FLAGS_temperature,
64-
FLAGS_eval_mode);
68+
FLAGS_eval_mode,
69+
FLAGS_kv_updator);
6570
std::vector<char> buf;
6671
buf.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char
6772
std::ofstream fout(FLAGS_output_path.c_str());

0 commit comments

Comments
 (0)