Skip to content

Conversion from raw to multi_ptr should be done with address_space_cast #1366

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions dpctl/tensor/libtensor/include/kernels/copy_and_cast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,25 +244,26 @@ class ContigCopyFunctor

if (base + n_vecs * vec_sz * sgSize < nelems &&
sgSize == max_sgSize) {
using src_ptrT =
sycl::multi_ptr<const srcT,
sycl::access::address_space::global_space>;
using dst_ptrT =
sycl::multi_ptr<dstT,
sycl::access::address_space::global_space>;
sycl::vec<srcT, vec_sz> src_vec;
sycl::vec<dstT, vec_sz> dst_vec;

#pragma unroll
for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
src_vec =
sg.load<vec_sz>(src_ptrT(&src_p[base + it * sgSize]));
auto src_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(
&src_p[base + it * sgSize]);
auto dst_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(
&dst_p[base + it * sgSize]);

src_vec = sg.load<vec_sz>(src_multi_ptr);
#pragma unroll
for (std::uint8_t k = 0; k < vec_sz; k++) {
dst_vec[k] = fn(src_vec[k]);
}
sg.store<vec_sz>(dst_ptrT(&dst_p[base + it * sgSize]),
dst_vec);
sg.store<vec_sz>(dst_multi_ptr, dst_vec);
}
}
else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,6 @@ struct UnaryContigFunctor
if constexpr (UnaryOperatorT::is_constant::value) {
// the operator's result is a compile-time constant, independent of the input
constexpr resT const_val = UnaryOperatorT::constant_value;
using out_ptrT =
sycl::multi_ptr<resT,
sycl::access::address_space::global_space>;

auto sg = ndit.get_sub_group();
std::uint8_t sgSize = sg.get_local_range()[0];
Expand All @@ -80,8 +77,11 @@ struct UnaryContigFunctor
sycl::vec<resT, vec_sz> res_vec(const_val);
#pragma unroll
for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
sg.store<vec_sz>(out_ptrT(&out[base + it * sgSize]),
res_vec);
auto out_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&out[base + it * sgSize]);

sg.store<vec_sz>(out_multi_ptr, res_vec);
}
}
else {
Expand All @@ -94,13 +94,6 @@ struct UnaryContigFunctor
else if constexpr (UnaryOperatorT::supports_sg_loadstore::value &&
UnaryOperatorT::supports_vec::value)
{
using in_ptrT =
sycl::multi_ptr<const argT,
sycl::access::address_space::global_space>;
using out_ptrT =
sycl::multi_ptr<resT,
sycl::access::address_space::global_space>;

auto sg = ndit.get_sub_group();
std::uint16_t sgSize = sg.get_local_range()[0];
std::uint16_t max_sgSize = sg.get_max_local_range()[0];
Expand All @@ -113,10 +106,16 @@ struct UnaryContigFunctor

#pragma unroll
for (std::uint16_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
x = sg.load<vec_sz>(in_ptrT(&in[base + it * sgSize]));
auto in_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&in[base + it * sgSize]);
auto out_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&out[base + it * sgSize]);

x = sg.load<vec_sz>(in_multi_ptr);
sycl::vec<resT, vec_sz> res_vec = op(x);
sg.store<vec_sz>(out_ptrT(&out[base + it * sgSize]),
res_vec);
sg.store<vec_sz>(out_multi_ptr, res_vec);
}
}
else {
Expand All @@ -141,23 +140,23 @@ struct UnaryContigFunctor

if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
(maxsgSize == sgSize)) {
using in_ptrT =
sycl::multi_ptr<const argT,
sycl::access::address_space::global_space>;
using out_ptrT =
sycl::multi_ptr<resT,
sycl::access::address_space::global_space>;
sycl::vec<argT, vec_sz> arg_vec;

#pragma unroll
for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
arg_vec = sg.load<vec_sz>(in_ptrT(&in[base + it * sgSize]));
auto in_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&in[base + it * sgSize]);
auto out_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&out[base + it * sgSize]);

arg_vec = sg.load<vec_sz>(in_multi_ptr);
#pragma unroll
for (std::uint8_t k = 0; k < vec_sz; ++k) {
arg_vec[k] = op(arg_vec[k]);
}
sg.store<vec_sz>(out_ptrT(&out[base + it * sgSize]),
arg_vec);
sg.store<vec_sz>(out_multi_ptr, arg_vec);
}
}
else {
Expand All @@ -179,24 +178,24 @@ struct UnaryContigFunctor

if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
(maxsgSize == sgSize)) {
using in_ptrT =
sycl::multi_ptr<const argT,
sycl::access::address_space::global_space>;
using out_ptrT =
sycl::multi_ptr<resT,
sycl::access::address_space::global_space>;
sycl::vec<argT, vec_sz> arg_vec;
sycl::vec<resT, vec_sz> res_vec;

#pragma unroll
for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
arg_vec = sg.load<vec_sz>(in_ptrT(&in[base + it * sgSize]));
auto in_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&in[base + it * sgSize]);
auto out_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&out[base + it * sgSize]);

