
Commit 1116125

Compute graph print readable

yipjustin authored and facebook-github-bot committed

Summary: Add the capability to print the compute graph's node list, with arguments, to make debugging easier.

Differential Revision: D55510335

1 parent bd6ceab commit 1116125
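The new print_readable() API dumps the graph's shared objects, values, prepack nodes, and execute nodes as fixed-width tables on stdout. A minimal sketch of how it might be called while debugging, assuming a graph has already been built elsewhere (the debug_dump helper below is hypothetical; the real call site added by this commit is in vulkan_compute_api_test.cpp, shown at the end of the diff):

  // Sketch only: dump a graph that was built elsewhere.
  void debug_dump(vkcompute::ComputeGraph& graph) {
    // Prints the Shared Object, Value, Prepack Node, and Execute Node tables.
    graph.print_readable();
  }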

File tree

4 files changed: +206 -0 lines changed

backends/vulkan/runtime/api/Runtime.cpp

Lines changed: 14 additions & 0 deletions

@@ -16,6 +16,20 @@
 namespace vkcompute {
 namespace api {
 
+#define PRINT_CASE(name)       \
+  case MemoryAccessType::name: \
+    out << #name;              \
+    break;
+
+std::ostream& operator<<(std::ostream& out, const MemoryAccessType& tag) {
+  switch (tag) {
+    PRINT_CASE(NONE)
+    PRINT_CASE(READ)
+    PRINT_CASE(WRITE)
+  }
+  return out;
+}
+
 namespace {
 
 void find_requested_layers_and_extensions(
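This stream operator makes a MemoryAccessType printable by its enumerator name. A minimal sketch of the behavior, assuming the operator's declaration is visible in the translation unit where it is used:

  using vkcompute::api::MemoryAccessType;
  std::cout << MemoryAccessType::READ;  // prints "READ"
  std::cout << MemoryAccessType::NONE;  // prints "NONE"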

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 6 additions & 0 deletions

@@ -296,6 +296,12 @@ class ComputeGraph final {
 
   void resize_input(const int64_t idx, const std::vector<int64_t>& new_sizes);
   void propagate_resize();
+
+  //
+  // Debug support
+  //
+
+  void print_readable();
 };
 
 template <typename T>
Lines changed: 184 additions & 0 deletions

@@ -0,0 +1,184 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
+
+#include <iomanip>
+#include <iostream>
+#include <map>
+#include <ratio>
+#include <set>
+
+namespace vkcompute {
+
+template <typename T>
+inline std::ostream& operator<<(std::ostream& os, const std::vector<T>& vec) {
+  os << '[';
+  for (const auto& elem : vec) {
+    os << elem << ',';
+  }
+  os << ']';
+  return os; // Return the ostream to allow chaining
+}
+
+void ComputeGraph::print_readable() {
+  std::set<ValueRef> input_set;
+  for (const IOValueRef& io_val : inputs_) {
+    input_set.insert(io_val.value);
+  }
+
+  std::set<ValueRef> output_set;
+  for (const IOValueRef& io_val : outputs_) {
+    output_set.insert(io_val.value);
+  }
+
+  std::set<ValueRef> prepack_set;
+  for (const std::unique_ptr<PrepackNode>& node : prepack_nodes_) {
+    prepack_set.insert(node->tref_);
+    prepack_set.insert(node->packed_);
+  }
+
+  std::map<ValueRef, size_t> value_ref_to_shared_object_idx;
+
+  std::cout << "====================" << std::left << std::setfill('=')
+            << std::setw(40) << " Shared Object List " << std::right
+            << std::setfill(' ') << std::endl;
+
+  std::cout << std::setw(6) << "idx" << std::setw(20) << "sizes"
+            << std::setw(24) << "users" << std::endl;
+
+  size_t so_idx = 0;
+  for (const SharedObject& shared_object : shared_objects_) {
+    std::cout << std::setw(6) << so_idx;
+    {
+      std::stringstream ss;
+      ss << shared_object.aggregate_memory_requirements.size;
+      std::cout << std::setw(20) << ss.str();
+    }
+
+    {
+      std::stringstream ss;
+      ss << shared_object.users;
+      std::cout << std::setw(24) << ss.str();
+    }
+    std::cout << std::endl;
+
+    for (const ValueRef& user : shared_object.users) {
+      value_ref_to_shared_object_idx[user] = so_idx;
+    }
+
+    so_idx++;
+  }
+
+  std::cout << "====================" << std::left << std::setfill('=')
+            << std::setw(40) << " Value List " << std::right
+            << std::setfill(' ') << std::endl;
+
+  std::cout << std::setw(6) << "idx" << std::setw(10) << "type" << std::setw(20)
+            << "sizes" << std::setw(10) << "node_type" << std::setw(10)
+            << "so_idx" << std::endl;
+
+  size_t value_idx = 0;
+  for (Value& val : values_) {
+    std::cout << std::setw(6) << value_idx << std::setw(10) << val.type();
+
+    // sizes
+    std::cout << std::setw(20);
+    if (val.isTensor()) {
+      vTensor& v_tensor = val.toTensor();
+      std::stringstream ss;
+      ss << v_tensor.sizes();
+      std::cout << ss.str();
+    } else if (val.isTensorRef()) {
+      TensorRef tensor_ref = val.toTensorRef();
+      std::stringstream ss;
+      ss << tensor_ref.sizes;
+      std::cout << ss.str();
+    } else {
+      std::cout << "";
+    }
+
+    // Node type
+    std::cout << std::setw(10);
+    {
+      if (input_set.count(value_idx) > 0) {
+        std::cout << "INPUT";
+      } else if (output_set.count(value_idx) > 0) {
+        std::cout << "OUTPUT";
+      } else if (prepack_set.count(value_idx) > 0) {
+        std::cout << "PREPACK";
+      } else {
+        std::cout << "";
+      }
+    }
+
+    std::cout << std::setw(10);
+    if (value_ref_to_shared_object_idx.count(value_idx) > 0) {
+      size_t shared_obj_idx = value_ref_to_shared_object_idx.at(value_idx);
+      std::cout << shared_obj_idx;
+    } else {
+      std::cout << "";
+    }
+
+    std::cout << std::endl;
+    value_idx++;
+  }
+
+  std::cout << "====================" << std::left << std::setfill('=')
+            << std::setw(40) << " Prepack Node List " << std::right
+            << std::setfill(' ') << std::endl;
+  std::cout << std::setw(6) << "idx" << std::setw(32) << "shader_name"
+            << std::setw(8) << "tref" << std::setw(8) << "packed" << std::endl;
+
+  size_t prepack_node_idx = 0;
+  for (const std::unique_ptr<PrepackNode>& node : prepack_nodes_) {
+    std::cout << std::setw(6) << prepack_node_idx << std::setw(32)
+              << node->shader_.kernel_name << std::setw(8) << node->tref_
+              << std::setw(8) << node->packed_ << std::endl;
+
+    prepack_node_idx++;
+  }
+
+  std::cout << "====================" << std::left << std::setfill('=')
+            << std::setw(40) << " Execute Node List " << std::right
+            << std::setfill(' ') << std::endl;
+
+  std::cout << std::setw(6) << "idx" << std::setw(32) << "shader_name"
+            << std::setw(24) << "READ_arg" << std::setw(24) << "WRITE_arg"
+            << std::endl;
+
+  size_t node_idx = 0;
+  for (const std::unique_ptr<ExecuteNode>& node : execute_nodes_) {
+    std::cout << std::setw(6) << node_idx;
+    std::cout << std::setw(32) << node->shader_.kernel_name;
+
+    std::stringstream read_s;
+    for (const ArgGroup& arg_group : node->args_) {
+      if (arg_group.access != api::MemoryAccessType::READ) {
+        continue;
+      }
+      read_s << arg_group.refs;
+    }
+    std::cout << std::setw(24) << read_s.str();
+
+    std::stringstream write_s;
+    for (const ArgGroup& arg_group : node->args_) {
+      if (arg_group.access != api::MemoryAccessType::WRITE) {
+        continue;
+      }
+      write_s << arg_group.refs;
+    }
+    std::cout << std::setw(24) << write_s.str();
+
+    std::cout << std::endl;
+
+    node_idx++;
+  }
+}
+
+} // namespace vkcompute
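The file-local operator<< for std::vector<T> at the top of this file is what lets the sizes, users, and refs lists above be streamed in one shot. A small sketch of what it produces (note the trailing comma), assuming that template is in scope:

  std::vector<int64_t> sizes = {1, 64, 128};
  std::stringstream ss;
  ss << sizes;  // uses the vector operator<< defined above
  // ss.str() == "[1,64,128,]"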

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 2 additions & 0 deletions

@@ -1024,6 +1024,8 @@ void test_mm(
 
   out.staging = graph.set_output_tensor(out.value);
 
+  graph.print_readable();
+
   graph.prepare();
   graph.encode_prepack();
   graph.prepack();
