Skip to content
This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit c630f3b

Browse files
committed
comments
1 parent 48bd03a commit c630f3b

File tree

1 file changed

+16 -2 lines changed

1 file changed

+16 -2 lines changed

float8_experimental/float8_tensor.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,14 @@ def to_fp8_no_autograd(
 old  new
  53   53           inner_float8_tensor = Float8Tensor(
  54   54               local_bits, local_scale, x.dtype, emulate=emulate
  55   55           )
  56      -         return DTensor.from_local(inner_float8_tensor, bits_mesh, bits_placements)
       56 +         return DTensor.from_local(
       57 +             inner_float8_tensor,
       58 +             bits_mesh,
       59 +             bits_placements,
       60 +             run_check=False,
       61 +             shape=bits_fp8.size(),
       62 +             stride=bits_fp8.stride(),
       63 +         )
  57   64
  58   65       return Float8Tensor(bits_fp8, x_scale, x.dtype, emulate=emulate)
  59   66
@@ -79,7 +86,14 @@ def to_original_precision(grad):
 old  new
  79   86       if isinstance(x, DTensor):
  80   87           local_grad = x.to_local()
  81   88           original_precision_grad = to_original_precision(local_grad)
  82      -         return DTensor.from_local(original_precision_grad, x.device_mesh, x.placements)
       89 +         return DTensor.from_local(
       90 +             original_precision_grad,
       91 +             x.device_mesh,
       92 +             x.placements,
       93 +             run_check=False,
       94 +             shape=x.size(),
       95 +             stride=x.stride(),
       96 +         )
  83   97       else:
  84   98           return to_original_precision(x)
  85   99
0 commit comments

Comments
 (0)