This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 7c7cbae

wanchaol authored and facebook-github-bot committed
add PrepareFloat8ModuleInput for sequence parallel (#275)
Summary:
When applying sequence parallel to a module whose input projection has multiple linear layers consuming the same activation, we often want to transform that activation from Shard to Replicate once (a single all-gather) and then reuse the all-gathered result. For fp8 we need to perform the cast before the Shard -> Replicate redistribute so that the all-gather itself runs in fp8. This PR subclasses PrepareModuleInput to add the fp8 casting logic, ensuring we run an fp8 all-gather instead of a bf16 all-gather followed by a cast for the computation. It also adjusts the test cases to exercise a realistic FFN under sequence parallel.

torchtitan perf benchmarks (8x H100 devgpu, Llama3 8B, 2-way DP, 4-way TP):

* eager (without fp8 all-gather): 3265 WPS
* eager (with fp8 all-gather, this PR): 3900 WPS
* compile (without fp8 all-gather): 5850 WPS
* compile (with fp8 all-gather): 6592 WPS, with 37% MFU on H100

So even in eager mode we get roughly a 20% perf improvement by running every all-gather in fp8, and the compiled fp8 all-gather throughput is more than double the eager bf16-all-gather baseline (102% more WPS). :)

Pull Request resolved: #275

Reviewed By: vkuzo

Differential Revision: D58346331

Pulled By: wanchaol

fbshipit-source-id: 008ca49b6aa6973d2f6d6165e13088d6571cabb4
1 parent 5fc07fc commit 7c7cbae
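
As a quick arithmetic check of the percentages quoted above (the variable names below are ours, not part of the commit):

# Recompute the quoted speedups from the raw WPS numbers in the summary.
eager_bf16_allgather = 3265
eager_fp8_allgather = 3900
compile_fp8_allgather = 6592

print(f"{eager_fp8_allgather / eager_bf16_allgather - 1:.0%}")    # ~19%, the "around 20%" eager gain
print(f"{compile_fp8_allgather / eager_bf16_allgather - 1:.0%}")  # ~102%, the "more than doubled" compiled gain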

2 files changed: +158, -14 lines

float8_experimental/float8_tensor_parallel.py

Lines changed: 96 additions & 1 deletion
@@ -1,11 +1,16 @@
+import torch
 import torch.nn as nn
 from float8_experimental.float8_dynamic_linear import (
     cast_to_float8_e4m3fn,
     cast_to_float8_e5m2_bw,
 )
 from torch.distributed._tensor import DTensor
 from torch.distributed.device_mesh import DeviceMesh
-from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    PrepareModuleInput,
+    RowwiseParallel,
+)
 
 # subclass the ColwiseParallel and RowwiseParallel classes
 # to add the float8 support
@@ -109,3 +114,93 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
         )
 
         return super()._apply(module, device_mesh)
+
+
+class PrepareFloat8ModuleInput(PrepareModuleInput):
+    # subclass the PrepareModuleInput classes to implement fp8 specific logic, the only difference is that
+    # after we prepare the input DTensor, we cast the input to DTensor(Float8Tensor)
+    # This is to ensure the float8 cast happens before the all-gather (i.e. Shard -> Replicate)
+    # so that if there are multiple float8 users of the input activation, we perform fp8 allgather
+    # only once.
+    # FP8 Args:
+    #   float8_dtype (torch.dtype, optional): control what float8 dtype to cast to when prepare the module input,
+    #       we currently only support torch.float8_e4m3fn. default: torch.float8_e4m3fn
+    #   fwd_config_submodule_fqn (str, optional): the fqn of the submodule that contains the forward config used
+    #       for the float8 cast. If not specified, we will search for the Float8DynamicLinear in the submodules
+    #       and use the forward config from that module, in this case all module's forward config must be
+    #       the same.
+
+    def __init__(
+        self,
+        *,
+        input_layouts=None,
+        desired_input_layouts=None,
+        input_kwarg_layouts=None,
+        desired_input_kwarg_layouts=None,
+        use_local_output=False,
+        float8_dtype=torch.float8_e4m3fn,
+        fwd_config_submodule_fqn=None,
+    ):
+        super().__init__(
+            input_layouts=input_layouts,
+            desired_input_layouts=desired_input_layouts,
+            input_kwarg_layouts=input_kwarg_layouts,
+            desired_input_kwarg_layouts=desired_input_kwarg_layouts,
+            use_local_output=use_local_output,
+        )
+
+        # fp8 specific fields
+        self.float8_dtype = float8_dtype
+        self.fwd_config_submodule_fqn = fwd_config_submodule_fqn
+
+        if self.float8_dtype != torch.float8_e4m3fn:
+            raise NotImplementedError(
+                "PrepareFloat8ModuleInput only support casting to float8_e4m3fn for now"
+            )
+
+    def _prepare_input_arg(self, input, mesh, input_layout, desired_layout):
+        if input_layout is not None:
+            if isinstance(input, DTensor):
+                # TODO: re-enable the check once we fix the compile path
+                # assert inp.placements[0] == input_layout
+                dt_inp = input
+            else:
+                assert isinstance(
+                    input, torch.Tensor
+                ), "expecting input to be a torch.Tensor!"
+                dt_inp = DTensor.from_local(
+                    input, mesh, (input_layout,), run_check=False
+                )
+
+            dt_inp = cast_to_float8_e4m3fn(
+                dt_inp, self.fwd_linear_config
+            )  # DTensor(Float8Tensor)
+            if desired_layout is not None and input_layout != desired_layout:
+                dt_inp = dt_inp.redistribute(placements=(desired_layout,))
+
+            return dt_inp.to_local() if self.use_local_output else dt_inp
+        else:
+            return input
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
+
+        fwd_linear_config = None
+        if self.fwd_config_submodule_fqn is not None:
+            fwd_linear = module.get_submodule(self.fwd_config_submodule_fqn)
+            assert isinstance(fwd_linear, Float8DynamicLinear)
+            fwd_linear_config = fwd_linear.forward_config
+        else:
+            # search for ScaledMM configs for all the submodules and make sure they are the same
+            for mod in module.modules():
+                if isinstance(mod, Float8DynamicLinear):
+                    if fwd_linear_config is None:
+                        fwd_linear_config = mod.forward_config
+                    else:
+                        assert (
+                            fwd_linear_config == mod.forward_config
+                        ), "All the Float8DynamicLinear modules should have same forward config!"
+
+        self.fwd_linear_config = fwd_linear_config
+        super()._apply(module, device_mesh)
+        return module
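
For reference, a minimal sketch of how the new class could be wired into a sequence-parallel plan. The module names (ffn, ffn.w1, ffn.w2, ffn.out_proj), the layouts, and the fwd_config_submodule_fqn usage mirror the test added below; the helper name apply_fp8_sequence_parallel is ours, and the model is assumed to have already been converted with swap_linear_with_float8_linear:

from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import parallelize_module

from float8_experimental.float8_tensor_parallel import (
    Float8ColwiseParallel,
    Float8RowwiseParallel,
    PrepareFloat8ModuleInput,
)


def apply_fp8_sequence_parallel(model, mesh):
    # Cast the ffn input to float8 *before* the Shard(1) -> Replicate() redistribute,
    # so the all-gather runs in fp8 and its result is shared by w1 and w2.
    return parallelize_module(
        model,
        mesh,
        {
            "ffn": PrepareFloat8ModuleInput(
                input_layouts=Shard(1),
                desired_input_layouts=Replicate(),
                fwd_config_submodule_fqn="w2",  # optional; omit to auto-discover the forward config
            ),
            "ffn.w1": Float8ColwiseParallel(),
            "ffn.w2": Float8ColwiseParallel(),
            "ffn.out_proj": Float8RowwiseParallel(
                output_layouts=Shard(1), use_local_output=False
            ),
        },
    )

When fwd_config_submodule_fqn is omitted, _apply searches the submodules for Float8DynamicLinear instances and asserts that they all share the same forward config.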

test/test_dtensor.py

Lines changed: 62 additions & 13 deletions
@@ -12,6 +12,7 @@
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from float8_experimental.float8_dynamic_linear import (
     Float8DynamicLinear,
@@ -22,6 +23,7 @@
 from float8_experimental.float8_tensor_parallel import (
     Float8ColwiseParallel,
     Float8RowwiseParallel,
+    PrepareFloat8ModuleInput,
 )
 from float8_experimental.float8_utils import tensor_to_scale
 from torch.distributed._tensor import distribute_tensor, DTensor, Replicate, Shard
@@ -38,17 +40,26 @@ def setup_distributed():
     return device_mesh
 
 
-class ToyModel(nn.Module):
+class FeedForward(nn.Module):
     """MLP based model"""
 
+    def __init__(self):
+        super(FeedForward, self).__init__()
+        self.w1 = nn.Linear(16, 32, bias=False)
+        self.w2 = nn.Linear(16, 32, bias=False)
+        self.out_proj = nn.Linear(32, 16, bias=False)
+
+    def forward(self, x):
+        return self.out_proj(F.silu(self.w1(x)) * self.w2(x))
+
+
+class ToyModel(nn.Module):
     def __init__(self):
         super(ToyModel, self).__init__()
-        self.in_proj = nn.Linear(16, 32)
-        self.relu = nn.ReLU()
-        self.out_proj = nn.Linear(32, 16)
+        self.ffn = FeedForward()
 
     def forward(self, x):
-        return self.out_proj(self.relu(self.in_proj(x)))
+        return self.ffn(x)
 
 
 def test_scaled_mm(mesh: DeviceMesh, size=16):
@@ -182,8 +193,9 @@ def test_fp8_mlp_tensor_parallelism_base(
         tp_model,
         mesh,
         {
-            "in_proj": Float8ColwiseParallel(),
-            "out_proj": Float8RowwiseParallel(),
+            "ffn.w1": Float8ColwiseParallel(),
+            "ffn.w2": Float8ColwiseParallel(),
+            "ffn.out_proj": Float8RowwiseParallel(),
         },
     )
 
@@ -192,17 +204,46 @@ def test_fp8_mlp_tensor_parallelism_base(
         sp_model,
         mesh,
         {
-            "in_proj": Float8ColwiseParallel(input_layouts=Shard(0)),
-            "out_proj": Float8RowwiseParallel(
-                output_layouts=Shard(0), use_local_output=False
+            "ffn": PrepareFloat8ModuleInput(
+                input_layouts=Shard(1), desired_input_layouts=Replicate()
+            ),
+            "ffn.w1": Float8ColwiseParallel(),
+            "ffn.w2": Float8ColwiseParallel(),
+            "ffn.out_proj": Float8RowwiseParallel(
+                output_layouts=Shard(1), use_local_output=False
+            ),
+        },
+    )
+
+    # PrepareFloat8ModuleInput with specific submodule fqn
+    sp_model2 = copy.deepcopy(toy_model)
+    sp_model2 = swap_linear_with_float8_linear(
+        sp_model2, Float8DynamicLinear, emulate=True
+    )
+
+    sp_model2 = parallelize_module(
+        sp_model2,
+        mesh,
+        {
+            "ffn": PrepareFloat8ModuleInput(
+                input_layouts=Shard(1),
+                desired_input_layouts=Replicate(),
+                fwd_config_submodule_fqn="w2",
+            ),
+            "ffn.w1": Float8ColwiseParallel(),
+            "ffn.w2": Float8ColwiseParallel(),
+            "ffn.out_proj": Float8RowwiseParallel(
+                output_layouts=Shard(1), use_local_output=False
             ),
         },
     )
 
     if compile:
         tp_model = torch.compile(tp_model)
+        sp_model = torch.compile(sp_model)
+        sp_model2 = torch.compile(sp_model2)
 
-    x_fp32 = torch.rand(size * 2, size, device=device, requires_grad=False)
+    x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
     x_fp32_tp_input = x_fp32.clone()
     x_fp32_sp_input = distribute_tensor(x_fp32.clone(), mesh, [Shard(0)])
 
@@ -214,11 +255,19 @@ def test_fp8_mlp_tensor_parallelism_base(
     global_out.sum().backward()
     torch.testing.assert_close(tp_out, global_out)
     torch.testing.assert_close(sp_out.full_tensor(), global_out)
+    torch.testing.assert_close(tp_model.ffn.w1.weight.grad, sp_model.ffn.w1.weight.grad)
+    torch.testing.assert_close(
+        tp_model.ffn.out_proj.weight.grad, sp_model.ffn.out_proj.weight.grad
+    )
+
+    sp_out2 = sp_model2(x_fp32_sp_input)
+    sp_out2.sum().backward()
+    torch.testing.assert_close(sp_out2.full_tensor(), global_out)
     torch.testing.assert_close(
-        tp_model.in_proj.weight.grad, sp_model.in_proj.weight.grad
+        tp_model.ffn.w1.weight.grad, sp_model2.ffn.w1.weight.grad
     )
     torch.testing.assert_close(
-        tp_model.out_proj.weight.grad, sp_model.out_proj.weight.grad
+        tp_model.ffn.out_proj.weight.grad, sp_model2.ffn.out_proj.weight.grad
     )
 
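
As a single-process illustration (no parallelism or float8 involved) of why this FFN benefits from a shared, prepared input: both w1 and w2 consume the same activation, so redistributing and casting it once covers both matmuls. A minimal sketch using the same toy dimensions as the test above:

import torch
import torch.nn as nn
import torch.nn.functional as F


class FeedForward(nn.Module):
    """Same toy FFN as in the test above: w1 and w2 share the input activation."""

    def __init__(self):
        super().__init__()
        self.w1 = nn.Linear(16, 32, bias=False)
        self.w2 = nn.Linear(16, 32, bias=False)
        self.out_proj = nn.Linear(32, 16, bias=False)

    def forward(self, x):
        return self.out_proj(F.silu(self.w1(x)) * self.w2(x))


x = torch.rand(16, 32, 16)      # same 3-D input shape the test feeds in
print(FeedForward()(x).shape)   # torch.Size([16, 32, 16])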
