Skip to content

Fix warnings in test_controller.py #30

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions python/monarch/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def local_rust_device_mesh(
num_hosts,
gpu_per_host,
activate: bool = True,
controller_params=None,
) -> Generator[DeviceMesh, None, None]:
# Create a new system and mesh for test.
with local_mesh(
Expand All @@ -111,6 +112,7 @@ def local_rust_device_mesh(
num_worker_procs=num_hosts * gpu_per_host,
gpus_per_host=gpu_per_host,
),
controller_params=controller_params,
) as dm:
try:
if activate:
Expand All @@ -129,11 +131,13 @@ def local_rust_device_mesh(

@contextmanager
def local_device_mesh(
self, num_hosts, gpu_per_host, activate=True, rust=False
self, num_hosts, gpu_per_host, activate=True, rust=False, controller_params=None
) -> Generator[DeviceMesh, None, None]:
start = time.time()
if rust:
generator = self.local_rust_device_mesh(num_hosts, gpu_per_host, activate)
generator = self.local_rust_device_mesh(
num_hosts, gpu_per_host, activate, controller_params=controller_params
)
else:
generator = self.local_py_device_mesh(num_hosts, gpu_per_host, activate)
with generator as dm:
Expand Down
4 changes: 2 additions & 2 deletions python/monarch/common/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def slice_mesh(self, **kwargs: Dict[str, Union[int, slice]]) -> "MeshSliceTensor
# because a device mesh also has caches for doing collectives.
# but this is an easy way to create a MeshSliceTensor until we optimize
# how we represent mesh slices.
slicing = self.mesh(**kwargs)
slicing = self.mesh.slice(**kwargs)
return MeshSliceTensor(self, slicing)

def delete_ref(self, ref: int):
Expand Down Expand Up @@ -432,7 +432,7 @@ def to_mesh(
for dim in broadcast_dims
]
destinations = [
mesh(**dict(dim_settings)).processes
mesh.slice(**dict(dim_settings)).processes
for dim_settings in itertools.product(*dim_sequences)
]
else:
Expand Down
70 changes: 39 additions & 31 deletions python/tests/test_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import sys
import traceback
from contextlib import contextmanager
from typing import Generator

import monarch
import monarch.random
Expand Down Expand Up @@ -68,7 +67,7 @@ def local_rust_device_mesh(
gpu_per_host,
activate: bool = True,
controller_params: ControllerParams | None = None,
) -> Generator[DeviceMesh, None, None]:
):
with local_mesh(
hosts=hosts,
gpus_per_host=gpu_per_host,
Expand Down Expand Up @@ -111,10 +110,11 @@ def local_device_mesh(
backend_type,
activate=True,
):
return (
local.local_device_mesh(N, gpu_per_host, activate)
if backend_type == BackendType.PY
else local_rust_device_mesh(N, gpu_per_host, activate)
return local.local_device_mesh(
N,
gpu_per_host,
activate,
rust=backend_type == BackendType.RS,
)

def test_errors(self, backend_type):
Expand All @@ -124,7 +124,7 @@ def test_errors(self, backend_type):
with pytest.raises(TypeError, match="LOCAL_TENSOR"):
t.add(y)
with pytest.raises(TypeError, match="WRONG_MESH"):
sm = device_mesh(host=0)
sm = device_mesh.slice(host=0)
with sm.activate():
x = torch.rand(3, 4)
x.add(y)
Expand All @@ -137,8 +137,8 @@ def test_errors(self, backend_type):

def test_sub_mesh(self, backend_type):
with self.local_device_mesh(2, 2, backend_type) as device_mesh:
h0 = device_mesh(host=0)
h1 = device_mesh(host=1)
h0 = device_mesh.slice(host=0)
h1 = device_mesh.slice(host=1)
with h0.activate():
_ = torch.rand(3, 4)
with h1.activate():
Expand Down Expand Up @@ -167,7 +167,7 @@ def test_dim1_mesh(self, backend_type):

def test_sub_mesh_use_only_one(self, backend_type):
with self.local_device_mesh(2, 2, backend_type, activate=False) as device_mesh:
h0 = device_mesh(host=0)
h0 = device_mesh.slice(host=0)

with h0.activate():
x = torch.ones(3, 4)
Expand All @@ -178,7 +178,7 @@ def test_sub_mesh_use_only_one(self, backend_type):

def test_sub_mesh_process_grop(self, backend_type):
with self.local_device_mesh(2, 2, backend_type, activate=False) as device_mesh:
h0 = device_mesh(host=0)
h0 = device_mesh.slice(host=0)
pg0 = h0.process_group(("gpu",))
pg1 = h0.process_group(("gpu",))
# Is there a way to functionally test that these two PG's aren't
Expand Down Expand Up @@ -309,8 +309,8 @@ def test_mutate(self, backend_type):

def test_movement(self, backend_type):
with self.local_device_mesh(2, 2, backend_type) as device_mesh:
sm0 = device_mesh(host=0)
sm1 = device_mesh(host=1)
sm0 = device_mesh.slice(host=0)
sm1 = device_mesh.slice(host=1)

with sm0.activate():
x = torch.rand(3, 4, device="cuda")
Expand All @@ -325,7 +325,7 @@ def test_movement(self, backend_type):
def test_broadcast_one(self, backend_type):
with self.local_device_mesh(2, 2, backend_type) as device_mesh:
for dim in ("host", "gpu"):
subset = device_mesh(**{dim: 1})
subset = device_mesh.slice(**{dim: 1})
with subset.activate():
x = torch.rand(3, device="cuda")
y = x.to_mesh(device_mesh)
Expand All @@ -339,7 +339,7 @@ def test_broadcast_one(self, backend_type):

def test_broadcast_two(self, backend_type):
with self.local_device_mesh(2, 2, backend_type) as device_mesh:
subset = device_mesh(host=1, gpu=1)
subset = device_mesh.slice(host=1, gpu=1)
with subset.activate():
x = torch.rand(3, device="cuda")
y = x.to_mesh(device_mesh)
Expand Down Expand Up @@ -368,8 +368,8 @@ def test_autograd(self, backend_type):

def test_mesh_semantics(self, backend_type):
with self.local_device_mesh(2, 2, backend_type) as device_mesh:
host0 = device_mesh(host=0)
host1 = device_mesh(host=1)
host0 = device_mesh.slice(host=0)
host1 = device_mesh.slice(host=1)
with host0.activate():
x = torch.randn(5)
y = x * 5
Expand All @@ -392,8 +392,8 @@ def backward(grad_x: Tensor):
return x.to_mesh(mesh), backward

with self.local_device_mesh(2, 2, backend_type) as device_mesh:
host0 = device_mesh(host=0)
host1 = device_mesh(host=1)
host0 = device_mesh.slice(host=0)
host1 = device_mesh.slice(host=1)
with host0.activate():
x = torch.rand(3, 4, requires_grad=True, device="cuda")
y = torch.rand(4, 3, requires_grad=True, device="cuda")
Expand Down Expand Up @@ -440,7 +440,7 @@ def test_to_mesh_aliasing(self, backend_type):
),
pp=2,
)
pp_meshes = [ppmesh(pp=i) for i in range(2)]
pp_meshes = [ppmesh.slice(pp=i) for i in range(2)]

