Skip to content

feat: Add DrawGraph tool for graph visualization #7172

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
238 changes: 236 additions & 2 deletions backends/qualcomm/aot/python/PyQnnWrapperAdaptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,37 @@ std::unique_ptr<QuantizeParamsWrapper> CreateQuantizationParamWrapper(
return quantize_param_wrapper;
}

// Renders a Qnn_Scalar_t as a human-readable string for graph
// visualization. The union member matching scalar.dataType is read;
// any data type without an explicit branch (including
// QNN_DATATYPE_UNDEFINED) yields the literal "QNN_DATATYPE_UNDEFINED".
std::string GetScalarValue(const Qnn_Scalar_t& scalar) {
  const auto type = scalar.dataType;
  if (type == QNN_DATATYPE_FLOAT_32) {
    return std::to_string(scalar.floatValue);
  }
  if (type == QNN_DATATYPE_FLOAT_64) {
    return std::to_string(scalar.doubleValue);
  }
  if (type == QNN_DATATYPE_UINT_64) {
    return std::to_string(scalar.uint64Value);
  }
  if (type == QNN_DATATYPE_INT_64) {
    return std::to_string(scalar.int64Value);
  }
  if (type == QNN_DATATYPE_UINT_32) {
    return std::to_string(scalar.uint32Value);
  }
  if (type == QNN_DATATYPE_INT_32) {
    return std::to_string(scalar.int32Value);
  }
  if (type == QNN_DATATYPE_UINT_16) {
    return std::to_string(scalar.uint16Value);
  }
  if (type == QNN_DATATYPE_INT_16) {
    return std::to_string(scalar.int16Value);
  }
  if (type == QNN_DATATYPE_UINT_8) {
    return std::to_string(scalar.uint8Value);
  }
  if (type == QNN_DATATYPE_INT_8) {
    return std::to_string(scalar.int8Value);
  }
  if (type == QNN_DATATYPE_BOOL_8) {
    // Widen to int so the bool prints as "0"/"1".
    return std::to_string(static_cast<int>(scalar.bool8Value));
  }
  if (type == QNN_DATATYPE_STRING) {
    return std::string(scalar.stringValue);
  }
  return "QNN_DATATYPE_UNDEFINED";
}

std::shared_ptr<TensorWrapper> CreateTensorWrapper(
const std::string& tensor_name,
Qnn_TensorType_t tensor_type,
Expand Down Expand Up @@ -176,11 +207,60 @@ PYBIND11_MODULE(PyQnnWrapperAdaptor, m) {
Qnn_QuantizationEncoding_t::
QNN_QUANTIZATION_ENCODING_BW_AXIS_SCALE_OFFSET)
.export_values();

py::class_<OpWrapper, std::shared_ptr<OpWrapper>>(m, "OpWrapper")
.def(py::init<
const std::string&,
const std::string&,
const std::string&>());
const std::string&>())
.def(
"GetInputTensors",
&OpWrapper::GetInputTensors,
"A function which gets input tensors")
.def(
"GetOutputTensors",
&OpWrapper::GetOutputTensors,
"A function which gets output tensors")
.def("GetOpType", &OpWrapper::GetOpType, "A function which gets op type")
.def("GetName", &OpWrapper::GetName, "A function which gets name")
.def(
"GetPackageName",
&OpWrapper::GetPackageName,
"A function which gets package name")
.def(
"GetParams", &OpWrapper::GetRawParams, "A function which gets params")
// lambda function
// python: op_wrapper.GetOpConfig()
.def(
"GetOpConfig",
[](OpWrapper& self) {
auto op_config = self.GetOpConfig();
py::dict result;
py::list params_list;
py::list input_tensors_list;
py::list output_tensors_list;
result["version"] = op_config.version;
result["name"] = op_config.v1.name;
result["packageName"] = op_config.v1.packageName;
result["typeName"] = op_config.v1.typeName;
result["numOfParams"] = op_config.v1.numOfParams;
for (size_t i = 0; i < op_config.v1.numOfParams; ++i) {
params_list.append(op_config.v1.params[i]);
}
result["params"] = params_list;
result["numOfInputs"] = op_config.v1.numOfInputs;
for (size_t i = 0; i < op_config.v1.numOfInputs; ++i) {
input_tensors_list.append(op_config.v1.inputTensors[i]);
}
result["inputTensors"] = input_tensors_list;
result["numOfOutputs"] = op_config.v1.numOfOutputs;
for (size_t i = 0; i < op_config.v1.numOfOutputs; ++i) {
output_tensors_list.append(op_config.v1.outputTensors[i]);
}
result["outputTensors"] = output_tensors_list;
return result;
},
"Get operator configuration");

