Skip to content

Commit ae989c1

Browse files
committed
[executorch] Implement operator<<() for EValue
Pull Request resolved: #479 Add a helper utility to print `EValue`s. Doesn't format multidimensional data like PyTorch does, but we can improve that in the future if we want to. ghstack-source-id: 201845601 @exported-using-ghexport Differential Revision: [D49574853](https://our.internmc.facebook.com/intern/diff/D49574853/)
1 parent a89017d commit ae989c1

File tree

7 files changed

+1082
-0
lines changed

7 files changed

+1082
-0
lines changed

extension/evalue_util/TARGETS

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Any targets that should be shared between fbcode and xplat must be defined in
2+
# targets.bzl. This file can contain fbcode-only targets.
3+
4+
load(":targets.bzl", "define_common_targets")
5+
6+
oncall("executorch")
7+
8+
define_common_targets()
Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/extension/evalue_util/print_evalue.h>
10+
11+
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
12+
13+
#include <cmath>
14+
#include <iomanip>
15+
#include <ostream>
16+
#include <sstream>
17+
18+
namespace torch {
19+
namespace executor {
20+
21+
namespace {
22+
23+
/// The default number of first/last list items to print before eliding.
24+
constexpr size_t kDefaultEdgeItems = 3;
25+
26+
/// Init-time allocation of a globally unique "iword" stream index that we can
27+
/// use to store the current "edge items" count on arbitrary streams.
28+
const int kPrintEvalueEdgeItemsXalloc = std::ios_base::xalloc();
29+
30+
void print_double(std::ostream& os, double value) {
31+
if (std::isfinite(value)) {
32+
// Mimic PyTorch by printing a trailing dot when the float value is
33+
// integral, to distinguish from actual integers.
34+
bool add_dot = false;
35+
if (value == -0.0) {
36+
// Special case that won't be detected by a comparison with int.
37+
add_dot = true;
38+
} else {
39+
std::ostringstream oss_float;
40+
oss_float << value;
41+
std::ostringstream oss_int;
42+
oss_int << static_cast<int64_t>(value);
43+
if (oss_float.str() == oss_int.str()) {
44+
add_dot = true;
45+
}
46+
}
47+
if (add_dot) {
48+
os << value << ".";
49+
} else {
50+
os << value;
51+
}
52+
} else {
53+
// Infinity or NaN.
54+
os << value;
55+
}
56+
}
57+
58+
template <class T>
59+
void print_scalar_list(
60+
std::ostream& os,
61+
exec_aten::ArrayRef<T> list,
62+
bool print_length = true,
63+
bool elide_inner_items = true) {
64+
long edge_items;
65+
if (elide_inner_items) {
66+
edge_items = os.iword(kPrintEvalueEdgeItemsXalloc);
67+
if (edge_items <= 0) {
68+
edge_items = kDefaultEdgeItems;
69+
}
70+
} else {
71+
edge_items = std::numeric_limits<long>::max();
72+
}
73+
74+
if (print_length) {
75+
os << "(len=" << list.size() << ")";
76+
}
77+
// TODO(T159700776): Wrap at a specified number of columns.
78+
os << "[";
79+
for (size_t i = 0; i < list.size(); ++i) {
80+
os << EValue(exec_aten::Scalar(list[i]));
81+
if (i < list.size() - 1) {
82+
os << ", ";
83+
}
84+
if (i + 1 == edge_items && i + edge_items + 1 < list.size()) {
85+
os << "..., ";
86+
i = list.size() - edge_items - 1;
87+
}
88+
}
89+
os << "]";
90+
}
91+
92+
void print_tensor(std::ostream& os, exec_aten::Tensor tensor) {
93+
os << "tensor(sizes=";
94+
// Always print every element of the sizes list.
95+
print_scalar_list(
96+
os, tensor.sizes(), /*print_length=*/false, /*elide_inner_items=*/false);
97+
os << ", ";
98+
99+
// Print the data as a one-dimensional list.
100+
//
101+
// TODO(T159700776): Print dim_order and strides when they have non-default
102+
// values.
103+
//
104+
// TODO(T159700776): Format multidimensional data like numpy/PyTorch does.
105+
// https://github.com/pytorch/pytorch/blob/main/torch/_tensor_str.py
106+
#define PRINT_TENSOR_DATA(ctype, dtype) \
107+
case ScalarType::dtype: \
108+
print_scalar_list( \
109+
os, \
110+
ArrayRef<ctype>(tensor.data_ptr<ctype>(), tensor.numel()), \
111+
/*print_length=*/false); \
112+
break;
113+
114+
switch (tensor.scalar_type()) {
115+
ET_FORALL_REAL_TYPES_AND(Bool, PRINT_TENSOR_DATA)
116+
default:
117+
os << "[<unhandled scalar type " << (int)tensor.scalar_type() << ">]";
118+
}
119+
os << ")";
120+
121+
#undef PRINT_TENSOR_DATA
122+
}
123+
124+
void print_tensor_list(
125+
std::ostream& os,
126+
exec_aten::ArrayRef<exec_aten::Tensor> list) {
127+
os << "(len=" << list.size() << ")[";
128+
for (size_t i = 0; i < list.size(); ++i) {
129+
if (list.size() > 1) {
130+
os << "\n [" << i << "]: ";
131+
}
132+
print_tensor(os, list[i]);
133+
if (list.size() > 1) {
134+
os << ",";
135+
}
136+
}
137+
if (list.size() > 1) {
138+
os << "\n";
139+
}
140+
os << "]";
141+
}
142+
143+
void print_list_optional_tensor(
144+
std::ostream& os,
145+
exec_aten::ArrayRef<exec_aten::optional<exec_aten::Tensor>> list) {
146+
os << "(len=" << list.size() << ")[";
147+
for (size_t i = 0; i < list.size(); ++i) {
148+
if (list.size() > 1) {
149+
os << "\n [" << i << "]: ";
150+
}
151+
if (list[i].has_value()) {
152+
print_tensor(os, list[i].value());
153+
} else {
154+
os << "None";
155+
}
156+
if (list.size() > 1) {
157+
os << ",";
158+
}
159+
}
160+
if (list.size() > 1) {
161+
os << "\n";
162+
}
163+
os << "]";
164+
}
165+
166+
} // namespace
167+
168+
std::ostream& operator<<(std::ostream& os, const EValue& value) {
169+
switch (value.tag) {
170+
case Tag::None:
171+
os << "None";
172+
break;
173+
case Tag::Bool:
174+
if (value.toBool()) {
175+
os << "True";
176+
} else {
177+
os << "False";
178+
}
179+
break;
180+
case Tag::Int:
181+
os << value.toInt();
182+
break;
183+
case Tag::Double:
184+
print_double(os, value.toDouble());
185+
break;
186+
case Tag::String: {
187+
auto str = value.toString();
188+
os << std::quoted(std::string(str.data(), str.size()));
189+
} break;
190+
case Tag::Tensor:
191+
print_tensor(os, value.toTensor());
192+
break;
193+
case Tag::ListBool:
194+
print_scalar_list(os, value.toBoolList());
195+
break;
196+
case Tag::ListInt:
197+
print_scalar_list(os, value.toIntList());
198+
break;
199+
case Tag::ListDouble:
200+
print_scalar_list(os, value.toDoubleList());
201+
break;
202+
case Tag::ListTensor:
203+
print_tensor_list(os, value.toTensorList());
204+
break;
205+
case Tag::ListOptionalTensor:
206+
print_list_optional_tensor(os, value.toListOptionalTensor());
207+
break;
208+
default:
209+
os << "<Unknown EValue tag " << static_cast<int>(value.tag) << ">";
210+
break;
211+
}
212+
return os;
213+
}
214+
215+
namespace util {
216+
// Lets us avoid exposing kPrintEvalueEdgeItemsXalloc in the header.
217+
long evalue_edge_items::xalloc_index() {
218+
return kPrintEvalueEdgeItemsXalloc;
219+
}
220+
} // namespace util
221+
222+
} // namespace executor
223+
} // namespace torch