with ppmesh.activate():
with pp_meshes[0].activate():
Expand Down Expand Up @@ -469,8 +469,8 @@ def test_to_mesh_cow(self, backend_type):
def test_to_mesh_stream(self, backend_type):
other = monarch.Stream("other")
with self.local_device_mesh(2, 2, backend_type) as mesh:
m0 = mesh(host=0)
m1 = mesh(host=1)
m0 = mesh.slice(host=0)
m1 = mesh.slice(host=1)
with m0.activate():
t2 = torch.rand(3, 4, device="cuda").to_mesh(m1, stream=other)
with m1.activate(), other.activate():
Expand All @@ -492,7 +492,7 @@ def test_dropped_trace(self, backend_type):

def test_sub_mesh_reduce(self, backend_type):
with self.local_device_mesh(2, 2, backend_type) as device_mesh:
host1 = device_mesh(host=1)
host1 = device_mesh.slice(host=1)
with host1.activate():
myrank = (
(device_mesh.rank("host") + 1) * 2 + device_mesh.rank("gpu") + 1
Expand Down Expand Up @@ -579,8 +579,8 @@ def test_reduce_pytree(self, backend_type):

def test_to_mesh_pytree(self, backend_type):
with self.local_device_mesh(2, 2, backend_type) as device_mesh:
host0 = device_mesh(host=0)
host1 = device_mesh(host=1)
host0 = device_mesh.slice(host=0)
host1 = device_mesh.slice(host=1)

with host0.activate():
a = torch.zeros((1,), device="cuda")
Expand Down Expand Up @@ -610,8 +610,8 @@ def test_slice_mesh_pytree(self, backend_type):
host0_slices = monarch.slice_mesh(tensor_dict, host=0)
host1_slices = monarch.slice_mesh(tensor_dict, host=1)

host0 = device_mesh(host=0)
host1 = device_mesh(host=1)
host0 = device_mesh.slice(host=0)
host1 = device_mesh.slice(host=1)

host0_tensors = monarch.to_mesh(host0_slices, host0)
host1_tensors = monarch.to_mesh(host1_slices, host1)
Expand Down Expand Up @@ -646,8 +646,12 @@ def test_panicking_worker():

def test_timeout_warning(caplog):
timeout = 3
with local_rust_device_mesh(
1, 2, True, ControllerParams(1, timeout, 100, False)
with local.local_device_mesh(
1,
2,
True,
rust=True,
controller_params=ControllerParams(1, timeout, 100, False),
) as dm:
for _ in range(3):
dm.client.new_node([], [])
Expand All @@ -672,8 +676,12 @@ def test_timeout_warning(caplog):

def test_timeout_failure():
timeout = 3
with local_rust_device_mesh(
1, 1, True, ControllerParams(1, timeout, 100, True)
with local.local_device_mesh(
1,
1,
True,
rust=True,
controller_params=ControllerParams(1, timeout, 100, True),
) as dm:
for _ in range(3):
dm.client.new_node([], [])
Expand Down
Loading