@@ -300,20 +300,37 @@ def write_png(array):
300
300
raise ValueError ("Data must be NxM (mono), " + \
301
301
"NxMx3 (rgb), or NxMx4 (rgba)" )
302
302
303
+ # have to broadcast up into a full rgba array
304
+ array_full = np .empty ((array .shape [0 ], array .shape [1 ], 4 ))
305
+ # NxM -> NxMx4
306
+ if array .shape [2 ] == 1 :
307
+ array_full [:,:,0 ] = array [:,:,0 ]
308
+ array_full [:,:,1 ] = array [:,:,0 ]
309
+ array_full [:,:,2 ] = array [:,:,0 ]
310
+ array_full [:,:,3 ] = 1
311
+ # NxMx3 -> NxMx4
312
+ elif array .shape [2 ] == 3 :
313
+ array_full [:,:,0 ] = array [:,:,0 ]
314
+ array_full [:,:,1 ] = array [:,:,1 ]
315
+ array_full [:,:,2 ] = array [:,:,2 ]
316
+ array_full [:,:,3 ] = 1
317
+ # NxMx4 -> keep
318
+ else :
319
+ array_full = array
320
+
303
321
# normalize to uint8 if it isn't already
304
- if array .dtype != 'uint8' :
305
- for component in range (array .shape [2 ]):
306
- frame = array [:,:,component ]
307
- array [:,:,component ] = (frame / frame .max () * 255 )
308
- array = array .astype ('uint8' )
309
- array = np .squeeze (array )
310
- width , height = array .shape [:2 ]
311
-
312
- array = array .tobytes ()
322
+ if array_full .dtype != 'uint8' :
323
+ for component in range (4 ):
324
+ frame = array_full [:,:,component ]
325
+ array_full [:,:,component ] = (frame / frame .max () * 255 )
326
+ array_full = array_full .astype ('uint8' )
327
+ width , height = array_full .shape [:2 ]
328
+
329
+ array_full = array_full .tobytes ()
313
330
314
331
# reverse the vertical line order and add null bytes at the start
315
332
width_byte_4 = width * 4
316
- raw_data = b'' .join (b'\x00 ' + array [span :span + width_byte_4 ]
333
+ raw_data = b'' .join (b'\x00 ' + array_full [span :span + width_byte_4 ]
317
334
for span in range ((height - 1 ) * width * 4 , - 1 , - width_byte_4 ))
318
335
319
336
def png_pack (png_tag , data ):
@@ -327,3 +344,4 @@ def png_pack(png_tag, data):
327
344
png_pack (b'IHDR' , struct .pack ("!2I5B" , width , height , 8 , 6 , 0 , 0 , 0 )),
328
345
png_pack (b'IDAT' , zlib .compress (raw_data , 9 )),
329
346
png_pack (b'IEND' , b'' )])
347
+
0 commit comments