Skip to content

Commit 3c512ca

Browse files
authored
Merge pull request #1043 from IntelPython/print-dtype-bug-fix
Fixed spacing of dtype string in array printing
2 parents 1916370 + da9b756 commit 3c512ca

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

dpctl/tensor/_print.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,11 @@ def _usm_ndarray_repr(x, line_width=None, precision=None, suppress=None):
316316
dtype_str = "dtype={}".format(x.dtype.name)
317317
bottom_len = len(s) - (s.rfind("\n") + 1)
318318
next_line = bottom_len + len(dtype_str) + 1 > line_width
319-
dtype_str = ",\n" + dtype_str if next_line else ", " + dtype_str
319+
dtype_str = (
320+
",\n" + " " * len(prefix) + dtype_str
321+
if next_line
322+
else ", " + dtype_str
323+
)
320324
else:
321325
dtype_str = ""
322326

dpctl/tests/test_usm_ndarray_print.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,16 @@ def test_print_repr(self):
211211
x = dpt.arange(4, dtype="i4", sycl_queue=q)
212212
assert repr(x) == "usm_ndarray([0, 1, 2, 3], dtype=int32)"
213213

214+
dpt.set_print_options(linewidth=1)
215+
np.testing.assert_equal(
216+
repr(x),
217+
"usm_ndarray([0,"
218+
"\n 1,"
219+
"\n 2,"
220+
"\n 3],"
221+
"\n dtype=int32)",
222+
)
223+
214224
def test_print_repr_abbreviated(self):
215225
q = get_queue_or_skip()
216226

@@ -237,6 +247,19 @@ def test_print_repr_abbreviated(self):
237247
"\n [6, ..., 8]], dtype=int32)",
238248
)
239249

250+
dpt.set_print_options(linewidth=1)
251+
np.testing.assert_equal(
252+
repr(y),
253+
"usm_ndarray([[0,"
254+
"\n ...,"
255+
"\n 2],"
256+
"\n ...,"
257+
"\n [6,"
258+
"\n ...,"
259+
"\n 8]],"
260+
"\n dtype=int32)",
261+
)
262+
240263
@pytest.mark.parametrize(
241264
"dtype",
242265
[

0 commit comments

Comments
 (0)