 
 from typing import Any, Optional, Tuple
 
+import float8_experimental.config as config
+
 import torch
 import torch.utils._pytree as pytree
 from float8_experimental.float8_dynamic_utils import cast_to_float8_e4m3_dynamic
-
 from float8_experimental.float8_tensor import (
     Float8Tensor,
     merge_mm_configs,
     ScaledMMConfig,
 )
+from float8_experimental.float8_utils import e4m3_dtype
 from torch._prims_common import suggest_memory_format
 
 # FSDP pads its local tensor on dim-0. The subclass should be preserved such
@@ -110,3 +112,181 @@ def fsdp_post_all_gather
         out._scale = scale
         return
     return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,)
+
+
+class WeightWithDelayedFloat8CastTensor(torch.Tensor):
+    @staticmethod
+    def __new__(
+        cls,
+        tensor: torch.Tensor,
+        amax_buffer: torch.Tensor,
+        amax_history_buffer: torch.Tensor,
+        scale_buffer: torch.Tensor,
+        mm_config: ScaledMMConfig,
+        is_amax_initialized: bool,
+    ):
+        return torch.Tensor._make_wrapper_subclass(
+            cls,
+            tensor.size(),
+            strides=tensor.stride(),
+            storage_offset=tensor.storage_offset(),
+            memory_format=suggest_memory_format(tensor),
+            dtype=tensor.dtype,
+            layout=tensor.layout,
+            device=tensor.device,
+            pin_memory=tensor.is_pinned(),
+            requires_grad=tensor.requires_grad,
+        )
+
+    def __init__(
+        self,
+        tensor: torch.Tensor,
+        amax_buffer: torch.Tensor,
+        amax_history_buffer: torch.Tensor,
+        scale_buffer: torch.Tensor,
+        mm_config: ScaledMMConfig,
+        is_amax_initialized: bool,
+    ):
+        self._tensor = tensor
+        self._amax_buffer = amax_buffer
+        self._amax_history_buffer = amax_history_buffer
+        self._scale_buffer = scale_buffer
+        self._mm_config = mm_config
+
+        # Note: is_amax_initialized is not a buffer to avoid data dependent
+        # control flow visible to dynamo
+        # TODO(future PR): add serialization for this flag
+        self.is_amax_initialized = is_amax_initialized
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs=None):
+        if func == torch.ops.aten.detach.default:
+            return WeightWithDelayedFloat8CastTensor(
+                args[0]._tensor,
+                args[0]._amax_buffer,
+                args[0]._amax_history_buffer,
+                args[0]._scale_buffer,
+                args[0]._mm_config,
+                args[0].is_amax_initialized,
+            )
+        mm_config: Optional[ScaledMMConfig] = None
+        amax_buffer: Optional[torch.Tensor] = None
+        amax_history_buffer: Optional[torch.Tensor] = None
+        scale_buffer: Optional[torch.Tensor] = None
+        is_amax_initialized: Optional[bool] = None
+
+        def unwrap(t):
+            nonlocal mm_config
+            if mm_config is None:
+                mm_config = t._mm_config
+            else:
+                mm_config = merge_mm_configs(mm_config, t._mm_config)
+            nonlocal amax_buffer
+            if amax_buffer is None:
+                amax_buffer = t._amax_buffer
+            nonlocal amax_history_buffer
+            if amax_history_buffer is None:
+                amax_history_buffer = t._amax_history_buffer
+            nonlocal scale_buffer
+            if scale_buffer is None:
+                scale_buffer = t._scale_buffer
+            nonlocal is_amax_initialized
+            if is_amax_initialized is None:
+                is_amax_initialized = t.is_amax_initialized
+            return t._tensor
+
+        args, kwargs = pytree.tree_map_only(
+            WeightWithDelayedFloat8CastTensor, unwrap, (args, kwargs or {})
+        )
+        out = func(*args, **kwargs)
+        if func not in _ops_to_preserve_subclass:
+            return out
+        return pytree.tree_map_only(
+            torch.Tensor,
+            lambda x: WeightWithDelayedFloat8CastTensor(
+                x,
+                amax_buffer,
+                amax_history_buffer,
+                scale_buffer,
+                mm_config,
+                is_amax_initialized,
+            ),
+            out,
+        )
+
+    def __tensor_flatten__(self):
+        return (
+            [
+                "_tensor",
+                "_amax_buffer",
+                "_amax_history_buffer",
+                "_scale_buffer",
+            ],
+            # flatten_spec consumed by __tensor_unflatten__ below
+            (self._mm_config, self.is_amax_initialized),
+        )
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
+        mm_config, is_amax_initialized = flatten_spec
+        return WeightWithDelayedFloat8CastTensor(
+            inner_tensors["_tensor"],
+            inner_tensors["_amax_buffer"],
+            inner_tensors["_amax_history_buffer"],
+            inner_tensors["_scale_buffer"],
+            mm_config,
+            is_amax_initialized,
+        )
+
+    def __repr__(self):
+        return f"WeightWithDelayedFloat8CastTensor(tensor={self._tensor}, amax_buffer={self._amax_buffer}, scale_buffer={self._scale_buffer}, mm_config={self._mm_config})"
+
+    def fsdp_pre_all_gather(self, mesh):
+        # initialize if needed
+        # TODO(before land): ensure settings are consistent between Float8Linear and here
+        if not self.is_amax_initialized:
+            from float8_experimental.float8_linear import (
+                _maybe_initialize_amaxes_scales_for_float8_cast,
+            )
+
+            _maybe_initialize_amaxes_scales_for_float8_cast(
+                self._tensor,
+                self._amax_buffer,
+                self._amax_history_buffer,
+                self._scale_buffer,
+                "max",  # TODO(before land): read this from parent
+                e4m3_dtype,
+                self.is_amax_initialized,
+                reduce_amax=True,
+            )
+            self.is_amax_initialized = True
+
+        # this will:
+        # 1. cast the tensor to float8 using `_scale_buffer`
+        # 2. populate `_amax_buffer` inplace
+        # TODO(future PR): clean up all the casting functions and clearly
+        # separate dynamic vs delayed, tech debt has accumulated
+        float8_tensor = Float8Tensor.to_float8(
+            self._tensor,
+            self._scale_buffer,
+            e4m3_dtype,
+            self._amax_buffer,
+            self._mm_config,
+        )
+        return (float8_tensor._data,), (float8_tensor._scale,)
+
+    def fsdp_post_all_gather(
+        self,
+        all_gather_outputs: Tuple[torch.Tensor, ...],
+        metadata: Any,
+        param_dtype: torch.dtype,
+        *,
+        out: Optional[torch.Tensor] = None,
+    ):
+        (data,) = all_gather_outputs
+        (scale,) = metadata
+        if out is not None:
+            assert isinstance(out, Float8Tensor), f"{type(out)}"
+            out._scale = scale
+            return
+        return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,)
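
For intuition on the delayed-scaling recipe referenced in `fsdp_pre_all_gather` (the `"max"` argument): the scale used for the cast is derived from previously observed absolute maxima rather than from the current tensor, which keeps the cast free of a data-dependent sync. A schematic of the math, as a plain illustration and not the library's actual helper:

import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def delayed_scale_from_history(amax_history: torch.Tensor) -> torch.Tensor:
    # "max" recipe: scale against the largest amax in the history window,
    # clamped to avoid division by zero on an all-zero history
    amax = torch.clamp(amax_history.max(), min=1e-12)
    return E4M3_MAX / amax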
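
Usage-wise, FSDP2 calls `fsdp_pre_all_gather` on each shard before the all-gather (casting to float8 and recording amax) and `fsdp_post_all_gather` on the gathered bytes (reassembling a `Float8Tensor`), so the weight only needs to be swapped for this subclass once at setup. A minimal, hypothetical wiring sketch; the helper name, buffer shapes, and history length are assumptions for illustration, not part of this PR:

import torch
import torch.nn as nn

from float8_experimental.float8_tensor import ScaledMMConfig

def wrap_weight_for_delayed_cast(
    linear: nn.Linear,
    mm_config: ScaledMMConfig,
    history_len: int = 16,
) -> None:
    # WeightWithDelayedFloat8CastTensor is the subclass added in this diff
    w = linear.weight
    linear.weight = nn.Parameter(
        WeightWithDelayedFloat8CastTensor(
            w.detach(),
            torch.zeros(1, device=w.device),            # amax buffer
            torch.zeros(history_len, device=w.device),  # amax history buffer
            torch.ones(1, device=w.device),             # scale buffer
            mm_config,
            False,  # amax not initialized yet; fsdp_pre_all_gather handles it
        ),
        requires_grad=w.requires_grad,
    )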