
Commit c80c123

haowhsu-quic authored and committed
apply review comments #1
1 parent 1997683 commit c80c123


12 files changed (+102, -100 lines)


backends/qualcomm/aot/ir/qcir_utils.cpp

Lines changed: 8 additions & 6 deletions
@@ -136,6 +136,7 @@ flatbuffers::Offset<qcir::QuantizeParam> ToQuantizeParam(
     case qcir::QuantizeType::AXIS_SCALE_OFFSET: {
       size_t len = param.axisScaleOffsetEncoding.numScaleOffsets;
       axis = param.axisScaleOffsetEncoding.axis;
+      data.reserve(len);
       for (uint i = 0; i < len; ++i) {
         data.emplace_back(qcir::ScaleOffset(
             param.axisScaleOffsetEncoding.scaleOffset[i].scale,
@@ -151,6 +152,8 @@ flatbuffers::Offset<qcir::QuantizeParam> ToQuantizeParam(
       bitwidth = param.bwAxisScaleOffsetEncoding.bitwidth;
       axis = param.bwAxisScaleOffsetEncoding.axis;
       size_t len = param.bwAxisScaleOffsetEncoding.numElements;
+      scales.reserve(len);
+      offsets.reserve(len);
       for (size_t i = 0; i < len; ++i) {
         scales.push_back(param.bwAxisScaleOffsetEncoding.scales[i]);
         offsets.push_back(param.bwAxisScaleOffsetEncoding.offsets[i]);
@@ -216,10 +219,10 @@ Qnn_QuantizeParams_t ToQuantizeParam(const qparam_type& param) {
       p.bwAxisScaleOffsetEncoding.bitwidth = param->bitwidth();
       p.bwAxisScaleOffsetEncoding.axis = param->axis();
       p.bwAxisScaleOffsetEncoding.numElements = param->scales()->size();
-      p.bwAxisScaleOffsetEncoding.scales = reinterpret_cast<float*>(
-          const_cast<uint8_t*>(param->scales()->Data()));
-      p.bwAxisScaleOffsetEncoding.offsets = reinterpret_cast<int32_t*>(
-          const_cast<uint8_t*>(param->offsets()->Data()));
+      p.bwAxisScaleOffsetEncoding.scales =
+          const_cast<float*>(param->scales()->data());
+      p.bwAxisScaleOffsetEncoding.offsets =
+          const_cast<int32_t*>(param->offsets()->data());
     } break;
     default:
       QNN_EXECUTORCH_LOG_ERROR("qcir::QuantizeType::UNDEFINED detected");
@@ -260,8 +263,7 @@ Qnn_Tensor_t ToTensor(const tensor_type& tensor) {
   QNN_VER_PTR(t)->dataType = ToDataType(tensor->dtype());
   QNN_VER_PTR(t)->quantizeParams = ToQuantizeParam(tensor->qparam());
   QNN_VER_PTR(t)->rank = tensor->shape()->size();
-  QNN_VER_PTR(t)->dimensions = reinterpret_cast<uint32_t*>(
-      const_cast<uint8_t*>(tensor->shape()->Data()));
+  QNN_VER_PTR(t)->dimensions = const_cast<uint32_t*>(tensor->shape()->data());
   QNN_VER_PTR(t)->clientBuf.dataSize = tensor->data()->size();
   QNN_VER_PTR(t)->clientBuf.data = is_io_tensor(QNN_VER_PTR(t)->type)
       ? nullptr

backends/qualcomm/aot/wrappers/TensorWrapper.h

Lines changed: 1 addition & 1 deletion
@@ -79,7 +79,7 @@ class TensorWrapper {
     return QNN_VER_PTR(tensor_)->quantizeParams;
   }

-  std::string GetName() const {
+  const std::string& GetName() const {
     return qnn_tensor_name_;
   };

backends/qualcomm/runtime/QnnManager.cpp

Lines changed: 1 addition & 1 deletion
@@ -329,7 +329,7 @@ Error QnnManager::AllocateTensor() {
     std::shared_ptr<TensorWrapper> tensor_wrapper =
         CreateTensorWrapper(output_tensors[i]);
     tensor_wrapper->UpdateQnnTensorMeta(output_tensors[i]);
-    std::string tensor_name = tensor_wrapper->GetName();
+    const std::string& tensor_name = tensor_wrapper->GetName();
     // this is required by identifying shared buffer mechanism
     // info might be missed if context binary came from qnn_converter
     if (tensor_name.find("output_") == std::string::npos) {

backends/qualcomm/runtime/backends/QnnGraphCommon.cpp

Lines changed: 1 addition & 1 deletion
@@ -82,7 +82,7 @@ Error QnnGraph::EnsureTensorInQnnGraph(

   int name_conflict_count = 0;
   while (error == QNN_TENSOR_ERROR_NAME_HASH_COLLISION) {
-    std::string old_name = tensor_wrapper->GetName();
+    const std::string& old_name = tensor_wrapper->GetName();

     std::string new_name =
         old_name + "_" + std::to_string(name_conflict_count);

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 2 additions & 2 deletions
@@ -1360,7 +1360,7 @@ def test_qnn_backend_context_direct(self):
                 module,
                 tuple(
                     torch.randn(size=v.shape, dtype=v.dtype)
-                    for _, v in bundle_program["inputs"].items()
+                    for v in bundle_program["inputs"].values()
                 ),
                 lowered_module,
             )
@@ -1528,7 +1528,7 @@ def test_qnn_backend_context_direct(self):
                 module,
                 tuple(
                     torch.randn(size=v.shape, dtype=v.dtype)
-                    for _, v in bundle_program["inputs"].items()
+                    for v in bundle_program["inputs"].values()
                 ),
                 lowered_module,
             )
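The change in both hunks swaps `.items()` with an unused key for `.values()`. A minimal sketch of the idiom, with a hypothetical inputs dict (names and shapes are placeholders, not taken from the test):

import torch

inputs = {"x": torch.randn(2, 3), "y": torch.randn(4)}  # hypothetical example inputs

# Before: the key is bound and immediately discarded.
old = tuple(torch.randn(size=v.shape, dtype=v.dtype) for _, v in inputs.items())
# After: iterate values directly; same tensors, clearer intent.
new = tuple(torch.randn(size=v.shape, dtype=v.dtype) for v in inputs.values())

assert all(a.shape == b.shape for a, b in zip(old, new))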

backends/qualcomm/tests/utils.py

Lines changed: 3 additions & 4 deletions
@@ -53,26 +53,25 @@ def generate_context_binary(
     assert qnn_sdk, "QNN_SDK_ROOT was not found in environment variable"
     assert ndk, "ANDROID_NDK_ROOT was not found in environment variable"

-    inputs_tup = tuple(v for _, v in inputs.items())
+    inputs_tup = tuple(inputs.values())
     jit_module = torch.jit.trace(module, inputs_tup)
     torch.jit.save(jit_module, f"{artifact_dir}/jit_module.pt")

     # input data
     if quantized:
-        input_list, idx = [], 0
+        input_list = []
         for name, data in inputs.items():
             file_name = f"{artifact_dir}/{name}.raw"
             data.detach().numpy().tofile(file_name)
             input_list.append(file_name)
-            idx += 1

         with open(f"{artifact_dir}/input_list.txt", "w") as f:
             f.write(" ".join(input_list))

     # flow of qnn tools
     target = "x86_64-linux-clang"
     inputs_str = [
-        f"-d '{k}' " + str(tuple(v.shape)).replace(" ", "")[1:-1]
+        f"-d '{k}' {str(tuple(v.shape)).replace(' ', '')[1:-1]}"
         for k, v in inputs.items()
     ]
     cmds = [
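The rewritten comprehension folds the shape formatting into a single f-string when building the converter's -d arguments. A small sketch of what it produces, using a hypothetical input dict (tensor name and shape are placeholders):

import torch

inputs = {"input_0": torch.randn(1, 3, 224, 224)}  # hypothetical name and shape

inputs_str = [
    f"-d '{k}' {str(tuple(v.shape)).replace(' ', '')[1:-1]}"
    for k, v in inputs.items()
]
print(inputs_str)  # ["-d 'input_0' 1,3,224,224"]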

backends/qualcomm/utils/utils.py

Lines changed: 21 additions & 24 deletions
@@ -228,37 +228,34 @@ def capture_program(


 def from_context_binary(ctx_path: str, op_name: str):
+    def implement_op(custom_op, op_name, outputs):
+        @torch.library.impl(
+            custom_op, str(op_name), dispatch_key="CompositeExplicitAutograd"
+        )
+        def op_impl(inputs: List[torch.Tensor]):
+            return tuple(
+                torch.zeros(tuple(v.shape), device="meta", dtype=v.dtype)
+                for v in outputs.values()
+            )
+
     def build_graph(inputs, outputs):
         # custom op declaration
         inputs_str = "Tensor[] inputs"
         func_proto = f"{op_name}({inputs_str}) -> Any"
         custom_op = Library(OpContextLoader.namespace, "FRAGMENT")
         custom_op.define(func_proto)
         # custom op implementation
-        args_name = "inputs"
-        inputs_str = f"{args_name}: List[torch.Tensor]"
-        outputs_str = "return " + ", ".join(
-            [
-                f"torch.zeros({tuple(v.shape)}, device='meta', dtype={v.dtype})"
-                for _, v in outputs.items()
-            ]
-        )
-        exec(
-            f'@torch.library.impl(custom_op, "{op_name}", '
-            'dispatch_key="CompositeExplicitAutograd")'
-            f"\ndef {op_name}_impl({inputs_str}):"
-            f"\n\t{outputs_str}",
-        )
+        implement_op(custom_op, op_name, outputs)
+
         # model architecture mimicking context binary
-        inputs_str = ", ".join(k for k in inputs.keys())
-        exec(
-            "class Model(torch.nn.Module):"
-            f"\n\tdef forward(self, {inputs_str}):"
-            f"\n\t\t{args_name} = [{inputs_str}]"
-            f"\n\t\treturn torch.ops.{OpContextLoader.namespace}.{op_name}.default({args_name})",
-        )
-        model = eval("Model()")
-        prog = torch.export.export(model, tuple(v for _, v in inputs.items()))
+        class Model(torch.nn.Module):
+            def forward(self, *inputs):
+                return getattr(
+                    getattr(torch.ops, OpContextLoader.namespace), op_name
+                ).default(inputs)
+
+        model = Model()
+        prog = torch.export.export(model, tuple(inputs.values()))
         # bookkeeping for variables' life cycle
         return {
             "custom_op": custom_op,
@@ -292,7 +289,7 @@ def build_tensor(tensors, dtype_map):
     assert qnn_mgr.Init().value == 0, "failed to load context binary"
     qnn_mgr.AllocateTensor()
     dtype_map = {}
-    for type_map in [QNN_QUANT_TYPE_MAP, QNN_TENSOR_TYPE_MAP]:
+    for type_map in (QNN_QUANT_TYPE_MAP, QNN_TENSOR_TYPE_MAP):
        for k, v in type_map.items():
            dtype_map.setdefault(v, k)
    inputs = build_tensor(qnn_mgr.GetGraphInputs(), dtype_map)
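The new implement_op helper registers the custom op through torch.library's decorator API instead of assembling source text for exec(). A minimal self-contained sketch of the same pattern; the namespace, op name, and output shape below are hypothetical stand-ins for OpContextLoader.namespace, op_name, and the context binary's outputs:

from typing import List

import torch
from torch.library import Library

lib = Library("demo_ns", "FRAGMENT")  # hypothetical namespace
lib.define("ctx_loader(Tensor[] inputs) -> Any")  # hypothetical op name


@torch.library.impl(lib, "ctx_loader", dispatch_key="CompositeExplicitAutograd")
def ctx_loader_impl(inputs: List[torch.Tensor]):
    # Mirrors the diff: emit meta tensors shaped like the graph outputs
    # (a single 1x16 float output here, purely as a placeholder).
    return (torch.zeros(1, 16, device="meta", dtype=torch.float32),)


class Model(torch.nn.Module):
    def forward(self, *inputs):
        return torch.ops.demo_ns.ctx_loader.default(inputs)


prog = torch.export.export(Model(), (torch.randn(1, 16),))

Compared with the old exec()-built string, the decorator route keeps the implementation inspectable, debuggable, and lintable.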

examples/qualcomm/executor_runner/qnn_qaihub_llama_runner.cpp

Lines changed: 4 additions & 8 deletions
@@ -43,7 +43,7 @@ DEFINE_double(
 DEFINE_int32(
     eval_mode,
     0,
-    "0: BERT-like evaluation / 1: KV cache based token generation / 2: Mixed mode (TBD)");
+    "0: PromptProcessor / 1: TokenGenerator / 2: MixedMode (TBD)");
 DEFINE_int32(
     seq_len,
     128,
@@ -74,15 +74,11 @@ int main(int argc, char** argv) {
       FLAGS_logits_scale,
       FLAGS_logits_offset);

-  // generate tokens
-  std::string inference_output;
+  // generate tokens & store inference output
+  std::ofstream fout(FLAGS_output_path.c_str());
   runner.generate(FLAGS_prompt, FLAGS_seq_len, [&](const std::string& piece) {
-    inference_output += piece;
+    fout << piece;
   });
-
-  // store inference output
-  std::ofstream fout(FLAGS_output_path.c_str());
-  fout << inference_output;
   fout.close();
   return 0;
 }

examples/qualcomm/llama2/llama_qaihub.py

Lines changed: 16 additions & 13 deletions
@@ -28,7 +28,7 @@
 from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass


-if __name__ == "__main__":
+def main():
     parser = setup_common_args_and_variables()

     parser.add_argument(
@@ -106,28 +106,27 @@
             use_multi_contexts=True,
         )
         compiler_specs = generate_qnn_executorch_compiler_spec(
-            soc_model=eval(f"QcomChipset.{args.model}"),
+            soc_model=getattr(QcomChipset, args.model),
             backend_options=backend_options,
             is_from_context_binary=True,
         )

     if args.pre_gen_pte is None:
         # create custom operators as context loader
-        bundle_programs = []
-        for i, target in enumerate(target_names):
-            file_name = f"{args.context_binaries}/{target}"
-            bundle_programs.append(from_context_binary(file_name, f"ctx_loader_{i}"))
+        bundle_programs = [
+            from_context_binary(f"{args.context_binaries}/{target}", f"ctx_loader_{i}")
+            for i, target in enumerate(target_names)
+        ]
         # lower with QnnBackend
-        lowered_modules = []
-        for prog in bundle_programs:
-            lowered_modules.append(
-                to_backend("QnnBackend", prog["edge_program"], compiler_specs)
-            )
+        lowered_modules = [
+            to_backend("QnnBackend", prog["edge_program"], compiler_specs)
+            for prog in bundle_programs
+        ]
         # setup spill-fill buffer for relieving runtime memory usage
         canonicalize_program(lowered_modules)
         # export pte files
         pte_name, pte_files = "qaihub_llama7b", []
-        for i, _ in enumerate(target_names):
+        for i in range(len(target_names)):
             memory_planning_pass = MemoryPlanningPass(
                 memory_planning_algo="greedy",
                 alloc_graph_input=False,
@@ -147,7 +146,7 @@
         pte_files = [f"{args.pre_gen_pte}/{pte_name}_{i}.pte" for i in range(4)]

     if args.compile_only:
-        exit(0)
+        return

     def get_logit_encoding(path_to_last_shard: str):
         with open(f"{args.context_binaries}/{path_to_last_shard}", "rb") as f:
@@ -230,3 +229,7 @@ def post_process():
     adb.push(files=custom_files)
     adb.execute(custom_runner_cmd=runner_cmds)
     adb.pull(args.artifact, callback=post_process)
+
+
+if __name__ == "__main__":
+    main()
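With this change the script body lives in main(), so --compile_only can simply return instead of calling exit(0), and the SoC model is resolved with getattr rather than eval. A minimal sketch of both points, using a stand-in enum rather than the real QcomChipset (member names are placeholders):

from enum import Enum, auto


class FakeChipset(Enum):  # stand-in for QcomChipset; members are placeholders
    SM_A = auto()
    SM_B = auto()


def main(model: str = "SM_A", compile_only: bool = False) -> None:
    soc_model = getattr(FakeChipset, model)  # replaces eval(f"FakeChipset.{model}")
    print(f"lowering for {soc_model}")
    if compile_only:
        return  # early return instead of exit(0), so callers can import and reuse main()
    # ... continue with on-device deployment steps ...


if __name__ == "__main__":
    main()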
