This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit f3630d0

update to make more compile friendly

1 parent bb0a585 · commit f3630d0

5 files changed (+41, -38 lines)


float8_experimental/float8_linear.py (7 additions, 6 deletions)

@@ -138,23 +138,24 @@ def __init__(self, *args, **kwargs):
         self.recipe = delayed_scaling_recipe
         history_len = self.recipe.history_len

-        self.register_always_float32_buffer("fp8_amax_x", torch.tensor(E4M3_MAX_POS))
+        self.register_always_float32_buffer("fp8_amax_x", torch.tensor([E4M3_MAX_POS]))
         self.register_always_float32_buffer(
             "fp8_amax_history_x", torch.zeros(history_len)
         )
-        self.register_always_float32_buffer("fp8_scale_x", torch.tensor(1.0))
-        self.register_always_float32_buffer("fp8_amax_w", torch.tensor(E4M3_MAX_POS))
+        self.register_always_float32_buffer("fp8_scale_x", torch.tensor([1.0]))
+        self.register_always_float32_buffer("fp8_amax_w", torch.tensor([E4M3_MAX_POS]))
         self.register_always_float32_buffer(
             "fp8_amax_history_w", torch.zeros(history_len)
         )
-        self.register_always_float32_buffer("fp8_scale_w", torch.tensor(1.0))
+        self.register_always_float32_buffer("fp8_scale_w", torch.tensor([1.0]))
         self.register_always_float32_buffer(
-            "fp8_amax_dL_dY", torch.tensor(E5M2_MAX_POS)
+            "fp8_amax_dL_dY", torch.tensor([E5M2_MAX_POS])
         )
         self.register_always_float32_buffer(
             "fp8_amax_history_dL_dY", torch.zeros(history_len)
         )
-        self.register_always_float32_buffer("fp8_scale_dL_dY", torch.tensor(1.0))
+        self.register_always_float32_buffer("fp8_scale_dL_dY", torch.tensor([1.0]))
+
         # Whether to emulate the fp8 matmul logic in float32
         self.emulate = False

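A minimal sketch (not from this commit) of why the buffers switch from 0-dim tensors to explicit shape-(1,) tensors: torch.cat, which the updated sync path below relies on, rejects zero-dimensional inputs, and 1-element buffers are also generally friendlier to torch.compile, which matches the commit message. The constant value below is an assumption for illustration only.

import torch

E4M3_MAX_POS = 448.0  # assumed value of the float8_e4m3fn max-representable constant

amax_0d = torch.tensor(E4M3_MAX_POS)    # shape (), the old buffer form
amax_1d = torch.tensor([E4M3_MAX_POS])  # shape (1,), the new buffer form

# Only the 1-D form can be combined across layers into a single tensor:
combined = torch.cat([amax_1d, amax_1d])  # works, shape (2,)
# torch.cat([amax_0d, amax_0d])           # RuntimeError: zero-dimensional tensor
#                                         # cannot be concatenated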
float8_experimental/float8_linear_utils.py (28 additions, 30 deletions)

@@ -145,39 +145,37 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None)
         fp8_layers = get_float8_layers(model)

     if dist.is_initialized():
-        fp8_amax_x_tensor = torch.tensor(
-            [child.fp8_amax_x for child in fp8_layers],
-            dtype=torch.float32,
-            device="cuda",
-            requires_grad=False,
-        )
-        fp8_amax_w_tensor = torch.tensor(
-            [child.fp8_amax_w for child in fp8_layers],
-            dtype=torch.float32,
-            device="cuda",
-            requires_grad=False,
-        )
-        fp8_amax_dL_dY_tensor = torch.tensor(
-            [child.fp8_amax_dL_dY for child in fp8_layers],
-            dtype=torch.float32,
-            device="cuda",
-            requires_grad=False,
-        )
-        dist.all_reduce(fp8_amax_x_tensor, op=dist.ReduceOp.MAX)
-        dist.all_reduce(fp8_amax_w_tensor, op=dist.ReduceOp.MAX)
-        dist.all_reduce(fp8_amax_dL_dY_tensor, op=dist.ReduceOp.MAX)
-
+        fp8_amax_x_tensors = [child.fp8_amax_x for child in fp8_layers]
+        fp8_amax_w_tensors = [child.fp8_amax_w for child in fp8_layers]
+        fp8_amax_dL_dY_tensors = [child.fp8_amax_dL_dY for child in fp8_layers]
+
+        assert (
+            len(fp8_amax_x_tensors)
+            == len(fp8_amax_w_tensors)
+            == len(fp8_amax_dL_dY_tensors)
+        ), "Mismatched lengths of amax tensors."
+        if len(fp8_amax_x_tensors) > 0:
+            # Combine all the amax tensors into one tensor and reduce it
+            fp8_amax_x_tensor = torch.cat(fp8_amax_x_tensors)
+            fp8_amax_w_tensor = torch.cat(fp8_amax_w_tensors)
+            fp8_amax_dL_dY_tensor = torch.cat(fp8_amax_dL_dY_tensors)
+
+            dist.all_reduce(fp8_amax_x_tensor, op=dist.ReduceOp.MAX)
+            dist.all_reduce(fp8_amax_w_tensor, op=dist.ReduceOp.MAX)
+            dist.all_reduce(fp8_amax_dL_dY_tensor, op=dist.ReduceOp.MAX)
+
+            # Reassign the reduced amax values to the original tensors
+
+            for idx in range(len(fp8_layers)):
+                child = fp8_layers[idx]
+                child.fp8_amax_x.copy_(fp8_amax_x_tensor[idx].clone())
+                child.fp8_amax_w.copy_(fp8_amax_w_tensor[idx].clone())
+                child.fp8_amax_dL_dY.copy_(fp8_amax_dL_dY_tensor[idx].clone())
+
+    # Iterate over all the layers and update the amax history and scales
     for idx in range(len(fp8_layers)):
         child = fp8_layers[idx]

-        #
-        # 1. in distributed contexts, syncs amax values across workers
-        #
-        if dist.is_initialized():
-            child.fp8_amax_x = fp8_amax_x_tensor[idx].clone()
-            child.fp8_amax_w = fp8_amax_w_tensor[idx].clone()
-            child.fp8_amax_dL_dY = fp8_amax_dL_dY_tensor[idx].clone()
-
         #
         # 2. adds the `amax` values to history
         #

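As a standalone illustration of the new sync pattern (a minimal sketch with made-up buffer values, not the library's code): the per-layer shape-(1,) amax buffers are concatenated into one tensor, reduced once with a MAX all_reduce instead of one collective per layer, and the reduced values are copied back in place into each layer's buffer.

import torch
import torch.distributed as dist

# Hypothetical stand-ins for the per-layer fp8_amax_x buffers (shape (1,) each).
layer_amax_buffers = [torch.tensor([0.5]), torch.tensor([2.0]), torch.tensor([1.25])]

# 1. Combine all per-layer amax buffers into a single tensor.
combined = torch.cat(layer_amax_buffers)  # shape (3,)

# 2. One collective for all layers; only valid with an initialized process group.
if dist.is_initialized():
    dist.all_reduce(combined, op=dist.ReduceOp.MAX)

# 3. Copy the (possibly reduced) values back into the original buffers in place.
for idx, buf in enumerate(layer_amax_buffers):
    buf.copy_(combined[idx])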
float8_experimental/float8_python_api.py (2 additions, 0 deletions)

@@ -12,6 +12,8 @@

 from typing import Optional, Tuple

+import float8_experimental.float8_aten_api  # noqa
+
 import torch
 from float8_experimental.float8_tensor import Float8Tensor

float8_experimental/float8_tensor.py (3 additions, 1 deletion)

@@ -54,7 +54,9 @@ class FromFloat8ConstrFunc(torch.autograd.Function):

     @staticmethod
     def forward(ctx, tensor):
-        return tensor._data.to(tensor._orig_dtype) / tensor._scale
+        return (tensor._data.to(tensor._orig_dtype) / tensor._scale).to(
+            tensor._orig_dtype
+        )

     @staticmethod
     def backward(ctx, g):

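A plausible reason for the extra cast, sketched below under the assumption that _scale is now a shape-(1,) float32 tensor (per the buffer changes above): dividing a lower-precision tensor by a 1-D float32 tensor promotes the result to float32, whereas a 0-dim scale did not, so the quotient must be cast back to the original dtype.

import torch

x = torch.randn(4, dtype=torch.bfloat16)  # stand-in for tensor._data.to(orig_dtype)

scale_0d = torch.tensor(2.0)    # 0-dim float32 scale (old form)
scale_1d = torch.tensor([2.0])  # shape-(1,) float32 scale (new form)

print((x / scale_0d).dtype)  # torch.bfloat16: 0-dim tensors do not drive promotion
print((x / scale_1d).dtype)  # torch.float32: the 1-D float32 tensor promotes the result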
float8_experimental/float8_utils.py (1 addition, 1 deletion)

@@ -23,7 +23,7 @@

 @torch.no_grad()
 def amax_to_scale(amax, float8_dtype, orig_dtype):
-    scale = torch.empty((), device=amax.device, dtype=torch.float32)
+    scale = torch.empty((1,), device=amax.device, dtype=torch.float32)
     if float8_dtype == torch.float8_e4m3fn:
         res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
     else: # e5m2

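A related sketch (assumed constant values, not the library's code): with the amax buffers now shape (1,), preallocating the scale with shape (1,) keeps the derived scale the same shape as the buffer it is computed from; one plausible way the preallocated tensor is filled is an in-place copy of the e4m3 formula shown in the hunk.

import torch

E4M3_MAX_POS = 448.0  # assumed constant value for float8_e4m3fn
EPS = 1e-12           # assumed epsilon constant

amax = torch.tensor([3.5])                      # shape (1,), matching the new amax buffers
scale = torch.empty((1,), dtype=torch.float32)  # preallocated as in the diff
scale.copy_(E4M3_MAX_POS / torch.clamp(amax, min=EPS))
print(scale, scale.shape)  # tensor([128.]) torch.Size([1])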