add PrepareFloat8ModuleInput for sequence parallel (#275)
Summary:
When applying Sequence Parallel to a module with more than 2 linear layers for the input projection, we often want to transform from Shard to Replicate once (allgather once) and then reuse the allgathered result. For fp8, we need to do the casting before the Shard -> Replicate transform so that the allgather itself runs in fp8.
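A minimal sketch of that ordering point, assuming `x_sharded` is a DTensor activation sharded on the sequence dim; `cast_to_float8` here is a plain dtype conversion standing in for the library's scaled float8 cast, not its actual internals:

```python
import torch
from torch.distributed._tensor import Replicate

def cast_to_float8(t):
    # Stand-in for the library's scaled float8 cast; a plain dtype
    # conversion is used purely to illustrate the ordering.
    return t.to(torch.float8_e4m3fn)

# x_sharded: a DTensor activation sharded on the sequence dim, i.e. Shard(1).

# Option A: allgather in bf16, then cast -- the collective moves 16-bit data.
x_replicated = x_sharded.redistribute(placements=[Replicate()])  # bf16 allgather
x_fp8 = cast_to_float8(x_replicated)

# Option B: cast first, then allgather -- the collective moves 8-bit data.
# This is the ordering PrepareFloat8ModuleInput arranges for the wrapped module.
x_fp8_sharded = cast_to_float8(x_sharded)
x_fp8_replicated = x_fp8_sharded.redistribute(placements=[Replicate()])  # fp8 allgather
```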
This PR subclasses PrepareModuleInput to add the fp8 casting logic, so that we run the fp8 allgather instead of a bf16 allgather followed by a cast for computation (see the usage sketch below).
It also adjusts the test cases to cover the real FFN case for sequence parallel.
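For context, a hedged usage sketch of how a sequence-parallel FFN plan might wire this in. The import path and exact constructor arguments of PrepareFloat8ModuleInput are assumptions based on the torch.distributed.tensor.parallel style, and `model` / `tp_mesh` are placeholders:

```python
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)
# Assumed import path for the class added in this PR; it may live elsewhere.
from float8_experimental.float8_tensor_parallel import PrepareFloat8ModuleInput

# Sequence-parallel plan for a Llama-style FFN: the input arrives sharded on the
# sequence dim; PrepareFloat8ModuleInput casts it to fp8 *before* the
# Shard(1) -> Replicate() redistribute, so the single allgather runs in fp8.
ffn_plan = {
    "feed_forward": PrepareFloat8ModuleInput(
        input_layouts=(Shard(1),),
        desired_input_layouts=(Replicate(),),
    ),
    "feed_forward.w1": ColwiseParallel(),
    "feed_forward.w3": ColwiseParallel(),
    "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
}

model = parallelize_module(model, tp_mesh, ffn_plan)  # model / tp_mesh: placeholders
```

Because the parent module is wrapped rather than each linear layer, the cast and allgather happen once and all input projections consume the same replicated fp8 activation.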
torchtitan perf benchmarks (8x H100 devgpu, Llama3 8B, 2-way DP, 4-way TP):
* eager (without fp8 allgather): 3265 wps
* eager (with fp8 allgather, this PR): 3900 wps
* compile (without fp8 allgather): 5850 wps
* compile (with fp8 allgather): 6592 wps, with 37% MFU on H100
So even in eager mode we get around a 20% perf improvement with every allgather running in fp8, and compiled fp8 allgather throughput is more than double the eager baseline (102% more WPS) :)
Pull Request resolved: #275
Reviewed By: vkuzo
Differential Revision: D58346331
Pulled By: wanchaol
fbshipit-source-id: 008ca49b6aa6973d2f6d6165e13088d6571cabb4