@@ -512,12 +512,53 @@ def pin_copy_to_device_nonblocking(*tensors):
512
512
#
513
513
# Until now, we have operated under the assumption that asynchronous copies from the CPU to the GPU are safe.
514
514
# This is generally true because CUDA automatically handles synchronization to ensure that the data being accessed is
515
- # valid at read time.
516
- # However, this guarantee does not extend to transfers in the opposite direction, from GPU to CPU.
517
- # Without explicit synchronization, these transfers offer no assurance that the copy will be complete at the time of
518
- # data access. Consequently, the data on the host might be incomplete or incorrect, effectively rendering it garbage:
515
+ # valid at read time __whenever the tensor is in pageable memory__.
516
+ #
517
+ # However, in other cases we cannot make the same assumption: when a tensor is placed in pinned memory, mutating the
518
+ # original copy after calling the host-to-device transfer may corrupt the data received on GPU.
519
+ # Similarly, when a transfer is achieved in the opposite direction, from GPU to CPU, or from any device that is not CPU
520
+ # or GPU to any device that is not a CUDA-handled GPU (e.g., MPS), there is no guarantee that the data read on the target device is
521
+ # valid without explicit synchronization.
522
+ #
523
+ # In these scenarios, these transfers offer no assurance that the copy will be complete at the time of
524
+ # data access. Consequently, the data on the host might be incomplete or incorrect, effectively rendering it garbage.
525
+ #
526
+ # Let's first demonstrate this with a pinned-memory tensor:
527
+ DELAY = 100000000
528
+ try :
529
+ i = - 1
530
+ for i in range (100 ):
531
+ # Create a tensor in pin-memory
532
+ cpu_tensor = torch .ones (1024 , 1024 , pin_memory = True )
533
+ torch .cuda .synchronize ()
534
+ # Send the tensor to CUDA
535
+ cuda_tensor = cpu_tensor .to ("cuda" , non_blocking = True )
536
+ torch .cuda ._sleep (DELAY )
537
+ # Corrupt the original tensor
538
+ cpu_tensor .zero_ ()
539
+ assert (cuda_tensor == 1 ).all ()
540
+ print ("No test failed with non_blocking" )
541
+ except AssertionError :
542
+ print (f"{ i } th test failed with non_blocking. Skipping remaining tests" )
543
+
544
+ ######################################################################
545
+ # Using a pageable tensor always works:
519
546
#
520
547
548
+ i = - 1
549
+ for i in range (100 ):
550
+ # Create a pageable (non-pinned) tensor
551
+ cpu_tensor = torch .ones (1024 , 1024 )
552
+ torch .cuda .synchronize ()
553
+ # Send the tensor to CUDA
554
+ cuda_tensor = cpu_tensor .to ("cuda" , non_blocking = True )
555
+ torch .cuda ._sleep (DELAY )
556
+ # Corrupt the original tensor
557
+ cpu_tensor .zero_ ()
558
+ assert (cuda_tensor == 1 ).all ()
559
+
560
+ ######################################################################
561
+ # Now let's demonstrate that CUDA to CPU also fails to produce reliable outputs without synchronization:
521
562
522
563
tensor = (
523
564
torch .arange (1 , 1_000_000 , dtype = torch .double , device = "cuda" )
@@ -551,9 +592,8 @@ def pin_copy_to_device_nonblocking(*tensors):
551
592
552
593
553
594
######################################################################
554
- # The same considerations apply to copies from the CPU to non-CUDA devices, such as MPS.
555
595
# Generally, asynchronous copies to a device are safe without explicit synchronization only when the target is a
556
- # CUDA-enabled device.
596
+ # CUDA-enabled device and the original tensor is in pageable memory.
557
597
#
558
598
# In summary, copying data from CPU to GPU is safe when using ``non_blocking=True``, but for any other direction,
559
599
# ``non_blocking=True`` can still be used but the user must make sure that a device synchronization is executed before
0 commit comments