Skip to content

Commit 9aa1842

Browse files
committed
Improve dpctl.tensor.full error when fill_value is of an invalid type
pybind11 raises RuntimeError when a value cannot be cast to a type, i.e., a string to an integer. Now when scalar is not an array, a check is performed, and TypeError is raised
1 parent 2023622 commit 9aa1842

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

dpctl/tensor/_ctors.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# limitations under the License.
1616

1717
import operator
18+
from numbers import Number
1819

1920
import numpy as np
2021

@@ -1113,6 +1114,11 @@ def full(
11131114

11141115
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
11151116
usm_type = usm_type if usm_type is not None else "device"
1117+
if not isinstance(fill_value, Number):
1118+
raise TypeError(
1119+
"`full` array cannot be constructed with value of type "
1120+
f"{type(fill_value)}"
1121+
)
11161122
dtype = _get_dtype(dtype, sycl_queue, ref_type=type(fill_value))
11171123
res = dpt.usm_ndarray(
11181124
shape,

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2621,3 +2621,10 @@ def test_setitem_from_numpy_contig():
26212621

26222622
expected = dpt.reshape(dpt.arange(-10, 10, dtype=fp_dt), (4, 5))
26232623
assert dpt.all(dpt.flip(Xdpt, axis=-1) == expected)
2624+
2625+
2626+
def test_full_raises_type_error():
2627+
get_queue_or_skip()
2628+
2629+
with pytest.raises(TypeError):
2630+
dpt.full(1, "0")

0 commit comments

Comments
 (0)