
Commit 23b6ebd

drisspg authored and ZelboK committed
Map float8 types to uint8 for allgather (pytorch#126556)
# Summary

Different take on this one: pytorch#126338. We should probably not allow this mapping for 'compute' ops, e.g. reductions.

### Corresponding fp8 PR

pytorch-labs/float8_experimental#263

Pull Request resolved: pytorch#126556
Approved by: https://github.com/wanchaol
1 parent b51e6dd commit 23b6ebd
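To make the resulting contract concrete, here is a minimal sketch (not part of this commit) of what the change permits and forbids. It assumes one GPU per rank and a launcher such as `torchrun --nproc_per_node=2` that sets `RANK`, `WORLD_SIZE`, and the rendezvous environment variables; the tests added below exercise the same paths.

```python
# Sketch of the user-visible behavior after this change (illustrative only).
# Assumes one GPU per rank and a launcher such as `torchrun --nproc_per_node=2`
# that sets RANK/WORLD_SIZE and the rendezvous environment variables.
import os

import torch
import torch.distributed as dist


def main() -> None:
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    x = torch.ones(10, 16, device="cuda").to(torch.float8_e4m3fn)

    # Allgather is pure data movement: the float8 bytes travel as uint8.
    out = torch.zeros(world_size * 10, 16, device="cuda").to(torch.float8_e4m3fn)
    dist.all_gather_into_tensor(out, x)

    # Reductions would have to do arithmetic on those bytes, so they are rejected.
    try:
        dist.all_reduce(x)
    except RuntimeError as err:
        print(f"rank {rank}: float8 all_reduce rejected: {err}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

Allgather only moves bytes, so reinterpreting each one-byte float8 element as uint8 is lossless; the reduction paths instead hit the new TORCH_CHECK guards shown in the ProcessGroupNCCL.cpp diff below.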

File tree

2 files changed: +115 -2 lines changed


test/distributed/test_c10d_nccl.py

Lines changed: 93 additions & 0 deletions
@@ -2577,6 +2577,27 @@ def test_all_reduce_coalesced_nccl(self):
                 ),
             )
 
+    @requires_nccl()
+    @skip_if_lt_x_gpu(2)
+    def test_all_reduce_coalesced_nccl_float8_errors(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        c10d.init_process_group(
+            backend="nccl", store=store, rank=self.rank, world_size=self.world_size
+        )
+        process_group = c10d.distributed_c10d._get_default_group()
+        device = torch.device("cuda:%d" % self.rank)
+        tensors = [
+            torch.full(
+                (60 + i,), self.rank + 1 + i, device=device, dtype=torch.float
+            ).to(torch.float8_e4m3fn)
+            for i in range(5)
+        ]
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Float8 dtypes are not currently supported for NCCL reductions",
+        ):
+            torch.distributed.all_reduce_coalesced(tensors, group=process_group)
+
     @requires_nccl()
     @skip_if_lt_x_gpu(2)
     def test_all_reduce_coalesced_manager_nccl(self):
@@ -2940,6 +2961,56 @@ def test_reduce_scatter_tensor_coalesced(self):
                 dist.reduce_scatter_tensor(output_tensors[i], input_tensors[i])
         self.assertEqual(output_tensors, input_tensors[self.rank] * self.world_size)
 
+    @requires_nccl()
+    @skip_if_lt_x_gpu(2)
+    def test_reduce_scatter_base_k_float8_errors(self):
+        store = dist.FileStore(self.file_name, self.world_size)
+        dist.init_process_group(
+            "nccl",
+            world_size=self.world_size,
+            rank=self.rank,
+            store=store,
+        )
+        output_tensor = (
+            torch.zeros(2, dtype=torch.float32).to(torch.float8_e4m3fn).to(self.rank)
+        )
+        input_tensors = (
+            torch.arange(self.world_size * 2, dtype=torch.float32)
+            .to(torch.float8_e4m3fn)
+            .to(self.rank)
+        )
+        input_tensors = torch.reshape(input_tensors, (self.world_size, 2))
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Float8 dtypes are not currently supported for NCCL reductions",
+        ):
+            dist.reduce_scatter_tensor(output_tensor, input_tensors)
+
+    @requires_nccl()
+    @skip_if_lt_x_gpu(2)
+    def test_reduce_scatter_tensor_coalesced_float8_errors(self):
+        store = dist.FileStore(self.file_name, self.world_size)
+        dist.init_process_group(
+            "nccl",
+            world_size=self.world_size,
+            rank=self.rank,
+            store=store,
+        )
+        output_tensors = torch.zeros(2, 2).to(torch.float8_e5m2).to(self.rank)
+        input_tensors = [
+            torch.ones(2, 2).to(torch.float8_e5m2).to(self.rank)
+            for _ in range(self.world_size)
+        ]
+
+        with self.assertRaisesRegex(
+            RuntimeError,
+            "Float8 dtypes are not currently supported for NCCL reductions",
+        ):
+            with dist._coalescing_manager():
+                for i in range(self.world_size):
+                    dist.reduce_scatter_tensor(output_tensors[i], input_tensors[i])
+            self.assertEqual(output_tensors, input_tensors[self.rank])
+
 
 class SetDeviceMethod(Enum):
     TORCH_CUDA_SET = auto()  # torch.cuda.set_device
@@ -2980,6 +3051,28 @@ def test_allgather_base(self):
         dist.all_gather_into_tensor(output_tensor, tensor)
         self.assertEqual(output_tensor, tensor)
 
+    @requires_nccl()
+    @skip_if_lt_x_gpu(1)
+    @parametrize("float8_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
+    def test_allgather_float8(self, float8_dtype):
+        store = dist.FileStore(self.file_name, self.world_size)
+        dist.init_process_group(
+            "nccl",
+            world_size=self.world_size,
+            rank=self.rank,
+            store=store,
+        )
+        device = "cuda"
+        tensor = torch.ones(10, 16, device=torch.device(device)).to(float8_dtype)
+        output_tensor = torch.zeros(10, 16, device=torch.device(device)).to(
+            float8_dtype
+        )
+        dist.all_gather_into_tensor(output_tensor, tensor)
+        self.assertEqual(output_tensor.view(torch.float32), tensor.view(torch.float32))
+
+
+instantiate_parametrized_tests(NcclProcessGroupWithDispatchedCollectivesTests)
+
 
 class LargeCommTest(test_c10d_common.AbstractLargeCommTest, MultiProcessTestCase):
     def setUp(self):

torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp

Lines changed: 22 additions & 2 deletions
@@ -1,4 +1,3 @@
-
 #ifdef USE_C10D_NCCL
 
 #include <exception>
@@ -64,6 +63,10 @@ std::map<at::ScalarType, ncclDataType_t> ncclDataType = {
     {at::kLong, ncclInt64},
     {at::kHalf, ncclHalf},
     {at::kBool, ncclUint8},
+    {at::kFloat8_e5m2, ncclUint8},
+    {at::kFloat8_e4m3fn, ncclUint8},
+    {at::kFloat8_e4m3fnuz, ncclUint8},
+    {at::kFloat8_e5m2fnuz, ncclUint8},
 #if HAS_NCCL_BF16_DATATYPE
     {at::kBFloat16, ncclBfloat16},
 #endif
@@ -3039,6 +3042,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_sparse(
     const AllreduceOptions& opts) {
   TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG);
   auto tensor = tensors.back();
+  TORCH_CHECK(
+      !isFloat8Type(tensor.scalar_type()),
+      "Float8 dtypes are not currently supported for NCCL reductions");
 #ifdef IS_NCCLX
   tensor = tensor.coalesce();
   at::Tensor outputTensor =
@@ -3153,7 +3159,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce(
       return c10::make_intrusive<IntraNodeCommWork>();
     }
   }
-
+  TORCH_CHECK(
+      !isFloat8Type(tensor.scalar_type()),
+      "Float8 dtypes are not currently supported for NCCL reductions");
   // @lint-ignore CLANGTIDY
   RECORD_PARAM_COMMS_DATA(
       static_cast<int>(
@@ -3180,6 +3188,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::allreduce_coalesced(
     std::vector<at::Tensor>& tensors,
    const AllreduceCoalescedOptions& opts) {
   auto total_numel = check_gpu_tensors_same_device(tensors);
+  TORCH_CHECK(
+      !isFloat8Type(tensors.back().scalar_type()),
+      "Float8 dtypes are not currently supported for NCCL reductions");
 
   // @lint-ignore CLANGTIDY
   RECORD_PARAM_COMMS_DATA(
@@ -3552,6 +3563,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter(
   check_gpu_single_tensor(outputTensor);
   // @lint-ignore CLANGTIDY
   auto inputTensors_ = inputTensors.back();
+  TORCH_CHECK(
+      !isFloat8Type(outputTensor.scalar_type()),
+      "Float8 dtypes are not currently supported for NCCL reductions");
 
   RECORD_PARAM_COMMS_DATA(
       static_cast<int>(
@@ -3663,6 +3677,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::_reduce_scatter_base(
 
   // @lint-ignore CLANGTIDY
   const auto& tensor = outputTensor;
+  TORCH_CHECK(
+      !isFloat8Type(tensor.scalar_type()),
+      "Float8 dtypes are not currently supported for NCCL reductions");
   RECORD_PARAM_COMMS_DATA(
       static_cast<int>(
           this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective
@@ -3723,6 +3740,9 @@ c10::intrusive_ptr<Work> ProcessGroupNCCL::reduce_scatter_tensor_coalesced(
     std::vector<at::Tensor>& outputs,
     std::vector<at::Tensor>& inputs,
     const ReduceScatterOptions& opts) {
+  TORCH_CHECK(
+      !isFloat8Type(inputs.back().scalar_type()),
+      "Float8 dtypes are not currently supported for NCCL reductions");
   return collectiveCoalesced(
       inputs,
       outputs,
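As a side note on why the ncclUint8 mapping is restricted to data movement, here is a small illustration (not code from this PR): each float8 element occupies one byte, so a uint8 view is a lossless bit reinterpretation, which is all an allgather needs, whereas summing the raw bytes is not float8 arithmetic.

```python
# Illustration only: why shipping float8 as uint8 bytes is fine for allgather
# but would be wrong for reductions.
import torch

x = torch.arange(8, dtype=torch.float32).to(torch.float8_e4m3fn)

# Lossless round trip through the byte view: this is all an allgather needs.
roundtrip = x.view(torch.uint8).view(torch.float8_e4m3fn)
assert torch.equal(roundtrip.view(torch.uint8), x.view(torch.uint8))

# A "reduction" over raw bytes is not float8 arithmetic, hence the TORCH_CHECKs.
byte_sum = x.view(torch.uint8) + x.view(torch.uint8)  # uint8 math on the bits
real_sum = (x.to(torch.float32) + x.to(torch.float32)).to(torch.float8_e4m3fn)
print(torch.equal(byte_sum.view(torch.float8_e4m3fn).view(torch.uint8),
                  real_sum.view(torch.uint8)))  # generally False
```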
