@@ -413,23 +413,32 @@ def test_read_interlaced_png():
 
 
 @needs_cuda
-@pytest.mark.parametrize(
-    "img_path",
-    [pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")],
-)
 @pytest.mark.parametrize("mode", [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB])
 @pytest.mark.parametrize("scripted", (False, True))
-def test_decode_jpeg_cuda(mode, img_path, scripted):
-    if "cmyk" in img_path:
-        pytest.xfail("Decoding a CMYK jpeg isn't supported")
+def test_decode_jpegs_cuda(mode, scripted):
+    encoded_images = []
+    for jpeg_path in get_images(IMAGE_ROOT, ".jpg"):
+        if "cmyk" in jpeg_path:
+            continue
+        encoded_image = read_file(jpeg_path)
+        encoded_images.append(encoded_image)
+    decoded_images_cpu = decode_jpeg(encoded_images, mode=mode)
+    decode_fn = torch.jit.script(decode_jpeg) if scripted else decode_jpeg
 
-    data = read_file(img_path)
-    img = decode_image(data, mode=mode)
-    f = torch.jit.script(decode_jpeg) if scripted else decode_jpeg
-    img_nvjpeg = f(data, mode=mode, device="cuda")
+    # test multithreaded decoding
+    # in the current version we prevent this by using a lock but we still want to test it
+    num_workers = 10
 
-    # Some difference expected between jpeg implementations
-    assert (img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2
+    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
+        futures = [executor.submit(decode_fn, encoded_images, mode, "cuda") for _ in range(num_workers)]
+        decoded_images_threaded = [future.result() for future in futures]
+    assert len(decoded_images_threaded) == num_workers
+    for decoded_images in decoded_images_threaded:
+        assert len(decoded_images) == len(encoded_images)
+        for decoded_image_cuda, decoded_image_cpu in zip(decoded_images, decoded_images_cpu):
+            assert decoded_image_cuda.shape == decoded_image_cpu.shape
+            assert decoded_image_cuda.dtype == decoded_image_cpu.dtype == torch.uint8
+            assert (decoded_image_cuda.cpu().float() - decoded_image_cpu.cpu().float()).abs().mean() < 2
 
 
 @needs_cuda
@@ -440,25 +449,95 @@ def test_decode_image_cuda_raises():
 
 
 @needs_cuda
-@pytest.mark.parametrize("cuda_device", ("cuda", "cuda:0", torch.device("cuda")))
-def test_decode_jpeg_cuda_device_param(cuda_device):
-    """Make sure we can pass a string or a torch.device as device param"""
+def test_decode_jpeg_cuda_device_param():
     path = next(path for path in get_images(IMAGE_ROOT, ".jpg") if "cmyk" not in path)
     data = read_file(path)
-    decode_jpeg(data, device=cuda_device)
+    current_device = torch.cuda.current_device()
+    current_stream = torch.cuda.current_stream()
+    num_devices = torch.cuda.device_count()
+    devices = ["cuda", torch.device("cuda")] + [torch.device(f"cuda:{i}") for i in range(num_devices)]
+    results = []
+    for device in devices:
+        results.append(decode_jpeg(data, device=device))
+    assert len(results) == len(devices)
+    for result in results:
+        assert torch.all(result.cpu() == results[0].cpu())
+    assert current_device == torch.cuda.current_device()
+    assert current_stream == torch.cuda.current_stream()
 
 
 @needs_cuda
 def test_decode_jpeg_cuda_errors():
     data = read_file(next(get_images(IMAGE_ROOT, ".jpg")))
     with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
         decode_jpeg(data.reshape(-1, 1), device="cuda")
-    with pytest.raises(RuntimeError, match="input tensor must be on CPU"):
+    with pytest.raises(ValueError, match="must be tensors"):
+        decode_jpeg([1, 2, 3])
+    with pytest.raises(ValueError, match="Input tensor must be a CPU tensor"):
         decode_jpeg(data.to("cuda"), device="cuda")
     with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"):
         decode_jpeg(data.to(torch.float), device="cuda")
-    with pytest.raises(RuntimeError, match="Expected a cuda device"):
-        torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, "cpu")
+    with pytest.raises(RuntimeError, match="Expected the device parameter to be a cuda device"):
+        torch.ops.image.decode_jpegs_cuda([data], ImageReadMode.UNCHANGED.value, "cpu")
+    with pytest.raises(ValueError, match="Input tensor must be a CPU tensor"):
+        decode_jpeg(
+            torch.empty((100,), dtype=torch.uint8, device="cuda"),
+        )
+    with pytest.raises(ValueError, match="Input list must contain tensors on CPU"):
+        decode_jpeg(
+            [
+                torch.empty((100,), dtype=torch.uint8, device="cuda"),
+                torch.empty((100,), dtype=torch.uint8, device="cuda"),
+            ]
+        )
+
+    with pytest.raises(ValueError, match="Input list must contain tensors on CPU"):
+        decode_jpeg(
+            [
+                torch.empty((100,), dtype=torch.uint8, device="cuda"),
+                torch.empty((100,), dtype=torch.uint8, device="cuda"),
+            ],
+            device="cuda",
+        )
+
+    with pytest.raises(ValueError, match="Input list must contain tensors on CPU"):
+        decode_jpeg(
+            [
+                torch.empty((100,), dtype=torch.uint8, device="cpu"),
+                torch.empty((100,), dtype=torch.uint8, device="cuda"),
+            ],
+            device="cuda",
+        )
+
+    with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"):
+        decode_jpeg(
+            [
+                torch.empty((100,), dtype=torch.uint8),
+                torch.empty((100,), dtype=torch.float32),
+            ],
+            device="cuda",
+        )
+
+    with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
+        decode_jpeg(
+            [
+                torch.empty((100,), dtype=torch.uint8),
+                torch.empty((1, 100), dtype=torch.uint8),
+            ],
+            device="cuda",
+        )
+
+    with pytest.raises(RuntimeError, match="Error while decoding JPEG images"):
+        decode_jpeg(
+            [
+                torch.empty((100,), dtype=torch.uint8),
+                torch.empty((100,), dtype=torch.uint8),
+            ],
+            device="cuda",
+        )
+
+    with pytest.raises(ValueError, match="Input list must contain at least one element"):
+        decode_jpeg([], device="cuda")
 
 
 def test_encode_jpeg_errors():
@@ -515,12 +594,10 @@ def test_encode_jpeg_cuda_device_param():
     devices = ["cuda", torch.device("cuda")] + [torch.device(f"cuda:{i}") for i in range(num_devices)]
     results = []
     for device in devices:
-        print(f"python: device: {device}")
         results.append(encode_jpeg(data.to(device=device)))
     assert len(results) == len(devices)
     for result in results:
         assert torch.all(result.cpu() == results[0].cpu())
-
     assert current_device == torch.cuda.current_device()
     assert current_stream == torch.cuda.current_stream()
 
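Taken together, the new tests exercise a batched GPU decoding path: decode_jpeg now also accepts a list of encoded CPU uint8 tensors plus a device argument and returns one decoded CHW uint8 tensor per input. Below is a minimal usage sketch, assuming the torchvision.io API as exercised in this diff; the file paths are placeholders.

    # Sketch of the batched CUDA JPEG decoding exercised by the tests above.
    # The list-input form of decode_jpeg is assumed from this diff; paths are placeholders.
    import torch
    from torchvision.io import ImageReadMode, decode_jpeg, read_file

    paths = ["img1.jpg", "img2.jpg"]  # placeholder JPEG files
    encoded_images = [read_file(p) for p in paths]  # 1-D uint8 tensors on CPU

    # One call decodes the whole batch on the GPU and returns a list of CHW uint8 CUDA tensors.
    decoded_images = decode_jpeg(encoded_images, mode=ImageReadMode.RGB, device="cuda")
    for img in decoded_images:
        assert img.device.type == "cuda" and img.dtype == torch.uint8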