This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 02a80f9

enable float types in pytorch for non compute comms

1 parent 6891cbe commit 02a80f9

File tree

2 files changed: +2 −3 lines changed


float8_experimental/float8_ops.py

Lines changed: 0 additions & 3 deletions
@@ -235,11 +235,8 @@ def allgather_fp8(aten_op, args, kwargs=None):
     ), f"expecting a Float8Tensor for allgather but found {type(fp8_input)}"

     fp8_data = fp8_input._data
-    fp8_data = fp8_data.view(torch.uint8)
     fp8_data = fp8_data.contiguous()
     fp8_out = aten_op(fp8_data, *args[1:], **kwargs)
-    fp8_out = torch.ops._c10d_functional.wait_tensor(fp8_out)
-    fp8_out = fp8_out.view(fp8_input._data.dtype)
     return Float8Tensor(
         fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config
     )
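For context, here is a minimal sketch (not part of the commit) of the pattern this deletion relies on: once PyTorch's collectives accept float8 dtypes directly, an all-gather can operate on the raw float8 payload without the uint8 view round-trip removed above. The function name and the use of torch.distributed._functional_collectives are illustrative assumptions; an initialized process group and a build with float8-aware comms are assumed.

    # Hypothetical sketch: all-gather a float8 payload without viewing it as uint8.
    # Assumes torch.distributed is already initialized and the PyTorch build
    # supports float8 dtypes in collectives.
    import torch
    import torch.distributed as dist
    from torch.distributed._functional_collectives import all_gather_tensor

    def allgather_float8_sketch(x: torch.Tensor) -> torch.Tensor:
        # x is a raw float8 payload, e.g. dtype=torch.float8_e4m3fn.
        # Before this commit the data had to round-trip through uint8:
        #   x.view(torch.uint8) -> all_gather -> view back to float8 -> wait_tensor
        # With float8-aware comms the gather can run on the tensor directly.
        return all_gather_tensor(x.contiguous(), gather_dim=0, group=dist.group.WORLD)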

test/test_dtensor.py

Lines changed: 2 additions & 0 deletions
@@ -246,3 +246,5 @@ def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
         except Exception as e:
             print(f"Test {test.__name__} failed with error: {e}")
             raise e
+
+    torch.distributed.destroy_process_group()
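The test-side change simply tears down the default process group after all tests have run. A rough sketch of the surrounding pattern is below; the harness names and structure here are illustrative, not the actual contents of test/test_dtensor.py.

    # Illustrative harness sketch, not the actual test file.
    import torch.distributed as dist

    def run_distributed_tests(tests, device_mesh):
        # Assumes torchrun (or an equivalent launcher) has set the usual
        # MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE environment variables.
        for test in tests:
            try:
                test(device_mesh)
            except Exception as e:
                print(f"Test {test.__name__} failed with error: {e}")
                raise e
        # The line added by this commit: release the default process group so
        # NCCL resources are freed and the worker processes exit cleanly.
        dist.destroy_process_group()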
