@@ -107,17 +107,29 @@ image_resample(py::array input_array,
107
107
}
108
108
109
109
py::ssize_t ncomponents = 0 ;
110
+ int rgb_step = 0 ;
110
111
if (ndim == 3 ) {
111
112
ncomponents = input_array.shape (2 );
112
- if (ncomponents != 3 && ncomponents != 4 ) {
113
+ if (ncomponents == 3 ) {
114
+ // We special-case a few options in order to avoid copying in the common case.
115
+ auto rgb_stride = input_array.strides (1 );
116
+ auto item_stride = input_array.strides (2 );
117
+ if (rgb_stride == 3 * item_stride) {
118
+ rgb_step = 3 ;
119
+ } else if (rgb_stride == 4 * item_stride) {
120
+ rgb_step = 4 ;
121
+ }
122
+ } else if (ncomponents != 4 ) {
113
123
throw std::invalid_argument (
114
124
" 3D input array must be RGB with shape (M, N, 3) or RGBA with shape (M, N, 4), "
115
125
" has trailing dimension of {}" _s.format (ncomponents));
116
126
}
117
127
}
118
128
119
- // Ensure input array is contiguous, regardless of dtype
120
- input_array = py::array::ensure (input_array, py::array::c_style);
129
+ if (rgb_step == 0 ) {
130
+ // Ensure input array is contiguous, regardless of dtype
131
+ input_array = py::array::ensure (input_array, py::array::c_style);
132
+ }
121
133
122
134
// Validate output array
123
135
auto out_ndim = output_array.ndim ();
@@ -194,13 +206,22 @@ image_resample(py::array input_array,
194
206
dtype.equal (py::dtype::of<double >()) ? resample<agg::rgba64, true > :
195
207
nullptr
196
208
) : (
197
- dtype.equal (py::dtype::of<std::uint8_t >()) ? resample<agg::rgba8, false > :
198
- dtype.equal (py::dtype::of<std::int8_t >()) ? resample<agg::rgba8, false > :
199
- dtype.equal (py::dtype::of<std::uint16_t >()) ? resample<agg::rgba16, false > :
200
- dtype.equal (py::dtype::of<std::int16_t >()) ? resample<agg::rgba16, false > :
201
- dtype.equal (py::dtype::of<float >()) ? resample<agg::rgba32, false > :
202
- dtype.equal (py::dtype::of<double >()) ? resample<agg::rgba64, false > :
203
- nullptr )))
209
+ (rgb_step == 4 ) ? (
210
+ dtype.equal (py::dtype::of<std::uint8_t >()) ? resample<agg::rgba8, false , 4 > :
211
+ dtype.equal (py::dtype::of<std::int8_t >()) ? resample<agg::rgba8, false , 4 > :
212
+ dtype.equal (py::dtype::of<std::uint16_t >()) ? resample<agg::rgba16, false , 4 > :
213
+ dtype.equal (py::dtype::of<std::int16_t >()) ? resample<agg::rgba16, false , 4 > :
214
+ dtype.equal (py::dtype::of<float >()) ? resample<agg::rgba32, false , 4 > :
215
+ dtype.equal (py::dtype::of<double >()) ? resample<agg::rgba64, false , 4 > :
216
+ nullptr
217
+ ) : (
218
+ dtype.equal (py::dtype::of<std::uint8_t >()) ? resample<agg::rgba8, false , 3 > :
219
+ dtype.equal (py::dtype::of<std::int8_t >()) ? resample<agg::rgba8, false , 3 > :
220
+ dtype.equal (py::dtype::of<std::uint16_t >()) ? resample<agg::rgba16, false , 3 > :
221
+ dtype.equal (py::dtype::of<std::int16_t >()) ? resample<agg::rgba16, false , 3 > :
222
+ dtype.equal (py::dtype::of<float >()) ? resample<agg::rgba32, false , 3 > :
223
+ dtype.equal (py::dtype::of<double >()) ? resample<agg::rgba64, false , 3 > :
224
+ nullptr ))))
204
225
{
205
226
Py_BEGIN_ALLOW_THREADS
206
227
resampler (
0 commit comments