|
9 | 9 | #pragma once
|
10 | 10 |
|
11 | 11 | #include <array> // for array
|
| 12 | +#include <limits> |
12 | 13 | #include <stddef.h> // for size_t
|
13 | 14 | #include <stdint.h> // for uint32_T
|
14 | 15 | #include <sycl/aspects.hpp> // for aspect
|
15 | 16 | #include <sycl/ext/oneapi/experimental/forward_progress.hpp> // for forward_progress_guarantee enum
|
16 |
| -#include <sycl/ext/oneapi/properties/property.hpp> // for PropKind |
17 |
| -#include <sycl/ext/oneapi/properties/property_utils.hpp> // for SizeListToStr |
18 |
| -#include <sycl/ext/oneapi/properties/property_value.hpp> // for property_value |
| 17 | +#include <sycl/ext/oneapi/properties/properties.hpp> |
19 | 18 | #include <type_traits> // for true_type
|
20 | 19 | #include <utility> // for declval
|
21 | 20 | namespace sycl {
|
@@ -351,6 +350,80 @@ struct HasKernelPropertiesGetMethod<T,
|
351 | 350 | decltype(std::declval<T>().get(std::declval<properties_tag>()));
|
352 | 351 | };
|
353 | 352 |
|
| 353 | +// Trait for property compile-time meta names and values. |
| 354 | +template <typename PropertyT> struct WGSizePropertyMetaInfo { |
| 355 | + static constexpr std::array<size_t, 0> WGSize = {}; |
| 356 | + static constexpr size_t LinearSize = 0; |
| 357 | +}; |
| 358 | + |
| 359 | +template <size_t Dim0, size_t... Dims> |
| 360 | +struct WGSizePropertyMetaInfo<work_group_size_key::value_t<Dim0, Dims...>> { |
| 361 | + static constexpr std::array<size_t, sizeof...(Dims) + 1> WGSize = {Dim0, |
| 362 | + Dims...}; |
| 363 | + static constexpr size_t LinearSize = (Dim0 * ... * Dims); |
| 364 | +}; |
| 365 | + |
| 366 | +template <size_t Dim0, size_t... Dims> |
| 367 | +struct WGSizePropertyMetaInfo<max_work_group_size_key::value_t<Dim0, Dims...>> { |
| 368 | + static constexpr std::array<size_t, sizeof...(Dims) + 1> WGSize = {Dim0, |
| 369 | + Dims...}; |
| 370 | + static constexpr size_t LinearSize = (Dim0 * ... * Dims); |
| 371 | +}; |
| 372 | + |
| 373 | +// Get the value of a work-group size related property from a property list |
| 374 | +template <typename PropKey, typename PropertiesT> |
| 375 | +struct GetWGPropertyFromPropList {}; |
| 376 | + |
| 377 | +template <typename PropKey, typename... PropertiesT> |
| 378 | +struct GetWGPropertyFromPropList<PropKey, std::tuple<PropertiesT...>> { |
| 379 | + using prop_val_t = std::conditional_t< |
| 380 | + ContainsProperty<PropKey, std::tuple<PropertiesT...>>::value, |
| 381 | + typename FindCompileTimePropertyValueType< |
| 382 | + PropKey, std::tuple<PropertiesT...>>::type, |
| 383 | + void>; |
| 384 | + static constexpr auto WGSize = |
| 385 | + WGSizePropertyMetaInfo<std::remove_const_t<prop_val_t>>::WGSize; |
| 386 | + static constexpr size_t LinearSize = |
| 387 | + WGSizePropertyMetaInfo<std::remove_const_t<prop_val_t>>::LinearSize; |
| 388 | +}; |
| 389 | + |
| 390 | +// If work_group_size and max_work_group_size coexist, check that the |
| 391 | +// dimensionality matches and that the required work-group size doesn't |
| 392 | +// trivially exceed the maximum size. |
| 393 | +template <typename Properties> |
| 394 | +struct ConflictingProperties<max_work_group_size_key, Properties> |
| 395 | + : std::false_type { |
| 396 | + using WGSizeVal = GetWGPropertyFromPropList<work_group_size_key, Properties>; |
| 397 | + using MaxWGSizeVal = |
| 398 | + GetWGPropertyFromPropList<max_work_group_size_key, Properties>; |
| 399 | + // If work_group_size_key doesn't exist in the list of properties, WGSize is |
| 400 | + // an empty array and so Dims == 0. |
| 401 | + static constexpr size_t Dims = WGSizeVal::WGSize.size(); |
| 402 | + static_assert( |
| 403 | + Dims == 0 || Dims == MaxWGSizeVal::WGSize.size(), |
| 404 | + "work_group_size and max_work_group_size dimensionality must match"); |
| 405 | + static_assert(Dims < 1 || WGSizeVal::WGSize[0] <= MaxWGSizeVal::WGSize[0], |
| 406 | + "work_group_size must not exceed max_work_group_size"); |
| 407 | + static_assert(Dims < 2 || WGSizeVal::WGSize[1] <= MaxWGSizeVal::WGSize[1], |
| 408 | + "work_group_size must not exceed max_work_group_size"); |
| 409 | + static_assert(Dims < 3 || WGSizeVal::WGSize[2] <= MaxWGSizeVal::WGSize[2], |
| 410 | + "work_group_size must not exceed max_work_group_size"); |
| 411 | +}; |
| 412 | + |
| 413 | +// If work_group_size and max_linear_work_group_size coexist, check that the |
| 414 | +// required linear work-group size doesn't trivially exceed the maximum size. |
| 415 | +template <typename Properties> |
| 416 | +struct ConflictingProperties<max_linear_work_group_size_key, Properties> |
| 417 | + : std::false_type { |
| 418 | + using WGSizeVal = GetWGPropertyFromPropList<work_group_size_key, Properties>; |
| 419 | + using MaxLinearWGSizeVal = |
| 420 | + GetPropertyValueFromPropList<max_linear_work_group_size_key, size_t, void, |
| 421 | + Properties>; |
| 422 | + static_assert(WGSizeVal::WGSize.empty() || |
| 423 | + WGSizeVal::LinearSize <= MaxLinearWGSizeVal::value, |
| 424 | + "work_group_size must not exceed max_linear_work_group_size"); |
| 425 | +}; |
| 426 | + |
354 | 427 | } // namespace detail
|
355 | 428 | } // namespace ext::oneapi::experimental
|
356 | 429 | } // namespace _V1
|
|
0 commit comments