
Commit d75531b

Add a shard helper for creating DTensors
1 parent 246b783 commit d75531b

Lines changed: 54 additions & 62 deletions
@@ -1,77 +1,69 @@
 import torch
-from torch.distributed._tensor import DTensor, Shard, Replicate
-
+from torch.distributed import DeviceMesh
+from torch.distributed._tensor import DTensor, Shard, Replicate, Placement
+from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
 
 from collections import defaultdict
+from typing import Optional, Sequence
 
 from torchchat.distributed.logging_utils import SingletonLogger
 logger = SingletonLogger.get_logger()
 
 
-def convert_to_dtensor(weight_tensor, dtensor_template):
-    """Adjust a loaded tensor to match the shape/placement of the model DTensor and copy the data into it"""
-
-    if weight_tensor.shape != dtensor_template.shape:
+def convert_to_dtensor(
+    full_tensor: torch.Tensor,
+    dtensor_template: DTensor,
+) -> DTensor:
+    """
+    Converts a full tensor to a DTensor with the same placements as the given
+    DTensor template.
+    """
+    if full_tensor.shape != dtensor_template.shape:
         raise ValueError(
-            f"Shape mismatch: weight tensor shape {weight_tensor.shape} "
+            f"Shape mismatch: weight tensor shape {full_tensor.shape} "
             f"doesn't match DTensor shape {dtensor_template.shape}"
         )
 
-    placements = dtensor_template.placements
-    mesh = dtensor_template.device_mesh
-    mesh_dims = mesh.ndim
-
-    for placement in placements:
-        if isinstance(placement, Shard):
-            shard_dim = placement.dim
-
-            if shard_dim >= weight_tensor.dim():
-                raise ValueError(
-                    f"Shard dimension {shard_dim} is out of range for tensor with {weight_tensor.dim()} dimensions."
-                )
-
-            num_shards = mesh.size(
-                0
-            )  # Assuming sharding is always along the first mesh dimension
-            shard_size = weight_tensor.size(shard_dim) // num_shards
-            shard_index = mesh.get_coordinate()[0]
-
-            start_idx = shard_index * shard_size
-            end_idx = start_idx + shard_size
-
-            slice_list = [slice(None)] * weight_tensor.dim()
-            slice_list[shard_dim] = slice(start_idx, end_idx)
-            weight_tensor = weight_tensor[tuple(slice_list)]
-
-        elif isinstance(placement, Replicate):
-            continue
-        else:
-            raise ValueError(f"Unsupported placement type: {type(placement)}")
-
-    new_dtensor = DTensor.from_local(weight_tensor, mesh, placements)
-
+    new_dtensor = shard(
+        full_tensor,
+        dtensor_template.placements,
+        dtensor_template.device_mesh
+    )
     return new_dtensor
 
 
-def inspect_dtensor_sharding(dtensor):
-    """hepful debug util for inspecting DTensor sharding"""
-    if not is_dtensor(dtensor):
-        logger.info(f"This tensor {dtensor} is not a DTensor")
-        return
-
-    placements = dtensor.placements
-    logger.info(f"DTensor shape: {dtensor.shape}")
-    logger.info(f"Number of dimensions: {len(placements)}")
-
-    for dim, placement in enumerate(placements):
-        logger.info(f"Dimension {dim}:")
-        logger.info(f"  Placement type: {placement.type}")
-        if placement.type == "shard":
-            logger.info(f"  Sharding spec: {placement.sharding_spec}")
-        elif placement.type == "replicate":
-            logger.info("  Replicated across devices")
-        else:
-            logger.info(f"  Other placement type: {placement.type}")
-
-    logger.info(f"Device mesh shape: {dtensor.device_mesh.shape}")
-    logger.info(f"Device mesh devices: {dtensor.device_mesh.device_type}")
+def shard(
+    full_tensor: torch.Tensor,
+    placements: Sequence[Placement],
+    device_mesh: Optional[DeviceMesh] = None,
+) -> DTensor:
+    """
+    Shards a full tensor based on indicated placements, and returns a
+    DTensor containing the shard.
+    Args:
+        full_tensor (torch.Tensor): the full tensor to be sharded.
+        placements (Sequence[:class:`Placement`]): the placements that
+            describes how to place the local tensor on DeviceMesh.
+        device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the
+            DTensor. Must have same dimension as the number of placements.
+            If not specified, would be retrieve from current context.
+    Returns:
+        A :class:`DTensor` object with the shard as its local tensor.
+    Examples:
+        >>> # xdoctest: +SKIP("need world_size and rank")
+        >>> device_mesh = dist.init_device_mesh("cuda", (world_size,))
+        >>> full_tensor = torch.arange(world_size, device=f"cuda:{rank}")
+        >>> placements = [Shard(1)]
+        >>> dtensor = shard(full_tensor, placements, device_mesh)
+    """
+    device_mesh = device_mesh or _mesh_resources.get_current_mesh()
+
+    shape, offset = compute_local_shape_and_global_offset(
+        full_tensor.shape, device_mesh, placements
+    )
+    slices = [
+        slice(cur_offset, cur_offset + cur_shape)
+        for cur_shape, cur_offset in zip(shape, offset)
+    ]
+    local_tensor = full_tensor[slices]
+    return DTensor.from_local(local_tensor, device_mesh, placements)
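
For context, a minimal sketch (not part of the commit) of how the new shard helper might be driven under torchrun. It assumes a 1-D device mesh over all ranks, a weight whose first dimension divides evenly by world_size, the NCCL backend, and that the shard function added above is already in scope; the tensor sizes are illustrative, not taken from the diff.

    import os
    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed._tensor import Shard

    # shard() is the helper added in this commit; its import path depends on
    # where the file lives in torchchat, so it is assumed to be in scope here.

    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group("nccl")
    torch.cuda.set_device(rank)

    # 1-D device mesh spanning every rank; Shard(0) splits dim 0 across it.
    device_mesh = init_device_mesh("cuda", (world_size,))
    full_weight = torch.randn(1024, 256, device=f"cuda:{rank}")

    # Each rank keeps only its own row block, wrapped as a DTensor.
    dtensor = shard(full_weight, [Shard(0)], device_mesh)
    assert dtensor.to_local().shape == (1024 // world_size, 256)

Because the device mesh is passed explicitly here, the helper's fallback to the current mesh context is never exercised in this sketch.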
