Skip to content

Commit 54989b3

Browse files
Sam Luryefacebook-github-bot
authored andcommitted
Fix slice to selection range conversion (#29)
Summary: Pull Request resolved: #29 Previous implementation was converting `Slice(offset, sizes=[size, 1, 1, ...], strides=[stride, ...])` into a selection with `Range(offset, offset + size, stride)`, which is wrong. It should be `Range(offset, offset + size * stride, stride)`. Reviewed By: shayne-fletcher, andrewjcg Differential Revision: D75102147 fbshipit-source-id: 979c3712e65a49cd4e04de71921b82baead721bc
1 parent 28ba977 commit 54989b3

File tree

2 files changed

+5
-7
lines changed

2 files changed

+5
-7
lines changed

controller/src/lib.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,11 @@ fn slice_to_selection(slice: Slice) -> Selection {
310310
([], []) => dsl::range(slice.offset()..=slice.offset(), dsl::true_()),
311311
// Special case trivial range `Selection`.
312312
([size, rsizes @ ..], [stride, ..]) if rsizes.iter().all(|s| *s == 1) => dsl::range(
313-
Range(slice.offset(), Some(slice.offset() + *size), *stride),
313+
Range(
314+
slice.offset(),
315+
Some(slice.offset() + *size * *stride),
316+
*stride,
317+
),
314318
dsl::true_(),
315319
),
316320
// Fallback to more heavy-weight translation for everything else.

python/tests/test_controller.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import sys
1212
import traceback
1313
from contextlib import contextmanager
14-
from enum import Enum
1514
from typing import Generator
1615

1716
import monarch
@@ -324,8 +323,6 @@ def test_movement(self, backend_type):
324323
_ = b.to_mesh(sm1)
325324

326325
def test_broadcast_one(self, backend_type):
327-
if backend_type == BackendType.RS:
328-
pytest.skip("deadlocks on rust")
329326
with self.local_device_mesh(2, 2, backend_type) as device_mesh:
330327
for dim in ("host", "gpu"):
331328
subset = device_mesh(**{dim: 1})
@@ -341,9 +338,6 @@ def test_broadcast_one(self, backend_type):
341338
assert torch.allclose(a.expand(2, -1), b, rtol=0, atol=0)
342339

343340
def test_broadcast_two(self, backend_type):
344-
if backend_type == BackendType.RS:
345-
pytest.skip("deadlocks on rust")
346-
347341
with self.local_device_mesh(2, 2, backend_type) as device_mesh:
348342
subset = device_mesh(host=1, gpu=1)
349343
with subset.activate():

0 commit comments

Comments
 (0)