Skip to content

Commit bd26dbf

Browse files
dbortfacebook-github-bot
authored andcommitted
Implement operator<<() for EValue (#479)
Summary: Pull Request resolved: #479 Add a helper utility to print `EValue`s. Doesn't format multidimensional data like PyTorch does, but we can improve that in the future if we want to. ghstack-source-id: 201868939 exported-using-ghexport Reviewed By: JacobSzwejbka Differential Revision: D49574853 fbshipit-source-id: c8757790fb260785cd8a8c39e1f691d097f9e5b5
1 parent 9b384a4 commit bd26dbf

File tree

7 files changed

+1090
-0
lines changed

7 files changed

+1090
-0
lines changed

extension/evalue_util/TARGETS

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Any targets that should be shared between fbcode and xplat must be defined in
2+
# targets.bzl. This file can contain fbcode-only targets.
3+
4+
load(":targets.bzl", "define_common_targets")
5+
6+
oncall("executorch")
7+
8+
define_common_targets()
Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/extension/evalue_util/print_evalue.h>
10+
11+
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
12+
13+
#include <cmath>
14+
#include <iomanip>
15+
#include <ostream>
16+
#include <sstream>
17+
18+
namespace torch {
19+
namespace executor {
20+
21+
namespace {
22+
23+
/// The default number of first/last list items to print before eliding.
24+
constexpr size_t kDefaultEdgeItems = 3;
25+
26+
/// Returns a globally unique "iword" index that we can use to store the current
27+
/// "edge items" count on arbitrary streams.
28+
int get_edge_items_xalloc() {
29+
// Wrapping this in a function avoids a -Wglobal-constructors warning.
30+
static const int xalloc = std::ios_base::xalloc();
31+
return xalloc;
32+
}
33+
34+
/// Returns the number of "edge items" to print at the beginning and end of
35+
/// lists when using the provided stream.
36+
long get_stream_edge_items(std::ostream& os) {
37+
long edge_items = os.iword(get_edge_items_xalloc());
38+
return edge_items <= 0 ? kDefaultEdgeItems : edge_items;
39+
}
40+
41+
void print_double(std::ostream& os, double value) {
42+
if (std::isfinite(value)) {
43+
// Mimic PyTorch by printing a trailing dot when the float value is
44+
// integral, to distinguish from actual integers.
45+
bool add_dot = false;
46+
if (value == -0.0) {
47+
// Special case that won't be detected by a comparison with int.
48+
add_dot = true;
49+
} else {
50+
std::ostringstream oss_float;
51+
oss_float << value;
52+
std::ostringstream oss_int;
53+
oss_int << static_cast<int64_t>(value);
54+
if (oss_float.str() == oss_int.str()) {
55+
add_dot = true;
56+
}
57+
}
58+
if (add_dot) {
59+
os << value << ".";
60+
} else {
61+
os << value;
62+
}
63+
} else {
64+
// Infinity or NaN.
65+
os << value;
66+
}
67+
}
68+
69+
template <class T>
70+
void print_scalar_list(
71+
std::ostream& os,
72+
exec_aten::ArrayRef<T> list,
73+
bool print_length = true,
74+
bool elide_inner_items = true) {
75+
long edge_items = elide_inner_items ? get_stream_edge_items(os)
76+
: std::numeric_limits<long>::max();
77+
if (print_length) {
78+
os << "(len=" << list.size() << ")";
79+
}
80+
// TODO(T159700776): Wrap at a specified number of columns.
81+
os << "[";
82+
for (size_t i = 0; i < list.size(); ++i) {
83+
os << EValue(exec_aten::Scalar(list[i]));
84+
if (i < list.size() - 1) {
85+
os << ", ";
86+
}
87+
if (i + 1 == edge_items && i + edge_items + 1 < list.size()) {
88+
os << "..., ";
89+
i = list.size() - edge_items - 1;
90+
}
91+
}
92+
os << "]";
93+
}
94+
95+
void print_tensor(std::ostream& os, exec_aten::Tensor tensor) {
96+
os << "tensor(sizes=";
97+
// Always print every element of the sizes list.
98+
print_scalar_list(
99+
os, tensor.sizes(), /*print_length=*/false, /*elide_inner_items=*/false);
100+
os << ", ";
101+
102+
// Print the data as a one-dimensional list.
103+
//
104+
// TODO(T159700776): Print dim_order and strides when they have non-default
105+
// values.
106+
//
107+
// TODO(T159700776): Format multidimensional data like numpy/PyTorch does.
108+
// https://github.com/pytorch/pytorch/blob/main/torch/_tensor_str.py
109+
#define PRINT_TENSOR_DATA(ctype, dtype) \
110+
case ScalarType::dtype: \
111+
print_scalar_list( \
112+
os, \
113+
ArrayRef<ctype>(tensor.data_ptr<ctype>(), tensor.numel()), \
114+
/*print_length=*/false); \
115+
break;
116+
117+
switch (tensor.scalar_type()) {
118+
ET_FORALL_REAL_TYPES_AND(Bool, PRINT_TENSOR_DATA)
119+
default:
120+
os << "[<unhandled scalar type " << (int)tensor.scalar_type() << ">]";
121+
}
122+
os << ")";
123+
124+
#undef PRINT_TENSOR_DATA
125+
}
126+
127+
void print_tensor_list(
128+
std::ostream& os,
129+
exec_aten::ArrayRef<exec_aten::Tensor> list) {
130+
os << "(len=" << list.size() << ")[";
131+
for (size_t i = 0; i < list.size(); ++i) {
132+
if (list.size() > 1) {
133+
os << "\n [" << i << "]: ";
134+
}
135+
print_tensor(os, list[i]);
136+
if (list.size() > 1) {
137+
os << ",";
138+
}
139+
}
140+
if (list.size() > 1) {
141+
os << "\n";
142+
}
143+
os << "]";
144+
}
145+
146+
void print_list_optional_tensor(
147+
std::ostream& os,
148+
exec_aten::ArrayRef<exec_aten::optional<exec_aten::Tensor>> list) {
149+
os << "(len=" << list.size() << ")[";
150+
for (size_t i = 0; i < list.size(); ++i) {
151+
if (list.size() > 1) {
152+
os << "\n [" << i << "]: ";
153+
}
154+
if (list[i].has_value()) {
155+
print_tensor(os, list[i].value());
156+
} else {
157+
os << "None";
158+
}
159+
if (list.size() > 1) {
160+
os << ",";
161+
}
162+
}
163+
if (list.size() > 1) {
164+
os << "\n";
165+
}
166+
os << "]";
167+
}
168+
169+
} // namespace
170+
171+
std::ostream& operator<<(std::ostream& os, const EValue& value) {
172+
switch (value.tag) {
173+
case Tag::None:
174+
os << "None";
175+
break;
176+
case Tag::Bool:
177+
if (value.toBool()) {
178+
os << "True";
179+
} else {
180+
os << "False";
181+
}
182+
break;
183+
case Tag::Int:
184+
os << value.toInt();
185+
break;
186+
case Tag::Double:
187+
print_double(os, value.toDouble());
188+
break;
189+
case Tag::String: {
190+
auto str = value.toString();
191+
os << std::quoted(std::string(str.data(), str.size()));
192+
} break;
193+
case Tag::Tensor:
194+
print_tensor(os, value.toTensor());
195+
break;
196+
case Tag::ListBool:
197+
print_scalar_list(os, value.toBoolList());
198+
break;
199+
case Tag::ListInt:
200+
print_scalar_list(os, value.toIntList());
201+
break;
202+
case Tag::ListDouble:
203+
print_scalar_list(os, value.toDoubleList());
204+
break;
205+
case Tag::ListTensor:
206+
print_tensor_list(os, value.toTensorList());
207+
break;
208+
case Tag::ListOptionalTensor:
209+
print_list_optional_tensor(os, value.toListOptionalTensor());
210+
break;
211+
default:
212+
os << "<Unknown EValue tag " << static_cast<int>(value.tag) << ">";
213+
break;
214+
}
215+
return os;
216+
}
217+
218+
namespace util {
219+
220+
void evalue_edge_items::set_edge_items(std::ostream& os, long edge_items) {
221+
os.iword(get_edge_items_xalloc()) = edge_items;
222+
}
223+
224+
} // namespace util
225+
226+
} // namespace executor
227+
} // namespace torch

extension/evalue_util/print_evalue.h

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <ostream>
12+
13+
#include <executorch/runtime/core/evalue.h>
14+
15+
namespace torch {
16+
namespace executor {
17+
18+
/**
19+
* Prints an Evalue to a stream.
20+
*/
21+
std::ostream& operator<<(std::ostream& os, const EValue& value);
22+
// Note that this must be declared in the same namespace as EValue.
23+
24+
namespace util {
25+
26+
/**
27+
* Sets the number of "edge items" when printing EValue lists to a stream.
28+
*
29+
* The edge item count is used to elide inner elements from large lists, and
30+
* like core PyTorch defaults to 3.
31+
*
32+
* For example,
33+
* ```
34+
* os << torch::executor::util::evalue_edge_items(3) << evalue_int_list << "\n";
35+
* os << torch::executor::util::evalue_edge_items(1) << evalue_int_list << "\n";
36+
* ```
37+
* will print the same list with three edge items, then with only one edge item:
38+
* ```
39+
* [0, 1, 2, ..., 6, 7, 8]
40+
* [0, ..., 8]
41+
* ```
42+
* This setting is sticky, and will affect all subsequent evalues printed to the
43+
* affected stream until the value is changed again.
44+
*
45+
* @param[in] os The stream to modify.
46+
* @param[in] edge_items The number of "edge items" to print at the beginning
47+
* and end of a list before eliding inner elements. If zero or negative,
48+
* uses the default number of edge items.
49+
*/
50+
class evalue_edge_items final {
51+
// See https://stackoverflow.com/a/29337924 for other examples of stream
52+
// manipulators like this.
53+
public:
54+
explicit evalue_edge_items(long edge_items)
55+
: edge_items_(edge_items < 0 ? 0 : edge_items) {}
56+
57+
friend std::ostream& operator<<(
58+
std::ostream& os,
59+
const evalue_edge_items& e) {
60+
set_edge_items(os, e.edge_items_);
61+
return os;
62+
}
63+
64+
private:
65+
static void set_edge_items(std::ostream& os, long edge_items);
66+
67+
const long edge_items_;
68+
};
69+
70+
} // namespace util
71+
} // namespace executor
72+
} // namespace torch

extension/evalue_util/targets.bzl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
def define_common_targets():
4+
"""Defines targets that should be shared between fbcode and xplat.
5+
6+
The directory containing this targets.bzl file should also contain both
7+
TARGETS and BUCK files that call this function.
8+
"""
9+
10+
for aten_mode in (True, False):
11+
aten_suffix = ("_aten" if aten_mode else "")
12+
13+
runtime.cxx_library(
14+
name = "print_evalue" + aten_suffix,
15+
srcs = ["print_evalue.cpp"],
16+
exported_headers = ["print_evalue.h"],
17+
visibility = ["@EXECUTORCH_CLIENTS"],
18+
exported_deps = [
19+
"//executorch/runtime/core:evalue" + aten_suffix,
20+
],
21+
deps = [
22+
"//executorch/runtime/core/exec_aten/util:scalar_type_util" + aten_suffix,
23+
],
24+
)

extension/evalue_util/test/TARGETS

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Any targets that should be shared between fbcode and xplat must be defined in
2+
# targets.bzl. This file can contain fbcode-only targets.
3+
4+
load(":targets.bzl", "define_common_targets")
5+
6+
define_common_targets()

0 commit comments

Comments
 (0)