This repository was archived by the owner on Aug 7, 2024. It is now read-only.

[4/x] add tests for DTensor TP/SP + Float8Linear #294

Closed · wants to merge 3 commits
38 changes: 30 additions & 8 deletions float8_experimental/float8_tensor_parallel.py
@@ -4,6 +4,7 @@
cast_to_float8_e4m3_dynamic,
cast_to_float8_e5m2_dynamic_bw,
)
from float8_experimental.float8_linear import TensorScalingType
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.parallel import (
@@ -19,7 +20,17 @@
# here is that in input/output handling we do casting after
# creating the DTensor.

# NOTE: This only works and tested with the DynamicLinear
# NOTE: This only works with, and is only tested with, dynamic scaling
# (Float8DynamicLinear and Float8Linear with dynamic scaling for all tensors)


def _float8_linear_supports_float8_allgather(m):
# TODO(future): add support for delayed scaling for activations
# and gradients
return (
m.scaling_type_x == TensorScalingType.DYNAMIC
and m.scaling_type_dL_dY == TensorScalingType.DYNAMIC
)
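To make the gating explicit: float8 all-gather is only allowed when both the activation (x) and gradient (dL_dY) scaling types are dynamic; the weight scaling type is not checked here. Below is a minimal, self-contained sketch of that check using hypothetical stand-in classes rather than the real Float8Linear, which carries `scaling_type_x` / `scaling_type_dL_dY` attributes with the same meaning.

```python
# Stand-in sketch of the all-gather gating check above; the classes here are
# hypothetical so the snippet runs without the library installed.
from dataclasses import dataclass
from enum import Enum


class TensorScalingType(Enum):  # mirrors the enum imported above
    DYNAMIC = "dynamic"
    DELAYED = "delayed"


@dataclass
class FakeFloat8Linear:  # hypothetical: only the two attributes the check reads
    scaling_type_x: TensorScalingType
    scaling_type_dL_dY: TensorScalingType


def supports_float8_allgather(m) -> bool:
    # Same predicate as _float8_linear_supports_float8_allgather above.
    return (
        m.scaling_type_x == TensorScalingType.DYNAMIC
        and m.scaling_type_dL_dY == TensorScalingType.DYNAMIC
    )


# Dynamic x and dL_dY -> float8 all-gather is allowed.
assert supports_float8_allgather(
    FakeFloat8Linear(TensorScalingType.DYNAMIC, TensorScalingType.DYNAMIC)
)
# Delayed scaling on x or dL_dY -> not supported yet.
assert not supports_float8_allgather(
    FakeFloat8Linear(TensorScalingType.DELAYED, TensorScalingType.DYNAMIC)
)
```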


class Float8ColwiseParallel(ColwiseParallel):
@@ -61,11 +72,16 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me

def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear

if not isinstance(module, Float8DynamicLinear):
if not isinstance(module, (Float8DynamicLinear, Float8Linear)):
raise ValueError(
f"Expecting module to be Float8DynamicLinear but found {type(module)}"
f"Expecting module to be Float8DynamicLinear or Float8Linear but found {type(module)}"
)
elif isinstance(
module, Float8Linear
) and not _float8_linear_supports_float8_allgather(module):
raise AssertionError("unsupported")

return super()._apply(module, device_mesh)

@@ -107,11 +123,16 @@ def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_me

def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear

if not isinstance(module, Float8DynamicLinear):
if not isinstance(module, (Float8DynamicLinear, Float8Linear)):
raise ValueError(
f"Expecting module to be Float8DynamicLinear but found {type(module)}"
f"Expecting module to be Float8DynamicLinear or Float8Linear but found {type(module)}"
)
elif isinstance(
module, Float8Linear
) and not _float8_linear_supports_float8_allgather(module):
raise AssertionError("unsupported")

return super()._apply(module, device_mesh)

@@ -184,22 +205,23 @@ def _prepare_input_arg(self, input, mesh, input_layout, desired_layout)

def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear

fwd_linear_config = None
if self.fwd_config_submodule_fqn is not None:
fwd_linear = module.get_submodule(self.fwd_config_submodule_fqn)
assert isinstance(fwd_linear, Float8DynamicLinear)
assert isinstance(fwd_linear, (Float8DynamicLinear, Float8Linear))
fwd_linear_config = fwd_linear.forward_config
else:
# search for ScaledMM configs for all the submodules and make sure they are the same
for mod in module.modules():
if isinstance(mod, Float8DynamicLinear):
if isinstance(mod, (Float8DynamicLinear, Float8Linear)):
if fwd_linear_config is None:
fwd_linear_config = mod.forward_config
else:
assert (
fwd_linear_config == mod.forward_config
), "All the Float8DynamicLinear modules should have same forward config!"
), "All the Float8DynamicLinear and Float8Linear modules should have the same forward config!"

self.fwd_linear_config = fwd_linear_config
super()._apply(module, device_mesh)
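Taken together, the styles in this file are meant to be applied through parallelize_module after swapping nn.Linear modules for float8 linears with all-dynamic scaling, which is what the test file below does. The sketch here is illustrative only: the toy module, the layer names ("w1", "out_proj"), and the plan layouts are assumptions, not the repository's exact test fixtures, and a 1-D device mesh is assumed to already exist.

```python
# Hedged sketch: swap to Float8Linear with all-dynamic scaling, then apply the
# float8-aware TP styles via parallelize_module.
import torch.nn as nn
from torch.distributed.tensor.parallel import parallelize_module

from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
from float8_experimental.float8_tensor_parallel import (
    Float8ColwiseParallel,
    Float8RowwiseParallel,
)


class ToyMLP(nn.Module):
    # Illustrative stand-in, not the test file's ToyModel.
    def __init__(self, dim: int = 16):
        super().__init__()
        self.w1 = nn.Linear(dim, dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        return self.out_proj(self.w1(x).relu())


def build_tp_model(mesh, dim: int = 16):
    model = ToyMLP(dim).to(mesh.device_type)
    # All-dynamic scaling keeps the modules eligible for float8 all-gather
    # (see _float8_linear_supports_float8_allgather above).
    model = swap_linear_with_float8_linear(
        model,
        Float8Linear,
        emulate=True,
        scaling_type_x=TensorScalingType.DYNAMIC,
        scaling_type_w=TensorScalingType.DYNAMIC,
        scaling_type_dL_dY=TensorScalingType.DYNAMIC,
    )
    # Shard w1 column-wise and out_proj row-wise; the float8 casts are handled
    # by the styles defined in this file.
    return parallelize_module(
        model,
        mesh,
        {
            "w1": Float8ColwiseParallel(),
            "out_proj": Float8RowwiseParallel(),
        },
    )
```

For sequence parallelism, PrepareFloat8ModuleInput would be added to the plan, optionally with fwd_config_submodule_fqn to pin which submodule's forward_config drives the float8 cast, mirroring the sp_model cases in the test below.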
53 changes: 45 additions & 8 deletions test/test_dtensor.py
@@ -18,6 +18,7 @@
Float8DynamicLinear,
NoopFwToFloat8E5M2Bw,
)
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
from float8_experimental.float8_tensor_parallel import (
@@ -169,23 +170,37 @@ def test_dtensor_fp8_autograd(mesh: DeviceMesh, size=16):
loss.backward()


def test_fp8_mlp_tensor_parallelism_base(
mesh: DeviceMesh, size=16, compile: bool = False
def _test_fp8_mlp_tensor_parallelism_base(
mesh: DeviceMesh, size=16, compile: bool = False, use_float8_linear: bool = False
):
device = mesh.device_type
# TODO(future): delete Float8DynamicLinear from this test once all the
# code is unified

Contributor: Nit: I wonder if we want to standardize on the TODO format so that in the future we can just command-F and find all the things to change
float8_cls = Float8Linear if use_float8_linear else Float8DynamicLinear
extra_kwargs = {}
if use_float8_linear:
# For now, just use Float8Linear with dynamic scaling, which is the
# same behavior as Float8DynamicLinear.
# TODO(future): add support for float8 all-gather with delayed scaling
# for activations and gradients.
extra_kwargs = {
"scaling_type_x": TensorScalingType.DYNAMIC,
"scaling_type_w": TensorScalingType.DYNAMIC,
"scaling_type_dL_dY": TensorScalingType.DYNAMIC,
}

toy_model = ToyModel().to(device)
toy_model_fp8 = swap_linear_with_float8_linear(
toy_model, Float8DynamicLinear, emulate=True
toy_model, float8_cls, emulate=True, **extra_kwargs
)

tp_model = copy.deepcopy(toy_model)
tp_model = swap_linear_with_float8_linear(
tp_model, Float8DynamicLinear, emulate=True
tp_model, float8_cls, emulate=True, **extra_kwargs
)
sp_model = copy.deepcopy(toy_model)
sp_model = swap_linear_with_float8_linear(
sp_model, Float8DynamicLinear, emulate=True
sp_model, float8_cls, emulate=True, **extra_kwargs
)

# vanilla TP
@@ -218,7 +233,7 @@ def test_fp8_mlp_tensor_parallelism_base(
# PrepareFloat8ModuleInput with specific submodule fqn
sp_model2 = copy.deepcopy(toy_model)
sp_model2 = swap_linear_with_float8_linear(
sp_model2, Float8DynamicLinear, emulate=True
sp_model2, float8_cls, emulate=True, **extra_kwargs
)

sp_model2 = parallelize_module(
@@ -271,8 +286,28 @@ def test_fp8_mlp_tensor_parallelism_base(
)


def test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
_test_fp8_mlp_tensor_parallelism_base(
mesh, size, compile=False, use_float8_linear=False
)


def test_fp8_mlp_tensor_parallelism_eager_float8_linear(mesh: DeviceMesh, size=16):
_test_fp8_mlp_tensor_parallelism_base(
mesh, size, compile=False, use_float8_linear=True
)


def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True)
_test_fp8_mlp_tensor_parallelism_base(
mesh, size, compile=True, use_float8_linear=False
)


def test_fp8_mlp_tensor_parallelism_compile_float8_linear(mesh: DeviceMesh, size=16):
_test_fp8_mlp_tensor_parallelism_base(
mesh, size, compile=True, use_float8_linear=True
)


if __name__ == "__main__":
@@ -285,8 +320,10 @@ def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
test_fp8_redistribute,
test_dtensor_cast_to_fp8,
test_dtensor_fp8_autograd,
test_fp8_mlp_tensor_parallelism_base,
test_fp8_mlp_tensor_parallelism_eager,
test_fp8_mlp_tensor_parallelism_eager_float8_linear,
test_fp8_mlp_tensor_parallelism_compile,
test_fp8_mlp_tensor_parallelism_compile_float8_linear,
]

for test in tqdm(tests, desc="Running tests"):
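The distributed setup these tests rely on is elided from this diff. The sketch below shows one typical way such a 1-D mesh could be created; the helper name, launch command, and backend choice are assumptions, not the repository's actual code.

```python
# Hypothetical sketch of the setup these tests assume; launch with e.g.:
#   torchrun --nproc_per_node=2 test/test_dtensor.py
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh


def setup_mesh():
    # NCCL for GPU runs; gloo as a CPU-only fallback.
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend)
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    world_size = int(os.environ["WORLD_SIZE"])
    # A 1-D mesh is all the TP/SP tests above need.
    return init_device_mesh(device_type, (world_size,))
```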