Skip to content

[ET-VK] Nit Arithmetic cleanup #2246

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 25 additions & 29 deletions backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,44 +36,40 @@ DEFINE_ARITHMETIC_FN(pow, POW);

ValueRef add_arithmetic_node(
ComputeGraph& graph,
const ValueRef t1,
const ValueRef t2,
const ValueRef in1,
const ValueRef in2,
const float alpha,
const arithmetic::OpType optype,
const int64_t shared_object_idx) {
std::vector<int64_t> t1_sizes = graph.get_val_sizes(t1);
api::ScalarType t1_dtype = graph.get_val_dtype(t1);
std::vector<int64_t> in1_sizes = graph.get_val_sizes(in1);
api::ScalarType in1_dtype = graph.get_val_dtype(in1);

ValueRef out = graph.add_tensor(t1_sizes, t1_dtype, shared_object_idx);
add_arithmetic_node(graph, t1, t2, out, alpha, optype);
ValueRef out = graph.add_tensor(in1_sizes, in1_dtype, shared_object_idx);
add_arithmetic_node(graph, in1, in2, out, alpha, optype);
return out;
}

// TODO(T181006464): Move to Utils when we remove ArithmeticPrepack.
ValueRef prepack_if_tensor_ref(ComputeGraph& graph, const ValueRef v) {
if (graph.get_val(v).isTensor()) {
return v;
} else {
TensorRef& tRef = graph.get_val(v).toTensorRef();
ValueRef vTen = graph.add_tensor(tRef.sizes, tRef.dtype);
graph.prepack_nodes().emplace_back(new ArithmeticPrepack(v, vTen));
return vTen;
}
}

void add_arithmetic_node(
ComputeGraph& graph,
const ValueRef t1,
const ValueRef t2,
const ValueRef in1,
const ValueRef in2,
const ValueRef out,
const float alpha,
const arithmetic::OpType optype) {
// Prepacking first arg (if needed)
ValueRef arg1 = t1;
if (graph.get_val(t1).isTensorRef()) {
TensorRef& t1_asref = graph.get_val(t1).toTensorRef();
ValueRef t1_vten = graph.add_tensor(t1_asref.sizes, t1_asref.dtype);
graph.prepack_nodes().emplace_back(new ArithmeticPrepack(t1, t1_vten));
arg1 = t1_vten;
}
VK_CHECK_COND(graph.get_val(arg1).isTensor());
// Prepacking second arg (if needed)
ValueRef arg2 = t2;
if (graph.get_val(t2).isTensorRef()) {
TensorRef& t2_asref = graph.get_val(t2).toTensorRef();
ValueRef t2_vten = graph.add_tensor(t2_asref.sizes, t2_asref.dtype);
graph.prepack_nodes().emplace_back(new ArithmeticPrepack(t2, t2_vten));
arg2 = t2_vten;
}
VK_CHECK_COND(graph.get_val(arg2).isTensor());
ValueRef arg1 = prepack_if_tensor_ref(graph, in1);
ValueRef arg2 = prepack_if_tensor_ref(graph, in2);

graph.execute_nodes().emplace_back(
new ArithmeticNode(arg1, arg2, out, alpha, optype));
Expand All @@ -97,12 +93,12 @@ void ArithmeticPrepack::encode(ComputeGraph* graph) const {
}

ArithmeticNode::ArithmeticNode(
const ValueRef t1,
const ValueRef t2,
const ValueRef in1,
const ValueRef in2,
const ValueRef out,
const float alpha,
const arithmetic::OpType optype)
: ExecuteNode({t1, t2}, {out}), alpha_(alpha), optype_(optype) {}
: ExecuteNode({in1, in2}, {out}), alpha_(alpha), optype_(optype) {}

void ArithmeticNode::encode(ComputeGraph* graph) const {
vTensor& in1 = graph->get_val(inputs_[0]).toTensor();
Expand Down
12 changes: 6 additions & 6 deletions backends/vulkan/runtime/graph/ops/impl/Arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,16 @@ DECLARE_OP_FN(pow);

ValueRef add_arithmetic_node(
ComputeGraph& graph,
const ValueRef t1,
const ValueRef t2,
const ValueRef in1,
const ValueRef in2,
const float alpha,
const arithmetic::OpType optype,
const int64_t shared_object_idx = -1);

void add_arithmetic_node(
ComputeGraph& graph,
const ValueRef t1,
const ValueRef t2,
const ValueRef in1,
const ValueRef in2,
const ValueRef out,
const float alpha,
const arithmetic::OpType optype);
Expand All @@ -53,8 +53,8 @@ class ArithmeticPrepack : public virtual PrepackNode {
class ArithmeticNode : public virtual ExecuteNode {
public:
explicit ArithmeticNode(
const ValueRef t1,
const ValueRef t2,
const ValueRef in1,
const ValueRef in2,
const ValueRef out,
const float alpha,
const arithmetic::OpType optype);
Expand Down