Skip to content

Commit 7095358

Browse files
committed
Factor out checks for fill value scalar type into _validate_fill_value function
1 parent 320e5e4 commit 7095358

File tree

1 file changed

+17
-20
lines changed

1 file changed

+17
-20
lines changed

dpctl/tensor/_ctors.py

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,6 +1038,19 @@ def _cast_fill_val(fill_val, dt):
10381038
return fill_val
10391039

10401040

1041+
def _validate_fill_value(fill_val):
1042+
"""
1043+
Validates that `fill_val` is a numeric or boolean scalar.
1044+
"""
1045+
# TODO: verify if `np.True_` and `np.False_` should be instances of
1046+
# Number in NumPy, like other NumPy scalars and like Python bools
1047+
# check for `np.bool_` separately as NumPy<2 has no `np.bool`
1048+
if not isinstance(fill_val, Number) and not isinstance(fill_val, np.bool_):
1049+
raise TypeError(
1050+
f"array cannot be filled with scalar of type {type(fill_val)}"
1051+
)
1052+
1053+
10411054
def full(
10421055
shape,
10431056
fill_value,
@@ -1111,16 +1124,8 @@ def full(
11111124
sycl_queue=sycl_queue,
11121125
)
11131126
return dpt.copy(dpt.broadcast_to(X, shape), order=order)
1114-
# TODO: verify if `np.True_` and `np.False_` should be instances of
1115-
# Number in NumPy, like other NumPy scalars and like Python bools
1116-
# check for `np.bool_` separately as NumPy<2 has no `np.bool`
1117-
elif not isinstance(fill_value, Number) and not isinstance(
1118-
fill_value, np.bool_
1119-
):
1120-
raise TypeError(
1121-
"`full` array cannot be constructed with value of type "
1122-
f"{type(fill_value)}"
1123-
)
1127+
else:
1128+
_validate_fill_value(fill_value)
11241129

11251130
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
11261131
usm_type = usm_type if usm_type is not None else "device"
@@ -1491,16 +1496,8 @@ def full_like(
14911496
)
14921497
_manager.add_event_pair(hev, copy_ev)
14931498
return res
1494-
# TODO: verify if `np.True_` and `np.False_` should be instances of
1495-
# Number in NumPy, like other NumPy scalars and like Python bools
1496-
# check for `np.bool_` separately as NumPy<2 has no `np.bool`
1497-
elif not isinstance(fill_value, Number) and not isinstance(
1498-
fill_value, np.bool_
1499-
):
1500-
raise TypeError(
1501-
"`full` array cannot be constructed with value of type "
1502-
f"{type(fill_value)}"
1503-
)
1499+
else:
1500+
_validate_fill_value(fill_value)
15041501

15051502
dtype = _get_dtype(dtype, sycl_queue, ref_type=type(fill_value))
15061503
res = _empty_like_orderK(x, dtype, usm_type, sycl_queue)

0 commit comments

Comments
 (0)