|
| 1 | +import contextlib |
| 2 | +import copy |
| 3 | +from typing import List, Type |
| 4 | + |
| 5 | +import torch |
| 6 | +import torch._dynamo.testing |
| 7 | +import torch.distributed as dist |
| 8 | +from float8_experimental import config |
| 9 | +from float8_experimental.float8_dynamic_linear import Float8DynamicLinear |
| 10 | +from float8_experimental.float8_linear import Float8Linear |
| 11 | +from float8_experimental.float8_linear_utils import ( |
| 12 | + swap_linear_with_float8_linear, |
| 13 | + sync_float8_amax_and_scale_history, |
| 14 | +) |
| 15 | +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision |
| 16 | +from torch.distributed.fsdp.wrap import ModuleWrapPolicy |
| 17 | +from torch.testing._internal.common_distributed import skip_if_lt_x_gpu |
| 18 | +from torch.testing._internal.common_fsdp import FSDPTest |
| 19 | +from torch.testing._internal.common_utils import run_tests |
| 20 | +from torch.testing._internal.distributed._tensor.common_dtensor import ( |
| 21 | + ModelArgs, |
| 22 | + Transformer, |
| 23 | + TransformerBlock, |
| 24 | +) |
| 25 | +from torch.testing._internal.distributed.fake_pg import FakeStore |
| 26 | + |
| 27 | + |
| 28 | +class TestFloat8CompileCommon: |
| 29 | + def _init_transformer_with_fp8( |
| 30 | + self, module_cls: Type, checkpoint_activations: bool = False |
| 31 | + ): |
| 32 | + torch.manual_seed(42) |
| 33 | + args = ModelArgs( |
| 34 | + n_layers=3, |
| 35 | + dim=768, |
| 36 | + n_heads=12, |
| 37 | + dropout_p=0.0, |
| 38 | + weight_tying=False, |
| 39 | + checkpoint_activations=checkpoint_activations, |
| 40 | + ) |
| 41 | + module = Transformer(args) |
| 42 | + # Only dynamic linear supports activation hooks |
| 43 | + use_hooks = module_cls is Float8DynamicLinear |
| 44 | + return swap_linear_with_float8_linear( |
| 45 | + module, module_cls, emulate=True, use_activation_hooks=use_hooks |
| 46 | + ) |
| 47 | + |
| 48 | + @contextlib.contextmanager |
| 49 | + def enable_amax_init(self, enable: bool): |
| 50 | + prev_value = config.enable_amax_init |
| 51 | + config.enable_amax_init = enable |
| 52 | + try: |
| 53 | + yield |
| 54 | + finally: |
| 55 | + config.enable_amax_init = prev_value |
| 56 | + |
| 57 | + @contextlib.contextmanager |
| 58 | + def enable_pre_and_post_forward(self, enable: bool): |
| 59 | + prev_value = config.enable_pre_and_post_forward |
| 60 | + config.enable_pre_and_post_forward = enable |
| 61 | + try: |
| 62 | + yield |
| 63 | + finally: |
| 64 | + config.enable_pre_and_post_forward = prev_value |
| 65 | + |
| 66 | + |
| 67 | +class TestFloat8CompileFakePG( |
| 68 | + TestFloat8CompileCommon, torch._dynamo.test_case.TestCase |
| 69 | +): |
| 70 | + def setUp(self): |
| 71 | + super().setUp() |
| 72 | + fake_store = FakeStore() |
| 73 | + dist.init_process_group( |
| 74 | + "fake", store=fake_store, rank=0, world_size=self.world_size |
| 75 | + ) |
| 76 | + |
| 77 | + def tearDown(self): |
| 78 | + super().tearDown() |
| 79 | + dist.destroy_process_group() |
| 80 | + |
| 81 | + @property |
| 82 | + def world_size(self) -> int: |
| 83 | + return min(torch.cuda.device_count(), 2) |
| 84 | + |
| 85 | + @skip_if_lt_x_gpu(2) |
| 86 | + def test_compile_submodule_dynamic(self): |
| 87 | + module = self._init_transformer_with_fp8(Float8DynamicLinear) |
| 88 | + |
| 89 | + # Compile each transformer block separately |
| 90 | + cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager") |
| 91 | + num_compiled_fns = 0 |
| 92 | + for submodule in module.modules(): |
| 93 | + if isinstance(submodule, TransformerBlock): |
| 94 | + submodule.forward = torch.compile(submodule.forward, backend=cnt) |
| 95 | + num_compiled_fns += 1 |
| 96 | + module = FSDP( |
| 97 | + module, |
| 98 | + auto_wrap_policy=ModuleWrapPolicy({TransformerBlock}), |
| 99 | + use_orig_params=True, |
| 100 | + device_id=dist.get_rank(), |
| 101 | + ) |
| 102 | + local_inp = torch.randint(0, 16, (1, 4), device="cuda") |
| 103 | + out = module(local_inp) |
| 104 | + out.sum().backward() |
| 105 | + self.assertEqual(cnt.frame_count, num_compiled_fns) |
| 106 | + |
| 107 | + # Compile the output projection |
| 108 | + module.output.forward = torch.compile(module.output.forward, backend=cnt) |
| 109 | + # in float8_mm |
| 110 | + # assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor) |
| 111 | + with self.assertRaises(RuntimeError): |
| 112 | + module(local_inp) |
| 113 | + |
| 114 | + @skip_if_lt_x_gpu(2) |
| 115 | + def test_compile_root_dynamic(self): |
| 116 | + module = self._init_transformer_with_fp8(Float8DynamicLinear) |
| 117 | + |
| 118 | + # Compile the root module |
| 119 | + module = FSDP( |
| 120 | + module, |
| 121 | + auto_wrap_policy=ModuleWrapPolicy({TransformerBlock}), |
| 122 | + use_orig_params=True, |
| 123 | + device_id=dist.get_rank(), |
| 124 | + ) |
| 125 | + cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager") |
| 126 | + module = torch.compile(module, backend=cnt) |
| 127 | + local_inp = torch.randint(0, 16, (1, 4), device="cuda") |
| 128 | + # in forward |
| 129 | + # h = layer(h) |
| 130 | + # in float8_mm |
| 131 | + # assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor) |
| 132 | + with self.assertRaises(RuntimeError): |
| 133 | + module(local_inp) |
| 134 | + |
| 135 | + @skip_if_lt_x_gpu(2) |
| 136 | + def test_compile_submodule_delayed(self): |
| 137 | + module = self._init_transformer_with_fp8(Float8Linear) |
| 138 | + |
| 139 | + # Compile each transformer block separately |
| 140 | + torch._dynamo.config.cache_size_limit = 16 |
| 141 | + cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager") |
| 142 | + for submodule in module.modules(): |
| 143 | + if isinstance(submodule, TransformerBlock): |
| 144 | + submodule.forward = torch.compile(submodule.forward, backend=cnt) |
| 145 | + module = FSDP( |
| 146 | + module, |
| 147 | + auto_wrap_policy=ModuleWrapPolicy({TransformerBlock}), |
| 148 | + use_orig_params=True, |
| 149 | + device_id=dist.get_rank(), |
| 150 | + ) |
| 151 | + local_inp = torch.randint(0, 16, (1, 4), device="cuda") |
| 152 | + # Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with 'allow_non_fake_inputs' |
| 153 | + # with self.assertRaises(RuntimeError): |
| 154 | + with self.assertRaises(RuntimeError): |
| 155 | + module(local_inp) |
| 156 | + |
| 157 | + # Compile each transformer block separately with amax init disabled |
| 158 | + with self.enable_amax_init(False): |
| 159 | + module = self._init_transformer_with_fp8(Float8Linear) |
| 160 | + cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager") |
| 161 | + num_compiled_fns = 0 |
| 162 | + for submodule in module.modules(): |
| 163 | + if isinstance(submodule, TransformerBlock): |
| 164 | + submodule.forward = torch.compile(submodule.forward, backend=cnt) |
| 165 | + num_compiled_fns += 1 |
| 166 | + module = FSDP( |
| 167 | + module, |
| 168 | + auto_wrap_policy=ModuleWrapPolicy({TransformerBlock}), |
| 169 | + use_orig_params=True, |
| 170 | + device_id=dist.get_rank(), |
| 171 | + ) |
| 172 | + module(local_inp).sum().backward() |
| 173 | + self.assertEqual(cnt.frame_count, 18) # TODO! |
| 174 | + |
| 175 | + # Compile each transformer block separately with amax init disabled and |
| 176 | + # pre/post-forward disabled |
| 177 | + with self.enable_amax_init(False), self.enable_pre_and_post_forward(False): |
| 178 | + module = self._init_transformer_with_fp8(Float8Linear) |
| 179 | + cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager") |
| 180 | + num_compiled_fns = 0 |
| 181 | + for submodule in module.modules(): |
| 182 | + if isinstance(submodule, TransformerBlock): |
| 183 | + submodule.forward = torch.compile(submodule.forward, backend=cnt) |
| 184 | + num_compiled_fns += 1 |
| 185 | + module = FSDP( |
| 186 | + module, |
| 187 | + auto_wrap_policy=ModuleWrapPolicy({TransformerBlock}), |
| 188 | + use_orig_params=True, |
| 189 | + device_id=dist.get_rank(), |
| 190 | + ) |
| 191 | + module(local_inp).sum().backward() |
| 192 | + self.assertEqual(cnt.frame_count, num_compiled_fns) |
| 193 | + |
| 194 | + @skip_if_lt_x_gpu(2) |
| 195 | + def test_compile_root_delayed(self): |
| 196 | + with self.enable_amax_init(False): |
| 197 | + module = self._init_transformer_with_fp8(Float8Linear) |
| 198 | + module = FSDP( |
| 199 | + module, |
| 200 | + auto_wrap_policy=ModuleWrapPolicy({TransformerBlock}), |
| 201 | + use_orig_params=True, |
| 202 | + device_id=dist.get_rank(), |
| 203 | + ) |
| 204 | + cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager") |
| 205 | + module = torch.compile(module, backend=cnt) |
| 206 | + local_inp = torch.randint(0, 16, (1, 4), device="cuda") |
| 207 | + out = module(local_inp) |
| 208 | + out.sum().backward() |
| 209 | + self.assertEqual(cnt.frame_count, 19) # TODO! |
| 210 | + |
| 211 | + with self.enable_amax_init(False), self.enable_pre_and_post_forward(False): |
| 212 | + module = self._init_transformer_with_fp8(Float8Linear) |
| 213 | + module = FSDP( |
| 214 | + module, |
| 215 | + auto_wrap_policy=ModuleWrapPolicy({TransformerBlock}), |
| 216 | + use_orig_params=True, |
| 217 | + device_id=dist.get_rank(), |
| 218 | + ) |
| 219 | + cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager") |
| 220 | + module = torch.compile(module, backend=cnt) |
| 221 | + local_inp = torch.randint(0, 16, (1, 4), device="cuda") |
| 222 | + out = module(local_inp) |
| 223 | + out.sum().backward() |
| 224 | + self.assertEqual(cnt.frame_count, 19) # TODO! |
| 225 | + |
| 226 | + |
| 227 | +class TestFloat8CompileNCCLPG(TestFloat8CompileCommon, FSDPTest): |
| 228 | + @property |
| 229 | + def world_size(self) -> int: |
| 230 | + return min(torch.cuda.device_count(), 2) |
| 231 | + |
| 232 | + @skip_if_lt_x_gpu(2) |
| 233 | + def test_transformer_parity_no_mp(self): |
| 234 | + """ |
| 235 | + Test numeric parity against manual data parallelism without using |
| 236 | + FSDP's mixed precision. |
| 237 | + """ |
| 238 | + self.run_subtests( |
| 239 | + { |
| 240 | + "module_cls": [Float8Linear, Float8DynamicLinear], |
| 241 | + "checkpoint_activations": [False, True], |
| 242 | + }, |
| 243 | + self._test_transformer_parity_no_mp, |
| 244 | + ) |
| 245 | + |
| 246 | + def _test_transformer_parity_no_mp( |
| 247 | + self, module_cls: Type, checkpoint_activations: bool |
| 248 | + ): |
| 249 | + model = self._init_transformer_with_fp8(module_cls, checkpoint_activations) |
| 250 | + ref_model = copy.deepcopy(model).cuda() |
| 251 | + ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2) |
| 252 | + fsdp_model = FSDP( |
| 253 | + model, |
| 254 | + auto_wrap_policy=ModuleWrapPolicy({TransformerBlock}), |
| 255 | + use_orig_params=True, |
| 256 | + device_id=self.rank, |
| 257 | + ) |
| 258 | + fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2) |
| 259 | + |
| 260 | + local_inp = torch.randint(0, 16, (1, 4), device="cuda") |
| 261 | + with self.enable_amax_init(False), self.enable_pre_and_post_forward( |
| 262 | + False |
| 263 | + ) if module_cls is Float8Linear else contextlib.nullcontext(): |
| 264 | + for iter_idx in range(10): |
| 265 | + losses: List[torch.Tensor] = [] |
| 266 | + for model, optim in ((ref_model, ref_optim), (fsdp_model, fsdp_optim)): |
| 267 | + optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) |
| 268 | + losses.append(model(local_inp).sum()) |
| 269 | + losses[-1].backward() |
| 270 | + if model is ref_model: |
| 271 | + for param in model.parameters(): |
| 272 | + dist.all_reduce(param.grad) |
| 273 | + param.grad.div_(self.world_size) |
| 274 | + if module_cls is Float8Linear: |
| 275 | + sync_float8_amax_and_scale_history(model) |
| 276 | + optim.step() |
| 277 | + self.assertEqual(losses[0], losses[1]) |
| 278 | + |
| 279 | + @skip_if_lt_x_gpu(2) |
| 280 | + def test_transformer_parity_bf16_mp(self): |
| 281 | + """ |
| 282 | + Test numeric parity against manual data parallelism using FSDP's bf16 |
| 283 | + mixed precision. |
| 284 | + """ |
| 285 | + self.run_subtests( |
| 286 | + { |
| 287 | + "module_cls": [Float8Linear, Float8DynamicLinear], |
| 288 | + "checkpoint_activations": [False, True], |
| 289 | + }, |
| 290 | + self._test_transformer_parity_bf16_mp, |
| 291 | + ) |
| 292 | + |
| 293 | + def _test_transformer_parity_bf16_mp( |
| 294 | + self, module_cls: Type, checkpoint_activations: bool |
| 295 | + ): |
| 296 | + model = self._init_transformer_with_fp8(module_cls, checkpoint_activations) |
| 297 | + ref_model = copy.deepcopy(model).cuda() # used for optimizer |
| 298 | + ref_model_bf16 = copy.deepcopy(ref_model).to( |
| 299 | + torch.bfloat16 |
| 300 | + ) # used for forward/backward |
| 301 | + ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2) |
| 302 | + fsdp_model = FSDP( |
| 303 | + model, |
| 304 | + auto_wrap_policy=ModuleWrapPolicy({TransformerBlock}), |
| 305 | + use_orig_params=True, |
| 306 | + device_id=self.rank, |
| 307 | + mixed_precision=MixedPrecision(param_dtype=torch.bfloat16), |
| 308 | + ) |
| 309 | + fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2) |
| 310 | + |
| 311 | + local_inp = torch.randint(0, 16, (1, 4), device="cuda") |
| 312 | + with self.enable_amax_init(False), self.enable_pre_and_post_forward( |
| 313 | + False |
| 314 | + ) if module_cls is Float8Linear else contextlib.nullcontext(): |
| 315 | + for iter_idx in range(10): |
| 316 | + losses: List[torch.Tensor] = [] |
| 317 | + for model, optim in ( |
| 318 | + (ref_model_bf16, ref_optim), |
| 319 | + (fsdp_model, fsdp_optim), |
| 320 | + ): |
| 321 | + optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) |
| 322 | + losses.append(model(local_inp).sum()) |
| 323 | + losses[-1].backward() |
| 324 | + if model is ref_model_bf16: |
| 325 | + for param_bf16, param_fp32 in zip( |
| 326 | + ref_model_bf16.parameters(), ref_model.parameters() |
| 327 | + ): |
| 328 | + dist.all_reduce(param_bf16.grad) |
| 329 | + param_bf16.grad.div_(self.world_size) |
| 330 | + param_fp32.grad = param_bf16.grad.float() |
| 331 | + param_bf16.grad = None |
| 332 | + if module_cls is Float8Linear: |
| 333 | + sync_float8_amax_and_scale_history(model) |
| 334 | + optim.step() |
| 335 | + for param_fp32, param_bf16 in zip( |
| 336 | + ref_model.parameters(), ref_model_bf16.parameters() |
| 337 | + ): |
| 338 | + param_bf16.detach().copy_(param_fp32) |
| 339 | + self.assertEqual(losses[0], losses[1]) |
| 340 | + |
| 341 | + |
| 342 | +if __name__ == "__main__": |
| 343 | + run_tests() |
0 commit comments