Skip to content

Commit 96d4992

Browse files
committed
[SYCL] Enabling image_array.
The patch maps N dimensional image arrays accessors to N + 1 image accessors. operator[](size_t index) method of an N dimensional image array accessor should return an instance of __image_array_slice__ which provides interface of N dimensional image accessor to image specified by "index". The patch introduces this __image_array_slice__ class which holds original accessor and adjusts read, write, get_size and get_count methods to imitate operations with N dimensional image. This mapping approach is an easiest way to implement image array accesors Signed-off-by: Vlad Romanov <[email protected]>
1 parent 422a5fa commit 96d4992

File tree

3 files changed

+324
-52
lines changed

3 files changed

+324
-52
lines changed

sycl/include/CL/sycl/accessor.hpp

Lines changed: 162 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <CL/__spirv/spirv_types.hpp>
1212
#include <CL/sycl/atomic.hpp>
1313
#include <CL/sycl/buffer.hpp>
14+
#include <CL/sycl/exception.hpp>
1415
#include <CL/sycl/detail/accessor_impl.hpp>
1516
#include <CL/sycl/detail/common.hpp>
1617
#include <CL/sycl/detail/generic_type_traits.hpp>
@@ -242,6 +243,26 @@ class accessor_common {
242243
};
243244
};
244245

246+
template <int Dim, typename T> struct IsValidCoordDataT;
247+
template <typename T> struct IsValidCoordDataT<1, T> {
248+
constexpr static bool value =
249+
detail::is_contained<T, detail::type_list<cl_int, cl_float>>::type::value;
250+
};
251+
template <typename T> struct IsValidCoordDataT<2, T> {
252+
constexpr static bool value =
253+
detail::is_contained<T,
254+
detail::type_list<cl_int2, cl_float2>>::type::value;
255+
};
256+
template <typename T> struct IsValidCoordDataT<3, T> {
257+
constexpr static bool value =
258+
detail::is_contained<T,
259+
detail::type_list<cl_int4, cl_float4>>::type::value;
260+
};
261+
262+
template <typename DataT, int Dimensions, access::mode AccessMode,
263+
access::placeholder IsPlaceholder>
264+
class __image_array_slice__;
265+
245266
// Image accessor
246267
template <typename DataT, int Dimensions, access::mode AccessMode,
247268
access::target AccessTarget, access::placeholder IsPlaceholder>
@@ -266,6 +287,9 @@ class image_accessor
266287

267288
private:
268289
#endif
290+
template <typename T1, int T2, access::mode T3, access::placeholder T4>
291+
friend class __image_array_slice__;
292+
269293
constexpr static bool IsHostImageAcc =
270294
(AccessTarget == access::target::host_image);
271295

@@ -310,6 +334,30 @@ class image_accessor
310334
throw feature_not_supported("Images are not supported by this device.");
311335
}
312336

337+
#ifdef __SYCL_DEVICE_ONLY__
338+
339+
sycl::vec<int, Dimensions> getCountInternal() const {
340+
return __invoke_ImageQuerySize<sycl::vec<int, Dimensions>, OCLImageTy>(MImageObj);
341+
}
342+
343+
size_t getElementSize() const {
344+
int ChannelType = __invoke_ImageQueryFormat<int, OCLImageTy>(MImageObj);
345+
int ChannelOrder = __invoke_ImageQueryOrder<int, OCLImageTy>(MImageObj);
346+
int ElementSize = getSPIRVElementSize(ChannelType, ChannelOrder);
347+
return ElementSize;
348+
}
349+
350+
#else
351+
352+
sycl::vec<int, Dimensions> getCountInternal() const {
353+
// TODO: Implement for host.
354+
throw runtime_error(
355+
"image::getCountInternal() is not implemented for host");
356+
return sycl::vec<int, Dimensions>{1};
357+
}
358+
359+
#endif
360+
313361
public:
314362
using value_type = DataT;
315363
using reference = DataT &;
@@ -371,30 +419,6 @@ class image_accessor
371419
}
372420
#endif
373421

374-
template <typename AllocatorT, int Dims = Dimensions,
375-
typename = detail::enable_if_t<(Dims > 0) && (Dims < 3) &&
376-
IsImageArrayAcc>>
377-
image_accessor(image<Dims + 1, AllocatorT> &ImageRef,
378-
handler &CommandGroupHandlerRef, int ImageElementSize)
379-
#ifdef __SYCL_DEVICE_ONLY__
380-
{
381-
// No implementation needed for device. The constructor is only called by
382-
// host.
383-
}
384-
#else
385-
: AccessorBaseHost(id<3>(0, 0, 0) /* Offset,*/,
386-
detail::convertToArrayOfN<3, 1>(ImageRef.get_range()),
387-
detail::convertToArrayOfN<3, 1>(ImageRef.get_range()),
388-
AccessMode, detail::getSyclObjImpl(ImageRef).get(),
389-
Dimensions, ImageElementSize),
390-
MImageCount(ImageRef.get_count()),
391-
MImageSize(MImageCount * ImageElementSize) {
392-
checkDeviceFeatureSupported<info::device::image_support>(
393-
CommandGroupHandlerRef.MQueue->get_device());
394-
// TODO: Implement this function.
395-
}
396-
#endif
397-
398422
/* -- common interface members -- */
399423

