@@ -1009,6 +1009,9 @@ template <typename T1,
1009
1009
typename T6>
1010
1010
class custom_reduction_over_group_temps_strided_krn ;
1011
1011
1012
+ template <typename T1, typename T2, typename T3>
1013
+ class reduction_over_group_temps_empty_krn ;
1014
+
1012
1015
template <typename T1, typename T2, typename T3, typename T4, typename T5>
1013
1016
class single_reduction_axis0_temps_contig_krn ;
1014
1017
@@ -1120,6 +1123,31 @@ sycl::event reduction_over_group_temps_strided_impl(
1120
1123
1121
1124
constexpr resTy identity_val = su_ns::Identity<ReductionOpT, resTy>::value;
1122
1125
1126
+ if (reduction_nelems == 0 ) {
1127
+ sycl::event res_init_ev = exec_q.submit ([&](sycl::handler &cgh) {
1128
+ using IndexerT =
1129
+ dpctl::tensor::offset_utils::UnpackedStridedIndexer;
1130
+
1131
+ const py::ssize_t *const &res_shape = iter_shape_and_strides;
1132
+ const py::ssize_t *const &res_strides =
1133
+ iter_shape_and_strides + 2 * iter_nd;
1134
+ IndexerT res_indexer (iter_nd, iter_res_offset, res_shape,
1135
+ res_strides);
1136
+ using InitKernelName =
1137
+ class reduction_over_group_temps_empty_krn <resTy, argTy,
1138
+ ReductionOpT>;
1139
+ cgh.depends_on (depends);
1140
+
1141
+ cgh.parallel_for <InitKernelName>(
1142
+ sycl::range<1 >(iter_nelems), [=](sycl::id<1 > id) {
1143
+ auto res_offset = res_indexer (id[0 ]);
1144
+ res_tp[res_offset] = identity_val;
1145
+ });
1146
+ });
1147
+
1148
+ return res_init_ev;
1149
+ }
1150
+
1123
1151
const sycl::device &d = exec_q.get_device ();
1124
1152
const auto &sg_sizes = d.get_info <sycl::info::device::sub_group_sizes>();
1125
1153
size_t wg = choose_workgroup_size<4 >(reduction_nelems, sg_sizes);
@@ -1244,7 +1272,7 @@ sycl::event reduction_over_group_temps_strided_impl(
1244
1272
resTy *partially_reduced_tmp2 = nullptr ;
1245
1273
1246
1274
if (partially_reduced_tmp == nullptr ) {
1247
- throw std::runtime_error (" Unabled to allocate device_memory" );
1275
+ throw std::runtime_error (" Unable to allocate device_memory" );
1248
1276
}
1249
1277
else {
1250
1278
partially_reduced_tmp2 =
@@ -1501,6 +1529,13 @@ sycl::event reduction_axis1_over_group_temps_contig_impl(
1501
1529
1502
1530
constexpr resTy identity_val = su_ns::Identity<ReductionOpT, resTy>::value;
1503
1531
1532
+ if (reduction_nelems == 0 ) {
1533
+ sycl::event res_init_ev = exec_q.fill <resTy>(
1534
+ res_tp, resTy (identity_val), iter_nelems, depends);
1535
+
1536
+ return res_init_ev;
1537
+ }
1538
+
1504
1539
const sycl::device &d = exec_q.get_device ();
1505
1540
const auto &sg_sizes = d.get_info <sycl::info::device::sub_group_sizes>();
1506
1541
size_t wg = choose_workgroup_size<4 >(reduction_nelems, sg_sizes);
@@ -1632,7 +1667,7 @@ sycl::event reduction_axis1_over_group_temps_contig_impl(
1632
1667
resTy *partially_reduced_tmp2 = nullptr ;
1633
1668
1634
1669
if (partially_reduced_tmp == nullptr ) {
1635
- throw std::runtime_error (" Unabled to allocate device_memory" );
1670
+ throw std::runtime_error (" Unable to allocate device_memory" );
1636
1671
}
1637
1672
else {
1638
1673
partially_reduced_tmp2 =
@@ -1879,6 +1914,13 @@ sycl::event reduction_axis0_over_group_temps_contig_impl(
1879
1914
1880
1915
constexpr resTy identity_val = su_ns::Identity<ReductionOpT, resTy>::value;
1881
1916
1917
+ if (reduction_nelems == 0 ) {
1918
+ sycl::event res_init_ev = exec_q.fill <resTy>(
1919
+ res_tp, resTy (identity_val), iter_nelems, depends);
1920
+
1921
+ return res_init_ev;
1922
+ }
1923
+
1882
1924
const sycl::device &d = exec_q.get_device ();
1883
1925
const auto &sg_sizes = d.get_info <sycl::info::device::sub_group_sizes>();
1884
1926
size_t wg = choose_workgroup_size<4 >(reduction_nelems, sg_sizes);
@@ -2015,7 +2057,7 @@ sycl::event reduction_axis0_over_group_temps_contig_impl(
2015
2057
resTy *partially_reduced_tmp2 = nullptr ;
2016
2058
2017
2059
if (partially_reduced_tmp == nullptr ) {
2018
- throw std::runtime_error (" Unabled to allocate device_memory" );
2060
+ throw std::runtime_error (" Unable to allocate device_memory" );
2019
2061
}
2020
2062
else {
2021
2063
partially_reduced_tmp2 =
@@ -2712,12 +2754,16 @@ struct TypePairSupportDataForSumReductionTemps
2712
2754
td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::uint32_t >,
2713
2755
td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::int64_t >,
2714
2756
td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::uint64_t >,
2757
+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, float >,
2758
+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, double >,
2715
2759
2716
2760
// input int8_t
2717
2761
td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, std::int8_t >,
2718
2762
td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, std::int16_t >,
2719
2763
td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, std::int32_t >,
2720
2764
td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, std::int64_t >,
2765
+ td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, float >,
2766
+ td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, double >,
2721
2767
2722
2768
// input uint8_t
2723
2769
td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, std::uint8_t >,
@@ -2727,32 +2773,44 @@ struct TypePairSupportDataForSumReductionTemps
2727
2773
td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, std::uint32_t >,
2728
2774
td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, std::int64_t >,
2729
2775
td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, std::uint64_t >,
2776
+ td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, float >,
2777
+ td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, double >,
2730
2778
2731
2779
// input int16_t
2732
2780
td_ns::TypePairDefinedEntry<argTy, std::int16_t , outTy, std::int16_t >,
2733
2781
td_ns::TypePairDefinedEntry<argTy, std::int16_t , outTy, std::int32_t >,
2734
2782
td_ns::TypePairDefinedEntry<argTy, std::int16_t , outTy, std::int64_t >,
2783
+ td_ns::TypePairDefinedEntry<argTy, std::int16_t , outTy, float >,
2784
+ td_ns::TypePairDefinedEntry<argTy, std::int16_t , outTy, double >,
2735
2785
2736
2786
// input uint16_t
2737
2787
td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, std::uint16_t >,
2738
2788
td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, std::int32_t >,
2739
2789
td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, std::uint32_t >,
2740
2790
td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, std::int64_t >,
2741
2791
td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, std::uint64_t >,
2792
+ td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, float >,
2793
+ td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, double >,
2742
2794
2743
2795
// input int32_t
2744
2796
td_ns::TypePairDefinedEntry<argTy, std::int32_t , outTy, std::int32_t >,
2745
2797
td_ns::TypePairDefinedEntry<argTy, std::int32_t , outTy, std::int64_t >,
2798
+ td_ns::TypePairDefinedEntry<argTy, std::int32_t , outTy, float >,
2799
+ td_ns::TypePairDefinedEntry<argTy, std::int32_t , outTy, double >,
2746
2800
2747
2801
// input uint32_t
2748
2802
td_ns::TypePairDefinedEntry<argTy, std::uint32_t , outTy, std::uint32_t >,
2749
2803
td_ns::TypePairDefinedEntry<argTy, std::uint32_t , outTy, std::uint64_t >,
2804
+ td_ns::TypePairDefinedEntry<argTy, std::uint32_t , outTy, float >,
2805
+ td_ns::TypePairDefinedEntry<argTy, std::uint32_t , outTy, double >,
2750
2806
2751
2807
// input int64_t
2752
2808
td_ns::TypePairDefinedEntry<argTy, std::int64_t , outTy, std::int64_t >,
2809
+ td_ns::TypePairDefinedEntry<argTy, std::int64_t , outTy, double >,
2753
2810
2754
- // input uint32_t
2811
+ // input uint64_t
2755
2812
td_ns::TypePairDefinedEntry<argTy, std::uint64_t , outTy, std::uint64_t >,
2813
+ td_ns::TypePairDefinedEntry<argTy, std::uint64_t , outTy, double >,
2756
2814
2757
2815
// input half
2758
2816
td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, sycl::half>,
@@ -2967,12 +3025,16 @@ struct TypePairSupportDataForProductReductionTemps
2967
3025
td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::uint32_t >,
2968
3026
td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::int64_t >,
2969
3027
td_ns::TypePairDefinedEntry<argTy, bool , outTy, std::uint64_t >,
3028
+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, float >,
3029
+ td_ns::TypePairDefinedEntry<argTy, bool , outTy, double >,
2970
3030
2971
3031
// input int8_t
2972
3032
td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, std::int8_t >,
2973
3033
td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, std::int16_t >,
2974
3034
td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, std::int32_t >,
2975
3035
td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, std::int64_t >,
3036
+ td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, float >,
3037
+ td_ns::TypePairDefinedEntry<argTy, std::int8_t , outTy, double >,
2976
3038
2977
3039
// input uint8_t
2978
3040
td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, std::uint8_t >,
@@ -2982,32 +3044,44 @@ struct TypePairSupportDataForProductReductionTemps
2982
3044
td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, std::uint32_t >,
2983
3045
td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, std::int64_t >,
2984
3046
td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, std::uint64_t >,
3047
+ td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, float >,
3048
+ td_ns::TypePairDefinedEntry<argTy, std::uint8_t , outTy, double >,
2985
3049
2986
3050
// input int16_t
2987
3051
td_ns::TypePairDefinedEntry<argTy, std::int16_t , outTy, std::int16_t >,
2988
3052
td_ns::TypePairDefinedEntry<argTy, std::int16_t , outTy, std::int32_t >,
2989
3053
td_ns::TypePairDefinedEntry<argTy, std::int16_t , outTy, std::int64_t >,
3054
+ td_ns::TypePairDefinedEntry<argTy, std::int16_t , outTy, float >,
3055
+ td_ns::TypePairDefinedEntry<argTy, std::int16_t , outTy, double >,
2990
3056
2991
3057
// input uint16_t
2992
3058
td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, std::uint16_t >,
2993
3059
td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, std::int32_t >,
2994
3060
td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, std::uint32_t >,
2995
3061
td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, std::int64_t >,
2996
3062
td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, std::uint64_t >,
3063
+ td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, float >,
3064
+ td_ns::TypePairDefinedEntry<argTy, std::uint16_t , outTy, double >,
2997
3065
2998
3066
// input int32_t
2999
3067
td_ns::TypePairDefinedEntry<argTy, std::int32_t , outTy, std::int32_t >,
3000
3068
td_ns::TypePairDefinedEntry<argTy, std::int32_t , outTy, std::int64_t >,
3069
+ td_ns::TypePairDefinedEntry<argTy, std::int32_t , outTy, float >,
3070
+ td_ns::TypePairDefinedEntry<argTy, std::int32_t , outTy, double >,
3001
3071
3002
3072
// input uint32_t
3003
3073
td_ns::TypePairDefinedEntry<argTy, std::uint32_t , outTy, std::uint32_t >,
3004
3074
td_ns::TypePairDefinedEntry<argTy, std::uint32_t , outTy, std::uint64_t >,
3075
+ td_ns::TypePairDefinedEntry<argTy, std::uint32_t , outTy, float >,
3076
+ td_ns::TypePairDefinedEntry<argTy, std::uint32_t , outTy, double >,
3005
3077
3006
3078
// input int64_t
3007
3079
td_ns::TypePairDefinedEntry<argTy, std::int64_t , outTy, std::int64_t >,
3080
+ td_ns::TypePairDefinedEntry<argTy, std::int64_t , outTy, double >,
3008
3081
3009
3082
// input uint32_t
3010
3083
td_ns::TypePairDefinedEntry<argTy, std::uint64_t , outTy, std::uint64_t >,
3084
+ td_ns::TypePairDefinedEntry<argTy, std::uint64_t , outTy, double >,
3011
3085
3012
3086
// input half
3013
3087
td_ns::TypePairDefinedEntry<argTy, sycl::half, outTy, sycl::half>,
@@ -3957,6 +4031,8 @@ template <typename T1,
3957
4031
bool b2>
3958
4032
class custom_search_over_group_temps_strided_krn ;
3959
4033
4034
+ template <typename T1, typename T2, typename T3> class search_empty_krn ;
4035
+
3960
4036
template <typename T1,
3961
4037
typename T2,
3962
4038
typename T3,
@@ -4160,6 +4236,30 @@ sycl::event search_over_group_temps_strided_impl(
4160
4236
constexpr argTy identity_val = su_ns::Identity<ReductionOpT, argTy>::value;
4161
4237
constexpr resTy idx_identity_val = su_ns::Identity<IndexOpT, resTy>::value;
4162
4238
4239
+ if (reduction_nelems == 0 ) {
4240
+ sycl::event res_init_ev = exec_q.submit ([&](sycl::handler &cgh) {
4241
+ using IndexerT =
4242
+ dpctl::tensor::offset_utils::UnpackedStridedIndexer;
4243
+
4244
+ const py::ssize_t *const &res_shape = iter_shape_and_strides;
4245
+ const py::ssize_t *const &res_strides =
4246
+ iter_shape_and_strides + 2 * iter_nd;
4247
+ IndexerT res_indexer (iter_nd, iter_res_offset, res_shape,
4248
+ res_strides);
4249
+ using InitKernelName =
4250
+ class search_empty_krn <resTy, argTy, ReductionOpT>;
4251
+ cgh.depends_on (depends);
4252
+
4253
+ cgh.parallel_for <InitKernelName>(
4254
+ sycl::range<1 >(iter_nelems), [=](sycl::id<1 > id) {
4255
+ auto res_offset = res_indexer (id[0 ]);
4256
+ res_tp[res_offset] = idx_identity_val;
4257
+ });
4258
+ });
4259
+
4260
+ return res_init_ev;
4261
+ }
4262
+
4163
4263
const sycl::device &d = exec_q.get_device ();
4164
4264
const auto &sg_sizes = d.get_info <sycl::info::device::sub_group_sizes>();
4165
4265
size_t wg = choose_workgroup_size<4 >(reduction_nelems, sg_sizes);
@@ -4590,6 +4690,13 @@ sycl::event search_axis1_over_group_temps_contig_impl(
4590
4690
constexpr argTy identity_val = su_ns::Identity<ReductionOpT, argTy>::value;
4591
4691
constexpr resTy idx_identity_val = su_ns::Identity<IndexOpT, resTy>::value;
4592
4692
4693
+ if (reduction_nelems == 0 ) {
4694
+ sycl::event res_init_ev = exec_q.fill <resTy>(
4695
+ res_tp, resTy (idx_identity_val), iter_nelems, depends);
4696
+
4697
+ return res_init_ev;
4698
+ }
4699
+
4593
4700
const sycl::device &d = exec_q.get_device ();
4594
4701
const auto &sg_sizes = d.get_info <sycl::info::device::sub_group_sizes>();
4595
4702
size_t wg = choose_workgroup_size<4 >(reduction_nelems, sg_sizes);
@@ -5005,6 +5112,13 @@ sycl::event search_axis0_over_group_temps_contig_impl(
5005
5112
constexpr argTy identity_val = su_ns::Identity<ReductionOpT, argTy>::value;
5006
5113
constexpr resTy idx_identity_val = su_ns::Identity<IndexOpT, resTy>::value;
5007
5114
5115
+ if (reduction_nelems == 0 ) {
5116
+ sycl::event res_init_ev = exec_q.fill <resTy>(
5117
+ res_tp, resTy (idx_identity_val), iter_nelems, depends);
5118
+
5119
+ return res_init_ev;
5120
+ }
5121
+
5008
5122
const sycl::device &d = exec_q.get_device ();
5009
5123
const auto &sg_sizes = d.get_info <sycl::info::device::sub_group_sizes>();
5010
5124
size_t wg = choose_workgroup_size<4 >(reduction_nelems, sg_sizes);
0 commit comments