Skip to content

Commit bb08b99

Browse files
committed
insure that we use a NxMx4 array
1 parent 0cf5d1b commit bb08b99

File tree

1 file changed

+28
-10
lines changed

1 file changed

+28
-10
lines changed

folium/utilities.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -300,20 +300,37 @@ def write_png(array):
300300
raise ValueError("Data must be NxM (mono), " + \
301301
"NxMx3 (rgb), or NxMx4 (rgba)")
302302

303+
# have to broadcast up into a full rgba array
304+
array_full = np.empty((array.shape[0], array.shape[1], 4))
305+
# NxM -> NxMx4
306+
if array.shape[2] == 1:
307+
array_full[:,:,0] = array[:,:,0]
308+
array_full[:,:,1] = array[:,:,0]
309+
array_full[:,:,2] = array[:,:,0]
310+
array_full[:,:,3] = 1
311+
# NxMx3 -> NxMx4
312+
elif array.shape[2] == 3:
313+
array_full[:,:,0] = array[:,:,0]
314+
array_full[:,:,1] = array[:,:,1]
315+
array_full[:,:,2] = array[:,:,2]
316+
array_full[:,:,3] = 1
317+
# NxMx4 -> keep
318+
else:
319+
array_full = array
320+
303321
# normalize to uint8 if it isn't already
304-
if array.dtype != 'uint8':
305-
for component in range(array.shape[2]):
306-
frame = array[:,:,component]
307-
array[:,:,component] = (frame / frame.max() * 255)
308-
array = array.astype('uint8')
309-
array = np.squeeze(array)
310-
width, height = array.shape[:2]
311-
312-
array = array.tobytes()
322+
if array_full.dtype != 'uint8':
323+
for component in range(4):
324+
frame = array_full[:,:,component]
325+
array_full[:,:,component] = (frame / frame.max() * 255)
326+
array_full = array_full.astype('uint8')
327+
width, height = array_full.shape[:2]
328+
329+
array_full = array_full.tobytes()
313330

314331
# reverse the vertical line order and add null bytes at the start
315332
width_byte_4 = width * 4
316-
raw_data = b''.join(b'\x00' + array[span:span + width_byte_4]
333+
raw_data = b''.join(b'\x00' + array_full[span:span + width_byte_4]
317334
for span in range((height - 1) * width * 4, -1, - width_byte_4))
318335

319336
def png_pack(png_tag, data):
@@ -327,3 +344,4 @@ def png_pack(png_tag, data):
327344
png_pack(b'IHDR', struct.pack("!2I5B", width, height, 8, 6, 0, 0, 0)),
328345
png_pack(b'IDAT', zlib.compress(raw_data, 9)),
329346
png_pack(b'IEND', b'')])
347+

0 commit comments

Comments
 (0)