arg_vec = sg.load<vec_sz>(in_multi_ptr);
#pragma unroll
for (std::uint8_t k = 0; k < vec_sz; ++k) {
res_vec[k] = op(arg_vec[k]);
}
sg.store<vec_sz>(out_ptrT(&out[base + it * sgSize]),
res_vec);
sg.store<vec_sz>(out_multi_ptr, res_vec);
}
}
else {
Expand Down Expand Up @@ -365,28 +364,26 @@ struct BinaryContigFunctor

if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
(sgSize == maxsgSize)) {
using in_ptrT1 =
sycl::multi_ptr<const argT1,
sycl::access::address_space::global_space>;
using in_ptrT2 =
sycl::multi_ptr<const argT2,
sycl::access::address_space::global_space>;
using out_ptrT =
sycl::multi_ptr<resT,
sycl::access::address_space::global_space>;
sycl::vec<argT1, vec_sz> arg1_vec;
sycl::vec<argT2, vec_sz> arg2_vec;
sycl::vec<resT, vec_sz> res_vec;

#pragma unroll
for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
arg1_vec =
sg.load<vec_sz>(in_ptrT1(&in1[base + it * sgSize]));
arg2_vec =
sg.load<vec_sz>(in_ptrT2(&in2[base + it * sgSize]));
auto in1_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&in1[base + it * sgSize]);
auto in2_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&in2[base + it * sgSize]);
auto out_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&out[base + it * sgSize]);

arg1_vec = sg.load<vec_sz>(in1_multi_ptr);
arg2_vec = sg.load<vec_sz>(in2_multi_ptr);
res_vec = op(arg1_vec, arg2_vec);
sg.store<vec_sz>(out_ptrT(&out[base + it * sgSize]),
res_vec);
sg.store<vec_sz>(out_multi_ptr, res_vec);
}
}
else {
Expand All @@ -407,32 +404,30 @@ struct BinaryContigFunctor

if ((base + n_vecs * vec_sz * sgSize < nelems_) &&
(sgSize == maxsgSize)) {
using in_ptrT1 =
sycl::multi_ptr<const argT1,
sycl::access::address_space::global_space>;
using in_ptrT2 =
sycl::multi_ptr<const argT2,
sycl::access::address_space::global_space>;
using out_ptrT =
sycl::multi_ptr<resT,
sycl::access::address_space::global_space>;
sycl::vec<argT1, vec_sz> arg1_vec;
sycl::vec<argT2, vec_sz> arg2_vec;
sycl::vec<resT, vec_sz> res_vec;

#pragma unroll
for (std::uint8_t it = 0; it < n_vecs * vec_sz; it += vec_sz) {
arg1_vec =
sg.load<vec_sz>(in_ptrT1(&in1[base + it * sgSize]));
arg2_vec =
sg.load<vec_sz>(in_ptrT2(&in2[base + it * sgSize]));
auto in1_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&in1[base + it * sgSize]);
auto in2_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&in2[base + it * sgSize]);
auto out_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&out[base + it * sgSize]);

arg1_vec = sg.load<vec_sz>(in1_multi_ptr);
arg2_vec = sg.load<vec_sz>(in2_multi_ptr);
#pragma unroll
for (std::uint8_t vec_id = 0; vec_id < vec_sz; ++vec_id) {
res_vec[vec_id] =
op(arg1_vec[vec_id], arg2_vec[vec_id]);
}
sg.store<vec_sz>(out_ptrT(&out[base + it * sgSize]),
res_vec);
sg.store<vec_sz>(out_multi_ptr, res_vec);
}
}
else {
Expand Down Expand Up @@ -530,22 +525,24 @@ struct BinaryContigMatrixContigRowBroadcastingFunctor
size_t base = gid - sg.get_local_id()[0];

if (base + sgSize < n_elems) {
using in_ptrT1 =
sycl::multi_ptr<const argT1,
sycl::access::address_space::global_space>;
using in_ptrT2 =
sycl::multi_ptr<const argT2,
sycl::access::address_space::global_space>;
using res_ptrT =
sycl::multi_ptr<resT,
sycl::access::address_space::global_space>;

const argT1 mat_el = sg.load(in_ptrT1(&mat[base]));
const argT2 vec_el = sg.load(in_ptrT2(&padded_vec[base % n1]));
auto in1_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&mat[base]);

auto in2_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&padded_vec[base % n1]);

auto out_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&res[base]);

const argT1 mat_el = sg.load(in1_multi_ptr);
const argT2 vec_el = sg.load(in2_multi_ptr);

resT res_el = op(mat_el, vec_el);

sg.store(res_ptrT(&res[base]), res_el);
sg.store(out_multi_ptr, res_el);
}
else {
for (size_t k = base + sg.get_local_id()[0]; k < n_elems;
Expand Down Expand Up @@ -592,22 +589,24 @@ struct BinaryContigRowContigMatrixBroadcastingFunctor
size_t base = gid - sg.get_local_id()[0];

if (base + sgSize < n_elems) {
using in_ptrT1 =
sycl::multi_ptr<const argT1,
sycl::access::address_space::global_space>;
using in_ptrT2 =
sycl::multi_ptr<const argT2,
sycl::access::address_space::global_space>;
using res_ptrT =
sycl::multi_ptr<resT,
sycl::access::address_space::global_space>;

const argT2 mat_el = sg.load(in_ptrT2(&mat[base]));
const argT1 vec_el = sg.load(in_ptrT1(&padded_vec[base % n1]));
auto in1_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&padded_vec[base % n1]);

auto in2_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&mat[base]);

auto out_multi_ptr = sycl::address_space_cast<
sycl::access::address_space::global_space,
sycl::access::decorated::yes>(&res[base]);

const argT2 mat_el = sg.load(in2_multi_ptr);
const argT1 vec_el = sg.load(in1_multi_ptr);

resT res_el = op(vec_el, mat_el);

sg.store(res_ptrT(&res[base]), res_el);
sg.store(out_multi_ptr, res_el);
}
else {
for (size_t k = base + sg.get_local_id()[0]; k < n_elems;
Expand Down
Loading