Skip to content

remove mixed host\dev impl from dpnp_all dpnp_any #1155

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 27 additions & 34 deletions dpnp/backend/kernels/dpnp_krnl_logic.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//*****************************************************************************
// Copyright (c) 2016-2020, Intel Corporation
// Copyright (c) 2016-2022, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
Expand Down Expand Up @@ -51,18 +51,16 @@ DPCTLSyclEventRef dpnp_all_c(DPCTLSyclQueueRef q_ref,
}

sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
sycl::event event;

DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref, array1_in, size);
DPNPC_ptr_adapter<_ResultType> result1_ptr(q_ref, result1, 1, true, true);
const _DataType* array_in = input1_ptr.get_ptr();
_ResultType* result = result1_ptr.get_ptr();
const _DataType* array_in = reinterpret_cast<const _DataType*>(array1_in);
_ResultType* result = reinterpret_cast<_ResultType*>(result1);

result[0] = true;
auto fill_event = q.fill<_ResultType>(result, true, 1);

if (!size)
{
return event_ref;
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&fill_event);
return DPCTLEvent_Copy(event_ref);
}

sycl::range<1> gws(size);
Expand All @@ -76,13 +74,12 @@ DPCTLSyclEventRef dpnp_all_c(DPCTLSyclQueueRef q_ref,
};

auto kernel_func = [&](sycl::handler& cgh) {
cgh.depends_on(fill_event);
cgh.parallel_for<class dpnp_all_c_kernel<_DataType, _ResultType>>(gws, kernel_parallel_for_func);
};

event = q.submit(kernel_func);

auto event = q.submit(kernel_func);
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);

return DPCTLEvent_Copy(event_ref);
}

Expand All @@ -97,6 +94,7 @@ void dpnp_all_c(const void* array1_in, void* result1, const size_t size)
size,
dep_event_vec_ref);
DPCTLEvent_WaitAndThrow(event_ref);
DPCTLEvent_Delete(event_ref);
}

template <typename _DataType, typename _ResultType>
Expand Down Expand Up @@ -127,26 +125,23 @@ DPCTLSyclEventRef dpnp_allclose_c(DPCTLSyclQueueRef q_ref,

DPCTLSyclEventRef event_ref = nullptr;

if (!array1_in || !result1)
if (!array1_in || !array2_in || !result1)
{
return event_ref;
}

sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
sycl::event event;

DPNPC_ptr_adapter<_DataType1> input1_ptr(q_ref, array1_in, size);
DPNPC_ptr_adapter<_DataType2> input2_ptr(q_ref, array2_in, size);
DPNPC_ptr_adapter<_ResultType> result1_ptr(q_ref, result1, 1, true, true);
const _DataType1* array1 = input1_ptr.get_ptr();
const _DataType2* array2 = input2_ptr.get_ptr();
_ResultType* result = result1_ptr.get_ptr();
const _DataType1* array1 = reinterpret_cast<const _DataType1*>(array1_in);
const _DataType2* array2 = reinterpret_cast<const _DataType2*>(array2_in);
_ResultType* result = reinterpret_cast<_ResultType*>(result1);

result[0] = true;
auto fill_event = q.fill<_ResultType>(result, true, 1);

if (!size)
{
return event_ref;
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&fill_event);
return DPCTLEvent_Copy(event_ref);
}

sycl::range<1> gws(size);
Expand All @@ -160,14 +155,13 @@ DPCTLSyclEventRef dpnp_allclose_c(DPCTLSyclQueueRef q_ref,
};

auto kernel_func = [&](sycl::handler& cgh) {
cgh.depends_on(fill_event);
cgh.parallel_for<class dpnp_allclose_c_kernel<_DataType1, _DataType2, _ResultType>>(gws,
kernel_parallel_for_func);
};

event = q.submit(kernel_func);

auto event = q.submit(kernel_func);
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);

return DPCTLEvent_Copy(event_ref);
}

Expand All @@ -186,6 +180,7 @@ void dpnp_allclose_c(
atol_val,
dep_event_vec_ref);
DPCTLEvent_WaitAndThrow(event_ref);
DPCTLEvent_Delete(event_ref);
}

template <typename _DataType1, typename _DataType2, typename _ResultType>
Expand Down Expand Up @@ -228,18 +223,16 @@ DPCTLSyclEventRef dpnp_any_c(DPCTLSyclQueueRef q_ref,
}

sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
sycl::event event;

DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref, array1_in, size);
DPNPC_ptr_adapter<_ResultType> result1_ptr(q_ref, result1, 1, true, true);
const _DataType* array_in = input1_ptr.get_ptr();
_ResultType* result = result1_ptr.get_ptr();
const _DataType* array_in = reinterpret_cast<const _DataType*>(array1_in);
_ResultType* result = reinterpret_cast<_ResultType*>(result1);

result[0] = false;
auto fill_event = q.fill<_ResultType>(result, false, 1);

if (!size)
{
return event_ref;
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&fill_event);
return DPCTLEvent_Copy(event_ref);
}

sycl::range<1> gws(size);
Expand All @@ -253,13 +246,12 @@ DPCTLSyclEventRef dpnp_any_c(DPCTLSyclQueueRef q_ref,
};

auto kernel_func = [&](sycl::handler& cgh) {
cgh.depends_on(fill_event);
cgh.parallel_for<class dpnp_any_c_kernel<_DataType, _ResultType>>(gws, kernel_parallel_for_func);
};

event = q.submit(kernel_func);

auto event = q.submit(kernel_func);
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);

return DPCTLEvent_Copy(event_ref);
}

Expand All @@ -274,6 +266,7 @@ void dpnp_any_c(const void* array1_in, void* result1, const size_t size)
size,
dep_event_vec_ref);
DPCTLEvent_WaitAndThrow(event_ref);
DPCTLEvent_Delete(event_ref);
}

template <typename _DataType, typename _ResultType>
Expand Down