extension/evalue_util/print_evalue.h

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <ostream>
12+
13+
#include <executorch/runtime/core/evalue.h>
14+
15+
namespace torch {
16+
namespace executor {
17+
18+
/**
19+
* Prints an Evalue to a stream.
20+
*/
21+
std::ostream& operator<<(std::ostream& os, const EValue& value);
22+
// Note that this must be declared in the same namespace as EValue.
23+
24+
namespace util {
25+
26+
/**
27+
* Sets the number of "edge items" when printing EValue lists to a stream.
28+
*
29+
* The edge item count is used to elide inner elements from large lists, and
30+
* like core PyTorch defaults to 3.
31+
*
32+
* For example,
33+
* ```
34+
* os << torch::executor::util::evalue_edge_items(3) << evalue_int_list << "\n";
35+
* os << torch::executor::util::evalue_edge_items(1) << evalue_int_list << "\n";
36+
* ```
37+
* will print the same list with three edge items, then with only one edge item:
38+
* ```
39+
* [0, 1, 2, ..., 6, 7, 8]
40+
* [0, ..., 8]
41+
* ```
42+
* This setting is sticky, and will affect all subsequent evalues printed to the
43+
* affected stream until the value is changed again.
44+
*
45+
* @param[in] os The stream to modify.
46+
* @param[in] edge_items The number of "edge items" to print at the beginning
47+
* and end of a list before eliding inner elements. If zero or negative,
48+
* uses the default number of edge items.
49+
*/
50+
class evalue_edge_items final {
51+
// See https://stackoverflow.com/a/29337924 for other examples of stream
52+
// manipulators like this.
53+
public:
54+
explicit evalue_edge_items(long edge_items)
55+
: edge_items_(edge_items < 0 ? 0 : edge_items) {}
56+
57+
friend std::ostream& operator<<(
58+
std::ostream& os,
59+
const evalue_edge_items& e) {
60+
os.iword(xalloc_index()) = e.edge_items_;
61+
return os;
62+
}
63+
64+
private:
65+
static long xalloc_index();
66+
67+
const long edge_items_;
68+
};
69+
70+
} // namespace util
71+
} // namespace executor
72+
} // namespace torch

extension/evalue_util/targets.bzl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
def define_common_targets():
4+
"""Defines targets that should be shared between fbcode and xplat.
5+
6+
The directory containing this targets.bzl file should also contain both
7+
TARGETS and BUCK files that call this function.
8+
"""
9+
10+
for aten_mode in (True, False):
11+
aten_suffix = ("_aten" if aten_mode else "")
12+
13+
runtime.cxx_library(
14+
name = "print_evalue" + aten_suffix,
15+
srcs = ["print_evalue.cpp"],
16+
exported_headers = ["print_evalue.h"],
17+
visibility = ["@EXECUTORCH_CLIENTS"],
18+
exported_deps = [
19+
"//executorch/runtime/core:evalue" + aten_suffix,
20+
],
21+
deps = [
22+
"//executorch/runtime/core/exec_aten/util:scalar_type_util" + aten_suffix,
23+
],
24+
)

extension/evalue_util/test/TARGETS

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Any targets that should be shared between fbcode and xplat must be defined in
2+
# targets.bzl. This file can contain fbcode-only targets.
3+
4+
load(":targets.bzl", "define_common_targets")
5+
6+
define_common_targets()

0 commit comments

Comments
 (0)