Skip to content

Commit 8736665

Browse files
committed
Resample RGB images in C++
Agg already has RGB resampling with output to RGBA builtin, so we just need to correctly wire up the corresponding templates. With this RGB resampling mode, we save the extra copy from RGB to RGBA in NumPy land that was required for the previous always-RGBA resampling.
1 parent 965efca commit 8736665

File tree

4 files changed

+103
-58
lines changed

4 files changed

+103
-58
lines changed

lib/matplotlib/image.py

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,7 @@ def flush_images():
160160
flush_images()
161161

162162

163-
def _resample(
164-
image_obj, data, out_shape, transform, *, resample=None, alpha=1):
163+
def _resample(image_obj, data, out_shape, transform, *, resample=None, alpha=1):
165164
"""
166165
Convenience wrapper around `._image.resample` to resample *data* to
167166
*out_shape* (with a third dimension if *data* is RGBA) that takes care of
@@ -204,7 +203,10 @@ def _resample(
204203
interpolation = 'nearest'
205204
else:
206205
interpolation = 'hanning'
207-
out = np.zeros(out_shape + data.shape[2:], data.dtype) # 2D->2D, 3D->3D.
206+
if len(data.shape) == 3:
207+
# Always output RGBA.
208+
out_shape += (4, )
209+
out = np.zeros(out_shape, data.dtype)
208210
if resample is None:
209211
resample = image_obj.get_resample()
210212
_image.resample(data, out, transform,
@@ -216,20 +218,6 @@ def _resample(
216218
return out
217219

218220

219-
def _rgb_to_rgba(A):
220-
"""
221-
Convert an RGB image to RGBA, as required by the image resample C++
222-
extension.
223-
"""
224-
rgba = np.zeros((A.shape[0], A.shape[1], 4), dtype=A.dtype)
225-
rgba[:, :, :3] = A
226-
if rgba.dtype == np.uint8:
227-
rgba[:, :, 3] = 255
228-
else:
229-
rgba[:, :, 3] = 1.0
230-
return rgba
231-
232-
233221
class _ImageBase(mcolorizer.ColorizingArtist):
234222
"""
235223
Base class for images.
@@ -508,10 +496,10 @@ def _make_image(self, A, in_bbox, out_bbox, clip_bbox, magnification=1.0,
508496
# alpha channel below.
509497
output_alpha = (255 * alpha) if A.dtype == np.uint8 else alpha
510498
else:
511-
output_alpha = _resample( # resample alpha channel
512-
self, A[..., 3], out_shape, t, alpha=alpha)
513-
output = _resample( # resample rgb channels
514-
self, _rgb_to_rgba(A[..., :3]), out_shape, t, alpha=alpha)
499+
# resample alpha channel
500+
output_alpha = _resample(self, A[..., 3], out_shape, t, alpha=alpha)
501+
# resample rgb channels
502+
output = _resample(self, A[..., :3], out_shape, t, alpha=alpha)
515503
output[..., 3] = output_alpha # recombine rgb and alpha
516504

517505
# output is now either a 2D array of normed (int or float) data

lib/matplotlib/tests/test_image.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1577,8 +1577,8 @@ def test__resample_valid_output():
15771577
resample(np.zeros((9, 9)), np.zeros((9, 9, 4)))
15781578
with pytest.raises(ValueError, match="different dimensionalities"):
15791579
resample(np.zeros((9, 9, 4)), np.zeros((9, 9)))
1580-
with pytest.raises(ValueError, match="3D input array must be RGBA"):
1581-
resample(np.zeros((9, 9, 3)), np.zeros((9, 9, 4)))
1580+
with pytest.raises(ValueError, match="3D input array must be RGB"):
1581+
resample(np.zeros((9, 9, 2)), np.zeros((9, 9, 4)))
15821582
with pytest.raises(ValueError, match="3D output array must be RGBA"):
15831583
resample(np.zeros((9, 9, 4)), np.zeros((9, 9, 3)))
15841584
with pytest.raises(ValueError, match="mismatched types"):

src/_image_resample.h

Lines changed: 60 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "agg_image_accessors.h"
99
#include "agg_path_storage.h"
1010
#include "agg_pixfmt_gray.h"
11+
#include "agg_pixfmt_rgb.h"
1112
#include "agg_pixfmt_rgba.h"
1213
#include "agg_renderer_base.h"
1314
#include "agg_renderer_scanline.h"
@@ -16,6 +17,7 @@
1617
#include "agg_span_allocator.h"
1718
#include "agg_span_converter.h"
1819
#include "agg_span_image_filter_gray.h"
20+
#include "agg_span_image_filter_rgb.h"
1921
#include "agg_span_image_filter_rgba.h"
2022
#include "agg_span_interpolator_adaptor.h"
2123
#include "agg_span_interpolator_linear.h"
@@ -496,16 +498,38 @@ typedef enum {
496498
} interpolation_e;
497499

498500

499-
// T is rgba if and only if it has an T::r field.
501+
// T is rgb(a) if and only if it has an T::r field.
500502
template<typename T, typename = void> struct is_grayscale : std::true_type {};
501503
template<typename T> struct is_grayscale<T, std::void_t<decltype(T::r)>> : std::false_type {};
502504
template<typename T> constexpr bool is_grayscale_v = is_grayscale<T>::value;
503505

504506

505-
template<typename color_type>
507+
template<typename color_type, bool input_has_alpha>
506508
struct type_mapping
507509
{
508-
using blender_type = std::conditional_t<
510+
using input_blender_type = std::conditional_t<
511+
is_grayscale_v<color_type>,
512+
agg::blender_gray<color_type>,
513+
std::conditional_t<
514+
input_has_alpha,
515+
std::conditional_t<
516+
std::is_same_v<color_type, agg::rgba8>,
517+
fixed_blender_rgba_plain<color_type, agg::order_rgba>,
518+
agg::blender_rgba_plain<color_type, agg::order_rgba>
519+
>,
520+
agg::blender_rgb<color_type, agg::order_rgb>
521+
>
522+
>;
523+
using input_pixfmt_type = std::conditional_t<
524+
is_grayscale_v<color_type>,
525+
agg::pixfmt_alpha_blend_gray<input_blender_type, agg::rendering_buffer>,
526+
std::conditional_t<
527+
input_has_alpha,
528+
agg::pixfmt_alpha_blend_rgba<input_blender_type, agg::rendering_buffer>,
529+
agg::pixfmt_alpha_blend_rgb<input_blender_type, agg::rendering_buffer, 3>
530+
>
531+
>;
532+
using output_blender_type = std::conditional_t<
509533
is_grayscale_v<color_type>,
510534
agg::blender_gray<color_type>,
511535
std::conditional_t<
@@ -514,25 +538,37 @@ struct type_mapping
514538
agg::blender_rgba_plain<color_type, agg::order_rgba>
515539
>
516540
>;
517-
using pixfmt_type = std::conditional_t<
541+
using output_pixfmt_type = std::conditional_t<
518542
is_grayscale_v<color_type>,
519-
agg::pixfmt_alpha_blend_gray<blender_type, agg::rendering_buffer>,
520-
agg::pixfmt_alpha_blend_rgba<blender_type, agg::rendering_buffer>
543+
agg::pixfmt_alpha_blend_gray<output_blender_type, agg::rendering_buffer>,
544+
agg::pixfmt_alpha_blend_rgba<output_blender_type, agg::rendering_buffer>
521545
>;
522546
template<typename A> using span_gen_affine_type = std::conditional_t<
523547
is_grayscale_v<color_type>,
524548
agg::span_image_resample_gray_affine<A>,
525-
agg::span_image_resample_rgba_affine<A>
549+
std::conditional_t<
550+
input_has_alpha,
551+
agg::span_image_resample_rgba_affine<A>,
552+
agg::span_image_resample_rgb_affine<A>
553+
>
526554
>;
527555
template<typename A, typename B> using span_gen_filter_type = std::conditional_t<
528556
is_grayscale_v<color_type>,
529557
agg::span_image_filter_gray<A, B>,
530-
agg::span_image_filter_rgba<A, B>
558+
std::conditional_t<
559+
input_has_alpha,
560+
agg::span_image_filter_rgba<A, B>,
561+
agg::span_image_filter_rgb<A, B>
562+
>
531563
>;
532564
template<typename A, typename B> using span_gen_nn_type = std::conditional_t<
533565
is_grayscale_v<color_type>,
534566
agg::span_image_filter_gray_nn<A, B>,
535-
agg::span_image_filter_rgba_nn<A, B>
567+
std::conditional_t<
568+
input_has_alpha,
569+
agg::span_image_filter_rgba_nn<A, B>,
570+
agg::span_image_filter_rgb_nn<A, B>
571+
>
536572
>;
537573
};
538574

@@ -686,16 +722,16 @@ static void get_filter(const resample_params_t &params,
686722
}
687723

688724

689-
template<typename color_type>
725+
template<typename color_type, bool input_has_alpha = true>
690726
void resample(
691727
const void *input, int in_width, int in_height,
692728
void *output, int out_width, int out_height,
693729
resample_params_t &params)
694730
{
695-
using type_mapping_t = type_mapping<color_type>;
731+
using type_mapping_t = type_mapping<color_type, input_has_alpha>;
696732

697-
using input_pixfmt_t = typename type_mapping_t::pixfmt_type;
698-
using output_pixfmt_t = typename type_mapping_t::pixfmt_type;
733+
using input_pixfmt_t = typename type_mapping_t::input_pixfmt_type;
734+
using output_pixfmt_t = typename type_mapping_t::output_pixfmt_type;
699735

700736
using renderer_t = agg::renderer_base<output_pixfmt_t>;
701737
using rasterizer_t = agg::rasterizer_scanline_aa<agg::rasterizer_sl_clip_dbl>;
@@ -711,9 +747,16 @@ void resample(
711747
using arbitrary_interpolator_t =
712748
agg::span_interpolator_adaptor<agg::span_interpolator_linear<>, lookup_distortion>;
713749

714-
size_t itemsize = sizeof(color_type);
750+
size_t in_itemsize = sizeof(color_type);
751+
size_t out_itemsize = sizeof(color_type);
715752
if (is_grayscale<color_type>::value) {
716-
itemsize /= 2; // agg::grayXX includes an alpha channel which we don't have.
753+
// agg::grayXX includes an alpha channel which we don't have.
754+
in_itemsize /= 2;
755+
out_itemsize /= 2;
756+
} else if(!input_has_alpha) {
757+
// color_type is the output type, but the input doesn't have an alpha channel,
758+
// so we remove one value's size off the input size.
759+
in_itemsize -= sizeof(typename color_type::value_type);
717760
}
718761

719762
if (params.interpolation != NEAREST &&
@@ -733,13 +776,13 @@ void resample(
733776

734777
agg::rendering_buffer input_buffer;
735778
input_buffer.attach(
736-
(unsigned char *)input, in_width, in_height, in_width * itemsize);
779+
(unsigned char *)input, in_width, in_height, in_width * in_itemsize);
737780
input_pixfmt_t input_pixfmt(input_buffer);
738781
image_accessor_t input_accessor(input_pixfmt);
739782

740783
agg::rendering_buffer output_buffer;
741784
output_buffer.attach(
742-
(unsigned char *)output, out_width, out_height, out_width * itemsize);
785+
(unsigned char *)output, out_width, out_height, out_width * out_itemsize);
743786
output_pixfmt_t output_pixfmt(output_buffer);
744787
renderer_t renderer(output_pixfmt);
745788

src/_image_wrapper.cpp

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,14 @@ image_resample(py::array input_array,
106106
throw std::invalid_argument("Input array must be a 2D or 3D array");
107107
}
108108

109-
if (ndim == 3 && input_array.shape(2) != 4) {
110-
throw std::invalid_argument(
111-
"3D input array must be RGBA with shape (M, N, 4), has trailing dimension of {}"_s.format(
112-
input_array.shape(2)));
109+
py::ssize_t ncomponents = 0;
110+
if (ndim == 3) {
111+
ncomponents = input_array.shape(2);
112+
if (ncomponents != 3 && ncomponents != 4) {
113+
throw std::invalid_argument(
114+
"3D input array must be RGB with shape (M, N, 3) or RGBA with shape (M, N, 4), "
115+
"has trailing dimension of {}"_s.format(ncomponents));
116+
}
113117
}
114118

115119
// Ensure input array is contiguous, regardless of dtype
@@ -173,21 +177,31 @@ image_resample(py::array input_array,
173177

174178
if (auto resampler =
175179
(ndim == 2) ? (
176-
(dtype.equal(py::dtype::of<std::uint8_t>())) ? resample<agg::gray8> :
177-
(dtype.equal(py::dtype::of<std::int8_t>())) ? resample<agg::gray8> :
178-
(dtype.equal(py::dtype::of<std::uint16_t>())) ? resample<agg::gray16> :
179-
(dtype.equal(py::dtype::of<std::int16_t>())) ? resample<agg::gray16> :
180-
(dtype.equal(py::dtype::of<float>())) ? resample<agg::gray32> :
181-
(dtype.equal(py::dtype::of<double>())) ? resample<agg::gray64> :
180+
dtype.equal(py::dtype::of<std::uint8_t>()) ? resample<agg::gray8> :
181+
dtype.equal(py::dtype::of<std::int8_t>()) ? resample<agg::gray8> :
182+
dtype.equal(py::dtype::of<std::uint16_t>()) ? resample<agg::gray16> :
183+
dtype.equal(py::dtype::of<std::int16_t>()) ? resample<agg::gray16> :
184+
dtype.equal(py::dtype::of<float>()) ? resample<agg::gray32> :
185+
dtype.equal(py::dtype::of<double>()) ? resample<agg::gray64> :
182186
nullptr) : (
183-
// ndim == 3
184-
(dtype.equal(py::dtype::of<std::uint8_t>())) ? resample<agg::rgba8> :
185-
(dtype.equal(py::dtype::of<std::int8_t>())) ? resample<agg::rgba8> :
186-
(dtype.equal(py::dtype::of<std::uint16_t>())) ? resample<agg::rgba16> :
187-
(dtype.equal(py::dtype::of<std::int16_t>())) ? resample<agg::rgba16> :
188-
(dtype.equal(py::dtype::of<float>())) ? resample<agg::rgba32> :
189-
(dtype.equal(py::dtype::of<double>())) ? resample<agg::rgba64> :
190-
nullptr)) {
187+
// ndim == 3
188+
(ncomponents == 4) ? (
189+
dtype.equal(py::dtype::of<std::uint8_t>()) ? resample<agg::rgba8, true> :
190+
dtype.equal(py::dtype::of<std::int8_t>()) ? resample<agg::rgba8, true> :
191+
dtype.equal(py::dtype::of<std::uint16_t>()) ? resample<agg::rgba16, true> :
192+
dtype.equal(py::dtype::of<std::int16_t>()) ? resample<agg::rgba16, true> :
193+
dtype.equal(py::dtype::of<float>()) ? resample<agg::rgba32, true> :
194+
dtype.equal(py::dtype::of<double>()) ? resample<agg::rgba64, true> :
195+
nullptr
196+
) : (
197+
dtype.equal(py::dtype::of<std::uint8_t>()) ? resample<agg::rgba8, false> :
198+
dtype.equal(py::dtype::of<std::int8_t>()) ? resample<agg::rgba8, false> :
199+
dtype.equal(py::dtype::of<std::uint16_t>()) ? resample<agg::rgba16, false> :
200+
dtype.equal(py::dtype::of<std::int16_t>()) ? resample<agg::rgba16, false> :
201+
dtype.equal(py::dtype::of<float>()) ? resample<agg::rgba32, false> :
202+
dtype.equal(py::dtype::of<double>()) ? resample<agg::rgba64, false> :
203+
nullptr)))
204+
{
191205
Py_BEGIN_ALLOW_THREADS
192206
resampler(
193207
input_array.data(), input_array.shape(1), input_array.shape(0),

0 commit comments

Comments
 (0)