
Commit 7a1bdab

vkuzo authored and facebook-github-bot committed
make profiling script support Float8Linear dynamic scaling (#298)
Summary:
Pull Request resolved: #298

Run with relevant settings and verify:
1. performance of Float8Linear with dynamic scaling is very close to that of Float8DynamicLinear
2. if we start with all delayed scaling and gradually turn on dynamic scaling tensor by tensor, performance decreases and approaches that of (1)

Reviewed By: drisspg

Differential Revision: D59305795

fbshipit-source-id: e5d525d1bdd22e78b4a0f9b068e0115f3f4336f5
1 parent 88d9407 commit 7a1bdab
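
For reference, a minimal sketch of scenario (1) from the summary, calling the benchmark's main() directly with the new per-tensor scaling arguments. Only the keyword names come from the signature in the diff below; the import path and the argument values are illustrative assumptions.

# Illustrative sketch only: Float8Linear with all-dynamic scaling, to compare
# against Float8DynamicLinear. The import path and values are assumptions;
# the keyword names are taken from the diff below.
from benchmarks.profile_linear_float8 import main

main(
    profile_path_prefix="/tmp/float8_profile",  # assumed output location
    compile=True,
    linear_type="delayed",         # selects the Float8Linear class
    scaling_type_x="dynamic",
    scaling_type_w="dynamic",
    scaling_type_dL_dY="dynamic",
    model_type="linear",
    dtype_filter="both",
)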

File tree

1 file changed: +16 -3 lines changed


benchmarks/profile_linear_float8.py

Lines changed: 16 additions & 3 deletions
@@ -19,7 +19,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
-from float8_experimental.float8_linear import Float8Linear
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_linear_utils import (
     linear_requires_sync,
     LinearType,
@@ -207,6 +207,9 @@ def main(
     profile_path_prefix: Path,
     compile: bool = True,
     linear_type: str = "dynamic",
+    scaling_type_x: str = "delayed",
+    scaling_type_w: str = "delayed",
+    scaling_type_dL_dY: str = "delayed",
     model_type: str = "linear",
     dtype_filter: str = "both",
 ):
@@ -250,9 +253,17 @@ def main(
     linear_cls = (
         Float8Linear if linear_type is LinearType.DELAYED else Float8DynamicLinear
     )
+    extra_kwargs = {}
+    scaling_type_x = TensorScalingType(scaling_type_x)
+    scaling_type_w = TensorScalingType(scaling_type_w)
+    scaling_type_dL_dY = TensorScalingType(scaling_type_dL_dY)
+    if linear_type is LinearType.DELAYED:
+        extra_kwargs["scaling_type_x"] = scaling_type_x
+        extra_kwargs["scaling_type_w"] = scaling_type_w
+        extra_kwargs["scaling_type_dL_dY"] = scaling_type_dL_dY
 
     m_float8 = copy.deepcopy(m_ref)
-    swap_linear_with_float8_linear(m_float8, linear_cls)
+    swap_linear_with_float8_linear(m_float8, linear_cls, **extra_kwargs)
 
     def ref_forw_backward(x):
         out = m_ref(x)
@@ -270,7 +281,9 @@ def float8_forw_backward_wrapper(x):
         # inspection of the fw+bw torch.compile without the scale
         # syncing code
         # TODO(future): make this better
-        if linear_requires_sync(linear_type):
+        if linear_requires_sync(
+            linear_type, scaling_type_x, scaling_type_w, scaling_type_dL_dY
+        ):
             with record_function("scale_amax_and_scales"):
                 sync_amax_history(m_float8)
         out = float8_forw(x)
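
The second check in the summary (start with all-delayed scaling and turn on dynamic scaling tensor by tensor) corresponds to the new extra_kwargs plumbing above. A minimal sketch of one such mixed configuration, assuming a toy model and assuming swap_linear_with_float8_linear lives in float8_linear_utils as the existing imports suggest:

# Illustrative sketch, not part of the commit: one step of "gradually turn on
# dynamic scaling tensor by tensor" from the summary. The toy model is an
# assumption; the class, enum, and keyword names come from the diff above.
import copy

import torch.nn as nn

from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

m_ref = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))
m_float8 = copy.deepcopy(m_ref)

# x and dL_dY use dynamic scaling; w stays on delayed scaling.
swap_linear_with_float8_linear(
    m_float8,
    Float8Linear,
    scaling_type_x=TensorScalingType("dynamic"),
    scaling_type_w=TensorScalingType("delayed"),
    scaling_type_dL_dY=TensorScalingType("dynamic"),
)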