py::class_<TensorWrapper, std::shared_ptr<TensorWrapper>>(m, "TensorWrapper")
.def(py::init(py::overload_cast<
Expand All @@ -197,7 +277,9 @@ PYBIND11_MODULE(PyQnnWrapperAdaptor, m) {
py::class_<QuantizeParamsWrapper>(m, "QuantizeParamsWrapper");

py::class_<Qnn_ScaleOffset_t>(m, "Qnn_ScaleOffset_t")
.def(py::init<float, int32_t>());
.def(py::init<float, int32_t>())
.def_readonly("scale", &Qnn_ScaleOffset_t::scale)
.def_readonly("offset", &Qnn_ScaleOffset_t::offset);

py::class_<PyQnnOpWrapper, std::shared_ptr<PyQnnOpWrapper>>(
m, "PyQnnOpWrapper")
Expand Down Expand Up @@ -248,6 +330,158 @@ PYBIND11_MODULE(PyQnnWrapperAdaptor, m) {
.def("GetDataType", &PyQnnTensorWrapper::GetDataType)
.def("GetName", &PyQnnTensorWrapper::GetName)
.def("GetEncodings", &PyQnnTensorWrapper::GetEncodings);

// Read-only Python view of Qnn_OpConfig_t (a versioned struct).
// Only the v1 layout is exposed; `version` tells callers which union
// member is valid.
py::class_<Qnn_OpConfig_t>(m, "Qnn_OpConfig")
    .def_readonly("version", &Qnn_OpConfig_t::version)
    // Python usage: op_wrapper.GetOpConfig().v1
    // NOTE(review): pybind's default return policy copies the returned
    // reference into Python — TODO confirm that is the intent here.
    .def_property_readonly(
        "v1", [](const Qnn_OpConfig_t& config) -> const Qnn_OpConfigV1_t& {
          return config.v1;
        });

// Mirror the op-config version tag so Python can check which union
// member of Qnn_OpConfig is populated.
py::enum_<Qnn_OpConfigVersion_t>(m, "Qnn_OpConfigVersion")
    .value("QNN_OPCONFIG_VERSION_1", QNN_OPCONFIG_VERSION_1)
    .value("QNN_OPCONFIG_VERSION_UNDEFINED", QNN_OPCONFIG_VERSION_UNDEFINED)
    .export_values();

// Read-only view of the v1 op-config payload: identifying strings plus
// raw param/tensor arrays. The array fields (`params`, `inputTensors`,
// `outputTensors`) are exposed as raw pointers; Python-friendly lists
// are built by OpWrapper.GetOpConfig() instead.
py::class_<Qnn_OpConfigV1_t>(m, "Qnn_OpConfigV1")
    .def_readonly("name", &Qnn_OpConfigV1_t::name)
    .def_readonly("packageName", &Qnn_OpConfigV1_t::packageName)
    .def_readonly("typeName", &Qnn_OpConfigV1_t::typeName)
    .def_readonly("numOfParams", &Qnn_OpConfigV1_t::numOfParams)
    .def_readonly("params", &Qnn_OpConfigV1_t::params)
    .def_readonly("numOfInputs", &Qnn_OpConfigV1_t::numOfInputs)
    .def_readonly("inputTensors", &Qnn_OpConfigV1_t::inputTensors)
    .def_readonly("numOfOutputs", &Qnn_OpConfigV1_t::numOfOutputs)
    .def_readonly("outputTensors", &Qnn_OpConfigV1_t::outputTensors);

// Read-only view of Qnn_Param_t (a tagged union of scalar/tensor
// params). Accessing the wrong union member raises RuntimeError in
// Python rather than returning garbage.
py::class_<Qnn_Param_t>(m, "Qnn_Param")
    .def_readonly("paramType", &Qnn_Param_t::paramType)
    .def_readonly("name", &Qnn_Param_t::name)
    .def_property_readonly(
        "scalarParam",
        [](const Qnn_Param_t& param) -> const Qnn_Scalar_t& {
          // Guard the union: only valid when the tag says scalar.
          if (param.paramType == Qnn_ParamType_t::QNN_PARAMTYPE_SCALAR) {
            return param.scalarParam;
          }
          throw std::runtime_error("ParamType is not scalar.");
        })
    .def_property_readonly(
        "tensorParam", [](const Qnn_Param_t& param) -> const Qnn_Tensor_t& {
          // Guard the union: only valid when the tag says tensor.
          if (param.paramType == Qnn_ParamType_t::QNN_PARAMTYPE_TENSOR) {
            return param.tensorParam;
          }
          throw std::runtime_error("ParamType is not tensor.");
        });

// Tag values used to discriminate the Qnn_Param union above.
py::enum_<Qnn_ParamType_t>(m, "Qnn_ParamType_t")
    .value("QNN_PARAMTYPE_SCALAR", Qnn_ParamType_t::QNN_PARAMTYPE_SCALAR)
    .value("QNN_PARAMTYPE_TENSOR", Qnn_ParamType_t::QNN_PARAMTYPE_TENSOR)
    .value(
        "QNN_PARAMTYPE_UNDEFINED", Qnn_ParamType_t::QNN_PARAMTYPE_UNDEFINED)
    .export_values();

// Scalars expose their tag plus a stringified value; GetScalarValue
// picks the union member that matches dataType.
py::class_<Qnn_Scalar_t>(m, "Qnn_Scalar_t")
    .def_readonly("dataType", &Qnn_Scalar_t::dataType)
    .def("value", &GetScalarValue, "Get the value of the scalar as a string");

// Read-only view of Qnn_Tensor_t (versioned union). The v1/v2
// accessors validate the version tag and raise RuntimeError on a
// mismatched access instead of exposing an inactive union member.
py::class_<Qnn_Tensor_t>(m, "Qnn_Tensor_t")
    .def_readonly("version", &Qnn_Tensor_t::version)
    .def_property_readonly(
        "v1",
        [](Qnn_Tensor_t& t) -> Qnn_TensorV1_t& {
          if (t.version == QNN_TENSOR_VERSION_1) {
            return t.v1;
          }
          throw std::runtime_error("Tensor version is not V1.");
        })
    .def_property_readonly("v2", [](Qnn_Tensor_t& t) -> Qnn_TensorV2_t& {
      if (t.version == QNN_TENSOR_VERSION_2) {
        return t.v2;
      }
      throw std::runtime_error("Tensor version is not V2.");
    });

// Version tag values for the Qnn_Tensor union above.
py::enum_<Qnn_TensorVersion_t>(m, "Qnn_TensorVersion_t")
    .value("QNN_TENSOR_VERSION_1", Qnn_TensorVersion_t::QNN_TENSOR_VERSION_1)
    .value("QNN_TENSOR_VERSION_2", Qnn_TensorVersion_t::QNN_TENSOR_VERSION_2)
    .value(
        "QNN_TENSOR_VERSION_UNDEFINED",
        Qnn_TensorVersion_t::QNN_TENSOR_VERSION_UNDEFINED)
    .export_values();

// Read-only view of the v1 tensor payload: identity, type/format
// metadata, quantization info, and shape.
py::class_<Qnn_TensorV1_t>(m, "QnnTensorV1")
    .def_readonly("id", &Qnn_TensorV1_t::id)
    .def_readonly("name", &Qnn_TensorV1_t::name)
    .def_readonly("type", &Qnn_TensorV1_t::type)
    .def_readonly("dataFormat", &Qnn_TensorV1_t::dataFormat)
    .def_readonly("dataType", &Qnn_TensorV1_t::dataType)
    .def_readonly("quantizeParams", &Qnn_TensorV1_t::quantizeParams)
    .def_readonly("rank", &Qnn_TensorV1_t::rank)
    // Copy the C `dimensions` pointer (length = rank) into a vector so
    // Python receives a list instead of a raw pointer.
    .def_property_readonly(
        "dimensions",
        [](const Qnn_TensorV1_t& t) {
          return std::vector<uint32_t>(t.dimensions, t.dimensions + t.rank);
        })
    .def_readonly("memType", &Qnn_TensorV1_t::memType);

// Memory-type tag for tensor storage.
py::enum_<Qnn_TensorMemType_t>(m, "Qnn_TensorMemType_t")
    .value(
        "QNN_TENSORMEMTYPE_RAW", Qnn_TensorMemType_t::QNN_TENSORMEMTYPE_RAW)
    .value(
        "QNN_TENSORMEMTYPE_MEMHANDLE",
        Qnn_TensorMemType_t::QNN_TENSORMEMTYPE_MEMHANDLE)
    .value(
        "QNN_TENSORMEMTYPE_UNDEFINED",
        Qnn_TensorMemType_t::QNN_TENSORMEMTYPE_UNDEFINED)
    .export_values();

// Read-only view of quantization parameters. The two encoding
// accessors validate `quantizationEncoding` and raise RuntimeError if
// the requested union member is not the active one.
py::class_<Qnn_QuantizeParams_t>(m, "QnnQuantizeParams")
    .def_readonly(
        "encodingDefinition", &Qnn_QuantizeParams_t::encodingDefinition)
    .def_readonly(
        "quantizationEncoding", &Qnn_QuantizeParams_t::quantizationEncoding)
    .def_property_readonly(
        "scaleOffsetEncoding",
        [](const Qnn_QuantizeParams_t& qp) {
          // Per-tensor scale/offset encoding.
          if (qp.quantizationEncoding ==
              QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) {
            return qp.scaleOffsetEncoding;
          }
          throw std::runtime_error(
              "Invalid quantization encoding type for scaleOffsetEncoding.");
        })
    .def_property_readonly(
        "axisScaleOffsetEncoding", [](const Qnn_QuantizeParams_t& qp) {
          // Per-axis (channel-wise) scale/offset encoding.
          if (qp.quantizationEncoding ==
              QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) {
            return qp.axisScaleOffsetEncoding;
          }
          throw std::runtime_error(
              "Invalid quantization encoding type for axisScaleOffsetEncoding.");
        });

// Whether a field was generated by the implementation or explicitly
// defined by the caller.
py::enum_<Qnn_Definition_t>(m, "QnnDefinition")
    .value(
        "QNN_DEFINITION_IMPL_GENERATED",
        Qnn_Definition_t::QNN_DEFINITION_IMPL_GENERATED)
    .value("QNN_DEFINITION_DEFINED", Qnn_Definition_t::QNN_DEFINITION_DEFINED)
    .value(
        "QNN_DEFINITION_UNDEFINED",
        Qnn_Definition_t::QNN_DEFINITION_UNDEFINED)
    .export_values();

// Read-only view of per-axis quantization: the quantized axis plus one
// scale/offset pair per slice along that axis.
py::class_<Qnn_AxisScaleOffset_t>(m, "QnnAxisScaleOffset")
    .def_readonly("axis", &Qnn_AxisScaleOffset_t::axis)
    .def_readonly("numScaleOffsets", &Qnn_AxisScaleOffset_t::numScaleOffsets)
    // Copy the C array (length = numScaleOffsets) into a vector so
    // Python receives a list of Qnn_ScaleOffset_t.
    .def_property_readonly(
        "scaleOffset", [](const Qnn_AxisScaleOffset_t& aso) {
          return std::vector<Qnn_ScaleOffset_t>(
              aso.scaleOffset, aso.scaleOffset + aso.numScaleOffsets);
        });
// Note: OpWrapper.GetParams() on the Python side returns the
// std::vector<ParamWrapper*> produced by OpWrapper::GetRawParams().
}
} // namespace qnn
} // namespace backends
Expand Down
13 changes: 13 additions & 0 deletions backends/qualcomm/aot/wrappers/OpWrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,19 @@ class OpWrapper final {
// Returns the QNN operator type string (e.g. the op's typeName).
const std::string GetOpType() {
  return op_type_;
}
// Returns the unique name of this op node in the graph.
const std::string GetName() {
  return name_;
}
// Returns the op package this operator belongs to.
const std::string GetPackageName() {
  return package_name_;
}
std::vector<ParamWrapper*> GetRawParams() const {
std::vector<ParamWrapper*> raw_params;
for (const auto& param : params_) {
raw_params.push_back(param.get());
}
return raw_params;
}
Qnn_OpConfig_t GetOpConfig();

private:
Expand Down
Loading
Loading