This repository was archived by the owner on Aug 7, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 19
[4/x] add tests for DTensor TP/SP + Float8Linear #294
Closed
Closed
Changes from 1 commit
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,6 +18,7 @@ | |
Float8DynamicLinear, | ||
NoopFwToFloat8E5M2Bw, | ||
) | ||
from float8_experimental.float8_linear import Float8Linear, TensorScalingType | ||
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear | ||
from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig | ||
from float8_experimental.float8_tensor_parallel import ( | ||
|
@@ -169,23 +170,37 @@ def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16): | |
loss.backward() | ||
|
||
|
||
def test_fp8_mlp_tensor_parallelism_base( | ||
mesh: DeviceMesh, size=16, compile: bool = False | ||
def _test_fp8_mlp_tensor_parallelism_base( | ||
mesh: DeviceMesh, size=16, compile: bool = False, use_float8_linear: bool = False | ||
): | ||
device = mesh.device_type | ||
# TODO(future): delete Float8DynamicLinear from this test once all the | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: I wonder if we want to standardize on the todo format so that in the future we can just command F and find all the things to change |
||
# code is unified | ||
float8_cls = Float8Linear if use_float8_linear else Float8DynamicLinear | ||
extra_kwargs = {} | ||
if use_float8_linear: | ||
# For now, just use Float8Linear with dynamic scaling, which is the | ||
# same behavior as Float8Linear. | ||
# TODO(future): add support for float8 all-gather with delayed scaling | ||
# for activations and gradients. | ||
extra_kwargs = { | ||
"scaling_type_x": TensorScalingType.DYNAMIC, | ||
"scaling_type_w": TensorScalingType.DYNAMIC, | ||
"scaling_type_dL_dY": TensorScalingType.DYNAMIC, | ||
} | ||
|
||
toy_model = ToyModel().to(device) | ||
toy_model_fp8 = swap_linear_with_float8_linear( | ||
toy_model, Float8DynamicLinear, emulate=True | ||
toy_model, float8_cls, emulate=True, **extra_kwargs | ||
) | ||
|
||
tp_model = copy.deepcopy(toy_model) | ||
tp_model = swap_linear_with_float8_linear( | ||
tp_model, Float8DynamicLinear, emulate=True | ||
tp_model, float8_cls, emulate=True, **extra_kwargs | ||
) | ||
sp_model = copy.deepcopy(toy_model) | ||
sp_model = swap_linear_with_float8_linear( | ||
sp_model, Float8DynamicLinear, emulate=True | ||
sp_model, float8_cls, emulate=True, **extra_kwargs | ||
) | ||
|
||
# vanilla TP | ||
|
@@ -218,7 +233,7 @@ def test_fp8_mlp_tensor_parallelism_base( | |
# PrepareFloat8ModuleInput with specific submodule fqn | ||
sp_model2 = copy.deepcopy(toy_model) | ||
sp_model2 = swap_linear_with_float8_linear( | ||
sp_model2, Float8DynamicLinear, emulate=True | ||
sp_model2, Float8DynamicLinear, emulate=True, **extra_kwargs | ||
) | ||
|
||
sp_model2 = parallelize_module( | ||
|
@@ -271,8 +286,28 @@ def test_fp8_mlp_tensor_parallelism_base( | |
) | ||
|
||
|
||
def test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16): | ||
_test_fp8_mlp_tensor_parallelism_base( | ||
mesh, size, compile=False, use_float8_linear=False | ||
) | ||
|
||
|
||
def test_fp8_mlp_tensor_parallelism_eager_float8_linear(mesh: DeviceMesh, size=16): | ||
_test_fp8_mlp_tensor_parallelism_base( | ||
mesh, size, compile=False, use_float8_linear=True | ||
) | ||
|
||
|
||
def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16): | ||
test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True) | ||
_test_fp8_mlp_tensor_parallelism_base( | ||
mesh, size, compile=True, use_float8_linear=False | ||
) | ||
|
||
|
||
def test_fp8_mlp_tensor_parallelism_compile_float8_linear(mesh: DeviceMesh, size=16): | ||
_test_fp8_mlp_tensor_parallelism_base( | ||
mesh, size, compile=True, use_float8_linear=True | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
|
@@ -285,8 +320,10 @@ def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16): | |
test_fp8_redistribute, | ||
test_dtensor_cast_to_fp8, | ||
test_dtensor_fp8_autograd, | ||
test_fp8_mlp_tensor_parallelism_base, | ||
test_fp8_mlp_tensor_parallelism_eager, | ||
test_fp8_mlp_tensor_parallelism_eager_float8_linear, | ||
test_fp8_mlp_tensor_parallelism_compile, | ||
test_fp8_mlp_tensor_parallelism_compile_float8_linear, | ||
] | ||
|
||
for test in tqdm(tests, desc="Running tests"): | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: maybe update this comment