Skip to content

Commit 127e255

Browse files
authored
feature: compare local pids (#611)
* rustler_sys: add 'nif_compare_pids' * rustler: add (Partial)Eq/Ord for LocalPid * rustler_tests: add tests for LocalPid cmp/eq * sys: define 'enif_compare_pids' to behave like macro in C code * tests: add unit test with equality check for local pids
1 parent b882d51 commit 127e255

File tree

6 files changed

+86
-0
lines changed

6 files changed

+86
-0
lines changed

rustler/src/types/local_pid.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use crate::wrapper::{pid, ErlNifPid};
22
use crate::{Decoder, Encoder, Env, Error, NifResult, Term};
3+
use std::cmp::Ordering;
34
use std::mem::MaybeUninit;
45

56
#[derive(Copy, Clone)]
@@ -36,6 +37,27 @@ impl Encoder for LocalPid {
3637
}
3738
}
3839

40+
impl PartialEq for LocalPid {
41+
fn eq(&self, other: &Self) -> bool {
42+
unsafe { rustler_sys::enif_compare_pids(self.as_c_arg(), other.as_c_arg()) == 0 }
43+
}
44+
}
45+
46+
impl Eq for LocalPid {}
47+
48+
impl PartialOrd for LocalPid {
49+
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
50+
Some(self.cmp(other))
51+
}
52+
}
53+
54+
impl Ord for LocalPid {
55+
fn cmp(&self, other: &Self) -> Ordering {
56+
let cmp = unsafe { rustler_sys::enif_compare_pids(self.as_c_arg(), other.as_c_arg()) };
57+
cmp.cmp(&0)
58+
}
59+
}
60+
3961
impl<'a> Env<'a> {
4062
/// Return the calling process's pid.
4163
///

rustler_sys/src/rustler_sys_api.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,12 @@ pub unsafe fn enif_make_pid(_env: *mut ErlNifEnv, pid: ErlNifPid) -> ERL_NIF_TER
203203
pid.pid
204204
}
205205

206+
/// See [enif_compare_pids](http://erlang.org/doc/man/erl_nif.html#enif_compare_pids) in the Erlang docs
207+
pub unsafe fn enif_compare_pids(pid1: *const ErlNifPid, pid2: *const ErlNifPid) -> c_int {
208+
// Mimics the implementation of the enif_compare_pids macro
209+
enif_compare((*pid1).pid, (*pid2).pid)
210+
}
211+
206212
/// See [ErlNifSysInfo](http://www.erlang.org/doc/man/erl_nif.html#ErlNifSysInfo) in the Erlang docs.
207213
#[allow(missing_copy_implementations)]
208214
#[repr(C)]

rustler_tests/lib/rustler_test.ex

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ defmodule RustlerTest do
3333
def sum_list(_), do: err()
3434
def make_list(), do: err()
3535

36+
def compare_local_pids(_, _), do: err()
37+
def are_equal_local_pids(_, _), do: err()
38+
3639
def term_debug(_), do: err()
3740

3841
def term_debug_and_reparse(term) do

rustler_tests/native/rustler_test/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ mod test_dirty;
55
mod test_env;
66
mod test_error;
77
mod test_list;
8+
mod test_local_pid;
89
mod test_map;
910
mod test_nif_attrs;
1011
mod test_path;
@@ -28,6 +29,8 @@ rustler::init!(
2829
test_primitives::echo_i128,
2930
test_list::sum_list,
3031
test_list::make_list,
32+
test_local_pid::compare_local_pids,
33+
test_local_pid::are_equal_local_pids,
3134
test_term::term_debug,
3235
test_term::term_eq,
3336
test_term::term_cmp,
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
use std::cmp::Ordering;
2+
3+
use rustler::LocalPid;
4+
5+
#[rustler::nif]
6+
pub fn compare_local_pids(lhs: LocalPid, rhs: LocalPid) -> i32 {
7+
match lhs.cmp(&rhs) {
8+
Ordering::Less => -1,
9+
Ordering::Equal => 0,
10+
Ordering::Greater => 1,
11+
}
12+
}
13+
14+
#[rustler::nif]
15+
pub fn are_equal_local_pids(lhs: LocalPid, rhs: LocalPid) -> bool {
16+
lhs == rhs
17+
}

rustler_tests/test/local_pid_test.exs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
defmodule RustlerTest.LocalPidTest do
2+
use ExUnit.Case, async: true
3+
4+
def make_pid() do
5+
{:ok, pid} = Task.start(fn -> :ok end)
6+
pid
7+
end
8+
9+
def compare(lhs, rhs) do
10+
cond do
11+
lhs < rhs -> -1
12+
lhs == rhs -> 0
13+
lhs > rhs -> 1
14+
end
15+
end
16+
17+
test "local pid comparison" do
18+
# We make sure that the code we have in rust code matches the comparisons
19+
# that are performed in the BEAM code.
20+
pids = for _ <- 1..3, do: make_pid()
21+
22+
for lhs <- pids, rhs <- pids do
23+
assert RustlerTest.compare_local_pids(lhs, rhs) == compare(lhs, rhs)
24+
end
25+
end
26+
27+
test "local pid equality" do
28+
pids = for _ <- 1..3, do: make_pid()
29+
30+
for lhs <- pids, rhs <- pids do
31+
expected = lhs == rhs
32+
assert RustlerTest.are_equal_local_pids(lhs, rhs) == expected
33+
end
34+
end
35+
end

0 commit comments

Comments
 (0)