.. _fsdp_notes:

FSDP Notes
==========

.. _fsdp_prefetch:

FSDP Prefetch Nuances
---------------------

For overlapping ``forward`` all-gathers with ``forward`` compute, there are two possible mechanisms:

1. Implicit forward prefetching (always enabled)
2. Explicit forward prefetching (``forward_prefetch=True``)

Implicit ``forward`` prefetching refers to relying on issuing the all-gathers from a separate CUDA
stream to allow for overlapping an all-gather with ``forward`` compute issued before it (from the
CPU perspective). For example, if we have layer 0 all-gather -> layer 0 ``forward`` compute ->
layer 1 all-gather -> …, then the layer 1 all-gather can overlap with the layer 0 ``forward``
compute even though the CPU thread issued it afterwards. (The first all-gather will not be able to
overlap with anything.)
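
The following is a minimal sketch of the stream mechanism in isolation (not FSDP's internal code);
the tensors and the ``torch.cat`` stand-in for an all-gather are purely illustrative:

.. code-block:: python

    import torch

    # Side stream standing in for FSDP's all-gather stream.
    comm_stream = torch.cuda.Stream()

    x = torch.randn(4096, 4096, device="cuda")

    # The side stream must wait until x has been produced on the default stream.
    comm_stream.wait_stream(torch.cuda.current_stream())

    with torch.cuda.stream(comm_stream):
        # Stand-in for "layer 1 all-gather": the CPU thread issues it here,
        # and it can run concurrently with the matmul issued below.
        gathered = torch.cat([x, x])

    y = x @ x  # stand-in for "layer 0 forward compute" on the default stream

    # "Layer 1 compute" must not start before the gather finishes, and the
    # caching allocator must know gathered is used on the default stream.
    torch.cuda.current_stream().wait_stream(comm_stream)
    gathered.record_stream(torch.cuda.current_stream())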

Explicit ``forward`` prefetching refers to changing the CPU thread's issue order, e.g. layer 0
all-gather -> layer 1 all-gather -> layer 0 ``forward`` compute -> …. In eager mode, there is no way
in general to know which layer comes next (e.g. layer 1 in the example) while still executing
layer 0. Therefore, explicit ``forward`` prefetching should only be used for models whose execution
order is fixed from iteration to iteration (which we sometimes call "static graph"). An example of a
model that does not satisfy this constraint is `FLAVA
<https://pytorch.org/blog/scaling-multimodal-foundation-models-in-torchmultimodal-with-pytorch-distributed/>`_.

Explicit ``forward`` prefetching only saves the time taken to issue a layer's ``forward`` compute
kernels, at the cost that the next all-gather's output tensor must be allocated while the current
one is still in use. By issuing the next all-gather before the current ``forward`` compute kernels,
the next all-gather can start sooner on GPU. This only helps when the CPU thread is the bottleneck;
for most LLM workloads it is not, so there is no motivation for enabling ``forward_prefetch=True``.
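
For reference, the flag is simply passed to the constructor; a minimal sketch, assuming ``model``
is already wrapped per transformer block:

.. code-block:: python

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    # Only safe for "static graph" models: FSDP uses the module execution
    # order recorded in the first iteration to decide what to prefetch next.
    model = FSDP(model, forward_prefetch=True)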

In contrast, for ``backward``, we must use explicit ``backward`` prefetching or else there will be
zero overlap of communication and computation. The reason is that we use a single NCCL process group
for both all-gather and reduce-scatter (partially because in earlier NCCL versions, it was not safe
to use multiple concurrently on the same device over the same ranks). A single NCCL process group
means a single internal NCCL stream, on which reduce-scatters and all-gathers run serially. As such,
unless we explicitly reorder the CPU issue order to be next all-gather -> current reduce-scatter,
the current reduce-scatter would block the next all-gather, and hence the next ``backward``
computation, preventing the current reduce-scatter from overlapping with any compute.
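
This reordering is controlled by the ``backward_prefetch`` argument; a minimal sketch:

.. code-block:: python

    from torch.distributed.fsdp import BackwardPrefetch, FullyShardedDataParallel as FSDP

    # BACKWARD_PRE (the default) issues the next all-gather before the current
    # gradient computation, so the current reduce-scatter can overlap with it;
    # BACKWARD_POST prefetches later, trading overlap for lower peak memory.
    model = FSDP(model, backward_prefetch=BackwardPrefetch.BACKWARD_PRE)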

.. _fsdp_comms_payload_size:

Communication payload size
--------------------------

In FSDP the communications are:

1. all-gather on parameters in ``forward``
2. all-gather on parameters in ``backward``
3. reduce-scatter on gradients in ``backward``

If activation checkpointing (:func:`~torch.utils.checkpoint.checkpoint`) is used, there is no
additional communication, since the parameters are prefetched anyway during ``backward``.
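
For example, per-block activation checkpointing on top of per-block FSDP might look like this
sketch (``block`` and ``x`` are placeholders):

.. code-block:: python

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.utils.checkpoint import checkpoint

    block = FSDP(block)  # per-block wrapping, as elsewhere in these notes

    # The recomputation in backward re-runs the block's forward using the
    # parameters that FSDP all-gathers for backward anyway, so checkpointing
    # adds no extra communication on top of items 1-3 above.
    out = checkpoint(block, x, use_reentrant=False)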

In the FSDP design, the communication payload per rank is determined as follows: each call to
:class:`FullyShardedDataParallel` creates one communication group consisting of the parameters in
``module.parameters()``, except any already assigned to a nested :class:`FullyShardedDataParallel`
instance. For example, for Llama, if you apply :class:`FullyShardedDataParallel` to every
transformer block and also to the root module, then there is one communication group for each
transformer block plus one communication group with the initial embedding and the final linear.
Each communication group corresponds to a single all-gather call and a single reduce-scatter call.
In that way, how you apply :class:`FullyShardedDataParallel` determines the communication size. In
general, applying FSDP to each transformer block is a good heuristic for LLMs, and it is hard to do
better than that given the current design.
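
Per-block wrapping can be written out manually or via an auto-wrap policy; a sketch, where
``TransformerBlock`` is a placeholder for the model's block class:

.. code-block:: python

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.wrap import ModuleWrapPolicy

    # Each TransformerBlock gets its own FSDP instance (one communication
    # group each); the root instance picks up the leftover parameters,
    # e.g. the initial embedding and the final linear.
    model = FSDP(model, auto_wrap_policy=ModuleWrapPolicy({TransformerBlock}))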

Let's consider an example where we have a Transformer-based model sharded over 8 GPUs, where the
sharding happens at the transformer-block level only, each transformer block contains 1.6B
parameters, and the parameters are in fp32 (4 bytes each). That means that, once sharded, each
transformer block will contain 0.2B parameters on each rank.

* The ``forward`` pass will communicate in chunks of ``0.2*4 = 0.8GB`` in all-gather
* The ``backward`` pass will communicate 2 times ``0.8GB`` (1x all-gather and 1x reduce-scatter)

In other words, there will be 3 communications with a payload of ``0.8GB`` each. If the model were
comprised of 10 transformer blocks, there would be a total of 30 communications for a total of
``30*0.8=24GB``.

To formalize: the payload size per communication per rank is
``total_transformer_block_params_in_B*dtype_bytes/num_gpus`` (GB).
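
As a small, illustrative helper (the names are not part of any API):

.. code-block:: python

    def payload_gb_per_rank(params_in_billions: float, dtype_bytes: int, num_gpus: int) -> float:
        """Per-rank payload of one all-gather or reduce-scatter, in GB."""
        return params_in_billions * dtype_bytes / num_gpus

    # The example above: a 1.6B-parameter block in fp32 sharded over 8 GPUs.
    assert payload_gb_per_rank(1.6, 4, 8) == 0.8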

Please note that in this example we didn't include the additional communications required for the
embedding, which should be accounted for as well. The math depends on whether the input and output
embeddings are tied or not: if they aren't tied, there will be 2x more communications.

.. _fsdp_buffers_sizes:

FSDP buffer sizes
-----------------

First, let's cover the buffers allocated for communications:

``forward`` currently requires 2x all-gather buffer size. Here is why:

As explained in :ref:`fsdp_prefetch`, in the explicit ``forward`` prefetching
(``forward_prefetch=True``) case of layer 0 all-gather -> layer 0 ``forward`` compute -> layer 1
all-gather, there is a need for 2 all-gather-sized buffers: one buffer is used in the current
``forward`` while the other is used to do the prefetching.

While the implicit ``forward`` prefetching (``forward_prefetch=False``, the default) case of the
same sequence should in theory need only 1 buffer, in reality it still uses 2 all-gather-sized
buffers. The reason is that in the flat-parameter FSDP design, we do not copy out of the all-gather
buffer: the parameters used for compute are directly viewed into the all-gather buffer (in fact,
avoiding that copy is the main benefit of the "flat parameter"). In that case, while the layer 1
all-gather is overlapping with the layer 0 ``forward`` compute, the layer 0 ``forward`` compute is
using the parameters viewed into the layer 0 all-gather buffer.

A natural question then is: when would you want ``forward_prefetch=False``? For static-graph models
(like most LLMs), there is no major technical reason. It is more that, practically, we added this
option quickly for some CPU-bound internal models and have not tested every code path with it in
unit testing, so we are less confident in it. ``forward_prefetch=False`` can also be slightly
easier to reason about, since we do not have to consider the recorded forward order as a possible
failure mode; a module's all-gather can always be found under its own ``record_function`` label in
its profiler trace.

``backward`` currently requires at least 2x all-gather buffer size, and potentially a bit more.
Here is why:

The current FSDP design uses ``recordStream`` to manage allocations produced in one stream and
consumed in another, which can lead to more memory usage than expected. How much more can be
"non-deterministic" in that it depends on GPU kernel timing relative to the CPU. The
``limit_all_gathers=True`` argument is a mitigation for that; for more details refer to this
discussion: `FSDP & CUDACachingAllocator
<https://dev-discuss.pytorch.org/t/fsdp-cudacachingallocator-an-outsider-newb-perspective/1486/1>`_.
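
In code, the mitigation is just a constructor argument:

.. code-block:: python

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    # The "rate limiter": the CPU thread waits before issuing the next
    # all-gather, bounding the extra memory retained via recordStream.
    model = FSDP(model, limit_all_gathers=True)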

The way existing FSDP works with autograd:

* Existing FSDP all-gathers the ``flat_param``, which is the autograd leaf.
* It calls ``torch.split`` to get 1D views into the ``flat_param`` corresponding to its constituent
  original parameters.
* It calls ``view`` on each 1D split to view it back to ND.
* This means that in ``backward``, we end up with ``ViewBackward`` (ND -> 1D) and
  ``SplitWithSizesBackward`` (which is a concat). In particular, each individual gradient is
  computed as a separate allocation, and an explicit concat happens to construct the reduce-scatter
  input buffer. This implies an actual 2x buffer size for reduce-scatter at that peak memory point.

In summary, for ``backward``, it is about 2x buffer size for reduce-scatter plus any
``recordStream`` effects; the sketch below reproduces this autograd structure outside of FSDP.
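
A minimal sketch of the same split/view autograd structure (the sizes are arbitrary):

.. code-block:: python

    import torch

    flat_param = torch.randn(12, requires_grad=True)  # the autograd leaf

    # 1D views into the flat parameter, viewed back to ND for compute.
    w_1d, b_1d = torch.split(flat_param, [8, 4])
    w = w_1d.view(2, 4)  # ViewBackward nodes in the autograd graph
    b = b_1d.view(4)

    out = (torch.randn(3, 2) @ w + b).sum()
    out.backward()

    # Each original parameter's gradient is computed as its own allocation;
    # SplitWithSizesBackward then concatenates them into flat_param.grad,
    # which is the 2x peak around the reduce-scatter input buffer.
    assert flat_param.grad.shape == (12,)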

Second, let's discuss the additional buffers:

Once the sharded parameters are gathered from all ranks, they require an additional buffer of
``total_transformer_block_params_in_B*dtype_bytes`` for the full parameters. Continuing the earlier
example: if each transformer block is 1.6B parameters and the parameters are in fp32, then it'd be
a ``1.6*4=6.4GB`` buffer.

And there is a need for 2 of those buffers, since there is one currently being used and another
being prefetched.

To summarize, we have:

1. 2 times the communication buffer of ``total_transformer_block_params_in_B*dtype_bytes/num_gpus``
2. 2 times the unsharded transformer block parameters buffer of
   ``total_transformer_block_params_in_B*dtype_bytes``

Or, if you have been following the example:

1. ``2*1.6*4/8=1.6GB``
2. ``2*1.6*4=12.8GB``

for a total of ``14.4GB``.
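
Putting the estimate into one illustrative helper (not an API):

.. code-block:: python

    def fsdp_block_buffers_gb(params_in_billions: float, dtype_bytes: int, num_gpus: int) -> float:
        """Rough communication + unsharded-parameter buffer estimate, in GB."""
        comm = 2 * params_in_billions * dtype_bytes / num_gpus  # 2 comm buffers
        unsharded = 2 * params_in_billions * dtype_bytes  # current + prefetched
        return comm + unsharded

    # The running example: 1.6B-parameter blocks, fp32, 8 GPUs.
    print(fsdp_block_buffers_gb(1.6, 4, 8))  # 1.6 + 12.8 = 14.4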

Now let's briefly discuss what happens to the embeddings, since we have left those out of the
calculations:

Given the rule discussed above, that each FSDP instance's communication group consists of the
parameters in ``module.parameters()`` not already assigned to a nested FSDP instance, we can
analyze as follows:

* Suppose we apply FSDP to the root module (e.g. the ``Transformer`` class) and further apply FSDP
  to each transformer block (e.g. the ``TransformerBlock`` class).
* Most commonly, the embedding and the final linear projection are direct children of the root
  ``Transformer`` class.
* Following our rule, that means that the embedding and the final linear projection are assigned to
  the root ``Transformer``'s flat parameter.
* We have *another* special rule, which is that the root does not free its parameters after
  forward, because they will be immediately all-gathered in backward anyway.
* Putting this together, this means that the root's flat parameter, including the embedding and the
  final projection, is all-gathered to begin forward and kept in GPU memory until the end of
  backward.
* If the embedding and the final linear are not weight-tied, then we *could* further apply FSDP to
  the embedding and to the final linear, as sketched below. For weight-tied parameters, we require
  them to be part of the same flat parameter (or else they would get double-counted). That separate
  wrapping would allow the embedding to be freed after its usage in forward and only all-gathered
  again toward the end of backward.
* Hopefully, this gives a better sense: each FSDP module gets assigned the parameters in its
  ``module.parameters()`` except those already assigned to another nested FSDP module, and the FSDP
  module's ``forward`` defines the "live" interval for its parameters. Hence, the nested
  ``nn.Module`` structure can affect the all-gather/free schedule and hence the memory/throughput
  performance.
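
If the embedding and the final linear are not tied, that could look like the following sketch
(``model.tok_embeddings`` and ``model.output`` are placeholder attribute names):

.. code-block:: python

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    # Untied weights only: separate FSDP instances let the embedding be freed
    # after its use in forward instead of staying allocated until the end of
    # backward as part of the root's flat parameter.
    model.tok_embeddings = FSDP(model.tok_embeddings)
    model.output = FSDP(model.output)
    model = FSDP(model)  # the root picks up any remaining parameters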