Skip to content

Commit a3f8141

Browse files
authored
Merge pull request #1013 from IntelPython/print-feature
Implemented printing for usm_ndarrays
2 parents bc9c5be + 23b3311 commit a3f8141

File tree

4 files changed

+612
-0
lines changed

4 files changed

+612
-0
lines changed

dpctl/tensor/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@
5858
squeeze,
5959
stack,
6060
)
61+
from dpctl.tensor._print import (
62+
get_print_options,
63+
print_options,
64+
set_print_options,
65+
)
6166
from dpctl.tensor._reshape import reshape
6267
from dpctl.tensor._usmarray import usm_ndarray
6368

@@ -129,4 +134,7 @@
129134
"can_cast",
130135
"result_type",
131136
"meshgrid",
137+
"get_print_options",
138+
"set_print_options",
139+
"print_options",
132140
]

dpctl/tensor/_print.py

Lines changed: 323 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,323 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2022 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import contextlib
18+
import operator
19+
20+
import numpy as np
21+
22+
import dpctl.tensor as dpt
23+
24+
__doc__ = "Print functions for :class:`dpctl.tensor.usm_ndarray`."
25+
26+
_print_options = {
27+
"linewidth": 75,
28+
"edgeitems": 3,
29+
"threshold": 1000,
30+
"precision": 8,
31+
"floatmode": "maxprec",
32+
"suppress": False,
33+
"nanstr": "nan",
34+
"infstr": "inf",
35+
"sign": "-",
36+
}
37+
38+
39+
def _options_dict(
40+
linewidth=None,
41+
edgeitems=None,
42+
threshold=None,
43+
precision=None,
44+
floatmode=None,
45+
suppress=None,
46+
nanstr=None,
47+
infstr=None,
48+
sign=None,
49+
numpy=False,
50+
):
51+
if numpy:
52+
numpy_options = np.get_printoptions()
53+
options = {k: numpy_options[k] for k in _print_options.keys()}
54+
else:
55+
options = _print_options.copy()
56+
57+
if suppress:
58+
options["suppress"] = True
59+
60+
local = dict(locals().items())
61+
for int_arg in ["linewidth", "precision", "threshold", "edgeitems"]:
62+
val = local[int_arg]
63+
if val is not None:
64+
options[int_arg] = operator.index(val)
65+
66+
for str_arg in ["nanstr", "infstr"]:
67+
val = local[str_arg]
68+
if val is not None:
69+
if not isinstance(val, str):
70+
raise TypeError(
71+
"`{}` ".format(str_arg) + "must be of `string` type."
72+
)
73+
options[str_arg] = val
74+
75+
signs = ["-", "+", " "]
76+
if sign is not None:
77+
if sign not in signs:
78+
raise ValueError(
79+
"`sign` must be one of"
80+
+ ", ".join("`{}`".format(s) for s in signs)
81+
)
82+
options["sign"] = sign
83+
84+
floatmodes = ["fixed", "unique", "maxprec", "maxprec_equal"]
85+
if floatmode is not None:
86+
if floatmode not in floatmodes:
87+
raise ValueError(
88+
"`floatmode` must be one of"
89+
+ ", ".join("`{}`".format(m) for m in floatmodes)
90+
)
91+
options["floatmode"] = floatmode
92+
93+
return options
94+
95+
96+
def set_print_options(
97+
linewidth=None,
98+
edgeitems=None,
99+
threshold=None,
100+
precision=None,
101+
floatmode=None,
102+
suppress=None,
103+
nanstr=None,
104+
infstr=None,
105+
sign=None,
106+
numpy=False,
107+
):
108+
"""
109+
set_print_options(linewidth=None, edgeitems=None, threshold=None,
110+
precision=None, floatmode=None, suppress=None, nanstr=None,
111+
infstr=None, sign=None, numpy=False)
112+
113+
Set options for printing ``dpctl.tensor.usm_ndarray`` class.
114+
115+
Args:
116+
linewidth (int, optional): Number of characters printed per line.
117+
Raises `TypeError` if linewidth is not an integer.
118+
Default: `75`.
119+
edgeitems (int, optional): Number of elements at the beginning and end
120+
when the printed array is abbreviated.
121+
Raises `TypeError` if edgeitems is not an integer.
122+
Default: `3`.
123+
threshold (int, optional): Number of elements that triggers array
124+
abbreviation.
125+
Raises `TypeError` if threshold is not an integer.
126+
Default: `1000`.
127+
precision (int or None, optional): Number of digits printed for
128+
floating point numbers.
129+
Raises `TypeError` if precision is not an integer.
130+
Default: `8`.
131+
floatmode (str, optional): Controls how floating point
132+
numbers are interpreted.
133+
134+
`"fixed:`: Always prints exactly `precision` digits.
135+
`"unique"`: Ignores precision, prints the number of
136+
digits necessary to uniquely specify each number.
137+
`"maxprec"`: Prints `precision` digits or fewer,
138+
if fewer will uniquely represent a number.
139+
`"maxprec_equal"`: Prints an equal number of digits
140+
for each number. This number is `precision` digits or fewer,
141+
if fewer will uniquely represent each number.
142+
Raises `ValueError` if floatmode is not one of
143+
`fixed`, `unique`, `maxprec`, or `maxprec_equal`.
144+
Default: "maxprec_equal"
145+
suppress (bool, optional): If `True,` numbers equal to zero
146+
in the current precision will print as zero.
147+
Default: `False`.
148+
nanstr (str, optional): String used to repesent nan.
149+
Raises `TypeError` if nanstr is not a string.
150+
Default: `"nan"`.
151+
infstr (str, optional): String used to represent infinity.
152+
Raises `TypeError` if infstr is not a string.
153+
Default: `"inf"`.
154+
sign (str, optional): Controls the sign of floating point
155+
numbers.
156+
`"-"`: Omit the sign of positive numbers.
157+
`"+"`: Always print the sign of positive numbers.
158+
`" "`: Always print a whitespace in place of the
159+
sign of positive numbers.
160+
Raises `ValueError` if sign is not one of
161+
`"-"`, `"+"`, or `" "`.
162+
Default: `"-"`.
163+
numpy (bool, optional): If `True,` then before other specified print
164+
options are set, a dictionary of Numpy's print options
165+
will be used to initialize dpctl's print options.
166+
Default: "False"
167+
"""
168+
options = _options_dict(
169+
linewidth=linewidth,
170+
edgeitems=edgeitems,
171+
threshold=threshold,
172+
precision=precision,
173+
floatmode=floatmode,
174+
suppress=suppress,
175+
nanstr=nanstr,
176+
infstr=infstr,
177+
sign=sign,
178+
numpy=numpy,
179+
)
180+
_print_options.update(options)
181+
182+
183+
def get_print_options():
184+
"""
185+
get_print_options() -> dict
186+
187+
Returns a copy of current options for printing
188+
``dpctl.tensor.usm_ndarray`` class.
189+
190+
Options:
191+
- "linewidth" : int, default 75
192+
- "edgeitems" : int, default 3
193+
- "threshold" : int, default 1000
194+
- "precision" : int, default 8
195+
- "floatmode" : str, default "maxprec_equal"
196+
- "suppress" : bool, default False
197+
- "nanstr" : str, default "nan"
198+
- "infstr" : str, default "inf"
199+
- "sign" : str, default "-"
200+
"""
201+
return _print_options.copy()
202+
203+
204+
@contextlib.contextmanager
205+
def print_options(*args, **kwargs):
206+
"""
207+
Context manager for print options.
208+
209+
Set print options for the scope of a `with` block.
210+
`as` yields dictionary of print options.
211+
"""
212+
options = dpt.get_print_options()
213+
try:
214+
dpt.set_print_options(*args, **kwargs)
215+
yield dpt.get_print_options()
216+
finally:
217+
dpt.set_print_options(**options)
218+
219+
220+
def _nd_corners(x, edge_items, slices=()):
221+
axes_reduced = len(slices)
222+
if axes_reduced == x.ndim:
223+
return x[slices]
224+
225+
if x.shape[axes_reduced] > 2 * edge_items:
226+
return dpt.concat(
227+
(
228+
_nd_corners(
229+
x, edge_items, slices + (slice(None, edge_items, None),)
230+
),
231+
_nd_corners(
232+
x, edge_items, slices + (slice(-edge_items, None, None),)
233+
),
234+
),
235+
axis=axes_reduced,
236+
)
237+
else:
238+
return _nd_corners(x, edge_items, slices + (slice(None, None, None),))
239+
240+
241+
def _usm_ndarray_str(
242+
x,
243+
line_width=None,
244+
edge_items=None,
245+
threshold=None,
246+
precision=None,
247+
floatmode=None,
248+
suppress=None,
249+
sign=None,
250+
numpy=False,
251+
separator=" ",
252+
prefix="",
253+
suffix="",
254+
):
255+
if not isinstance(x, dpt.usm_ndarray):
256+
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
257+
258+
options = get_print_options()
259+
options.update(
260+
_options_dict(
261+
linewidth=line_width,
262+
edgeitems=edge_items,
263+
threshold=threshold,
264+
precision=precision,
265+
floatmode=floatmode,
266+
suppress=suppress,
267+
sign=sign,
268+
numpy=numpy,
269+
)
270+
)
271+
272+
threshold = options["threshold"]
273+
edge_items = options["edgeitems"]
274+
275+
if x.size > threshold:
276+
# need edge_items + 1 elements for np.array2string to abbreviate
277+
data = dpt.asnumpy(_nd_corners(x, edge_items + 1))
278+
options["threshold"] = 0
279+
else:
280+
data = dpt.asnumpy(x)
281+
with np.printoptions(**options):
282+
s = np.array2string(
283+
data, separator=separator, prefix=prefix, suffix=suffix
284+
)
285+
return s
286+
287+
288+
def _usm_ndarray_repr(x, line_width=None, precision=None, suppress=None):
289+
if not isinstance(x, dpt.usm_ndarray):
290+
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
291+
292+
if line_width is None:
293+
line_width = _print_options["linewidth"]
294+
295+
show_dtype = x.dtype not in [
296+
dpt.bool,
297+
dpt.int64,
298+
dpt.float64,
299+
dpt.complex128,
300+
]
301+
302+
prefix = "usm_ndarray("
303+
suffix = ")"
304+
305+
s = _usm_ndarray_str(
306+
x,
307+
line_width=line_width,
308+
precision=precision,
309+
suppress=suppress,
310+
separator=", ",
311+
prefix=prefix,
312+
suffix=suffix,
313+
)
314+
315+
if show_dtype:
316+
dtype_str = "dtype={}".format(x.dtype.name)
317+
bottom_len = len(s) - (s.rfind("\n") + 1)
318+
next_line = bottom_len + len(dtype_str) + 1 > line_width
319+
dtype_str = ",\n" + dtype_str if next_line else ", " + dtype_str
320+
else:
321+
dtype_str = ""
322+
323+
return prefix + s + dtype_str + suffix

dpctl/tensor/_usmarray.pyx

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import dpctl
2626
import dpctl.memory as dpmem
2727

2828
from ._device import Device
29+
from ._print import _usm_ndarray_repr, _usm_ndarray_str
2930

3031
from cpython.mem cimport PyMem_Free
3132
from cpython.tuple cimport PyTuple_New, PyTuple_SetItem
@@ -1131,6 +1132,12 @@ cdef class usm_ndarray:
11311132
self.__setitem__(Ellipsis, res)
11321133
return self
11331134

1135+
def __str__(self):
1136+
return _usm_ndarray_str(self)
1137+
1138+
def __repr__(self):
1139+
return _usm_ndarray_repr(self)
1140+
11341141

11351142
cdef usm_ndarray _real_view(usm_ndarray ary):
11361143
"""

0 commit comments

Comments
 (0)