@@ -2577,6 +2577,27 @@ def test_all_reduce_coalesced_nccl(self):
             ),
         )
 
+    @requires_nccl()
+    @skip_if_lt_x_gpu(2)
+    def test_all_reduce_coalesced_nccl_float8_errors(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        c10d.init_process_group(
+            backend="nccl", store=store, rank=self.rank, world_size=self.world_size
+        )
+        process_group = c10d.distributed_c10d._get_default_group()
+        device = torch.device("cuda:%d" % self.rank)
+        tensors = [
+            torch.full(
+                (60 + i,), self.rank + 1 + i, device=device, dtype=torch.float
+            ).to(torch.float8_e4m3fn)
+            for i in range(5)
+        ]
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Float8 dtypes are not currenlty supported for NCCL reductions",
+        ):
+            torch.distributed.all_reduce_coalesced(tensors, group=process_group)
+
     @requires_nccl()
     @skip_if_lt_x_gpu(2)
     def test_all_reduce_coalesced_manager_nccl(self):
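Note on the assertion string: the regex in these tests copies the upstream NCCL error message verbatim, including its "currenlty" misspelling, so the tests match exactly what the reduction path raises. As a rough illustration of the kind of dtype guard being exercised (the real check lives in ProcessGroupNCCL's C++ reduction path; the Python helper name below is hypothetical):

```python
import torch

# Hypothetical Python-level sketch of the guard these tests exercise: NCCL has
# no reduction kernels for 8-bit float formats, so reduction collectives must
# reject float8 inputs up front. The message (typo included) mirrors upstream.
def _check_nccl_reduce_dtype(tensor: torch.Tensor) -> None:
    if tensor.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        raise RuntimeError(
            "Float8 dtypes are not currenlty supported for NCCL reductions"
        )
```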
@@ -2940,6 +2961,56 @@ def test_reduce_scatter_tensor_coalesced(self):
                 dist.reduce_scatter_tensor(output_tensors[i], input_tensors[i])
         self.assertEqual(output_tensors, input_tensors[self.rank] * self.world_size)
 
+    @requires_nccl()
+    @skip_if_lt_x_gpu(2)
+    def test_reduce_scatter_base_k_float8_errors(self):
+        store = dist.FileStore(self.file_name, self.world_size)
+        dist.init_process_group(
+            "nccl",
+            world_size=self.world_size,
+            rank=self.rank,
+            store=store,
+        )
+        output_tensor = (
+            torch.zeros(2, dtype=torch.float32).to(torch.float8_e4m3fn).to(self.rank)
+        )
+        input_tensors = (
+            torch.arange(self.world_size * 2, dtype=torch.float32)
+            .to(torch.float8_e4m3fn)
+            .to(self.rank)
+        )
+        input_tensors = torch.reshape(input_tensors, (self.world_size, 2))
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Float8 dtypes are not currenlty supported for NCCL reductions",
+        ):
+            dist.reduce_scatter_tensor(output_tensor, input_tensors)
+
+    @requires_nccl()
+    @skip_if_lt_x_gpu(2)
+    def test_reduce_scatter_tensor_coalesced_float8_errors(self):
+        store = dist.FileStore(self.file_name, self.world_size)
+        dist.init_process_group(
+            "nccl",
+            world_size=self.world_size,
+            rank=self.rank,
+            store=store,
+        )
+        output_tensors = torch.zeros(2, 2).to(torch.float8_e5m2).to(self.rank)
+        input_tensors = [
+            torch.ones(2, 2).to(torch.float8_e5m2).to(self.rank)
+            for _ in range(self.world_size)
+        ]
+
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Float8 dtypes are not currenlty supported for NCCL reductions",
+        ):
+            with dist._coalescing_manager():
+                for i in range(self.world_size):
+                    dist.reduce_scatter_tensor(output_tensors[i], input_tensors[i])
+            self.assertEqual(output_tensors, input_tensors[self.rank])
+
 
 class SetDeviceMethod(Enum):
     TORCH_CUDA_SET = auto()  # torch.cuda.set_device
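Both new reduce-scatter tests go through `dist._coalescing_manager()` (a private API), which batches the collectives issued inside the `with` block into a single NCCL group call, so the float8 check must also fire on the coalesced path. A minimal usage sketch of the happy path, assuming the script is launched with `torchrun` so rank and world size come from the environment:

```python
import torch
import torch.distributed as dist

def coalesced_reduce_scatter_demo() -> None:
    # torchrun populates RANK / WORLD_SIZE / MASTER_ADDR in the environment.
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)

    outputs = torch.zeros(world_size, 2, device="cuda")
    inputs = [torch.ones(world_size, 2, device="cuda") for _ in range(world_size)]

    # The manager batches the reduce-scatters below into one NCCL group call.
    with dist._coalescing_manager():
        for i in range(world_size):
            dist.reduce_scatter_tensor(outputs[i], inputs[i])

    # Every element is summed across ranks, so each output equals world_size.
    assert torch.equal(outputs, torch.full_like(outputs, float(world_size)))
    dist.destroy_process_group()

if __name__ == "__main__":
    coalesced_reduce_scatter_demo()
```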
@@ -2980,6 +3051,28 @@ def test_allgather_base(self):
         dist.all_gather_into_tensor(output_tensor, tensor)
         self.assertEqual(output_tensor, tensor)
 
+    @requires_nccl()
+    @skip_if_lt_x_gpu(1)
+    @parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+    def test_allgather_float8(self, float8_dtype):
+        store = dist.FileStore(self.file_name, self.world_size)
+        dist.init_process_group(
+            "nccl",
+            world_size=self.world_size,
+            rank=self.rank,
+            store=store,
+        )
+        device = "cuda"
+        tensor = torch.ones(10, 16, device=torch.device(device)).to(float8_dtype)
+        output_tensor = torch.zeros(10, 16, device=torch.device(device)).to(
+            float8_dtype
+        )
+        dist.all_gather_into_tensor(output_tensor, tensor)
+        self.assertEqual(output_tensor.view(torch.float32), tensor.view(torch.float32))
+
+
+instantiate_parametrized_tests(NcclProcessGroupWithDispatchedCollectivesTests)
+
 
 class LargeCommTest(test_c10d_common.AbstractLargeCommTest, MultiProcessTestCase):
     def setUp(self):
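The allgather test is the positive case: gathering only moves bytes and performs no arithmetic, so float8 is permitted there, and `instantiate_parametrized_tests` expands the `@parametrize` decorator into one test per float8 dtype. The final comparison reinterprets the raw bytes as `float32` via `.view(torch.float32)` rather than comparing float8 tensors directly, since comparison kernels may not be implemented for float8 dtypes. A standalone illustration of that reinterpret trick (CPU-only, no process group needed):

```python
import torch

# Two identical float8 tensors; comparing them directly may not be supported,
# so reinterpret the underlying bytes as float32 and compare those instead.
a = torch.ones(10, 16).to(torch.float8_e4m3fn)
b = torch.ones(10, 16).to(torch.float8_e4m3fn)

# 16 one-byte float8 values per row reinterpret as 4 four-byte float32 values.
a32 = a.view(torch.float32)
b32 = b.view(torch.float32)
print(a32.shape)  # torch.Size([10, 4])

# Equal reinterpretations imply bitwise-equal float8 payloads.
assert torch.equal(a32, b32)
```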