
Commit ae89bec

amirafzali authored and facebook-github-bot committed
create UDF builtins for rng/seed torch functions
Summary:

**Diff Purpose & Changes**

1. Creating the random/RNG remote functions for our new builtins library:
   torch.manual_seed(seed)
   torch.initial_seed()
   torch.get_rng_state()
   torch.set_rng_state(state)
   torch.cuda.get_rng_state_all()
   torch.cuda.set_rng_state_all(states)

   These two appear to be the same function; we can consider removing one or the other from the library:
   torch.seed()
   torch.random.seed()

Reviewed By: vidhyav, colin2328

Differential Revision: D72944566
1 parent a1acccd commit ae89bec

File tree

2 files changed: +238, −0 lines changed


python/monarch/builtins/random.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
import random

import torch
from monarch.common.remote import remote


# XOR the seed with the process index so each process seeds its
# generator differently.
@remote(propagate="inspect")
def set_manual_seed_remote(seed: int, process_idx: int = 0) -> None:
    torch.manual_seed(seed ^ process_idx)


@remote(propagate=lambda: 0)
def initial_seed_remote() -> int:
    return torch.initial_seed()


@remote(propagate=lambda: torch.zeros(1))
def get_rng_state_remote() -> torch.Tensor:
    return torch.get_rng_state()


@remote(propagate="inspect")
def set_rng_state_remote(new_state: torch.Tensor) -> None:
    torch.set_rng_state(new_state)


# torch.seed() and torch.random.seed() appear to be the same function;
# both wrappers are kept for now (see the commit summary).
@remote(propagate=lambda: int(random.random()))
def seed_remote() -> int:
    return torch.seed()


@remote(propagate=lambda: int(random.random()))
def random_seed_remote() -> int:
    return torch.random.seed()


@remote(propagate="inspect")
def manual_seed_cuda_remote(seed: int) -> None:
    torch.cuda.manual_seed(seed)


@remote(propagate="inspect")
def manual_seed_all_cuda_remote(seed: int) -> None:
    torch.cuda.manual_seed_all(seed)


@remote(propagate=lambda: [torch.zeros(1)])
def get_rng_state_all_cuda_remote() -> list[torch.Tensor]:
    return torch.cuda.get_rng_state_all()


@remote(propagate="inspect")
def set_rng_state_all_cuda_remote(states: list[torch.Tensor]) -> None:
    torch.cuda.set_rng_state_all(states)
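
For illustration, here is a minimal sketch of how these builtins are meant to be called, mirroring the test idiom in the file below. The `device_mesh` here is assumed to come from a mesh helper (not shown; the tests use `local.local_device_mesh`), and the sketch only uses names introduced in this diff plus `fetch_shard`/`no_mesh` from monarch:

import torch
from monarch import fetch_shard, no_mesh
from monarch.builtins.random import get_rng_state_remote, set_rng_state_remote

# `device_mesh` is assumed to be obtained elsewhere, as in the tests below.
with device_mesh.activate():
    state = get_rng_state_remote()  # snapshot the worker's CPU RNG state
    t1 = torch.rand(3, 3)

    set_rng_state_remote(state)     # rewind the worker's RNG
    t2 = torch.rand(3, 3)           # replays the same stream as t1

    result = fetch_shard((t1, t2)).result()
with no_mesh.activate():
    assert torch.equal(result[0], result[1])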

python/tests/builtins/test_random.py

Lines changed: 184 additions & 0 deletions
@@ -0,0 +1,184 @@
# pyre-unsafe
import pytest
import torch
from monarch import fetch_shard, no_mesh

from monarch._testing import BackendType, TestingContext
from monarch.builtins.random import (
    get_rng_state_all_cuda_remote,
    get_rng_state_remote,
    initial_seed_remote,
    manual_seed_all_cuda_remote,
    manual_seed_cuda_remote,
    random_seed_remote,
    seed_remote,
    set_manual_seed_remote,
    set_rng_state_all_cuda_remote,
    set_rng_state_remote,
)


@pytest.fixture(scope="module", autouse=True)
def testing_context():
    global local
    with TestingContext() as local:
        yield


@pytest.mark.timeout(120)
@pytest.mark.parametrize("backend_type", [BackendType.PY, BackendType.RS])
class TestRandomFunctions:
    @classmethod
    def local_device_mesh(cls, num_hosts, gpu_per_host, backend_type, activate=True):
        return local.local_device_mesh(
            num_hosts,
            gpu_per_host,
            activate,
            rust=backend_type == BackendType.RS,
        )

    def test_set_manual_seed_remote(self, backend_type):
        with self.local_device_mesh(1, 1, backend_type) as device_mesh:
            with device_mesh.activate():
                set_manual_seed_remote(12345)
                t1 = torch.rand(5, 5)

                set_manual_seed_remote(12345)
                t2 = torch.rand(5, 5)

                set_manual_seed_remote(12346)
                t3 = torch.rand(5, 5)

                # t1 == t2 (same seed), t1 != t3 (different seed)
                result = fetch_shard((t1, t2, t3)).result()
            with no_mesh.activate():
                assert torch.equal(result[0], result[1])
                assert not torch.equal(result[0], result[2])

    def test_set_manual_seed_remote_with_process_idx(self, backend_type):
        with self.local_device_mesh(1, 1, backend_type) as device_mesh:
            with device_mesh.activate():
                set_manual_seed_remote(12345, process_idx=0)
                t1 = torch.rand(5, 5)

                set_manual_seed_remote(12345, process_idx=1)
                t2 = torch.rand(5, 5)

                result = fetch_shard((t1, t2)).result()
            with no_mesh.activate():
                assert not torch.equal(result[0], result[1])

    def test_initial_seed_remote(self, backend_type):
        with self.local_device_mesh(1, 1, backend_type) as device_mesh:
            with device_mesh.activate():
                seed_value = initial_seed_remote()

                result = fetch_shard(seed_value).result()
            with no_mesh.activate():
                assert isinstance(result, int)

    def test_get_rng_state(self, backend_type):
        with self.local_device_mesh(1, 1, backend_type) as device_mesh:
            with device_mesh.activate():
                state1 = get_rng_state_remote()
                state2 = get_rng_state_remote()

                # generate a random tensor to change the state
                _ = torch.rand(5, 5)

                state3 = get_rng_state_remote()

                result = fetch_shard((state1, state2, state3)).result()
            with no_mesh.activate():
                assert torch.equal(result[0], result[1])
                assert not torch.equal(result[0], result[2])

    def test_set_rng_state(self, backend_type):
        with self.local_device_mesh(1, 1, backend_type) as device_mesh:
            with device_mesh.activate():
                # save the initial RNG state
                state = get_rng_state_remote()

                t1 = torch.rand(3, 3)
                t2 = torch.rand(3, 3)

                # restore the saved RNG state
                set_rng_state_remote(state)
                t3 = torch.rand(3, 3)

                # t1 == t3 (same state), t1 != t2 (different state)
                result = fetch_shard((t1, t2, t3)).result()
            with no_mesh.activate():
                assert not torch.equal(result[0], result[1])
                assert torch.equal(result[0], result[2])

    # seed and random.seed seem to be the same function.
    def test_random_seed(self, backend_type):
        with self.local_device_mesh(1, 1, backend_type) as device_mesh:
            with device_mesh.activate():
                random_seed_remote()
                t1 = torch.rand(5, 5)

                random_seed_remote()
                t2 = torch.rand(5, 5)

                seed_remote()
                t3 = torch.rand(5, 5)

                result = fetch_shard((t1, t2, t3)).result()
            with no_mesh.activate():
                assert not torch.equal(result[0], result[1])
                assert not torch.equal(result[1], result[2])

    def test_get_rng_state_all_cuda(self, backend_type):
        NUM_GPUS = 1
        with self.local_device_mesh(1, NUM_GPUS, backend_type) as device_mesh:
            with device_mesh.activate():
                states = get_rng_state_all_cuda_remote()

                result = fetch_shard(states).result()
            with no_mesh.activate():
                assert isinstance(result, list)
                assert len(result) == NUM_GPUS

    def test_set_rng_state_all_cuda(self, backend_type):
        with self.local_device_mesh(1, 1, backend_type) as device_mesh:
            with device_mesh.activate():
                # save the initial RNG states
                states = get_rng_state_all_cuda_remote()
                t1 = torch.rand(3, 3, device="cuda")

                # restore the saved RNG states
                set_rng_state_all_cuda_remote(states)
                t2 = torch.rand(3, 3, device="cuda")

                # t1 == t2 (same state)
                result = fetch_shard((t1, t2)).result()
            with no_mesh.activate():
                assert torch.equal(result[0], result[1])

    def test_cuda_manual_seed(self, backend_type):
        with self.local_device_mesh(1, 1, backend_type) as device_mesh:
            with device_mesh.activate():
                self._cuda_seed_test(manual_seed_cuda_remote)

    def test_cuda_manual_seed_all(self, backend_type):
        with self.local_device_mesh(1, 1, backend_type) as device_mesh:
            with device_mesh.activate():
                self._cuda_seed_test(manual_seed_all_cuda_remote)

    def _cuda_seed_test(self, seed_func):
        seed_func(12345)
        t1 = torch.rand(5, 5, device="cuda")

        seed_func(12345)
        t2 = torch.rand(5, 5, device="cuda")

        seed_func(54321)
        t3 = torch.rand(5, 5, device="cuda")

        # t1 == t2 (same seed), t1 != t3 (different seed)
        result = fetch_shard((t1, t2, t3)).result()
        with no_mesh.activate():
            assert torch.equal(result[0], result[1])
            assert not torch.equal(result[0], result[2])
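
On the open question from the summary (echoed in the comment above test_random_seed): in current PyTorch, torch.seed is defined in torch/random.py and re-exported at the top level, so the two names should refer to the same function object. A quick local sanity check, assuming a recent PyTorch build (not part of this diff):

import torch

# If this holds, one of seed_remote/random_seed_remote could be dropped.
assert torch.seed is torch.random.seed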
