This repository was archived by the owner on Aug 7, 2024. It is now read-only.

[7/x] make profiling script support Float8Linear dynamic scaling #298

Closed · wants to merge 1 commit
benchmarks/profile_linear_float8.py (19 changes: 16 additions & 3 deletions)
```diff
@@ -19,7 +19,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
-from float8_experimental.float8_linear import Float8Linear
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_linear_utils import (
     linear_requires_sync,
     LinearType,
```
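For context, TensorScalingType is the per-tensor scaling selector that the new import pulls in. A minimal sketch of its likely shape, assuming its string values mirror the "delayed"/"dynamic" defaults used below (the real definition lives in float8_experimental.float8_linear and may carry more members or helpers):

```python
import enum

# Sketch only: illustrates the enum this diff imports, not the exact source.
class TensorScalingType(enum.Enum):
    DELAYED = "delayed"  # scale derived from an amax history; needs a sync pass
    DYNAMIC = "dynamic"  # scale computed inline from the current tensor

# The string-to-enum conversion used later in this diff:
assert TensorScalingType("delayed") is TensorScalingType.DELAYED
```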
```diff
@@ -207,6 +207,9 @@ def main(
     profile_path_prefix: Path,
     compile: bool = True,
     linear_type: str = "dynamic",
+    scaling_type_x: str = "delayed",
+    scaling_type_w: str = "delayed",
+    scaling_type_dL_dY: str = "delayed",
     model_type: str = "linear",
     dtype_filter: str = "both",
 ):
```
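These three parameters let a caller choose a scaling strategy per tensor (input x, weight w, and output gradient dL_dY). A hypothetical direct call, with illustrative values only:

```python
# Hypothetical invocation; the path and scaling choices are illustrative.
main(
    profile_path_prefix="./profiles/float8_run",
    compile=True,
    linear_type="delayed",         # selects Float8Linear
    scaling_type_x="dynamic",      # dynamic scaling for the input
    scaling_type_w="delayed",      # delayed scaling for the weight
    scaling_type_dL_dY="dynamic",  # dynamic scaling for the output gradient
)
```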
```diff
@@ -250,9 +253,17 @@ def main(
     linear_cls = (
         Float8Linear if linear_type is LinearType.DELAYED else Float8DynamicLinear
     )
+    extra_kwargs = {}
+    scaling_type_x = TensorScalingType(scaling_type_x)
+    scaling_type_w = TensorScalingType(scaling_type_w)
+    scaling_type_dL_dY = TensorScalingType(scaling_type_dL_dY)
+    if linear_type is LinearType.DELAYED:
+        extra_kwargs["scaling_type_x"] = scaling_type_x
+        extra_kwargs["scaling_type_w"] = scaling_type_w
+        extra_kwargs["scaling_type_dL_dY"] = scaling_type_dL_dY

     m_float8 = copy.deepcopy(m_ref)
-    swap_linear_with_float8_linear(m_float8, linear_cls)
+    swap_linear_with_float8_linear(m_float8, linear_cls, **extra_kwargs)

     def ref_forw_backward(x):
         out = m_ref(x)
```
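The kwargs are populated only for the delayed linear type, which suggests Float8DynamicLinear does not accept per-tensor scaling arguments. A sketch of how a swap utility like this one might thread them through, assuming a from_float constructor on the target class (the real swap_linear_with_float8_linear in float8_experimental.float8_linear_utils also handles nesting and filtering):

```python
import torch.nn as nn

# Sketch only: recurses through a module tree and swaps nn.Linear children,
# forwarding extra_kwargs (e.g. scaling_type_x/w/dL_dY) to the float8 class.
def swap_linear_sketch(module: nn.Module, linear_cls, **extra_kwargs):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, linear_cls.from_float(child, **extra_kwargs))
        else:
            swap_linear_sketch(child, linear_cls, **extra_kwargs)
```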
```diff
@@ -270,7 +281,9 @@ def float8_forw_backward_wrapper(x):
         # inspection of the fw+bw torch.compile without the scale
         # syncing code
         # TODO(future): make this better
-        if linear_requires_sync(linear_type):
+        if linear_requires_sync(
+            linear_type, scaling_type_x, scaling_type_w, scaling_type_dL_dY
+        ):
             with record_function("scale_amax_and_scales"):
                 sync_amax_history(m_float8)
         out = float8_forw(x)
```
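With the scaling types now passed in, linear_requires_sync can presumably return False when every tensor uses dynamic scaling, skipping the amax sync pass entirely. A sketch of the predicate this call site implies, under that assumption (the real function lives in float8_experimental.float8_linear_utils and may differ in detail):

```python
# Sketch only: approximates the decision the updated call site implies.
def linear_requires_sync_sketch(
    linear_type, scaling_type_x, scaling_type_w, scaling_type_dL_dY
):
    # Float8DynamicLinear computes scales inline, so no sync pass is needed.
    if linear_type is not LinearType.DELAYED:
        return False
    # Float8Linear needs the sync pass iff any tensor still uses delayed scaling.
    return any(
        s is TensorScalingType.DELAYED
        for s in (scaling_type_x, scaling_type_w, scaling_type_dL_dY)
    )
```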