|
@@ -1,5 +1,5 @@
 //*****************************************************************************
-// Copyright (c) 2016-2020, Intel Corporation
+// Copyright (c) 2016-2023, Intel Corporation
 // All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
@@ -27,6 +27,7 @@
 
 #include <dpnp_iface.hpp>
 #include "dpnp_fptr.hpp"
+#include "dpnp_iterator.hpp"
 #include "dpnpc_memory_adapter.hpp"
 #include "queue_sycl.hpp"
 
@@ -139,6 +140,258 @@ DPCTLSyclEventRef (*dpnp_argmin_ext_c)(DPCTLSyclQueueRef,
                                        size_t,
                                        const DPCTLEventVectorRef) = dpnp_argmin_c<_DataType, _idx_DataType>;
 
+
+template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
+class dpnp_where_c_broadcast_kernel;
+
+template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
+class dpnp_where_c_strides_kernel;
+
+template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
+class dpnp_where_c_kernel;
+
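+/* Computes the equivalent of numpy.where: result[i] = condition[i] ? input1[i] : input2[i].
+ * Dispatches at runtime to one of three paths: broadcasting (operand shapes differ),
+ * strided (matching shapes, non-contiguous data) or plain contiguous. */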
+template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
+DPCTLSyclEventRef dpnp_where_c(DPCTLSyclQueueRef q_ref,
+                               void* result_out,
+                               const size_t result_size,
+                               const size_t result_ndim,
+                               const shape_elem_type* result_shape,
+                               const shape_elem_type* result_strides,
+                               const void* condition_in,
+                               const size_t condition_size,
+                               const size_t condition_ndim,
+                               const shape_elem_type* condition_shape,
+                               const shape_elem_type* condition_strides,
+                               const void* input1_in,
+                               const size_t input1_size,
+                               const size_t input1_ndim,
+                               const shape_elem_type* input1_shape,
+                               const shape_elem_type* input1_strides,
+                               const void* input2_in,
+                               const size_t input2_size,
+                               const size_t input2_ndim,
+                               const shape_elem_type* input2_shape,
+                               const shape_elem_type* input2_strides,
+                               const DPCTLEventVectorRef dep_event_vec_ref)
+{
+    /* avoid unused-variable warning */
+    (void)dep_event_vec_ref;
+
+    DPCTLSyclEventRef event_ref = nullptr;
+
+    if (!condition_size || !input1_size || !input2_size)
+    {
+        return event_ref;
+    }
+
+    sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
+
+    bool* condition_data = static_cast<bool*>(const_cast<void*>(condition_in));
+    _DataType_input1* input1_data = static_cast<_DataType_input1*>(const_cast<void*>(input1_in));
+    _DataType_input2* input2_data = static_cast<_DataType_input2*>(const_cast<void*>(input2_in));
+    _DataType_output* result = static_cast<_DataType_output*>(result_out);
+
+    bool use_broadcasting = !array_equal(input1_shape, input1_ndim, input2_shape, input2_ndim);
+    use_broadcasting = use_broadcasting || !array_equal(condition_shape, condition_ndim, input1_shape, input1_ndim);
+    use_broadcasting = use_broadcasting || !array_equal(condition_shape, condition_ndim, input2_shape, input2_ndim);
+
+    shape_elem_type* condition_shape_offsets = new shape_elem_type[condition_ndim];
+
+    get_shape_offsets_inkernel(condition_shape, condition_ndim, condition_shape_offsets);
+    bool use_strides = !array_equal(condition_strides, condition_ndim, condition_shape_offsets, condition_ndim);
+    delete[] condition_shape_offsets;
+
+    shape_elem_type* input1_shape_offsets = new shape_elem_type[input1_ndim];
+
+    get_shape_offsets_inkernel(input1_shape, input1_ndim, input1_shape_offsets);
+    use_strides = use_strides || !array_equal(input1_strides, input1_ndim, input1_shape_offsets, input1_ndim);
+    delete[] input1_shape_offsets;
+
+    shape_elem_type* input2_shape_offsets = new shape_elem_type[input2_ndim];
+
+    get_shape_offsets_inkernel(input2_shape, input2_ndim, input2_shape_offsets);
+    use_strides = use_strides || !array_equal(input2_strides, input2_ndim, input2_shape_offsets, input2_ndim);
+    delete[] input2_shape_offsets;
+
+    sycl::event event;
+    sycl::range<1> gws(result_size);
+
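+    /* DPNPC_id iterators (dpnp_iterator.hpp) remap a flat result index onto each
+     * broadcast operand, so the kernel body can use plain linear indexing */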
+    if (use_broadcasting)
+    {
+        DPNPC_id<bool>* condition_it;
+        const size_t condition_it_size_in_bytes = sizeof(DPNPC_id<bool>);
+        condition_it = reinterpret_cast<DPNPC_id<bool>*>(dpnp_memory_alloc_c(q_ref, condition_it_size_in_bytes));
+        new (condition_it) DPNPC_id<bool>(q_ref, condition_data, condition_shape, condition_strides, condition_ndim);
+
+        condition_it->broadcast_to_shape(result_shape, result_ndim);
+
+        DPNPC_id<_DataType_input1>* input1_it;
+        const size_t input1_it_size_in_bytes = sizeof(DPNPC_id<_DataType_input1>);
+        input1_it = reinterpret_cast<DPNPC_id<_DataType_input1>*>(dpnp_memory_alloc_c(q_ref, input1_it_size_in_bytes));
+        new (input1_it) DPNPC_id<_DataType_input1>(q_ref, input1_data, input1_shape, input1_strides, input1_ndim);
+
+        input1_it->broadcast_to_shape(result_shape, result_ndim);
+
+        DPNPC_id<_DataType_input2>* input2_it;
+        const size_t input2_it_size_in_bytes = sizeof(DPNPC_id<_DataType_input2>);
+        input2_it = reinterpret_cast<DPNPC_id<_DataType_input2>*>(dpnp_memory_alloc_c(q_ref, input2_it_size_in_bytes));
+        new (input2_it) DPNPC_id<_DataType_input2>(q_ref, input2_data, input2_shape, input2_strides, input2_ndim);
+
+        input2_it->broadcast_to_shape(result_shape, result_ndim);
+
+        auto kernel_parallel_for_func = [=](sycl::id<1> global_id) {
+            const size_t i = global_id[0]; /* for (size_t i = 0; i < result_size; ++i) */
+            {
+                const bool condition = (*condition_it)[i];
+                const _DataType_output input1_elem = (*input1_it)[i];
+                const _DataType_output input2_elem = (*input2_it)[i];
+                result[i] = (condition) ? input1_elem : input2_elem;
+            }
+        };
+        auto kernel_func = [&](sycl::handler& cgh) {
+            cgh.parallel_for<class dpnp_where_c_broadcast_kernel<_DataType_output, _DataType_input1, _DataType_input2>>(
+                gws, kernel_parallel_for_func);
+        };
+
+        q.submit(kernel_func).wait();
+
+        condition_it->~DPNPC_id();
+        input1_it->~DPNPC_id();
+        input2_it->~DPNPC_id();
+
+        return event_ref;
+    }
+    else if (use_strides)
+    {
+        if ((result_ndim != condition_ndim) || (result_ndim != input1_ndim) || (result_ndim != input2_ndim))
+        {
+            throw std::runtime_error("Result ndim=" + std::to_string(result_ndim) +
+                                     " mismatches with either condition ndim=" + std::to_string(condition_ndim) +
+                                     " or input1 ndim=" + std::to_string(input1_ndim) +
+                                     " or input2 ndim=" + std::to_string(input2_ndim));
+        }
+
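+        /* the kernel below unpacks each flat output index into per-array offsets
+         * using a packed stride table that is first copied to the device */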
+        /* memory transfer optimization: using a USM-host temporary speeds up the transfer to the device */
+        using usm_host_allocatorT = sycl::usm_allocator<shape_elem_type, sycl::usm::alloc::host>;
+
+        size_t strides_size = 4 * result_ndim;
+        shape_elem_type* dev_strides_data = sycl::malloc_device<shape_elem_type>(strides_size, q);
+
+        /* host temporary for the packed strides, allocated in USM-host memory */
+        auto strides_host_packed =
+            std::vector<shape_elem_type, usm_host_allocatorT>(strides_size, usm_host_allocatorT(q));
+
+        /* packed vector is the concatenation of result_strides, condition_strides, input1_strides and input2_strides */
+        std::copy(result_strides, result_strides + result_ndim, strides_host_packed.begin());
+        std::copy(condition_strides, condition_strides + result_ndim, strides_host_packed.begin() + result_ndim);
+        std::copy(input1_strides, input1_strides + result_ndim, strides_host_packed.begin() + 2 * result_ndim);
+        std::copy(input2_strides, input2_strides + result_ndim, strides_host_packed.begin() + 3 * result_ndim);
+
+        auto copy_strides_ev =
+            q.copy<shape_elem_type>(strides_host_packed.data(), dev_strides_data, strides_host_packed.size());
+
+        auto kernel_parallel_for_func = [=](sycl::id<1> global_id) {
+            const size_t output_id = global_id[0]; /* for (size_t i = 0; i < result_size; ++i) */
+            {
+                /* pointers into the packed stride table: result, condition,
+                 * input1 and input2 each occupy result_ndim elements */
+                const shape_elem_type* result_strides_data = &dev_strides_data[0];
+                const shape_elem_type* condition_strides_data = &dev_strides_data[result_ndim];
+                const shape_elem_type* input1_strides_data = &dev_strides_data[2 * result_ndim];
+                const shape_elem_type* input2_strides_data = &dev_strides_data[3 * result_ndim];
+
+                size_t condition_id = 0;
+                size_t input1_id = 0;
+                size_t input2_id = 0;
+
+                for (size_t i = 0; i < result_ndim; ++i)
+                {
+                    const size_t output_xyz_id =
+                        get_xyz_id_by_id_inkernel(output_id, result_strides_data, result_ndim, i);
+                    condition_id += output_xyz_id * condition_strides_data[i];
+                    input1_id += output_xyz_id * input1_strides_data[i];
+                    input2_id += output_xyz_id * input2_strides_data[i];
+                }
+
+                const bool condition = condition_data[condition_id];
+                const _DataType_output input1_elem = input1_data[input1_id];
+                const _DataType_output input2_elem = input2_data[input2_id];
+                result[output_id] = (condition) ? input1_elem : input2_elem;
+            }
+        };
+        auto kernel_func = [&](sycl::handler& cgh) {
+            cgh.depends_on(copy_strides_ev);
+            cgh.parallel_for<class dpnp_where_c_strides_kernel<_DataType_output, _DataType_input1, _DataType_input2>>(
+                gws, kernel_parallel_for_func);
+        };
+
+        q.submit(kernel_func).wait();
+
+        sycl::free(dev_strides_data, q);
+        return event_ref;
+    }
+    else
+    {
+        auto kernel_parallel_for_func = [=](sycl::id<1> global_id) {
+            const size_t i = global_id[0]; /* for (size_t i = 0; i < result_size; ++i) */
+
+            const bool condition = condition_data[i];
+            const _DataType_output input1_elem = input1_data[i];
+            const _DataType_output input2_elem = input2_data[i];
+            result[i] = (condition) ? input1_elem : input2_elem;
+        };
+        auto kernel_func = [&](sycl::handler& cgh) {
+            cgh.parallel_for<class dpnp_where_c_kernel<_DataType_output, _DataType_input1, _DataType_input2>>(
+                gws, kernel_parallel_for_func);
+        };
+        event = q.submit(kernel_func);
+    }
+
+    event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);
+    return DPCTLEvent_Copy(event_ref);
+}
+
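+/* function-pointer alias matching the signature registered in the dpnp function map */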
+template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
+DPCTLSyclEventRef (*dpnp_where_ext_c)(DPCTLSyclQueueRef,
+                                      void*,
+                                      const size_t,
+                                      const size_t,
+                                      const shape_elem_type*,
+                                      const shape_elem_type*,
+                                      const void*,
+                                      const size_t,
+                                      const size_t,
+                                      const shape_elem_type*,
+                                      const shape_elem_type*,
+                                      const void*,
+                                      const size_t,
+                                      const size_t,
+                                      const shape_elem_type*,
+                                      const shape_elem_type*,
+                                      const void*,
+                                      const size_t,
+                                      const size_t,
+                                      const shape_elem_type*,
+                                      const shape_elem_type*,
+                                      const DPCTLEventVectorRef) = dpnp_where_c<_DataType_output, _DataType_input1, _DataType_input2>;
+
+template <DPNPFuncType FT1, DPNPFuncType... FTs>
+static void func_map_searching_2arg_3type_core(func_map_t& fmap)
+{
+    ((fmap[DPNPFuncName::DPNP_FN_WHERE_EXT][FT1][FTs] =
+          {populate_func_types<FT1, FTs>(),
+           (void*)dpnp_where_ext_c<func_type_map_t::find_type<populate_func_types<FT1, FTs>()>,
+                                   func_type_map_t::find_type<FT1>,
+                                   func_type_map_t::find_type<FTs>>}),
+     ...);
+}
+
+template <DPNPFuncType... FTs>
+static void func_map_searching_2arg_3type_helper(func_map_t& fmap)
+{
+    ((func_map_searching_2arg_3type_core<FTs, FTs...>(fmap)), ...);
+}
+
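+/* The helper expands the core once per type in the pack, and each core call in
+ * turn registers one fmap entry per ordered pair of types. For (eft_INT, eft_DBL),
+ * for instance, the generated registration is roughly
+ *     fmap[DPNPFuncName::DPNP_FN_WHERE_EXT][eft_INT][eft_DBL] =
+ *         {eft_DBL, (void*)dpnp_where_ext_c<double, int32_t, double>};
+ * assuming populate_func_types promotes the (int32, double) pair to eft_DBL. */
+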
 void func_map_init_searching(func_map_t& fmap)
 {
     fmap[DPNPFuncName::DPNP_FN_ARGMAX][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_argmax_default_c<int32_t, int32_t>};
@@ -177,5 +430,7 @@ void func_map_init_searching(func_map_t& fmap)
     fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_DBL][eft_INT] = {eft_INT, (void*)dpnp_argmin_ext_c<double, int32_t>};
     fmap[DPNPFuncName::DPNP_FN_ARGMIN_EXT][eft_DBL][eft_LNG] = {eft_LNG, (void*)dpnp_argmin_ext_c<double, int64_t>};
 
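+    /* register where() kernels for every ordered pair of the seven supported types */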
|
+    func_map_searching_2arg_3type_helper<eft_BLN, eft_INT, eft_LNG, eft_FLT, eft_DBL, eft_C64, eft_C128>(fmap);
+
     return;
 }
|