Skip to content

Commit 30aa372

Browse files
Implement check_zeros_shape_gesvd
1 parent 4ee5576 commit 30aa372

File tree

3 files changed

+53
-2
lines changed

3 files changed

+53
-2
lines changed

dpnp/backend/extensions/lapack/gesvd.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,14 @@ std::pair<sycl::event, sycl::event>
158158
jobu_val, jobvt_val, expected_a_u_vt_ndim,
159159
expected_s_ndim);
160160

161-
// TODO: check non_zero shape
161+
// Ensure `m` and 'n' are non-zero, otherwise return empty
162+
// events
163+
if (gesvd_utils::check_zeros_shape_gesvd(a_array, out_s, out_u, out_vt,
164+
jobu_val, jobvt_val))
165+
{
166+
// nothing to do
167+
return std::make_pair(sycl::event(), sycl::event());
168+
}
162169

163170
auto array_types = dpctl_td_ns::usm_ndarray_types();
164171
int a_array_type_id =

dpnp/backend/extensions/lapack/gesvd_batch.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,14 @@ std::pair<sycl::event, sycl::event>
222222
jobu_val, jobvt_val, expected_a_u_vt_ndim,
223223
expected_s_ndim);
224224

225-
// TODO: check non_zero shape
225+
// Ensure `batch_size`, `m` and 'n' are non-zero, otherwise return empty
226+
// events
227+
if (gesvd_utils::check_zeros_shape_gesvd(a_array, out_s, out_u, out_vt,
228+
jobu_val, jobvt_val))
229+
{
230+
// nothing to do
231+
return std::make_pair(sycl::event(), sycl::event());
232+
}
226233

227234
auto array_types = dpctl_td_ns::usm_ndarray_types();
228235
int a_array_type_id =

dpnp/backend/extensions/lapack/gesvd_common_utils.hpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
#include "utils/output_validation.hpp"
3333
#include "utils/type_dispatch.hpp"
3434

35+
#include "common_helpers.hpp"
36+
3537
namespace dpnp::extensions::lapack::gesvd_utils
3638
{
3739
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
@@ -170,6 +172,41 @@ inline void common_gesvd_checks(sycl::queue &exec_q,
170172
}
171173
}
172174

175+
// Checks if the shape of input arrays for gesvd has any non-zero dimension.
176+
inline bool check_zeros_shape_gesvd(dpctl::tensor::usm_ndarray a_array,
177+
dpctl::tensor::usm_ndarray out_s,
178+
dpctl::tensor::usm_ndarray out_u,
179+
dpctl::tensor::usm_ndarray out_vt,
180+
const std::int8_t jobu_val,
181+
const std::int8_t jobvt_val)
182+
{
183+
184+
const int a_array_nd = a_array.get_ndim();
185+
const int out_u_array_nd = out_u.get_ndim();
186+
const int out_s_array_nd = out_s.get_ndim();
187+
const int out_vt_array_nd = out_vt.get_ndim();
188+
189+
const py::ssize_t *a_array_shape = a_array.get_shape_raw();
190+
const py::ssize_t *s_out_shape = out_s.get_shape_raw();
191+
const py::ssize_t *u_out_shape = out_u.get_shape_raw();
192+
const py::ssize_t *vt_out_shape = out_vt.get_shape_raw();
193+
194+
bool is_zeros_shape = helper::check_zeros_shape(a_array_nd, a_array_shape);
195+
if (jobu_val == 'N' && jobvt_val == 'N') {
196+
is_zeros_shape = is_zeros_shape || helper::check_zeros_shape(
197+
out_vt_array_nd, vt_out_shape);
198+
}
199+
else {
200+
is_zeros_shape =
201+
is_zeros_shape ||
202+
helper::check_zeros_shape(out_u_array_nd, s_out_shape) ||
203+
helper::check_zeros_shape(out_s_array_nd, u_out_shape) ||
204+
helper::check_zeros_shape(out_vt_array_nd, vt_out_shape);
205+
}
206+
207+
return is_zeros_shape;
208+
}
209+
173210
inline void handle_lapack_exc(std::int64_t scratchpad_size,
174211
const oneapi::mkl::lapack::exception &e,
175212
std::stringstream &error_msg)

0 commit comments

Comments
 (0)