400424
// operator == and != need to be defined only for host application as per the
@@ -426,37 +450,21 @@ class image_accessor
426450

427451
template <int Dims = Dimensions> size_t get_count() const;
428452

429-
template <> size_t get_count<1>() const {
430-
return __invoke_ImageQuerySize<int, OCLImageTy>(MImageObj);
431-
}
453+
template <> size_t get_count<1>() const { return getCountInternal(); }
432454
template <> size_t get_count<2>() const {
433-
cl_int2 Count = __invoke_ImageQuerySize<cl_int2, OCLImageTy>(MImageObj);
455+
cl_int2 Count = getCountInternal();
434456
return (Count.x() * Count.y());
435457
};
436458
template <> size_t get_count<3>() const {
437-
cl_int3 Count = __invoke_ImageQuerySize<cl_int3, OCLImageTy>(MImageObj);
459+
cl_int3 Count = getCountInternal();
438460
return (Count.x() * Count.y() * Count.z());
439461
};
462+
440463
#else
441464
size_t get_size() const { return MImageSize; };
442465
size_t get_count() const { return MImageCount; };
443466
#endif
444467

445-
template <int Dim, typename T> struct IsValidCoordDataT;
446-
template <typename T> struct IsValidCoordDataT<1, T> {
447-
constexpr static bool value =
448-
detail::is_contained<T,
449-
detail::type_list<cl_int, cl_float>>::type::value;
450-
};
451-
template <typename T> struct IsValidCoordDataT<2, T> {
452-
constexpr static bool value = detail::is_contained<
453-
T, detail::type_list<cl_int2, cl_float2>>::type::value;
454-
};
455-
template <typename T> struct IsValidCoordDataT<3, T> {
456-
constexpr static bool value = detail::is_contained<
457-
T, detail::type_list<cl_int4, cl_float4>>::type::value;
458-
};
459-
460468
// Available only when:
461469
// (accessTarget == access::target::image && accessMode == access::mode::read)
462470
// || (accessTarget == access::target::host_image && ( accessMode ==
@@ -514,10 +522,96 @@ class image_accessor
514522
"Read API is not implemented on host.");
515523
#endif
516524
}
525+
};
526+
527+
template <typename DataT, int Dimensions, access::mode AccessMode,
528+
access::placeholder IsPlaceholder>
529+
class __image_array_slice__ {
530+
531+
static_assert(Dimensions < 3,
532+
"Image slice cannot have more then 2 dimensions");
533+
534+
constexpr static int AdjustedDims = (Dimensions == 2) ? 4 : Dimensions + 1;
535+
536+
template <typename CoordT,
537+
typename CoordElemType =
538+
typename detail::TryToGetElementType<CoordT>::type>
539+
sycl::vec<CoordElemType, AdjustedDims>
540+
getAdjustedCoords(const CoordT &Coords) const {
541+
CoordElemType LastCoord = 0;
542+
543+
if (std::is_same<float, CoordElemType>::value) {
544+
sycl::vec<int, Dimensions + 1> Size = MBaseAcc.getCountInternal();
545+
LastCoord =
546+
MIdx / static_cast<float>(Size.template swizzle<Dimensions>());
547+
} else {
548+
LastCoord = MIdx;
549+
}
517550

518-
// Available only when: accessTarget == access::target::image_array &&
519-
// dimensions < 3
520-
//__image_array_slice__ operator[](size_t index) const;
551+
sycl::vec<CoordElemType, Dimensions> LeftoverCoords{LastCoord};
552+
sycl::vec<CoordElemType, AdjustedDims> AdjustedCoords{Coords,
553+
LeftoverCoords};
554+
return AdjustedCoords;
555+
}
556+
557+
public:
558+
__image_array_slice__(accessor<DataT, Dimensions, AccessMode,
559+
access::target::image_array, IsPlaceholder>
560+
BaseAcc,
561+
size_t Idx)
562+
: MBaseAcc(BaseAcc), MIdx(Idx) {}
563+
564+
template <typename CoordT, int Dims = Dimensions,
565+
typename = detail::enable_if_t<
566+
(Dims > 0) && (IsValidCoordDataT<Dims, CoordT>::value)>>
567+
DataT read(const CoordT &Coords) const {
568+
return MBaseAcc.read(getAdjustedCoords(Coords));
569+
}
570+
571+
template <typename CoordT, int Dims = Dimensions,
572+
typename = detail::enable_if_t<
573+
(Dims > 0) && IsValidCoordDataT<Dims, CoordT>::value>>
574+
DataT read(const CoordT &Coords, const sampler &Smpl) const {
575+
return MBaseAcc.read(getAdjustedCoords(Coords), Smpl);
576+
}
577+
578+
template <typename CoordT, int Dims = Dimensions,
579+
typename = detail::enable_if_t<
580+
(Dims > 0) && IsValidCoordDataT<Dims, CoordT>::value>>
581+
void write(const CoordT &Coords, const DataT &Color) const {
582+
return MBaseAcc.write(getAdjustedCoords(Coords), Color);
583+
}
584+
585+
586+
#ifdef __SYCL_DEVICE_ONLY__
587+
size_t get_size() const { return MBaseAcc.getElementSize() * get_count(); }
588+
589+
template <int Dims = Dimensions> size_t get_count() const;
590+
591+
template <> size_t get_count<1>() const {
592+
cl_int2 Count = MBaseAcc.getCountInternal();
593+
return Count.x();
594+
}
595+
template <> size_t get_count<2>() const {
596+
cl_int3 Count = MBaseAcc.getCountInternal();
597+
return (Count.x() * Count.y());
598+
};
599+
#else
600+
601+
size_t get_size() const {
602+
return MBaseAcc.MImageSize / MBaseAcc.getAccessRange()[Dimensions];
603+
};
604+
605+
size_t get_count() const {
606+
return MBaseAcc.MImageCount / MBaseAcc.getAccessRange()[Dimensions];
607+
};
608+
#endif
609+
610+
private:
611+
size_t MIdx;
612+
accessor<DataT, Dimensions, AccessMode, access::target::image_array,
613+
IsPlaceholder>
614+
MBaseAcc;
521615
};
522616

