Commit aea6fa8

Sam Lurye authored and facebook-github-bot committed
Fix warnings in test_controller.py (#30)
Summary:
Pull Request resolved: #30

Fix warnings related to mesh slicing in `test_controller.py`, and update the file to use `TestingContext` in order to deal with device mesh shutdown warnings.

Reviewed By: colin2328

Differential Revision: D75106997

fbshipit-source-id: 05fd5c85b612c6356560a4df26854931d2d90c64
1 parent 54989b3 commit aea6fa8
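
The mesh-slicing warnings come from calling a DeviceMesh object directly, e.g. device_mesh(host=0); the diffs below switch both the library code and the tests to the explicit DeviceMesh.slice(...) method. A minimal sketch of the new pattern, assuming a 2-host x 2-GPU local mesh obtained through the TestingContext helper the tests use (the import path and setup shown here are assumptions, not part of this commit):

import torch

from monarch._testing import TestingContext  # assumed import path for the test helper

# Sketch only: a local 2x2 mesh roughly as the tests create it; exact setup may differ.
with TestingContext() as local:
    with local.local_device_mesh(2, 2, rust=True) as device_mesh:
        h0 = device_mesh.slice(host=0)  # explicit slice(); replaces the old device_mesh(host=0) call
        with h0.activate():
            x = torch.rand(3, 4)  # issued against host 0's sub-mesh while h0 is active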

File tree: 3 files changed (+47, -35 lines changed)


python/monarch/_testing.py

Lines changed: 6 additions & 2 deletions
@@ -98,6 +98,7 @@ def local_rust_device_mesh(
         num_hosts,
         gpu_per_host,
         activate: bool = True,
+        controller_params=None,
     ) -> Generator[DeviceMesh, None, None]:
         # Create a new system and mesh for test.
         with local_mesh(
@@ -111,6 +112,7 @@ def local_rust_device_mesh(
                 num_worker_procs=num_hosts * gpu_per_host,
                 gpus_per_host=gpu_per_host,
             ),
+            controller_params=controller_params,
         ) as dm:
             try:
                 if activate:
@@ -129,11 +131,13 @@
 
     @contextmanager
     def local_device_mesh(
-        self, num_hosts, gpu_per_host, activate=True, rust=False
+        self, num_hosts, gpu_per_host, activate=True, rust=False, controller_params=None
     ) -> Generator[DeviceMesh, None, None]:
         start = time.time()
         if rust:
-            generator = self.local_rust_device_mesh(num_hosts, gpu_per_host, activate)
+            generator = self.local_rust_device_mesh(
+                num_hosts, gpu_per_host, activate, controller_params=controller_params
+            )
         else:
             generator = self.local_py_device_mesh(num_hosts, gpu_per_host, activate)
         with generator as dm:
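
With controller_params now threaded from TestingContext.local_device_mesh into local_rust_device_mesh, callers can tune the Rust controller when requesting a mesh. A hedged sketch of such a call, mirroring the updated timeout tests below; the ControllerParams import path is an assumption, the argument values are copied from test_timeout_warning, and local is the same TestingContext helper as above:

from monarch.rust_local_mesh import ControllerParams  # assumed import path

# controller_params is only meaningful for the Rust backend; the Python path ignores it.
with local.local_device_mesh(
    1,
    2,
    True,
    rust=True,
    controller_params=ControllerParams(1, 3, 100, False),  # values mirror test_timeout_warning
) as dm:
    dm.client.new_node([], [])  # exercise the controller, as the test does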

python/monarch/common/tensor.py

Lines changed: 2 additions & 2 deletions
@@ -349,7 +349,7 @@ def slice_mesh(self, **kwargs: Dict[str, Union[int, slice]]) -> "MeshSliceTensor
         # because a device mesh also has caches for doing collectives.
         # but this is an easy way to create a MeshSliceTensor until we optimize
         # how we represent mesh slices.
-        slicing = self.mesh(**kwargs)
+        slicing = self.mesh.slice(**kwargs)
         return MeshSliceTensor(self, slicing)
 
     def delete_ref(self, ref: int):
@@ -432,7 +432,7 @@ def to_mesh(
                 for dim in broadcast_dims
             ]
             destinations = [
-                mesh(**dict(dim_settings)).processes
+                mesh.slice(**dict(dim_settings)).processes
                 for dim_settings in itertools.product(*dim_sequences)
             ]
         else:
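
Tensor.slice_mesh and to_mesh are the library-side callers of mesh slicing that triggered the warnings; both now go through DeviceMesh.slice internally. The user-facing pattern, as exercised by test_slice_mesh_pytree further down, looks roughly like this (tensor_dict contents are illustrative, and device_mesh is assumed to be an active local mesh):

import torch

import monarch

# Sketch: move a pytree of mesh tensors onto the host-0 sub-mesh.
tensor_dict = {"w": torch.zeros((1,), device="cuda")}   # illustrative pytree created on device_mesh

host0_slices = monarch.slice_mesh(tensor_dict, host=0)  # MeshSliceTensor views, via Tensor.slice_mesh
host0 = device_mesh.slice(host=0)                       # the sub-mesh to land on
host0_tensors = monarch.to_mesh(host0_slices, host0)    # materialize the slices on host 0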

python/tests/test_controller.py

Lines changed: 39 additions & 31 deletions
@@ -11,7 +11,6 @@
 import sys
 import traceback
 from contextlib import contextmanager
-from typing import Generator
 
 import monarch
 import monarch.random
@@ -68,7 +67,7 @@ def local_rust_device_mesh(
     gpu_per_host,
     activate: bool = True,
     controller_params: ControllerParams | None = None,
-) -> Generator[DeviceMesh, None, None]:
+):
     with local_mesh(
         hosts=hosts,
         gpus_per_host=gpu_per_host,
@@ -111,10 +110,11 @@ def local_device_mesh(
         backend_type,
         activate=True,
     ):
-        return (
-            local.local_device_mesh(N, gpu_per_host, activate)
-            if backend_type == BackendType.PY
-            else local_rust_device_mesh(N, gpu_per_host, activate)
+        return local.local_device_mesh(
+            N,
+            gpu_per_host,
+            activate,
+            rust=backend_type == BackendType.RS,
         )
 
     def test_errors(self, backend_type):
@@ -124,7 +124,7 @@ def test_errors(self, backend_type):
         with pytest.raises(TypeError, match="LOCAL_TENSOR"):
             t.add(y)
         with pytest.raises(TypeError, match="WRONG_MESH"):
-            sm = device_mesh(host=0)
+            sm = device_mesh.slice(host=0)
             with sm.activate():
                 x = torch.rand(3, 4)
                 x.add(y)
@@ -137,8 +137,8 @@ def test_errors(self, backend_type):
 
     def test_sub_mesh(self, backend_type):
         with self.local_device_mesh(2, 2, backend_type) as device_mesh:
-            h0 = device_mesh(host=0)
-            h1 = device_mesh(host=1)
+            h0 = device_mesh.slice(host=0)
+            h1 = device_mesh.slice(host=1)
             with h0.activate():
                 _ = torch.rand(3, 4)
             with h1.activate():
@@ -167,7 +167,7 @@ def test_dim1_mesh(self, backend_type):
 
     def test_sub_mesh_use_only_one(self, backend_type):
         with self.local_device_mesh(2, 2, backend_type, activate=False) as device_mesh:
-            h0 = device_mesh(host=0)
+            h0 = device_mesh.slice(host=0)
 
             with h0.activate():
                 x = torch.ones(3, 4)
@@ -178,7 +178,7 @@ def test_sub_mesh_use_only_one(self, backend_type):
 
     def test_sub_mesh_process_grop(self, backend_type):
         with self.local_device_mesh(2, 2, backend_type, activate=False) as device_mesh:
-            h0 = device_mesh(host=0)
+            h0 = device_mesh.slice(host=0)
             pg0 = h0.process_group(("gpu",))
             pg1 = h0.process_group(("gpu",))
             # Is there a way to functionally test that these two PG's aren't
@@ -309,8 +309,8 @@ def test_mutate(self, backend_type):
 
     def test_movement(self, backend_type):
         with self.local_device_mesh(2, 2, backend_type) as device_mesh:
-            sm0 = device_mesh(host=0)
-            sm1 = device_mesh(host=1)
+            sm0 = device_mesh.slice(host=0)
+            sm1 = device_mesh.slice(host=1)
 
             with sm0.activate():
                 x = torch.rand(3, 4, device="cuda")
@@ -325,7 +325,7 @@
     def test_broadcast_one(self, backend_type):
         with self.local_device_mesh(2, 2, backend_type) as device_mesh:
             for dim in ("host", "gpu"):
-                subset = device_mesh(**{dim: 1})
+                subset = device_mesh.slice(**{dim: 1})
                 with subset.activate():
                     x = torch.rand(3, device="cuda")
                     y = x.to_mesh(device_mesh)
@@ -339,7 +339,7 @@ def test_broadcast_one(self, backend_type):
 
     def test_broadcast_two(self, backend_type):
         with self.local_device_mesh(2, 2, backend_type) as device_mesh:
-            subset = device_mesh(host=1, gpu=1)
+            subset = device_mesh.slice(host=1, gpu=1)
             with subset.activate():
                 x = torch.rand(3, device="cuda")
                 y = x.to_mesh(device_mesh)
@@ -368,8 +368,8 @@ def test_autograd(self, backend_type):
 
     def test_mesh_semantics(self, backend_type):
         with self.local_device_mesh(2, 2, backend_type) as device_mesh:
-            host0 = device_mesh(host=0)
-            host1 = device_mesh(host=1)
+            host0 = device_mesh.slice(host=0)
+            host1 = device_mesh.slice(host=1)
             with host0.activate():
                 x = torch.randn(5)
                 y = x * 5
@@ -392,8 +392,8 @@ def backward(grad_x: Tensor):
             return x.to_mesh(mesh), backward
 
         with self.local_device_mesh(2, 2, backend_type) as device_mesh:
-            host0 = device_mesh(host=0)
-            host1 = device_mesh(host=1)
+            host0 = device_mesh.slice(host=0)
+            host1 = device_mesh.slice(host=1)
             with host0.activate():
                 x = torch.rand(3, 4, requires_grad=True, device="cuda")
                 y = torch.rand(4, 3, requires_grad=True, device="cuda")
@@ -440,7 +440,7 @@ def test_to_mesh_aliasing(self, backend_type):
                 ),
                 pp=2,
             )
-            pp_meshes = [ppmesh(pp=i) for i in range(2)]
+            pp_meshes = [ppmesh.slice(pp=i) for i in range(2)]
 
             with ppmesh.activate():
                 with pp_meshes[0].activate():
@@ -469,8 +469,8 @@ def test_to_mesh_cow(self, backend_type):
     def test_to_mesh_stream(self, backend_type):
         other = monarch.Stream("other")
         with self.local_device_mesh(2, 2, backend_type) as mesh:
-            m0 = mesh(host=0)
-            m1 = mesh(host=1)
+            m0 = mesh.slice(host=0)
+            m1 = mesh.slice(host=1)
             with m0.activate():
                 t2 = torch.rand(3, 4, device="cuda").to_mesh(m1, stream=other)
             with m1.activate(), other.activate():
@@ -492,7 +492,7 @@ def test_dropped_trace(self, backend_type):
 
     def test_sub_mesh_reduce(self, backend_type):
         with self.local_device_mesh(2, 2, backend_type) as device_mesh:
-            host1 = device_mesh(host=1)
+            host1 = device_mesh.slice(host=1)
             with host1.activate():
                 myrank = (
                     (device_mesh.rank("host") + 1) * 2 + device_mesh.rank("gpu") + 1
@@ -579,8 +579,8 @@ def test_reduce_pytree(self, backend_type):
 
     def test_to_mesh_pytree(self, backend_type):
         with self.local_device_mesh(2, 2, backend_type) as device_mesh:
-            host0 = device_mesh(host=0)
-            host1 = device_mesh(host=1)
+            host0 = device_mesh.slice(host=0)
+            host1 = device_mesh.slice(host=1)
 
             with host0.activate():
                 a = torch.zeros((1,), device="cuda")
@@ -610,8 +610,8 @@ def test_slice_mesh_pytree(self, backend_type):
             host0_slices = monarch.slice_mesh(tensor_dict, host=0)
             host1_slices = monarch.slice_mesh(tensor_dict, host=1)
 
-            host0 = device_mesh(host=0)
-            host1 = device_mesh(host=1)
+            host0 = device_mesh.slice(host=0)
+            host1 = device_mesh.slice(host=1)
 
             host0_tensors = monarch.to_mesh(host0_slices, host0)
             host1_tensors = monarch.to_mesh(host1_slices, host1)
@@ -646,8 +646,12 @@ def test_panicking_worker():
 
 def test_timeout_warning(caplog):
     timeout = 3
-    with local_rust_device_mesh(
-        1, 2, True, ControllerParams(1, timeout, 100, False)
+    with local.local_device_mesh(
+        1,
+        2,
+        True,
+        rust=True,
+        controller_params=ControllerParams(1, timeout, 100, False),
    ) as dm:
         for _ in range(3):
             dm.client.new_node([], [])
@@ -672,8 +676,12 @@ def test_timeout_warning(caplog):
 
 def test_timeout_failure():
     timeout = 3
-    with local_rust_device_mesh(
-        1, 1, True, ControllerParams(1, timeout, 100, True)
+    with local.local_device_mesh(
+        1,
+        1,
+        True,
+        rust=True,
+        controller_params=ControllerParams(1, timeout, 100, True),
     ) as dm:
         for _ in range(3):
             dm.client.new_node([], [])
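
The backend_type argument threaded through these tests selects between the Python and Rust backends via the class-level local_device_mesh helper shown above. A hypothetical pytest parametrization illustrating how a single test might run against both backends; BackendType and local are names the test module already imports, and the real suite's parametrization may differ:

import pytest
import torch

# Hypothetical standalone test; BackendType and local come from test_controller.py's existing imports.
@pytest.mark.parametrize("backend_type", [BackendType.PY, BackendType.RS])
def test_sub_mesh_smoke(backend_type):
    with local.local_device_mesh(2, 2, rust=backend_type == BackendType.RS) as device_mesh:
        h0 = device_mesh.slice(host=0)
        with h0.activate():
            _ = torch.rand(3, 4)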
