@@ -15,8 +15,6 @@
     PipelineStage,
     Schedule1F1B,
     ScheduleGPipe,
-    ScheduleInterleaved1F1B,
-    ScheduleLoopedBFS,
 )
 from torch.testing._internal.common_cuda import TEST_MULTIGPU
 from torch.testing._internal.common_distributed import (
@@ -32,6 +30,7 @@
 
 d_hid = 512
 batch_size = 256
+chunks = 4
 
 torch.manual_seed(0)
 
@@ -64,7 +63,6 @@ def test_kwargs_with_tracer(self, ScheduleClass):
         target = torch.randn(batch_size, d_hid, device=self.device)
         loss_fn = torch.nn.MSELoss(reduction="sum")
 
-        chunks = 4
         pipe = pipeline(
             mod,
             chunks,
@@ -125,7 +123,6 @@ def test_grad_with_tracer(self, ScheduleClass, ModelClass):
         ref_loss.backward()
 
         # Create a pipeline
-        chunks = 4
         split_spec = mod.split_spec if hasattr(mod, "split_spec") else None
         pipe = pipeline(
             mod,
@@ -203,7 +200,6 @@ def test_grad_with_manual(self, ScheduleClass):
         # Get a submodule, e.g. `layers.0` or `layers.1`
         submod_name = f"layers.{self.rank}"
         stage_module = full_mod.get_submodule(submod_name)
-        chunks = 4
         # Create a pipeline stage to wrap that submodule
         stage = ManualPipelineStage(
             stage_module,
@@ -251,96 +247,6 @@ def test_grad_with_manual(self, ScheduleClass):
                 print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}")
                 raise
 
-    @requires_nccl()
-    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
-    @parametrize("ScheduleClass", [ScheduleInterleaved1F1B, ScheduleLoopedBFS])
-    def test_grad_with_manual_interleaved(self, ScheduleClass):
-        stages_per_rank = 2
-        n_stages = stages_per_rank * self.world_size
-        full_mod = MultiMLP(d_hid, n_layers=n_stages)
-        full_mod.to(self.device)
-
-        ref_mod = copy.deepcopy(full_mod)
-        x = torch.randn(batch_size, d_hid, device=self.device)
-        with torch.no_grad():
-            y = ref_mod(x)
-            # Add a small perturbation
-            target = y + torch.randn(batch_size, d_hid, device=self.device)
-
-        loss_fn = torch.nn.MSELoss(reduction="sum")
-
-        # Run reference
-        for _ in range(2):
-            ref_mod.zero_grad()
-            ref_out = ref_mod(x)
-            ref_loss = loss_fn(ref_out, target)
-            ref_loss.backward()
-
-        # Get a submodule, e.g. `layers.0` or `layers.1`
-        stage_indices = [
-            self.rank + i * self.world_size for i in range(stages_per_rank)
-        ]
-        print(f"Rank {self.rank} stages: {stage_indices}")
-        submod_names = [f"layers.{i}" for i in stage_indices]
-        stage_modules = [
-            full_mod.get_submodule(submod_name) for submod_name in submod_names
-        ]
-        # Create a pipeline stage to wrap that submodule
-        chunks = 8
-        input_args = x.chunk(chunks)[0]
-        stages = [
-            ManualPipelineStage(
-                stage_module,
-                stage_idx,
-                n_stages,
-                self.device,
-                chunks,
-                input_args=input_args,
-            )
-            for stage_module, stage_idx in zip(stage_modules, stage_indices)
-        ]
-
-        # Attach to a schedule
-        schedule = ScheduleClass(stages, chunks, loss_fn=loss_fn)
-
-        # Run
-        for _ in range(2):
-            # Zero gradients
-            for stage_module in stage_modules:
-                stage_module.zero_grad()
-            if self.rank == 0:
-                schedule.step(x)
-            elif self.rank == self.world_size - 1:
-                losses = []
-                out = schedule.step(target=target, losses=losses)
-            else:
-                schedule.step()
-
-            dist.barrier()
-
-        # Last rank checks result
-        if self.rank == self.world_size - 1:
-            # Check output
-            torch.testing.assert_close(out, ref_out)
-            # Check loss
-            # Since the reduction used in the loss function above is "sum", we use
-            # "sum" here to reduce microbatch losses into a single value too.
-            pipe_loss = sum(losses)
-            torch.testing.assert_close(pipe_loss, ref_loss)
-
-        # Every rank checks gradients
-        for stage_module, submod_name in zip(stage_modules, submod_names):
-            # Get corresponding submodule from reference model
-            ref_submod = ref_mod.get_submodule(submod_name)
-            # Check gradients per parameter
-            for name, p in stage_module.named_parameters():
-                ref_p = ref_submod.get_parameter(name)
-                try:
-                    torch.testing.assert_close(p.grad, ref_p.grad, rtol=1e-5, atol=4e-5)
-                except AssertionError:
-                    print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}")
-                    raise
-
 
 instantiate_parametrized_tests(ScheduleTest)
 