|
11 | 11 | #include <CL/__spirv/spirv_types.hpp>
|
12 | 12 | #include <sycl/atomic.hpp>
|
13 | 13 | #include <sycl/buffer.hpp>
|
14 |
| -#include <sycl/detail/accessor_impl.hpp> |
15 | 14 | #include <sycl/detail/cl.h>
|
16 | 15 | #include <sycl/detail/common.hpp>
|
17 | 16 | #include <sycl/detail/export.hpp>
|
@@ -230,6 +229,7 @@ void __SYCL_EXPORT constructorNotification(void *BufferObj, void *AccessorObj,
|
230 | 229 | access::target Target,
|
231 | 230 | access::mode Mode,
|
232 | 231 | const code_location &CodeLoc);
|
| 232 | + |
233 | 233 | template <typename T>
|
234 | 234 | using IsPropertyListT = typename std::is_base_of<PropertyListBase, T>;
|
235 | 235 |
|
@@ -340,6 +340,159 @@ class accessor_common {
|
340 | 340 | };
|
341 | 341 | };
|
342 | 342 |
|
| 343 | +#if __cplusplus >= 201703L |
| 344 | + |
| 345 | +template <typename MayBeTag1, typename MayBeTag2> |
| 346 | +constexpr access::mode deduceAccessMode() { |
| 347 | + // property_list = {} is not properly detected by deduction guide, |
| 348 | + // when parameter is passed without curly braces: access(buffer, no_init) |
| 349 | + // thus simplest approach is to check 2 last arguments for being a tag |
| 350 | + if constexpr (std::is_same<MayBeTag1, |
| 351 | + mode_tag_t<access::mode::read>>::value || |
| 352 | + std::is_same<MayBeTag2, |
| 353 | + mode_tag_t<access::mode::read>>::value) { |
| 354 | + return access::mode::read; |
| 355 | + } |
| 356 | + |
| 357 | + if constexpr (std::is_same<MayBeTag1, |
| 358 | + mode_tag_t<access::mode::write>>::value || |
| 359 | + std::is_same<MayBeTag2, |
| 360 | + mode_tag_t<access::mode::write>>::value) { |
| 361 | + return access::mode::write; |
| 362 | + } |
| 363 | + |
| 364 | + if constexpr ( |
| 365 | + std::is_same<MayBeTag1, |
| 366 | + mode_target_tag_t<access::mode::read, |
| 367 | + access::target::constant_buffer>>::value || |
| 368 | + std::is_same<MayBeTag2, |
| 369 | + mode_target_tag_t<access::mode::read, |
| 370 | + access::target::constant_buffer>>::value) { |
| 371 | + return access::mode::read; |
| 372 | + } |
| 373 | + |
| 374 | + return access::mode::read_write; |
| 375 | +} |
| 376 | + |
| 377 | +template <typename MayBeTag1, typename MayBeTag2> |
| 378 | +constexpr access::target deduceAccessTarget(access::target defaultTarget) { |
| 379 | + if constexpr ( |
| 380 | + std::is_same<MayBeTag1, |
| 381 | + mode_target_tag_t<access::mode::read, |
| 382 | + access::target::constant_buffer>>::value || |
| 383 | + std::is_same<MayBeTag2, |
| 384 | + mode_target_tag_t<access::mode::read, |
| 385 | + access::target::constant_buffer>>::value) { |
| 386 | + return access::target::constant_buffer; |
| 387 | + } |
| 388 | + |
| 389 | + return defaultTarget; |
| 390 | +} |
| 391 | + |
| 392 | +#endif |
| 393 | + |
| 394 | + |
| 395 | + |
| 396 | +template <int Dims> class LocalAccessorBaseDevice { |
| 397 | +public: |
| 398 | + LocalAccessorBaseDevice(sycl::range<Dims> Size) |
| 399 | + : AccessRange(Size), |
| 400 | + MemRange(InitializedVal<Dims, range>::template get<0>()) {} |
| 401 | + // TODO: Actually we need only one field here, but currently compiler requires |
| 402 | + // all of them. |
| 403 | + range<Dims> AccessRange; |
| 404 | + range<Dims> MemRange; |
| 405 | + id<Dims> Offset; |
| 406 | + |
| 407 | + bool operator==(const LocalAccessorBaseDevice &Rhs) const { |
| 408 | + return (AccessRange == Rhs.AccessRange); |
| 409 | + } |
| 410 | +}; |
| 411 | + |
| 412 | +// The class describes a requirement to access a SYCL memory object such as |
| 413 | +// sycl::buffer and sycl::image. For example, each accessor used in a kernel, |
| 414 | +// except one with access target "local", adds such requirement for the command |
| 415 | +// group. |
| 416 | + |
| 417 | +template <int Dims> class AccessorImplDevice { |
| 418 | +public: |
| 419 | + AccessorImplDevice() = default; |
| 420 | + AccessorImplDevice(id<Dims> Offset, range<Dims> AccessRange, |
| 421 | + range<Dims> MemoryRange) |
| 422 | + : Offset(Offset), AccessRange(AccessRange), MemRange(MemoryRange) {} |
| 423 | + |
| 424 | + id<Dims> Offset; |
| 425 | + range<Dims> AccessRange; |
| 426 | + range<Dims> MemRange; |
| 427 | + |
| 428 | + bool operator==(const AccessorImplDevice &Rhs) const { |
| 429 | + return (Offset == Rhs.Offset && AccessRange == Rhs.AccessRange && |
| 430 | + MemRange == Rhs.MemRange); |
| 431 | + } |
| 432 | +}; |
| 433 | + |
| 434 | +class AccessorImplHost; |
| 435 | + |
| 436 | +void __SYCL_EXPORT addHostAccessorAndWait(AccessorImplHost *Req); |
| 437 | + |
| 438 | +class SYCLMemObjI; |
| 439 | + |
| 440 | +using AccessorImplPtr = std::shared_ptr<AccessorImplHost>; |
| 441 | + |
| 442 | +class __SYCL_EXPORT AccessorBaseHost { |
| 443 | +public: |
| 444 | + AccessorBaseHost(id<3> Offset, range<3> AccessRange, range<3> MemoryRange, |
| 445 | + access::mode AccessMode, void *SYCLMemObject, int Dims, |
| 446 | + int ElemSize, int OffsetInBytes = 0, |
| 447 | + bool IsSubBuffer = false); |
| 448 | + |
| 449 | +protected: |
| 450 | + id<3> &getOffset(); |
| 451 | + range<3> &getAccessRange(); |
| 452 | + range<3> &getMemoryRange(); |
| 453 | + void *getPtr(); |
| 454 | + unsigned int getElemSize() const; |
| 455 | + |
| 456 | + const id<3> &getOffset() const; |
| 457 | + const range<3> &getAccessRange() const; |
| 458 | + const range<3> &getMemoryRange() const; |
| 459 | + void *getPtr() const; |
| 460 | + |
| 461 | + void *getMemoryObject() const; |
| 462 | + |
| 463 | + template <class Obj> |
| 464 | + friend decltype(Obj::impl) getSyclObjImpl(const Obj &SyclObject); |
| 465 | + |
| 466 | + template <typename, int, access::mode, access::target, access::placeholder, |
| 467 | + typename> |
| 468 | + friend class accessor; |
| 469 | + |
| 470 | + AccessorImplPtr impl; |
| 471 | + |
| 472 | +private: |
| 473 | + friend class sycl::ext::intel::esimd::detail::AccessorPrivateProxy; |
| 474 | +}; |
| 475 | + |
| 476 | +class LocalAccessorImplHost; |
| 477 | +using LocalAccessorImplPtr = std::shared_ptr<LocalAccessorImplHost>; |
| 478 | + |
| 479 | +class __SYCL_EXPORT LocalAccessorBaseHost { |
| 480 | +public: |
| 481 | + LocalAccessorBaseHost(sycl::range<3> Size, int Dims, int ElemSize); |
| 482 | + sycl::range<3> &getSize(); |
| 483 | + const sycl::range<3> &getSize() const; |
| 484 | + void *getPtr(); |
| 485 | + void *getPtr() const; |
| 486 | + int getNumOfDims(); |
| 487 | + int getElementSize(); |
| 488 | + |
| 489 | +protected: |
| 490 | + template <class Obj> |
| 491 | + friend decltype(Obj::impl) getSyclObjImpl(const Obj &SyclObject); |
| 492 | + |
| 493 | + std::shared_ptr<LocalAccessorImplHost> impl; |
| 494 | +}; |
| 495 | + |
343 | 496 | template <int Dim, typename T> struct IsValidCoordDataT;
|
344 | 497 | template <typename T> struct IsValidCoordDataT<1, T> {
|
345 | 498 | constexpr static bool value =
|
@@ -1663,8 +1816,8 @@ class __SYCL_SPECIAL_CLASS accessor :
|
1663 | 1816 | PropertyListT::template areSameCompileTimeProperties<NewPropsT...>(),
|
1664 | 1817 | "Compile-time-constant properties must be the same");
|
1665 | 1818 | #ifndef __SYCL_DEVICE_ONLY__
|
1666 |
| - detail::constructorNotification(impl.get()->MSYCLMemObj, impl.get(), |
1667 |
| - AccessTarget, AccessMode, CodeLoc); |
| 1819 | + detail::constructorNotification(getMemoryObject(), impl.get(), AccessTarget, |
| 1820 | + AccessMode, CodeLoc); |
1668 | 1821 | #endif
|
1669 | 1822 | }
|
1670 | 1823 |
|
|
0 commit comments