@@ -11,7 +11,6 @@
 import sys
 import traceback
 from contextlib import contextmanager
-from typing import Generator
 
 import monarch
 import monarch.random
@@ -68,7 +67,7 @@ def local_rust_device_mesh(
     gpu_per_host,
     activate: bool = True,
     controller_params: ControllerParams | None = None,
-) -> Generator[DeviceMesh, None, None]:
+):
     with local_mesh(
         hosts=hosts,
         gpus_per_host=gpu_per_host,
@@ -111,10 +110,11 @@ def local_device_mesh(
         backend_type,
         activate=True,
     ):
-        return (
-            local.local_device_mesh(N, gpu_per_host, activate)
-            if backend_type == BackendType.PY
-            else local_rust_device_mesh(N, gpu_per_host, activate)
+        return local.local_device_mesh(
+            N,
+            gpu_per_host,
+            activate,
+            rust=backend_type == BackendType.RS,
         )
 
     def test_errors(self, backend_type):
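Note on the hunk above: the test helper no longer branches between the Python and Rust controllers itself; both backends now flow through `local.local_device_mesh`, selected by the `rust=` keyword. A minimal sketch of the resulting call shape, assuming `local` and `BackendType` are the names already imported by this test module (an illustration, not a definitive API reference):

    # Sketch: the backend is now requested explicitly via rust=;
    # BackendType.PY maps to rust=False, BackendType.RS to rust=True.
    with local.local_device_mesh(2, 2, True, rust=True) as dm:
        ...  # dm is the device mesh the test body runs against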
@@ -124,7 +124,7 @@ def test_errors(self, backend_type):
             with pytest.raises(TypeError, match="LOCAL_TENSOR"):
                 t.add(y)
             with pytest.raises(TypeError, match="WRONG_MESH"):
-                sm = device_mesh(host=0)
+                sm = device_mesh.slice(host=0)
                 with sm.activate():
                     x = torch.rand(3, 4)
                     x.add(y)
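The same mechanical rename repeats in every remaining hunk of this file: selecting a sub-mesh is now spelled with an explicit `slice` method instead of calling the mesh object directly, with identical dimension keywords. In sketch form (both spellings taken from the pattern shown in these hunks):

    h0 = device_mesh(host=0)        # old spelling: call the mesh
    h0 = device_mesh.slice(host=0)  # new spelling: explicit slice()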
@@ -137,8 +137,8 @@ def test_errors(self, backend_type):
 
     def test_sub_mesh(self, backend_type):
         with self.local_device_mesh(2, 2, backend_type) as device_mesh:
-            h0 = device_mesh(host=0)
-            h1 = device_mesh(host=1)
+            h0 = device_mesh.slice(host=0)
+            h1 = device_mesh.slice(host=1)
             with h0.activate():
                 _ = torch.rand(3, 4)
             with h1.activate():
@@ -167,7 +167,7 @@ def test_dim1_mesh(self, backend_type):
 
     def test_sub_mesh_use_only_one(self, backend_type):
         with self.local_device_mesh(2, 2, backend_type, activate=False) as device_mesh:
-            h0 = device_mesh(host=0)
+            h0 = device_mesh.slice(host=0)
 
             with h0.activate():
                 x = torch.ones(3, 4)
@@ -178,7 +178,7 @@ def test_sub_mesh_use_only_one(self, backend_type):
 
     def test_sub_mesh_process_grop(self, backend_type):
         with self.local_device_mesh(2, 2, backend_type, activate=False) as device_mesh:
-            h0 = device_mesh(host=0)
+            h0 = device_mesh.slice(host=0)
             pg0 = h0.process_group(("gpu",))
             pg1 = h0.process_group(("gpu",))
             # Is there a way to functionally test that these two PG's aren't
@@ -309,8 +309,8 @@ def test_mutate(self, backend_type):
 
     def test_movement(self, backend_type):
         with self.local_device_mesh(2, 2, backend_type) as device_mesh:
-            sm0 = device_mesh(host=0)
-            sm1 = device_mesh(host=1)
+            sm0 = device_mesh.slice(host=0)
+            sm1 = device_mesh.slice(host=1)
 
             with sm0.activate():
                 x = torch.rand(3, 4, device="cuda")
@@ -325,7 +325,7 @@ def test_movement(self, backend_type):
     def test_broadcast_one(self, backend_type):
         with self.local_device_mesh(2, 2, backend_type) as device_mesh:
             for dim in ("host", "gpu"):
-                subset = device_mesh(**{dim: 1})
+                subset = device_mesh.slice(**{dim: 1})
                 with subset.activate():
                     x = torch.rand(3, device="cuda")
                     y = x.to_mesh(device_mesh)
@@ -339,7 +339,7 @@ def test_broadcast_one(self, backend_type):
 
     def test_broadcast_two(self, backend_type):
         with self.local_device_mesh(2, 2, backend_type) as device_mesh:
-            subset = device_mesh(host=1, gpu=1)
+            subset = device_mesh.slice(host=1, gpu=1)
             with subset.activate():
                 x = torch.rand(3, device="cuda")
                 y = x.to_mesh(device_mesh)
@@ -368,8 +368,8 @@ def test_autograd(self, backend_type):
 
     def test_mesh_semantics(self, backend_type):
         with self.local_device_mesh(2, 2, backend_type) as device_mesh:
-            host0 = device_mesh(host=0)
-            host1 = device_mesh(host=1)
+            host0 = device_mesh.slice(host=0)
+            host1 = device_mesh.slice(host=1)
             with host0.activate():
                 x = torch.randn(5)
                 y = x * 5
@@ -392,8 +392,8 @@ def backward(grad_x: Tensor):
             return x.to_mesh(mesh), backward
 
         with self.local_device_mesh(2, 2, backend_type) as device_mesh:
-            host0 = device_mesh(host=0)
-            host1 = device_mesh(host=1)
+            host0 = device_mesh.slice(host=0)
+            host1 = device_mesh.slice(host=1)
             with host0.activate():
                 x = torch.rand(3, 4, requires_grad=True, device="cuda")
                 y = torch.rand(4, 3, requires_grad=True, device="cuda")
@@ -440,7 +440,7 @@ def test_to_mesh_aliasing(self, backend_type):
                 ),
                 pp=2,
             )
-            pp_meshes = [ppmesh(pp=i) for i in range(2)]
+            pp_meshes = [ppmesh.slice(pp=i) for i in range(2)]
 
             with ppmesh.activate():
                 with pp_meshes[0].activate():
@@ -469,8 +469,8 @@ def test_to_mesh_cow(self, backend_type):
     def test_to_mesh_stream(self, backend_type):
         other = monarch.Stream("other")
         with self.local_device_mesh(2, 2, backend_type) as mesh:
-            m0 = mesh(host=0)
-            m1 = mesh(host=1)
+            m0 = mesh.slice(host=0)
+            m1 = mesh.slice(host=1)
             with m0.activate():
                 t2 = torch.rand(3, 4, device="cuda").to_mesh(m1, stream=other)
             with m1.activate(), other.activate():
@@ -492,7 +492,7 @@ def test_dropped_trace(self, backend_type):
 
     def test_sub_mesh_reduce(self, backend_type):
         with self.local_device_mesh(2, 2, backend_type) as device_mesh:
-            host1 = device_mesh(host=1)
+            host1 = device_mesh.slice(host=1)
             with host1.activate():
                 myrank = (
                     (device_mesh.rank("host") + 1) * 2 + device_mesh.rank("gpu") + 1
@@ -579,8 +579,8 @@ def test_reduce_pytree(self, backend_type):
 
     def test_to_mesh_pytree(self, backend_type):
         with self.local_device_mesh(2, 2, backend_type) as device_mesh:
-            host0 = device_mesh(host=0)
-            host1 = device_mesh(host=1)
+            host0 = device_mesh.slice(host=0)
+            host1 = device_mesh.slice(host=1)
 
             with host0.activate():
                 a = torch.zeros((1,), device="cuda")
@@ -610,8 +610,8 @@ def test_slice_mesh_pytree(self, backend_type):
             host0_slices = monarch.slice_mesh(tensor_dict, host=0)
             host1_slices = monarch.slice_mesh(tensor_dict, host=1)
 
-            host0 = device_mesh(host=0)
-            host1 = device_mesh(host=1)
+            host0 = device_mesh.slice(host=0)
+            host1 = device_mesh.slice(host=1)
 
             host0_tensors = monarch.to_mesh(host0_slices, host0)
             host1_tensors = monarch.to_mesh(host1_slices, host1)
@@ -646,8 +646,12 @@ def test_panicking_worker():
 
 def test_timeout_warning(caplog):
     timeout = 3
-    with local_rust_device_mesh(
-        1, 2, True, ControllerParams(1, timeout, 100, False)
+    with local.local_device_mesh(
+        1,
+        2,
+        True,
+        rust=True,
+        controller_params=ControllerParams(1, timeout, 100, False),
     ) as dm:
         for _ in range(3):
             dm.client.new_node([], [])
@@ -672,8 +676,12 @@ def test_timeout_warning(caplog):
 
 def test_timeout_failure():
     timeout = 3
-    with local_rust_device_mesh(
-        1, 1, True, ControllerParams(1, timeout, 100, True)
+    with local.local_device_mesh(
+        1,
+        1,
+        True,
+        rust=True,
+        controller_params=ControllerParams(1, timeout, 100, True),
     ) as dm:
         for _ in range(3):
             dm.client.new_node([], [])
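The two controller-timeout tests above stop going through the module's `local_rust_device_mesh` helper and instead call `local.local_device_mesh` directly, selecting the Rust backend with `rust=True` and passing the `ControllerParams` by keyword rather than positionally. A condensed before/after sketch, with the argument values copied verbatim from the diff:

    # before: dedicated Rust helper, positional ControllerParams
    with local_rust_device_mesh(1, 2, True, ControllerParams(1, timeout, 100, False)) as dm:
        ...
    # after: shared entry point, explicit keywords
    with local.local_device_mesh(
        1, 2, True, rust=True,
        controller_params=ControllerParams(1, timeout, 100, False),
    ) as dm:
        ...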