This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 6891cbe

bdhirsh authored and facebook-github-bot committed
add wait_tensor() after all_gather in float8 to fix mem leak (#262)
Summary: I'm going to write a more detailed post internally to explain this memory leak.

Tracking issue for a better fix in inductor: pytorch/pytorch#126338

Pull Request resolved: #262

Reviewed By: drisspg

Differential Revision: D57464230

Pulled By: bdhirsh

fbshipit-source-id: 134c50e95045c43f95b5aec4dd3df496ff3fb9a3
1 parent cb55df2 commit 6891cbe

File tree

1 file changed: 1 addition, 0 deletions


float8_experimental/float8_ops.py

Lines changed: 1 addition & 0 deletions
@@ -238,6 +238,7 @@ def allgather_fp8(aten_op, args, kwargs=None):
     fp8_data = fp8_data.view(torch.uint8)
     fp8_data = fp8_data.contiguous()
     fp8_out = aten_op(fp8_data, *args[1:], **kwargs)
+    fp8_out = torch.ops._c10d_functional.wait_tensor(fp8_out)
     fp8_out = fp8_out.view(fp8_input._data.dtype)
     return Float8Tensor(
         fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config
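For context, a minimal sketch of the pattern this one-line change applies, not the repository's own code: it assumes PyTorch's _c10d_functional ops, including an all_gather_into_tensor(input, group_size, group_name) signature that does not appear in the commit itself; only wait_tensor is confirmed by the diff above.

# Illustrative sketch: a functional collective returns a tensor whose data is
# only valid once the pending async work completes. Explicitly calling
# wait_tensor() resolves that work; if the result is never waited on, the
# unfinished collective and its buffers can be kept alive, the kind of leak
# tracked in pytorch/pytorch#126338.
import torch

def all_gather_and_wait(shard, group_size, group_name):
    # Assumed op/signature: _c10d_functional.all_gather_into_tensor(input, group_size, group_name)
    out = torch.ops._c10d_functional.all_gather_into_tensor(shard, group_size, group_name)
    # Same call the commit adds after the fp8 all_gather output.
    out = torch.ops._c10d_functional.wait_tensor(out)
    return out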

0 commit comments
