@@ -267,7 +267,6 @@ class image_accessor
267
267
#ifndef __SYCL_DEVICE_ONLY__
268
268
: public detail::AccessorBaseHost {
269
269
size_t MImageCount;
270
- size_t MImageSize;
271
270
image_channel_order MImgChannelOrder;
272
271
image_channel_type MImgChannelType;
273
272
#else
@@ -277,9 +276,8 @@ class image_accessor
277
276
AccessTarget>::type;
278
277
OCLImageTy MImageObj;
279
278
char MPadding[sizeof (detail::AccessorBaseHost) +
280
- sizeof (size_t /* MImageSize*/ ) + sizeof (size_t /* MImageCount*/ ) +
281
- sizeof (image_channel_order) + sizeof (image_channel_type) -
282
- sizeof (OCLImageTy)];
279
+ sizeof (size_t /* MImageCount*/ ) + sizeof (image_channel_order) +
280
+ sizeof (image_channel_type) - sizeof (OCLImageTy)];
283
281
284
282
protected:
285
283
void imageAccessorInit (OCLImageTy Image) { MImageObj = Image; }
@@ -342,7 +340,7 @@ class image_accessor
342
340
343
341
#ifdef __SYCL_DEVICE_ONLY__
344
342
345
- sycl::vec<int , Dimensions> getCountInternal () const {
343
+ sycl::vec<int , Dimensions> getRangeInternal () const {
346
344
return __invoke_ImageQuerySize<sycl::vec<int , Dimensions>, OCLImageTy>(
347
345
MImageObj);
348
346
}
@@ -356,10 +354,10 @@ class image_accessor
356
354
357
355
#else
358
356
359
- sycl::vec<int , Dimensions> getCountInternal () const {
357
+ sycl::vec<int , Dimensions> getRangeInternal () const {
360
358
// TODO: Implement for host.
361
359
throw runtime_error (
362
- " image::getCountInternal () is not implemented for host" );
360
+ " image::getRangeInternal () is not implemented for host" );
363
361
return sycl::vec<int , Dimensions>{1 };
364
362
}
365
363
@@ -397,7 +395,6 @@ class image_accessor
397
395
AccessMode, detail::getSyclObjImpl (ImageRef).get (),
398
396
Dimensions, ImageElementSize),
399
397
MImageCount (ImageRef.get_count ()),
400
- MImageSize (MImageCount * ImageElementSize),
401
398
MImgChannelOrder (detail::getSyclObjImpl (ImageRef)->getChannelOrder ()),
402
399
MImgChannelType (detail::getSyclObjImpl (ImageRef)->getChannelType ()) {
403
400
detail::EventImplPtr Event =
@@ -429,7 +426,6 @@ class image_accessor
429
426
AccessMode, detail::getSyclObjImpl (ImageRef).get (),
430
427
Dimensions, ImageElementSize),
431
428
MImageCount (ImageRef.get_count ()),
432
- MImageSize (MImageCount * ImageElementSize),
433
429
MImgChannelOrder (detail::getSyclObjImpl (ImageRef)->getChannelOrder ()),
434
430
MImgChannelType (detail::getSyclObjImpl (ImageRef)->getChannelType ()) {
435
431
checkDeviceFeatureSupported<info::device::image_support>(
@@ -455,32 +451,39 @@ class image_accessor
455
451
// get_count() method : Returns the number of elements of the SYCL image this
456
452
// SYCL accessor is accessing.
457
453
//
458
- // get_size() method : Returns the size in bytes of the SYCL image this SYCL
459
- // accessor is accessing. Returns ElementSize*get_count().
454
+ // get_range() method : Returns a range object which represents the number of
455
+ // elements of dataT per dimension that this accessor may access.
456
+ // The range object returned must equal to the range of the image this
457
+ // accessor is associated with.
460
458
461
459
#ifdef __SYCL_DEVICE_ONLY__
462
- size_t get_size () const {
463
- int ChannelType = __invoke_ImageQueryFormat<int , OCLImageTy>(MImageObj);
464
- int ChannelOrder = __invoke_ImageQueryOrder<int , OCLImageTy>(MImageObj);
465
- int ElementSize = getSPIRVElementSize (ChannelType, ChannelOrder);
466
- return (ElementSize * get_count ());
467
- }
468
460
469
- template < int Dims = Dimensions> size_t get_count () const ;
461
+ size_t get_count () const { return get_range< Dimensions>(). size (); }
470
462
471
- template <> size_t get_count<1 >() const { return getCountInternal (); }
472
- template <> size_t get_count<2 >() const {
473
- cl_int2 Count = getCountInternal ();
474
- return (Count.x () * Count.y ());
475
- };
476
- template <> size_t get_count<3 >() const {
477
- cl_int3 Count = getCountInternal ();
478
- return (Count.x () * Count.y () * Count.z ());
479
- };
463
+ template <int Dims = Dimensions, typename = detail::enable_if_t <Dims == 1 >>
464
+ range<1 > get_range () const {
465
+ cl_int Range = getRangeInternal ();
466
+ return range<1 >(Range);
467
+ }
468
+ template <int Dims = Dimensions, typename = detail::enable_if_t <Dims == 2 >>
469
+ range<2 > get_range () const {
470
+ cl_int2 Range = getRangeInternal ();
471
+ return range<2 >(Range[0 ], Range[1 ]);
472
+ }
473
+ template <int Dims = Dimensions, typename = detail::enable_if_t <Dims == 3 >>
474
+ range<3 > get_range () const {
475
+ cl_int3 Range = getRangeInternal ();
476
+ return range<3 >(Range[0 ], Range[1 ], Range[3 ]);
477
+ }
480
478
481
479
#else
482
- size_t get_size () const { return MImageSize; };
483
480
size_t get_count () const { return MImageCount; };
481
+
482
+ template <int Dims = Dimensions, typename = detail::enable_if_t <(Dims > 0 )>>
483
+ range<Dims> get_range () const {
484
+ return detail::convertToArrayOfN<Dims, 1 >(getAccessRange ());
485
+ }
486
+
484
487
#endif
485
488
486
489
// Available only when:
@@ -566,7 +569,7 @@ class __image_array_slice__ {
566
569
CoordElemType LastCoord = 0 ;
567
570
568
571
if (std::is_same<float , CoordElemType>::value) {
569
- sycl::vec<int , Dimensions + 1 > Size = MBaseAcc.getCountInternal ();
572
+ sycl::vec<int , Dimensions + 1 > Size = MBaseAcc.getRangeInternal ();
570
573
LastCoord =
571
574
MIdx / static_cast <float >(Size.template swizzle <Dimensions>());
572
575
} else {
@@ -608,27 +611,31 @@ class __image_array_slice__ {
608
611
}
609
612
610
613
#ifdef __SYCL_DEVICE_ONLY__
611
- size_t get_size () const { return MBaseAcc.getElementSize () * get_count (); }
612
-
613
- template <int Dims = Dimensions> size_t get_count () const ;
614
+ size_t get_count () const { return get_range<Dimensions>().size (); }
614
615
615
- template <> size_t get_count<1 >() const {
616
- cl_int2 Count = MBaseAcc.getCountInternal ();
617
- return Count.x ();
616
+ template <int Dims = Dimensions, typename = detail::enable_if_t <Dims == 1 >>
617
+ range<1 > get_range () const {
618
+ cl_int2 Count = MBaseAcc.getRangeInternal ();
619
+ return range<1 >(Count.x ());
620
+ }
621
+ template <int Dims = Dimensions, typename = detail::enable_if_t <Dims == 2 >>
622
+ range<2 > get_range () const {
623
+ cl_int3 Count = MBaseAcc.getRangeInternal ();
624
+ return range<2 >(Count.x (), Count.y ());
618
625
}
619
- template <> size_t get_count<2 >() const {
620
- cl_int3 Count = MBaseAcc.getCountInternal ();
621
- return (Count.x () * Count.y ());
622
- };
623
- #else
624
626
625
- size_t get_size () const {
626
- return MBaseAcc.MImageSize / MBaseAcc.getAccessRange ()[Dimensions];
627
- };
627
+ #else
628
628
629
629
size_t get_count () const {
630
630
return MBaseAcc.MImageCount / MBaseAcc.getAccessRange ()[Dimensions];
631
- };
631
+ }
632
+
633
+ template <int Dims = Dimensions,
634
+ typename = detail::enable_if_t <(Dims == 1 || Dims == 2 )>>
635
+ range<Dims> get_range () const {
636
+ return detail::convertToArrayOfN<Dims, 1 >(MBaseAcc.getAccessRange ());
637
+ }
638
+
632
639
#endif
633
640
634
641
private:
@@ -1099,6 +1106,11 @@ class accessor<DataT, Dimensions, AccessMode, access::target::local,
1099
1106
1100
1107
size_t get_count () const { return getSize ().size (); }
1101
1108
1109
+ template <int Dims = Dimensions, typename = detail::enable_if_t <(Dims > 0 )>>
1110
+ range<Dims> get_range () const {
1111
+ return detail::convertToArrayOfN<Dims, 1 >(getSize ());
1112
+ }
1113
+
1102
1114
template <int Dims = Dimensions,
1103
1115
typename = detail::enable_if_t <Dims == 0 && IsAccessAnyWrite>>
1104
1116
operator RefType () const {
0 commit comments