 import torch

-from torch.distributed._tensor import DTensor, Shard, Replicate
+from torch.distributed import DeviceMesh
+from torch.distributed._tensor import DTensor, Shard, Replicate, Placement
+from torch.distributed.device_mesh import _mesh_resources  # used by shard() when no mesh is passed
+from torch.distributed.tensor._utils import compute_local_shape_and_global_offset

 from collections import defaultdict
+from typing import Optional, Sequence

 from torchchat.distributed.logging_utils import SingletonLogger
 logger = SingletonLogger.get_logger()


-def convert_to_dtensor(weight_tensor, dtensor_template):
-    """Adjust a loaded tensor to match the shape/placement of the model DTensor and copy the data into it"""
-
-    if weight_tensor.shape != dtensor_template.shape:
+def convert_to_dtensor(
+    full_tensor: torch.Tensor,
+    dtensor_template: DTensor,
+) -> DTensor:
+    """
+    Converts a full tensor to a DTensor with the same placements as the given
+    DTensor template.
+    """
+    if full_tensor.shape != dtensor_template.shape:
         raise ValueError(
-            f"Shape mismatch: weight tensor shape {weight_tensor.shape} "
+            f"Shape mismatch: weight tensor shape {full_tensor.shape} "
             f"doesn't match DTensor shape {dtensor_template.shape}"
         )

-    placements = dtensor_template.placements
-    mesh = dtensor_template.device_mesh
-    mesh_dims = mesh.ndim
-
-    for placement in placements:
-        if isinstance(placement, Shard):
-            shard_dim = placement.dim
-
-            if shard_dim >= weight_tensor.dim():
-                raise ValueError(
-                    f"Shard dimension {shard_dim} is out of range for tensor with {weight_tensor.dim()} dimensions."
-                )
-
-            num_shards = mesh.size(
-                0
-            )  # Assuming sharding is always along the first mesh dimension
-            shard_size = weight_tensor.size(shard_dim) // num_shards
-            shard_index = mesh.get_coordinate()[0]
-
-            start_idx = shard_index * shard_size
-            end_idx = start_idx + shard_size
-
-            slice_list = [slice(None)] * weight_tensor.dim()
-            slice_list[shard_dim] = slice(start_idx, end_idx)
-            weight_tensor = weight_tensor[tuple(slice_list)]
-
-        elif isinstance(placement, Replicate):
-            continue
-        else:
-            raise ValueError(f"Unsupported placement type: {type(placement)}")
-
-    new_dtensor = DTensor.from_local(weight_tensor, mesh, placements)
-
+    new_dtensor = shard(
+        full_tensor,
+        dtensor_template.placements,
+        dtensor_template.device_mesh
+    )
     return new_dtensor

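Note: a minimal sketch of how the new convert_to_dtensor is typically used when loading a full checkpoint tensor into a model whose parameters are already DTensors. The model, checkpoint, and the in-place copy below are illustrative assumptions, not part of this change:

    # Hypothetical checkpoint-loading loop; `model` and `checkpoint` are assumed names.
    with torch.no_grad():
        for name, param in model.named_parameters():
            if isinstance(param, DTensor) and name in checkpoint:
                full_weight = checkpoint[name]               # full, unsharded tensor
                dt = convert_to_dtensor(full_weight, param)  # keep only this rank's shard
                param.to_local().copy_(dt.to_local())        # copy the local shard in place
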
-def inspect_dtensor_sharding(dtensor):
-    """hepful debug util for inspecting DTensor sharding"""
-    if not is_dtensor(dtensor):
-        logger.info(f"This tensor {dtensor} is not a DTensor")
-        return
-
-    placements = dtensor.placements
-    logger.info(f"DTensor shape: {dtensor.shape}")
-    logger.info(f"Number of dimensions: {len(placements)}")
-
-    for dim, placement in enumerate(placements):
-        logger.info(f"Dimension {dim}:")
-        logger.info(f"  Placement type: {placement.type}")
-        if placement.type == "shard":
-            logger.info(f"  Sharding spec: {placement.sharding_spec}")
-        elif placement.type == "replicate":
-            logger.info("  Replicated across devices")
-        else:
-            logger.info(f"  Other placement type: {placement.type}")
-
-    logger.info(f"Device mesh shape: {dtensor.device_mesh.shape}")
-    logger.info(f"Device mesh devices: {dtensor.device_mesh.device_type}")
+def shard(
+    full_tensor: torch.Tensor,
+    placements: Sequence[Placement],
+    device_mesh: Optional[DeviceMesh] = None,
+) -> DTensor:
+    """
+    Shards a full tensor based on the indicated placements and returns a
+    DTensor containing the local shard.
+
+    Args:
+        full_tensor (torch.Tensor): the full tensor to be sharded.
+        placements (Sequence[:class:`Placement`]): the placements that
+            describe how to place the local tensor on the DeviceMesh.
+        device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the
+            DTensor on. Must have the same number of dimensions as the number
+            of placements. If not specified, it is retrieved from the current
+            mesh context.
+
+    Returns:
+        A :class:`DTensor` object with the shard as its local tensor.
+
+    Examples:
+        >>> # xdoctest: +SKIP("need world_size and rank")
+        >>> device_mesh = dist.init_device_mesh("cuda", (world_size,))
+        >>> full_tensor = torch.arange(world_size, device=f"cuda:{rank}")
+        >>> placements = [Shard(0)]
+        >>> dtensor = shard(full_tensor, placements, device_mesh)
+    """
+    device_mesh = device_mesh or _mesh_resources.get_current_mesh()
+
+    # Compute this rank's local shape and where that block starts in the
+    # full tensor, then slice the corresponding block out of the full tensor.
+    shape, offset = compute_local_shape_and_global_offset(
+        full_tensor.shape, device_mesh, placements
+    )
+    slices = tuple(
+        slice(cur_offset, cur_offset + cur_shape)
+        for cur_shape, cur_offset in zip(shape, offset)
+    )
+    local_tensor = full_tensor[slices]
+    return DTensor.from_local(local_tensor, device_mesh, placements)
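For intuition, a worked example of the slicing step above (assumed numbers, not part of this change):

    # rank 2 on a 1-D mesh of 4 ranks, full_tensor.shape == (8, 6), placements == [Shard(0)]
    # compute_local_shape_and_global_offset -> shape == (2, 6), offset == (4, 0)
    # slices == (slice(4, 6), slice(0, 6))  ->  local_tensor = full_tensor[4:6, :]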