
Commit 19eab2a

switch float8 logic from Float8DynamicLinear to Float8Linear
Summary:

In the stack ending in pytorch-labs/float8_experimental#300 in float8_experimental, we are unifying `Float8DynamicLinear` and `Float8Linear`, with a future PR planned to delete the `Float8DynamicLinear` object. After pytorch-labs/float8_experimental#300, `Float8Linear` with default settings is equivalent to `Float8DynamicLinear`. This PR changes `torchtitan` to use `Float8Linear`.

To better support the new UX of `float8_experimental`, I also switched the `fp8_linear` configuration to a boolean controlling whether to swap the linears or not. In the future we can add new options for how to configure each linear (scaling type, scaling granularity, etc.); saving that for a future PR.

Test Plan:

```
// run baseline (Float8DynamicLinear) for llama3_8b for 50 iterations on 4 GPUs,
// verify performance and loss values do not change meaningfully between
// baseline and this PR

// baseline (before this PR)
// 1. compile, bf16
// 2. compile, float8
// 3. compile, float8, fdsp_fp8_allgather=True
// 4. compile, float8, fdsp_fp8_allgather=True, tp=2
// logs: https://gist.github.com/vkuzo/e6d5f3b15349862bfad3706baad8c9ce

// experiment (this PR): repeat all of the above, but with Float8Linear
// logs: https://gist.github.com/vkuzo/a4d6754358facffa64df931654459631
```

Reviewers:

Subscribers:

Tasks:

Tags:
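As a rough illustration of the change described above, here is a minimal before/after sketch of the module swap, assuming `float8_experimental` is installed. The toy `nn.Sequential` model is only for illustration; the import paths and the `swap_linear_with_float8_linear` call mirror the diff to `torchtitan/float8_linear.py` below.

```python
# Sketch only: the swap torchtitan performs, before and after this commit.
# Assumes float8_experimental is installed; the toy model is illustrative.
import torch.nn as nn

from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))

# Before this PR (fp8_linear = "dynamic"):
#     swap_linear_with_float8_linear(model, Float8DynamicLinear)
# After this PR (fp8_linear = true), Float8Linear with default settings
# (dynamic scaling) is equivalent:
swap_linear_with_float8_linear(model, Float8Linear)  # mutates `model` in place
```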
1 parent b0ed7f0 commit 19eab2a

File tree: 8 files changed, 18 additions, 34 deletions

torchtitan/config_manager.py

Lines changed: 4 additions & 8 deletions
```diff
@@ -339,15 +339,11 @@ def __init__(self):
         )
         self.parser.add_argument(
             "--training.fp8_linear",
-            type=str,
-            default="",
-            choices=[
-                "dynamic",
-                "",
-            ], # TODO: add "delayed" option back in when supported
+            action="store_true",
             help="""
-                Type of fp8 linear quantization to apply to the model ['', 'dynamic'].
-                This features requires you to install 'float8_experimental' which can be found
+                If true, swaps `torch.nn.Linear` with `Float8Linear` with
+                default settings (dynamic scaling).
+                This feature requires you to install 'float8_experimental' which can be found
                 here: https://github.com/pytorch-labs/float8_experimental
             """,
         )
```
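For clarity, the `store_true` action above means the flag now defaults to `False` and flips to `True` only when passed on the command line. A minimal standalone sketch (this parser is a stand-in for illustration, not torchtitan's actual `JobConfig` wiring):

```python
# Standalone illustration of the new boolean flag semantics.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--training.fp8_linear", action="store_true")

print(vars(parser.parse_args([])))
# {'training.fp8_linear': False}
print(vars(parser.parse_args(["--training.fp8_linear"])))
# {'training.fp8_linear': True}
```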

torchtitan/float8_linear.py

Lines changed: 8 additions & 20 deletions
```diff
@@ -21,34 +21,22 @@
 
 def build_fp8_linear(model: nn.Module, job_config: JobConfig):
     """
-    This function converts the linear layers to one of the fp8 types:
-    - Float8DynamicLinear: Dynamic quantization of the weights and the activations
-    - [Not Yet Supported] Float8Linear: Uses a history of amaxs to quantize the weights and activations
+    This function converts the linear layers to `Float8Linear`. Note that today,
+    only dynamic tensor scaling (the default) is supported.
 
     This will mutate the model inplace.
     """
-    linear_type = job_config.training.fp8_linear.lower()
+    use_fp8_linear = job_config.training.fp8_linear
     try:
-        from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
-
-        # from float8_experimental.float8_linear import Float8Linear
+        from float8_experimental.float8_linear import Float8Linear
         from float8_experimental.float8_linear_utils import (
             swap_linear_with_float8_linear,
         )
     except ImportError as exc:
         raise ImportError(
             "float8_experimental is not installed. Please install it to use fp8 linear layers."
         ) from exc
-    if linear_type:
-        linear_type_map = {
-            # "delayed": Float8Linear, # TODO: add "delayed" option back in when supported
-            "dynamic": Float8DynamicLinear,
-        }
-        assert (
-            linear_type in linear_type_map
-        ), f"Invalid fp8 linear type: {linear_type}, supported types: {', '.join(linear_type_map.keys())}."
-        float8_linear_type = linear_type_map[linear_type.lower()]
-
-        # Mutates the model inplace replacing instances of torch.nn.Linear with float8_linear_type
-        swap_linear_with_float8_linear(model, float8_linear_type)
-        logger.info(f"Swapped to {linear_type} float8 linear layers")
+    if use_fp8_linear:
+        # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
+        swap_linear_with_float8_linear(model, Float8Linear)
+        logger.info("Swapped to Float8Linear layers")
```

train_configs/debug_model.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -37,7 +37,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 10
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-fp8_linear = ""
+fp8_linear = false
 compile = false
 dataset = "c4_mini" # supported datasets: c4_mini (45K), c4 (177M)
 
```

train_configs/llama2_13b.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-fp8_linear = ""
+fp8_linear = false
 compile = false
 dataset = "c4"
 
```

train_configs/llama2_70b.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 8 # 8-way TP
-fp8_linear = ""
+fp8_linear = false
 compile = false
 dataset = "c4"
 
```

train_configs/llama2_7b.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -32,7 +32,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 1 # dp-only would be sufficient for 7B
-fp8_linear = ""
+fp8_linear = false
 compile = false
 dataset = "c4"
 
```

train_configs/llama3_70b.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 8 # 8-way TP
-fp8_linear = ""
+fp8_linear = false
 compile = false
 dataset = "c4"
 
```

train_configs/llama3_8b.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -33,7 +33,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 1000
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-fp8_linear = ""
+fp8_linear = false
 compile = false
 dataset = "c4"
 
```
