@@ -88,16 +88,14 @@ void validate(const usm_ndarray &sample,
88
88
{&histogram}, names);
89
89
90
90
check_size_at_least (bins_ptr, 2 , names);
91
-
92
91
check_size_at_least (&histogram, 1 , names);
93
- check_num_dims (&histogram, 1 , names);
94
92
95
93
if (weights_ptr) {
96
94
check_num_dims (weights_ptr, 1 , names);
97
95
98
- auto sample_size = sample.get_size ( );
96
+ auto sample_size = sample.get_shape ( 0 );
99
97
auto weights_size = weights_ptr->get_size ();
100
- if (sample. get_size () != weights_ptr->get_size ()) {
98
+ if (sample_size != weights_ptr->get_size ()) {
101
99
throw py::value_error (name_of (&sample, names) + " size (" +
102
100
std::to_string (sample_size) + " ) and " +
103
101
name_of (weights_ptr, names) + " size (" +
@@ -110,61 +108,74 @@ void validate(const usm_ndarray &sample,
110
108
111
109
if (sample.get_ndim () == 1 ) {
112
110
check_num_dims (bins_ptr, 1 , names);
111
+
112
+ if (bins_ptr && histogram.get_size () != bins_ptr->get_size () - 1 ) {
113
+ auto hist_size = histogram.get_size ();
114
+ auto bins_size = bins_ptr->get_size ();
115
+ throw py::value_error (
116
+ name_of (&histogram, names) + " parameter and " +
117
+ name_of (bins_ptr, names) + " parameters shape mismatch. " +
118
+ name_of (&histogram, names) + " size is " +
119
+ std::to_string (hist_size) + name_of (bins_ptr, names) +
120
+ " must have size " + std::to_string (hist_size + 1 ) +
121
+ " but have " + std::to_string (bins_size));
122
+ }
113
123
}
114
124
else if (sample.get_ndim () == 2 ) {
115
125
auto sample_count = sample.get_shape (0 );
116
126
auto expected_dims = sample.get_shape (1 );
117
127
118
- if (bins_ptr != nullptr && bins_ptr-> get_ndim () != expected_dims) {
128
+ if (histogram. get_ndim () != expected_dims) {
119
129
throw py::value_error (
120
- name_of (&sample, names) + " parameter has shape { " +
121
- std::to_string (sample_count) + " x " +
122
- std::to_string (expected_dims) + " } " + " , so " +
123
- name_of (bins_ptr , names) + " parameter expected to be " +
130
+ name_of (&sample, names) + " parameter has shape ( " +
131
+ std::to_string (sample_count) + " , " +
132
+ std::to_string (expected_dims) + " ) " + " , so " +
133
+ name_of (&histogram , names) + " parameter expected to be " +
124
134
std::to_string (expected_dims) +
125
135
" d. "
126
136
" Actual " +
127
- std::to_string (bins-> get_ndim ()) + " d" );
137
+ std::to_string (histogram. get_ndim ()) + " d" );
128
138
}
129
- }
130
139
131
- if (bins_ptr != nullptr ) {
132
- py::ssize_t expected_hist_size = 1 ;
133
- for (int i = 0 ; i < bins_ptr->get_ndim (); ++i) {
134
- expected_hist_size *= (bins_ptr->get_shape (i) - 1 );
140
+ if (bins_ptr != nullptr ) {
141
+ py::ssize_t expected_bins_size = 0 ;
142
+ for (int i = 0 ; i < histogram.get_ndim (); ++i) {
143
+ expected_bins_size += histogram.get_shape (i) + 1 ;
144
+ }
145
+
146
+ auto actual_bins_size = bins_ptr->get_size ();
147
+ if (actual_bins_size != expected_bins_size) {
148
+ throw py::value_error (
149
+ name_of (&histogram, names) + " and " +
150
+ name_of (bins_ptr, names) + " shape mismatch. " +
151
+ name_of (bins_ptr, names) + " expected to have size = " +
152
+ std::to_string (expected_bins_size) + " . Actual " +
153
+ std::to_string (actual_bins_size));
154
+ }
135
155
}
136
156
137
- if (histogram.get_size () != expected_hist_size) {
138
- throw py::value_error (
139
- name_of (&histogram, names) + " and " +
140
- name_of (bins_ptr, names) + " shape mismatch. " +
141
- name_of (&histogram, names) + " expected to have size = " +
142
- std::to_string (expected_hist_size) + " . Actual " +
143
- std::to_string (histogram.get_size ()));
157
+ int64_t max_hist_size = std::numeric_limits<uint32_t >::max () - 1 ;
158
+ if (histogram.get_size () > max_hist_size) {
159
+ throw py::value_error (name_of (&histogram, names) +
160
+ " parameter size expected to be less than " +
161
+ std::to_string (max_hist_size) + " . Actual " +
162
+ std::to_string (histogram.get_size ()));
144
163
}
145
- }
146
-
147
- int64_t max_hist_size = std::numeric_limits<uint32_t >::max () - 1 ;
148
- if (histogram.get_size () > max_hist_size) {
149
- throw py::value_error (name_of (&histogram, names) +
150
- " parameter size expected to be less than " +
151
- std::to_string (max_hist_size) + " . Actual " +
152
- std::to_string (histogram.get_size ()));
153
- }
154
164
155
- auto array_types = dpctl_td_ns::usm_ndarray_types ();
156
- auto hist_type = static_cast <typenum_t >(
157
- array_types.typenum_to_lookup_id (histogram.get_typenum ()));
158
- if (histogram.get_elemsize () == 8 && hist_type != typenum_t ::CFLOAT) {
159
- auto device = exec_q.get_device ();
160
- bool _64bit_atomics = device.has (sycl::aspect::atomic64);
161
-
162
- if (!_64bit_atomics) {
163
- auto device_name = device.get_info <sycl::info::device::name>();
164
- throw py::value_error (
165
- name_of (&histogram, names) +
166
- " parameter has 64-bit type, but 64-bit atomics " +
167
- " are not supported for " + device_name);
165
+ auto array_types = dpctl_td_ns::usm_ndarray_types ();
166
+ auto hist_type = static_cast <typenum_t >(
167
+ array_types.typenum_to_lookup_id (histogram.get_typenum ()));
168
+ if (histogram.get_elemsize () == 8 && hist_type != typenum_t ::CFLOAT) {
169
+ auto device = exec_q.get_device ();
170
+ bool _64bit_atomics = device.has (sycl::aspect::atomic64);
171
+
172
+ if (!_64bit_atomics) {
173
+ auto device_name = device.get_info <sycl::info::device::name>();
174
+ throw py::value_error (
175
+ name_of (&histogram, names) +
176
+ " parameter has 64-bit type, but 64-bit atomics " +
177
+ " are not supported for " + device_name);
178
+ }
168
179
}
169
180
}
170
181
}
0 commit comments