Skip to content

Commit d385570

Browse files
yipjustinfacebook-github-bot
authored andcommitted
Compute graph print readable (#2825)
Summary: Add capability to print the node list with arguments to allow better debugging. Reviewed By: SS-JIA Differential Revision: D55510335
1 parent 02f565e commit d385570

File tree

4 files changed

+223
-0
lines changed

4 files changed

+223
-0
lines changed

backends/vulkan/runtime/api/Runtime.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,22 @@
1616
namespace vkcompute {
1717
namespace api {
1818

19+
#define PRINT_CASE(name) \
20+
case MemoryAccessType::name: \
21+
out << #name; \
22+
break;
23+
24+
std::ostream& operator<<(std::ostream& out, const MemoryAccessType& tag) {
25+
switch (tag) {
26+
PRINT_CASE(NONE)
27+
PRINT_CASE(READ)
28+
PRINT_CASE(WRITE)
29+
}
30+
return out;
31+
}
32+
33+
#undef PRINT_CASE
34+
1935
namespace {
2036

2137
void find_requested_layers_and_extensions(

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,12 @@ class ComputeGraph final {
312312

313313
void resize_input(const int64_t idx, const std::vector<int64_t>& new_sizes);
314314
void propagate_resize();
315+
316+
//
317+
// Debug support (implemented in Logging.cpp)
318+
//
319+
320+
void print_readable();
315321
};
316322

317323
template <typename T>
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/Logging.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
12+
13+
#include <iomanip>
14+
#include <iostream>
15+
#include <map>
16+
#include <set>
17+
18+
namespace vkcompute {
19+
20+
void ComputeGraph::print_readable() {
21+
std::set<ValueRef> input_set;
22+
for (const IOValueRef& io_val : inputs()) {
23+
input_set.insert(io_val.value);
24+
}
25+
26+
std::set<ValueRef> output_set;
27+
for (const IOValueRef& io_val : outputs()) {
28+
output_set.insert(io_val.value);
29+
}
30+
31+
std::set<ValueRef> prepack_set;
32+
for (const std::unique_ptr<PrepackNode>& node : prepack_nodes()) {
33+
prepack_set.insert(node->tref_);
34+
prepack_set.insert(node->packed_);
35+
}
36+
37+
std::map<ValueRef, size_t> value_ref_to_shared_object_idx;
38+
39+
std::cout << "====================" << std::left << std::setfill('=')
40+
<< std::setw(40) << " Shared Object List " << std::right
41+
<< std::setfill(' ') << std::endl;
42+
43+
std::cout << std::setw(6) << "idx" << std::setw(20) << "sizes"
44+
<< std::setw(24) << "users" << std::endl;
45+
46+
size_t so_idx = 0;
47+
for (const SharedObject& shared_object : shared_objects_) {
48+
std::cout << std::setw(6) << so_idx;
49+
{
50+
std::stringstream ss;
51+
ss << shared_object.aggregate_memory_requirements.size;
52+
std::cout << std::setw(20) << ss.str();
53+
}
54+
55+
{
56+
std::stringstream ss;
57+
ss << shared_object.users;
58+
std::cout << std::setw(24) << ss.str();
59+
}
60+
std::cout << std::endl;
61+
62+
for (const ValueRef& user : shared_object.users) {
63+
value_ref_to_shared_object_idx[user] = so_idx;
64+
}
65+
66+
so_idx++;
67+
}
68+
69+
std::cout << "====================" << std::left << std::setfill('=')
70+
<< std::setw(40) << " Value List " << std::right
71+
<< std::setfill(' ') << std::endl;
72+
73+
std::cout << std::setw(6) << "idx" << std::setw(10) << "type" << std::setw(20)
74+
<< "sizes" << std::setw(10) << "node_type" << std::setw(10)
75+
<< "so_idx" << std::endl;
76+
77+
size_t value_idx = 0;
78+
for (Value& val : values_) {
79+
std::cout << std::setw(6) << value_idx << std::setw(10) << val.type();
80+
81+
// sizes
82+
std::cout << std::setw(20);
83+
if (val.isTensor()) {
84+
const vTensor& v_tensor = val.toTensor();
85+
std::stringstream ss;
86+
ss << v_tensor.sizes();
87+
std::cout << ss.str();
88+
} else if (val.isTensorRef()) {
89+
const TensorRef& tensor_ref = val.toTensorRef();
90+
std::stringstream ss;
91+
ss << tensor_ref.sizes;
92+
std::cout << ss.str();
93+
} else {
94+
std::cout << "";
95+
}
96+
97+
// Node type
98+
std::cout << std::setw(10);
99+
{
100+
if (input_set.count(value_idx) > 0) {
101+
std::cout << "INPUT";
102+
} else if (output_set.count(value_idx) > 0) {
103+
std::cout << "OUTPUT";
104+
} else if (prepack_set.count(value_idx) > 0) {
105+
std::cout << "PREPACK";
106+
} else {
107+
std::cout << "";
108+
}
109+
}
110+
111+
std::cout << std::setw(10);
112+
if (value_ref_to_shared_object_idx.count(value_idx) > 0) {
113+
size_t shared_obj_idx = value_ref_to_shared_object_idx.at(value_idx);
114+
std::cout << shared_obj_idx;
115+
} else {
116+
std::cout << "";
117+
}
118+
119+
std::cout << std::endl;
120+
value_idx++;
121+
}
122+
123+
std::cout << "====================" << std::left << std::setfill('=')
124+
<< std::setw(40) << " Prepack Node List " << std::right
125+
<< std::setfill(' ') << std::endl;
126+
std::cout << std::setw(6) << "idx" << std::setw(32) << "shader_name"
127+
<< std::setw(8) << "tref" << std::setw(8) << "packed" << std::endl;
128+
129+
size_t prepack_node_idx = 0;
130+
for (const std::unique_ptr<PrepackNode>& node : prepack_nodes()) {
131+
std::cout << std::setw(6) << prepack_node_idx << std::setw(32)
132+
<< node->shader_.kernel_name << std::setw(8) << node->tref_
133+
<< std::setw(8) << node->packed_ << std::endl;
134+
135+
prepack_node_idx++;
136+
}
137+
138+
std::cout << "====================" << std::left << std::setfill('=')
139+
<< std::setw(40) << " Execute Node List " << std::right
140+
<< std::setfill(' ') << std::endl;
141+
142+
std::cout << std::setw(6) << "idx" << std::setw(32) << "shader_name"
143+
<< std::setw(24) << "READ_arg" << std::setw(24) << "WRITE_arg"
144+
<< std::endl;
145+
146+
size_t node_idx = 0;
147+
for (const std::unique_ptr<ExecuteNode>& node : execute_nodes()) {
148+
std::cout << std::setw(6) << node_idx;
149+
std::cout << std::setw(32) << node->shader_.kernel_name;
150+
151+
std::stringstream read_s;
152+
for (const ArgGroup& arg_group : node->args_) {
153+
if (arg_group.access != api::MemoryAccessType::READ) {
154+
continue;
155+
}
156+
read_s << arg_group.refs;
157+
}
158+
std::cout << std::setw(24) << read_s.str();
159+
160+
std::stringstream write_s;
161+
for (const ArgGroup& arg_group : node->args_) {
162+
if (arg_group.access != api::MemoryAccessType::WRITE) {
163+
continue;
164+
}
165+
write_s << arg_group.refs;
166+
}
167+
std::cout << std::setw(24) << write_s.str();
168+
169+
std::cout << std::endl;
170+
171+
node_idx++;
172+
}
173+
}
174+
175+
} // namespace vkcompute
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <ostream>
12+
#include <vector>
13+
14+
namespace vkcompute {
15+
16+
template <typename T>
17+
inline std::ostream& operator<<(std::ostream& os, const std::vector<T>& vec) {
18+
os << '[';
19+
for (const auto& elem : vec) {
20+
os << elem << ',';
21+
}
22+
os << ']';
23+
return os; // Return the ostream to allow chaining
24+
}
25+
26+
} // namespace vkcompute

0 commit comments

Comments
 (0)