Skip to content

Commit 40be0e3

Browse files
committed
Fix hanging in tril
1 parent c286b52 commit 40be0e3

File tree

3 files changed

+20
-0
lines changed

3 files changed

+20
-0
lines changed

dpnp/backend/extensions/vm/div.cpp

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -64,6 +64,8 @@ static sycl::event div_impl(sycl::queue exec_q,
6464
{
6565
type_utils::validate_type_for_device<T>(exec_q);
6666

67+
std::cout << typeid(T).name() << std::endl;
68+
6769
const T* a = reinterpret_cast<const T*>(in_a);
6870
const T* b = reinterpret_cast<const T*>(in_b);
6971
T* y = reinterpret_cast<T*>(out_y);
@@ -169,12 +171,14 @@ std::pair<sycl::event, sycl::event> div(sycl::queue exec_q,
169171
throw py::value_error("Input and outpur arrays must be C-contiguous");
170172
}
171173

174+
std::cout << "dst_typeid = " << int(dst_typeid) << std::endl;
172175
auto div_fn = div_dispatch_vector[dst_typeid];
173176
if (div_fn == nullptr)
174177
{
175178
throw py::value_error("No div implementation defined");
176179
}
177180
sycl::event sum_ev = div_fn(exec_q, src_nelems, src1_data, src2_data, dst_data, depends);
181+
std::cout << "leaving div_fn" << std::endl;
178182

179183
sycl::event ht_ev = dpctl::utils::keep_args_alive(exec_q, {src1, src2, dst}, {sum_ev});
180184
return std::make_pair(ht_ev, sum_ev);

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -75,6 +75,17 @@ def dpnp_divide(x1, x2, out=None, order='K'):
7575
def _call_divide(src1, src2, dst, sycl_queue, depends=[]):
7676
"""A callback to register in BinaryElementwiseFunc class of dpctl.tensor"""
7777

78+
print("_call_divide", sycl_queue)
79+
print("src1 =", src1, type(src1), src1.sycl_queue, src1.device, src1.usm_type, src1.ndim, src1.dtype, src1.shape)
80+
print(src1.__sycl_usm_array_interface__)
81+
print(src1._byte_bounds)
82+
print("src2 =", src2, type(src2), src2.sycl_queue, src2.device, src2.usm_type, src2.ndim, src2.dtype, src2.shape)
83+
print(src2.__sycl_usm_array_interface__)
84+
print(src2._byte_bounds)
85+
print("dst =", dst, type(dst), dst.sycl_queue, dst.device, dst.usm_type, dst.ndim, dst.dtype, dst.shape)
86+
print(dst.__sycl_usm_array_interface__)
87+
print(dst._byte_bounds)
88+
7889
if vmi._can_call_div(sycl_queue, src1, src2, dst):
7990
# call pybind11 extension for div() function from OneMKL VM
8091
return vmi._div(sycl_queue, src1, src2, dst, depends)
@@ -86,5 +97,6 @@ def _call_divide(src1, src2, dst, sycl_queue, depends=[]):
8697
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
8798

8899
func = BinaryElementwiseFunc("divide", ti._divide_result_type, _call_divide, _divide_docstring_)
100+
print("func is done")
89101
res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order)
90102
return dpnp_array._create_from_usm_ndarray(res_usm)

dpnp/dpnp_array.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -388,6 +388,10 @@ def __sub__(self, other):
388388
# '__subclasshook__',
389389

390390
def __truediv__(self, other):
391+
print("__truediv__")
392+
print("self =", self, type(self), self.sycl_queue, self.device, self.usm_type, self.ndim, self.dtype)
393+
print(self.__sycl_usm_array_interface__)
394+
print("other =", other, type(other))
391395
return dpnp.true_divide(self, other)
392396

393397
def __xor__(self, other):

0 commit comments

Comments
 (0)