Skip to content

Commit d08e538

Browse files
amirafzalifacebook-github-bot
authored andcommitted
create UDF builtins for rng/seed torch functions (#27)
Summary: Pull Request resolved: #27 **Diff Purpose & Changes** 1. Creating the random/RNG remote functions for our new builtins library. torch.manual_seed(seed) torch.initial_seed() torch.get_rng_state() torch.set_rng_state(state) torch.cuda.get_rng_state_all() torch.cuda.set_rng_state_all(states) These two appear to be the same function. we can consider removing one or the other in the library. torch.seed() torch.random.seed() Reviewed By: vidhyav, colin2328 Differential Revision: D72944566 fbshipit-source-id: 77abd37b764685c9d6d29ecfb860b435f07b0e11
1 parent c1172ad commit d08e538

File tree

2 files changed

+263
-0
lines changed

2 files changed

+263
-0
lines changed

python/monarch/builtins/random.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre strict
8+
from typing import Callable
9+
10+
import torch
11+
from monarch.common.remote import remote
12+
13+
14+
@remote(propagate="inspect")
15+
def set_manual_seed_remote(seed: int, process_idx: int = 0) -> None:
16+
torch.manual_seed(seed ^ process_idx)
17+
18+
19+
@remote(propagate=lambda: 0)
20+
def initial_seed_remote() -> int:
21+
return torch.initial_seed()
22+
23+
24+
@remote(propagate=lambda: torch.zeros(1))
25+
def get_rng_state_remote() -> torch.Tensor:
26+
return torch.get_rng_state()
27+
28+
29+
@remote(propagate="inspect")
30+
def set_rng_state_remote(new_state: torch.Tensor) -> None:
31+
torch.set_rng_state(new_state)
32+
33+
34+
def _run_no_return(f: Callable) -> None:
35+
f()
36+
return None
37+
38+
39+
# TODO: return result when uint64 is supported from remote function
40+
@remote(propagate=lambda: _run_no_return(torch.seed))
41+
def seed_remote() -> None:
42+
torch.seed()
43+
44+
45+
# same underlying implementation as seed_remote (torch.seed)
46+
# TODO: return result when uint64 is supported from remote function
47+
@remote(propagate=lambda: _run_no_return(torch.random.seed))
48+
def random_seed_remote() -> None:
49+
torch.random.seed()
50+
51+
52+
@remote(propagate="inspect")
53+
def manual_seed_cuda_remote(seed: int) -> None:
54+
torch.cuda.manual_seed(seed)
55+
56+
57+
@remote(propagate="inspect")
58+
def manual_seed_all_cuda_remote(seed: int) -> None:
59+
torch.cuda.manual_seed_all(seed)
60+
61+
62+
@remote(propagate=lambda: [torch.zeros(1)])
63+
def get_rng_state_all_cuda_remote() -> list[torch.Tensor]:
64+
return torch.cuda.get_rng_state_all()
65+
66+
67+
@remote(propagate="inspect")
68+
def set_rng_state_all_cuda_remote(states: list[torch.Tensor]) -> None:
69+
torch.cuda.set_rng_state_all(states)

python/tests/builtins/test_random.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
import pytest
9+
import torch
10+
from monarch import fetch_shard, no_mesh
11+
12+
from monarch._testing import BackendType, TestingContext
13+
from monarch.builtins.random import (
14+
get_rng_state_all_cuda_remote,
15+
get_rng_state_remote,
16+
initial_seed_remote,
17+
manual_seed_all_cuda_remote,
18+
manual_seed_cuda_remote,
19+
random_seed_remote,
20+
seed_remote,
21+
set_manual_seed_remote,
22+
set_rng_state_all_cuda_remote,
23+
set_rng_state_remote,
24+
)
25+
26+
27+
@pytest.mark.timeout(120)
28+
@pytest.mark.parametrize("backend_type", [BackendType.PY, BackendType.RS])
29+
class TestRandomFunctions:
30+
local = None
31+
32+
@classmethod
33+
def setup_class(cls):
34+
cls.local = TestingContext().__enter__()
35+
36+
@classmethod
37+
def teardown_class(cls):
38+
if cls.local is not None:
39+
cls.local.__exit__(None, None, None)
40+
41+
@classmethod
42+
def local_device_mesh(cls, num_hosts, gpu_per_host, backend_type, activate=True):
43+
return cls.local.local_device_mesh(
44+
num_hosts,
45+
gpu_per_host,
46+
activate,
47+
rust=backend_type == BackendType.RS,
48+
)
49+
50+
def test_set_manual_seed_remote(self, backend_type):
51+
with self.local_device_mesh(1, 1, backend_type) as device_mesh:
52+
with device_mesh.activate():
53+
set_manual_seed_remote(12345)
54+
t1 = torch.rand(5, 5)
55+
56+
set_manual_seed_remote(12345)
57+
t2 = torch.rand(5, 5)
58+
59+
set_manual_seed_remote(12346)
60+
t3 = torch.rand(5, 5)
61+
62+
# t1 == t2 (same seed), t1 != t3 (different seed)
63+
result = fetch_shard((t1, t2, t3)).result()
64+
with no_mesh.activate():
65+
assert torch.equal(result[0], result[1])
66+
assert not torch.equal(result[0], result[2])
67+
68+
def test_set_manual_seed_remote_with_process_idx(self, backend_type):
69+
with self.local_device_mesh(1, 1, backend_type) as device_mesh:
70+
with device_mesh.activate():
71+
set_manual_seed_remote(12345, process_idx=0)
72+
t1 = torch.rand(5, 5)
73+
74+
set_manual_seed_remote(12345, process_idx=1)
75+
t2 = torch.rand(5, 5)
76+
77+
result = fetch_shard((t1, t2)).result()
78+
with no_mesh.activate():
79+
assert not torch.equal(result[0], result[1])
80+
81+
def test_initial_seed_remote(self, backend_type):
82+
with self.local_device_mesh(1, 1, backend_type) as device_mesh:
83+
with device_mesh.activate():
84+
seed_value = initial_seed_remote()
85+
86+
result = fetch_shard(seed_value).result()
87+
with no_mesh.activate():
88+
assert isinstance(result, int)
89+
90+
def test_get_rng_state(self, backend_type):
91+
with self.local_device_mesh(1, 1, backend_type) as device_mesh:
92+
with device_mesh.activate():
93+
state1 = get_rng_state_remote()
94+
state2 = get_rng_state_remote()
95+
96+
# generate a random tensor to change the state
97+
_ = torch.rand(5, 5)
98+
99+
state3 = get_rng_state_remote()
100+
101+
result = fetch_shard((state1, state2, state3)).result()
102+
with no_mesh.activate():
103+
assert torch.equal(result[0], result[1])
104+
assert not torch.equal(result[0], result[2])
105+
106+
def test_set_rng_state(self, backend_type):
107+
with self.local_device_mesh(1, 1, backend_type) as device_mesh:
108+
with device_mesh.activate():
109+
# save the initial RNG state
110+
state = get_rng_state_remote()
111+
112+
t1 = torch.rand(3, 3)
113+
t2 = torch.rand(3, 3)
114+
115+
# restore the saved RNG state
116+
set_rng_state_remote(state)
117+
t3 = torch.rand(3, 3)
118+
119+
# t1 == t3 (same state), t1 != t2 (different state)
120+
result = fetch_shard((t1, t2, t3)).result()
121+
with no_mesh.activate():
122+
assert not torch.equal(result[0], result[1])
123+
assert torch.equal(result[0], result[2])
124+
125+
# seed and random.seed seem to be the same function.
126+
def test_random_seed(self, backend_type):
127+
with self.local_device_mesh(1, 1, backend_type) as device_mesh:
128+
with device_mesh.activate():
129+
random_seed_remote()
130+
t1 = torch.rand(5, 5)
131+
132+
random_seed_remote()
133+
t2 = torch.rand(5, 5)
134+
135+
seed_remote()
136+
t3 = torch.rand(5, 5)
137+
138+
result = fetch_shard((t1, t2, t3)).result()
139+
with no_mesh.activate():
140+
assert not torch.equal(result[0], result[1])
141+
assert not torch.equal(result[1], result[2])
142+
143+
def test_get_rng_state_all_cuda(self, backend_type):
144+
NUM_GPUS = 1
145+
with self.local_device_mesh(1, NUM_GPUS, backend_type) as device_mesh:
146+
with device_mesh.activate():
147+
states = get_rng_state_all_cuda_remote()
148+
149+
result = fetch_shard(states).result()
150+
with no_mesh.activate():
151+
assert isinstance(result, list)
152+
assert len(result) == NUM_GPUS
153+
154+
def test_set_rng_state_all_cuda(self, backend_type):
155+
with self.local_device_mesh(1, 1, backend_type) as device_mesh:
156+
with device_mesh.activate():
157+
# save the initial RNG states
158+
states = get_rng_state_all_cuda_remote()
159+
t1 = torch.rand(3, 3, device="cuda")
160+
161+
# restore the saved RNG states
162+
set_rng_state_all_cuda_remote(states)
163+
t2 = torch.rand(3, 3, device="cuda")
164+
165+
# t1 == t2 (same state)
166+
result = fetch_shard((t1, t2)).result()
167+
with no_mesh.activate():
168+
assert torch.equal(result[0], result[1])
169+
170+
def test_cuda_manual_seed(self, backend_type):
171+
with self.local_device_mesh(1, 1, backend_type) as device_mesh:
172+
with device_mesh.activate():
173+
self._cuda_seed_test(manual_seed_cuda_remote)
174+
175+
def test_cuda_manual_seed_all(self, backend_type):
176+
with self.local_device_mesh(1, 1, backend_type) as device_mesh:
177+
with device_mesh.activate():
178+
self._cuda_seed_test(manual_seed_all_cuda_remote)
179+
180+
def _cuda_seed_test(self, seed_func):
181+
seed_func(12345)
182+
t1 = torch.rand(5, 5, device="cuda")
183+
184+
seed_func(12345)
185+
t2 = torch.rand(5, 5, device="cuda")
186+
187+
seed_func(54321)
188+
t3 = torch.rand(5, 5, device="cuda")
189+
190+
# t1 = t2 (same seed), t1 != t3 (different seed)
191+
result = fetch_shard((t1, t2, t3)).result()
192+
with no_mesh.activate():
193+
assert torch.equal(result[0], result[1])
194+
assert not torch.equal(result[0], result[2])

0 commit comments

Comments
 (0)