Skip to content

Commit 0d80848

Browse files
deekay42NicolasHug
andauthored
GPU jpeg decoder: add batch support and hardware decoding (#8496)
Co-authored-by: Nicolas Hug <[email protected]>
1 parent 5242d6a commit 0d80848

File tree

10 files changed

+934
-317
lines changed

10 files changed

+934
-317
lines changed

benchmarks/encoding.py

Lines changed: 0 additions & 67 deletions
This file was deleted.

benchmarks/encoding_decoding.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import os
2+
import platform
3+
import statistics
4+
5+
import torch
6+
import torch.utils.benchmark as benchmark
7+
import torchvision
8+
9+
10+
def print_machine_specs():
11+
print("Processor:", platform.processor())
12+
print("Platform:", platform.platform())
13+
print("Logical CPUs:", os.cpu_count())
14+
print(f"\nCUDA device: {torch.cuda.get_device_name()}")
15+
print(f"Total Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
16+
17+
18+
def get_data():
19+
transform = torchvision.transforms.Compose(
20+
[
21+
torchvision.transforms.PILToTensor(),
22+
]
23+
)
24+
path = os.path.join(os.getcwd(), "data")
25+
testset = torchvision.datasets.Places365(
26+
root="./data", download=not os.path.exists(path), transform=transform, split="val"
27+
)
28+
testloader = torch.utils.data.DataLoader(
29+
testset, batch_size=1000, shuffle=False, num_workers=1, collate_fn=lambda batch: [r[0] for r in batch]
30+
)
31+
return next(iter(testloader))
32+
33+
34+
def run_encoding_benchmark(decoded_images):
35+
results = []
36+
for device in ["cpu", "cuda"]:
37+
decoded_images_device = [t.to(device=device) for t in decoded_images]
38+
for size in [1, 100, 1000]:
39+
for num_threads in [1, 12, 24]:
40+
for stmt, strat in zip(
41+
[
42+
"[torchvision.io.encode_jpeg(img) for img in decoded_images_device_trunc]",
43+
"torchvision.io.encode_jpeg(decoded_images_device_trunc)",
44+
],
45+
["unfused", "fused"],
46+
):
47+
decoded_images_device_trunc = decoded_images_device[:size]
48+
t = benchmark.Timer(
49+
stmt=stmt,
50+
setup="import torchvision",
51+
globals={"decoded_images_device_trunc": decoded_images_device_trunc},
52+
label="Image Encoding",
53+
sub_label=f"{device.upper()} ({strat}): {stmt}",
54+
description=f"{size} images",
55+
num_threads=num_threads,
56+
)
57+
results.append(t.blocked_autorange())
58+
compare = benchmark.Compare(results)
59+
compare.print()
60+
61+
62+
def run_decoding_benchmark(encoded_images):
63+
results = []
64+
for device in ["cpu", "cuda"]:
65+
for size in [1, 100, 1000]:
66+
for num_threads in [1, 12, 24]:
67+
for stmt, strat in zip(
68+
[
69+
f"[torchvision.io.decode_jpeg(img, device='{device}') for img in encoded_images_trunc]",
70+
f"torchvision.io.decode_jpeg(encoded_images_trunc, device='{device}')",
71+
],
72+
["unfused", "fused"],
73+
):
74+
encoded_images_trunc = encoded_images[:size]
75+
t = benchmark.Timer(
76+
stmt=stmt,
77+
setup="import torchvision",
78+
globals={"encoded_images_trunc": encoded_images_trunc},
79+
label="Image Decoding",
80+
sub_label=f"{device.upper()} ({strat}): {stmt}",
81+
description=f"{size} images",
82+
num_threads=num_threads,
83+
)
84+
results.append(t.blocked_autorange())
85+
compare = benchmark.Compare(results)
86+
compare.print()
87+
88+
89+
if __name__ == "__main__":
90+
print_machine_specs()
91+
decoded_images = get_data()
92+
mean_h, mean_w = statistics.mean(t.shape[-2] for t in decoded_images), statistics.mean(
93+
t.shape[-1] for t in decoded_images
94+
)
95+
print(f"\nMean image size: {int(mean_h)}x{int(mean_w)}")
96+
run_encoding_benchmark(decoded_images)
97+
encoded_images_cuda = torchvision.io.encode_jpeg([img.cuda() for img in decoded_images])
98+
encoded_images_cpu = [img.cpu() for img in encoded_images_cuda]
99+
run_decoding_benchmark(encoded_images_cpu)

test/test_image.py

Lines changed: 99 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -413,23 +413,32 @@ def test_read_interlaced_png():
413413

414414

415415
@needs_cuda
416-
@pytest.mark.parametrize(
417-
"img_path",
418-
[pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(IMAGE_ROOT, ".jpg")],
419-
)
420416
@pytest.mark.parametrize("mode", [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB])
421417
@pytest.mark.parametrize("scripted", (False, True))
422-
def test_decode_jpeg_cuda(mode, img_path, scripted):
423-
if "cmyk" in img_path:
424-
pytest.xfail("Decoding a CMYK jpeg isn't supported")
418+
def test_decode_jpegs_cuda(mode, scripted):
419+
encoded_images = []
420+
for jpeg_path in get_images(IMAGE_ROOT, ".jpg"):
421+
if "cmyk" in jpeg_path:
422+
continue
423+
encoded_image = read_file(jpeg_path)
424+
encoded_images.append(encoded_image)
425+
decoded_images_cpu = decode_jpeg(encoded_images, mode=mode)
426+
decode_fn = torch.jit.script(decode_jpeg) if scripted else decode_jpeg
425427

426-
data = read_file(img_path)
427-
img = decode_image(data, mode=mode)
428-
f = torch.jit.script(decode_jpeg) if scripted else decode_jpeg
429-
img_nvjpeg = f(data, mode=mode, device="cuda")
428+
# test multithreaded decoding
429+
# in the current version we prevent this by using a lock but we still want to test it
430+
num_workers = 10
430431

431-
# Some difference expected between jpeg implementations
432-
assert (img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2
432+
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
433+
futures = [executor.submit(decode_fn, encoded_images, mode, "cuda") for _ in range(num_workers)]
434+
decoded_images_threaded = [future.result() for future in futures]
435+
assert len(decoded_images_threaded) == num_workers
436+
for decoded_images in decoded_images_threaded:
437+
assert len(decoded_images) == len(encoded_images)
438+
for decoded_image_cuda, decoded_image_cpu in zip(decoded_images, decoded_images_cpu):
439+
assert decoded_image_cuda.shape == decoded_image_cpu.shape
440+
assert decoded_image_cuda.dtype == decoded_image_cpu.dtype == torch.uint8
441+
assert (decoded_image_cuda.cpu().float() - decoded_image_cpu.cpu().float()).abs().mean() < 2
433442

434443

435444
@needs_cuda
@@ -440,25 +449,95 @@ def test_decode_image_cuda_raises():
440449

441450

442451
@needs_cuda
443-
@pytest.mark.parametrize("cuda_device", ("cuda", "cuda:0", torch.device("cuda")))
444-
def test_decode_jpeg_cuda_device_param(cuda_device):
445-
"""Make sure we can pass a string or a torch.device as device param"""
452+
def test_decode_jpeg_cuda_device_param():
446453
path = next(path for path in get_images(IMAGE_ROOT, ".jpg") if "cmyk" not in path)
447454
data = read_file(path)
448-
decode_jpeg(data, device=cuda_device)
455+
current_device = torch.cuda.current_device()
456+
current_stream = torch.cuda.current_stream()
457+
num_devices = torch.cuda.device_count()
458+
devices = ["cuda", torch.device("cuda")] + [torch.device(f"cuda:{i}") for i in range(num_devices)]
459+
results = []
460+
for device in devices:
461+
results.append(decode_jpeg(data, device=device))
462+
assert len(results) == len(devices)
463+
for result in results:
464+
assert torch.all(result.cpu() == results[0].cpu())
465+
assert current_device == torch.cuda.current_device()
466+
assert current_stream == torch.cuda.current_stream()
449467

450468

451469
@needs_cuda
452470
def test_decode_jpeg_cuda_errors():
453471
data = read_file(next(get_images(IMAGE_ROOT, ".jpg")))
454472
with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
455473
decode_jpeg(data.reshape(-1, 1), device="cuda")
456-
with pytest.raises(RuntimeError, match="input tensor must be on CPU"):
474+
with pytest.raises(ValueError, match="must be tensors"):
475+
decode_jpeg([1, 2, 3])
476+
with pytest.raises(ValueError, match="Input tensor must be a CPU tensor"):
457477
decode_jpeg(data.to("cuda"), device="cuda")
458478
with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"):
459479
decode_jpeg(data.to(torch.float), device="cuda")
460-
with pytest.raises(RuntimeError, match="Expected a cuda device"):
461-
torch.ops.image.decode_jpeg_cuda(data, ImageReadMode.UNCHANGED.value, "cpu")
480+
with pytest.raises(RuntimeError, match="Expected the device parameter to be a cuda device"):
481+
torch.ops.image.decode_jpegs_cuda([data], ImageReadMode.UNCHANGED.value, "cpu")
482+
with pytest.raises(ValueError, match="Input tensor must be a CPU tensor"):
483+
decode_jpeg(
484+
torch.empty((100,), dtype=torch.uint8, device="cuda"),
485+
)
486+
with pytest.raises(ValueError, match="Input list must contain tensors on CPU"):
487+
decode_jpeg(
488+
[
489+
torch.empty((100,), dtype=torch.uint8, device="cuda"),
490+
torch.empty((100,), dtype=torch.uint8, device="cuda"),
491+
]
492+
)
493+
494+
with pytest.raises(ValueError, match="Input list must contain tensors on CPU"):
495+
decode_jpeg(
496+
[
497+
torch.empty((100,), dtype=torch.uint8, device="cuda"),
498+
torch.empty((100,), dtype=torch.uint8, device="cuda"),
499+
],
500+
device="cuda",
501+
)
502+
503+
with pytest.raises(ValueError, match="Input list must contain tensors on CPU"):
504+
decode_jpeg(
505+
[
506+
torch.empty((100,), dtype=torch.uint8, device="cpu"),
507+
torch.empty((100,), dtype=torch.uint8, device="cuda"),
508+
],
509+
device="cuda",
510+
)
511+
512+
with pytest.raises(RuntimeError, match="Expected a torch.uint8 tensor"):
513+
decode_jpeg(
514+
[
515+
torch.empty((100,), dtype=torch.uint8),
516+
torch.empty((100,), dtype=torch.float32),
517+
],
518+
device="cuda",
519+
)
520+
521+
with pytest.raises(RuntimeError, match="Expected a non empty 1-dimensional tensor"):
522+
decode_jpeg(
523+
[
524+
torch.empty((100,), dtype=torch.uint8),
525+
torch.empty((1, 100), dtype=torch.uint8),
526+
],
527+
device="cuda",
528+
)
529+
530+
with pytest.raises(RuntimeError, match="Error while decoding JPEG images"):
531+
decode_jpeg(
532+
[
533+
torch.empty((100,), dtype=torch.uint8),
534+
torch.empty((100,), dtype=torch.uint8),
535+
],
536+
device="cuda",
537+
)
538+
539+
with pytest.raises(ValueError, match="Input list must contain at least one element"):
540+
decode_jpeg([], device="cuda")
462541

463542

464543
def test_encode_jpeg_errors():
@@ -515,12 +594,10 @@ def test_encode_jpeg_cuda_device_param():
515594
devices = ["cuda", torch.device("cuda")] + [torch.device(f"cuda:{i}") for i in range(num_devices)]
516595
results = []
517596
for device in devices:
518-
print(f"python: device: {device}")
519597
results.append(encode_jpeg(data.to(device=device)))
520598
assert len(results) == len(devices)
521599
for result in results:
522600
assert torch.all(result.cpu() == results[0].cpu())
523-
524601
assert current_device == torch.cuda.current_device()
525602
assert current_stream == torch.cuda.current_stream()
526603

0 commit comments

Comments
 (0)