523617
} // namespace detail
@@ -1066,6 +1160,7 @@ class accessor<DataT, Dimensions, AccessMode, access::target::host_image,
10661160
Image, (detail::getSyclObjImpl(Image))->getElementSize()) {}
10671161
};
10681162

1163+
10691164
// Available only when: accessTarget == access::target::image_array &&
10701165
// dimensions < 3
10711166
// template <typename AllocatorT> accessor(image<dimensions + 1,
@@ -1074,20 +1169,35 @@ template <typename DataT, int Dimensions, access::mode AccessMode,
10741169
access::placeholder IsPlaceholder>
10751170
class accessor<DataT, Dimensions, AccessMode, access::target::image_array,
10761171
IsPlaceholder>
1077-
: public detail::image_accessor<DataT, Dimensions, AccessMode,
1078-
access::target::image_array,
1172+
: public detail::image_accessor<DataT, Dimensions + 1, AccessMode,
1173+
access::target::image,
10791174
IsPlaceholder> {
1080-
// TODO: Add Constructor.
10811175
#ifdef __SYCL_DEVICE_ONLY__
10821176
private:
10831177
using OCLImageTy =
1084-
typename detail::opencl_image_type<Dimensions, AccessMode,
1085-
access::target::image_array>::type;
1178+
typename detail::opencl_image_type<Dimensions + 1, AccessMode,
1179+
access::target::image>::type;
10861180

10871181
// Front End requires this method to be defined in the accessor class.
10881182
// It does not call the base class's init method.
10891183
void __init(OCLImageTy Image) { this->imageAccessorInit(Image); }
10901184
#endif
1185+
public:
1186+
template <typename AllocatorT>
1187+
accessor(cl::sycl::image<Dimensions + 1, AllocatorT> &Image,
1188+
handler &CommandGroupHandler)
1189+
: detail::image_accessor<DataT, Dimensions + 1, AccessMode,
1190+
access::target::image, IsPlaceholder>(
1191+
Image, CommandGroupHandler,
1192+
(detail::getSyclObjImpl(Image))->getElementSize()) {
1193+
CommandGroupHandler.associateWithHandler(*this);
1194+
}
1195+
1196+
detail::__image_array_slice__<DataT, Dimensions, AccessMode, IsPlaceholder>
1197+
operator[](size_t Index) const {
1198+
return detail::__image_array_slice__<DataT, Dimensions, AccessMode,
1199+
IsPlaceholder>(*this, Index);
1200+
}
10911201
};
10921202

10931203
} // namespace sycl

sycl/include/CL/sycl/detail/image_ocl_types.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ inline int getSPIRVNumChannels(int ImageChannelOrder) {
136136
case 18: // sBGRA
137137
// TODO: Enable the below assert after assert is supported for device
138138
// compiler. assert(!"Unhandled image channel order in sycl.");
139+
default:
139140
return 0;
140141
}
141142
}

0 commit comments

Comments
 (0)