This repository was archived by the owner on Aug 7, 2024. It is now read-only.

static scaling support for training #306

Closed
wants to merge 6 commits
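
For orientation, here is a minimal usage sketch of the API this PR adds, pieced together from the diff below; the toy model, dimensions, and the placeholder scale value are illustrative, and the import paths assume this repository's layout.

import torch
import torch.nn as nn

from float8_experimental.float8_linear import TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

# Toy model; any module containing nn.Linear layers is swapped the same way.
model = nn.Sequential(nn.Linear(4096, 4096)).cuda()

# Enable static scaling for the input `x` only; `w` and `dL_dY` stay dynamic,
# matching the note in the diff that static scaling is currently `x`-only.
# The scale is a placeholder, mirroring the dummy scale used in the benchmark.
swap_linear_with_float8_linear(
    model,
    scaling_type_x=TensorScalingType.STATIC,
    scaling_type_w=TensorScalingType.DYNAMIC,
    scaling_type_dL_dY=TensorScalingType.DYNAMIC,
    static_scale_x=torch.tensor(1.0, device="cuda"),
)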
184 changes: 108 additions & 76 deletions benchmarks/profile_linear_float8.py
@@ -149,6 +149,18 @@ def forward(self, h):
return x


class SigmoidLinear(nn.Module):
def __init__(self, d1, d2):
super().__init__()
self.sigmoid1 = nn.Sigmoid()
self.fc = nn.Linear(d1, d2)

def forward(self, x):
x = self.sigmoid1(x)
x = self.fc(x)
return x


@dataclass
class ProfileConfig:
file_path: Optional[str] = None
@@ -210,7 +222,12 @@ def main(
model_type: str = "linear",
dtype_filter: str = "both",
):
assert model_type in ("linear", "ln_linear", "norm_ffn_norm"), "unsupported"
assert model_type in (
"linear",
"ln_linear",
"norm_ffn_norm",
"sigmoid_linear",
), "unsupported"
assert dtype_filter in ("both", "float8", "bfloat16")

scaling_type_x = TensorScalingType(scaling_type_x)
@@ -242,6 +259,12 @@ def main(
input_tensor = torch.randn(
1, 8192, 4096, device=device, dtype=ref_dtype
).requires_grad_()
elif model_type == "sigmoid_linear":
bsz, d1, d2 = 4096, 4096, 4096
m_ref = SigmoidLinear(d1, d2)
input_tensor = torch.randn(
bsz, d1, device=device, dtype=ref_dtype, requires_grad=True
)
else:
M, K, N = 4 * 4096, 8192, 7168
m_ref = torch.nn.Sequential(
@@ -258,6 +281,9 @@ def main(
"scaling_type_w": scaling_type_w,
"scaling_type_dL_dY": scaling_type_dL_dY,
}
if scaling_type_x is TensorScalingType.STATIC:
# for now, dummy scale
extra_kwargs["static_scale_x"] = torch.tensor(1.0, device="cuda")

m_float8 = copy.deepcopy(m_ref)
swap_linear_with_float8_linear(m_float8, **extra_kwargs)
@@ -300,85 +326,91 @@ def float8_forw_backward_wrapper(x):
# if the `TORCHINDUCTOR_PROFILE` env var is enabled, parse its output
# to populate triton kernel bandwidth further down in the script
f = io.StringIO()
with redirect_stdout(f):
# warm up
for _ in range(1):
try:
with redirect_stdout(f):
# warm up
for _ in range(1):
if dtype_filter != "float8":
ref_forw_backward(input_tensor)
if dtype_filter != "bfloat16":
float8_forw_backward_wrapper(input_tensor)

profile_iters = 5
ref_times, float8_times = None, None
data = []

if dtype_filter != "float8":
ref_forw_backward(input_tensor)
if dtype_filter != "bfloat16":
float8_forw_backward_wrapper(input_tensor)

profile_iters = 5
ref_times, float8_times = None, None
data = []

if dtype_filter != "float8":
# Profile Reference Model
print("profiling ref")
ref_suffix = f"_{model_type}_ref_compile_{compile}.json"
ref_path = profile_path_prefix + ref_suffix
profile_config = ProfileConfig(
ref_path, ref_suffix, iters=profile_iters, warmup_iters=2, sync=True
)
p = profile_function(profile_config, ref_forw_backward, input_tensor)
print(f"saved {ref_path}")
ref_times = profiler_output_to_time_by_kernel_name(p)
total_time_ms = sum(v for v in ref_times.values()) / 1e3 / profile_iters
for k, v in ref_times.items():
v_ms = v / 1e3 / profile_iters
data.append(
[
"0_ref",
k,
kernel_name_to_category(k),
v_ms,
v_ms / total_time_ms,
None,
]
# Profile Reference Model
print("profiling ref")
ref_suffix = f"_{model_type}_ref_compile_{compile}.json"
ref_path = profile_path_prefix + ref_suffix
profile_config = ProfileConfig(
ref_path, ref_suffix, iters=profile_iters, warmup_iters=2, sync=True
)
p = profile_function(profile_config, ref_forw_backward, input_tensor)
print(f"saved {ref_path}")
ref_times = profiler_output_to_time_by_kernel_name(p)
total_time_ms = sum(v for v in ref_times.values()) / 1e3 / profile_iters
for k, v in ref_times.items():
v_ms = v / 1e3 / profile_iters
data.append(
[
"0_ref",
k,
kernel_name_to_category(k),
v_ms,
v_ms / total_time_ms,
None,
]
)

if dtype_filter != "bfloat16":
# Profile Float8 Model
print("profiling float8")
float8_suffix = (
f"_{model_type}_float8_compile_{compile}_{scaling_repr}.json"
)
float8_path = profile_path_prefix + float8_suffix
profile_config = ProfileConfig(
float8_path,
float8_suffix,
iters=profile_iters,
warmup_iters=2,
sync=True,
)
p = profile_function(
profile_config, float8_forw_backward_wrapper, input_tensor
)
print(f"saved {float8_path}")
float8_times = profiler_output_to_time_by_kernel_name(p)
total_time_ms = sum(v for v in float8_times.values()) / 1e3 / profile_iters
for k, v in float8_times.items():
v_ms = v / 1e3 / profile_iters
data.append(
[
"1_float8",
k,
kernel_name_to_category(k),
v / 1e3 / profile_iters,
v_ms / total_time_ms,
None,
]
if dtype_filter != "bfloat16":
# Profile Float8 Model
print("profiling float8")
float8_suffix = (
f"_{model_type}_float8_compile_{compile}_{scaling_repr}.json"
)

# get the time spent per user annotation
sync_time_us = profiler_output_to_gpu_time_for_key(
p, "scale_amax_and_scales"
)
sync_time_ms = sync_time_us / profile_iters / 1e3
print(f"Sync time ms: {sync_time_ms}")

# print the redirected stdout back to regular stdout
print(f.getvalue())
float8_path = profile_path_prefix + float8_suffix
profile_config = ProfileConfig(
float8_path,
float8_suffix,
iters=profile_iters,
warmup_iters=2,
sync=True,
)
p = profile_function(
profile_config, float8_forw_backward_wrapper, input_tensor
)
print(f"saved {float8_path}")
float8_times = profiler_output_to_time_by_kernel_name(p)
total_time_ms = (
sum(v for v in float8_times.values()) / 1e3 / profile_iters
)
for k, v in float8_times.items():
v_ms = v / 1e3 / profile_iters
data.append(
[
"1_float8",
k,
kernel_name_to_category(k),
v / 1e3 / profile_iters,
v_ms / total_time_ms,
None,
]
)

# get the time spent per user annotation
sync_time_us = profiler_output_to_gpu_time_for_key(
p, "scale_amax_and_scales"
)
sync_time_ms = sync_time_us / profile_iters / 1e3
print(f"Sync time ms: {sync_time_ms}")

finally:
# print the redirected stdout back to regular stdout
# the finally clause is to help print output in the presence of exceptions,
# to aid local debugging
print(f.getvalue())

# populate the triton kernel bandwidth
for line in f.getvalue().split("\n"):
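The benchmark above wires in a dummy static_scale_x of 1.0. Below is a sketch of how a calibrated static scale might be derived instead; it assumes the common float8 convention scale = e4m3_max / amax and that torch.float8_e4m3fn is available, and the helper name and calibration loop are illustrative rather than part of this PR.

import torch

@torch.no_grad()
def calibrate_static_scale_x(calibration_batches, eps: float = 1e-12) -> torch.Tensor:
    # Derive a static input scale from representative activations (illustrative helper).
    e4m3_max = torch.finfo(torch.float8_e4m3fn).max
    amax = torch.tensor(0.0)
    for x in calibration_batches:
        # Track the largest absolute value the linear input is expected to see.
        amax = torch.maximum(amax, x.detach().abs().max().float().cpu())
    # Map the observed amax onto the top of the e4m3 representable range.
    return e4m3_max / torch.clamp(amax, min=eps)

# Example: replace the benchmark's dummy scale with a calibrated one.
# extra_kwargs["static_scale_x"] = calibrate_static_scale_x(sample_batches).to("cuda")
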
40 changes: 33 additions & 7 deletions float8_experimental/float8_linear.py
@@ -135,13 +135,16 @@ def __init__(self, history_len: int = 16, scale_fn_name: str = "max"):
class TensorScalingType(enum.Enum):
DELAYED = "delayed"
DYNAMIC = "dynamic"
STATIC = "static" # note: only supported for `x`

def short_str(self):
if self is TensorScalingType.DELAYED:
return "del"
else:
assert self is TensorScalingType.DYNAMIC
elif self is TensorScalingType.DYNAMIC:
return "dyn"
else:
assert self is TensorScalingType.STATIC
return "sta"


class Float8Linear(torch.nn.Linear):
@@ -154,9 +157,10 @@ def __init__(self, *args, **kwargs):
"""
Additional arguments on top of `torch.nn.Linear`'s arguments:
* `delayed_scaling_recipe`: configuration for delayed scaling
* `scaling_type_x`: delayed vs dynamic scaling for `x`
* `scaling_type_w`: delayed vs dynamic scaling for `w`
* `scaling_type_dL_dY`: delayed vs dynamic scaling for `dL_dY`
* `scaling_type_x`: dynamic/delayed/static scaling for `x`
* `scaling_type_w`: dynamic/delayed scaling for `w`
* `scaling_type_dL_dY`: dynamic/delayed scaling for `dL_dY`
* `static_scale_x`: static scale for `x`, only used when `scaling_type_x` is static
"""

delayed_scaling_recipe = kwargs.pop(
@@ -168,19 +172,32 @@ def __init__(self, *args, **kwargs):
scaling_type_x = kwargs.pop("scaling_type_x", TensorScalingType.DELAYED)
scaling_type_w = kwargs.pop("scaling_type_w", TensorScalingType.DELAYED)
scaling_type_dL_dY = kwargs.pop("scaling_type_dL_dY", TensorScalingType.DELAYED)
static_scale_x = kwargs.pop("static_scale_x", None)
super().__init__(*args, **kwargs)

# Defines the scaling behavior of x, w, dL_dY
self.scaling_type_x = scaling_type_x
self.scaling_type_w = scaling_type_w
assert self.scaling_type_w in (
TensorScalingType.DELAYED,
TensorScalingType.DYNAMIC,
), "unsupported"
self.scaling_type_dL_dY = scaling_type_dL_dY
assert self.scaling_type_dL_dY in (
TensorScalingType.DELAYED,
TensorScalingType.DYNAMIC,
), "unsupported"
# Convenience flag to skip code related to delayed scaling
self.has_any_delayed_scaling = (
self.scaling_type_x is TensorScalingType.DELAYED
or self.scaling_type_w is TensorScalingType.DELAYED
or self.scaling_type_dL_dY is TensorScalingType.DELAYED
)

if static_scale_x is not None:
assert self.scaling_type_x is TensorScalingType.STATIC, "unsupported"
self.register_always_float32_buffer("static_scale_x", static_scale_x)

# TODO(future): have a unique recipe per buffer instead of one per
# module, saving implementing that until we need it.
# TODO(future): serialization for recipes
@@ -306,9 +323,16 @@ def cast_x_to_float8(
self.fp8_amax_x,
self.forward_config,
)
else:
assert self.scaling_type_x is TensorScalingType.DYNAMIC
elif self.scaling_type_x is TensorScalingType.DYNAMIC:
x_fp8 = cast_to_float8_e4m3_dynamic(x, self.forward_config)
else:
assert self.scaling_type_x is TensorScalingType.STATIC
x_fp8 = Float8Tensor.to_float8(
x,
self.static_scale_x,
e4m3_dtype,
mm_config=self.forward_config,
)
return x_fp8

def cast_w_to_float8(
@@ -417,6 +441,7 @@ def from_float(
scaling_type_x=TensorScalingType.DELAYED,
scaling_type_w=TensorScalingType.DELAYED,
scaling_type_dL_dY=TensorScalingType.DELAYED,
static_scale_x=None,
):
"""
Create an nn.Linear with fp8 compute from a regular nn.Linear
@@ -434,6 +459,7 @@ def from_float(
scaling_type_w=scaling_type_w,
scaling_type_dL_dY=scaling_type_dL_dY,
emulate=emulate,
static_scale_x=static_scale_x,
)
if (
scaling_type_w == TensorScalingType.DYNAMIC
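For intuition, here is a rough plain-PyTorch sketch of what the new static branch of cast_x_to_float8 computes. It assumes Float8Tensor.to_float8 follows the usual convention of multiplying by the scale and doing a saturated cast to e4m3; the actual Float8Tensor internals are not shown in this diff, so treat this as an approximation rather than the library's implementation.

import torch

def static_cast_to_e4m3_sketch(x: torch.Tensor, static_scale: torch.Tensor) -> torch.Tensor:
    # Assumed convention: scale the high-precision tensor into the e4m3 range,
    # clamp (saturate) to the representable maximum, then cast to float8.
    e4m3_max = torch.finfo(torch.float8_e4m3fn).max
    x_scaled = (x.float() * static_scale).clamp(min=-e4m3_max, max=e4m3_max)
    return x_scaled.to(torch.float8_e4m3fn)

# Dequantization would divide by the same scale:
#   x_approx = x_fp8.to(torch.float32) / static_scale
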
3 changes: 3 additions & 0 deletions float8_experimental/float8_linear_utils.py
@@ -154,6 +154,7 @@ def swap_linear_with_float8_linear(
scaling_type_x: TensorScalingType = TensorScalingType.DYNAMIC,
scaling_type_w: TensorScalingType = TensorScalingType.DYNAMIC,
scaling_type_dL_dY: TensorScalingType = TensorScalingType.DYNAMIC,
static_scale_x: Optional[torch.Tensor] = None,
) -> Optional[nn.Module]:
"""
Swaps `torch.nn.Linear` in `module` with `Float8Linear`.
@@ -167,6 +168,7 @@ def swap_linear_with_float8_linear(
scaling_type_x (TensorScalingType): scaling type for `x`
scaling_type_w (TensorScalingType): scaling type for `w`
scaling_type_dL_dY (TensorScalingType): scaling type for `dL_dY`
static_scale_x (Optional[torch.Tensor]): static scale for `x`, only used when `scaling_type_x` is static

Returns:
nn.Module: The modified module with swapped linear layers.
@@ -177,6 +179,7 @@
scaling_type_x=scaling_type_x,
scaling_type_w=scaling_type_w,
scaling_type_dL_dY=scaling_type_dL_dY,
static_scale_x=static_scale_x,
)
return swap_linear_layers(
module,