Skip to content
This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit f7a920d

Browse files
drisspgfacebook-github-bot
authored andcommitted
enable float types in pytorch for non comptue comms (#263)
Summary: Coupled with this: pytorch/pytorch#126556 test everytihng is pasing Pull Request resolved: #263 Reviewed By: wanchaol Differential Revision: D57505783 Pulled By: drisspg fbshipit-source-id: cd928420f559839c63d79bfe7558416fbcfe1d69
1 parent 6891cbe commit f7a920d

File tree

2 files changed

+2
-3
lines changed

2 files changed

+2
-3
lines changed

float8_experimental/float8_ops.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,11 +235,8 @@ def allgather_fp8(aten_op, args, kwargs=None):
235235
), f"expecting a Float8Tensor for allgather but found {type(fp8_input)}"
236236

237237
fp8_data = fp8_input._data
238-
fp8_data = fp8_data.view(torch.uint8)
239238
fp8_data = fp8_data.contiguous()
240239
fp8_out = aten_op(fp8_data, *args[1:], **kwargs)
241-
fp8_out = torch.ops._c10d_functional.wait_tensor(fp8_out)
242-
fp8_out = fp8_out.view(fp8_input._data.dtype)
243240
return Float8Tensor(
244241
fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config
245242
)

test/test_dtensor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,3 +246,5 @@ def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
246246
except Exception as e:
247247
print(f"Test {test.__name__} failed with error: {e}")
248248
raise e
249+
250+
torch.distributed.destroy_process_group()

0 commit comments

Comments
 (0)