Skip to content

feat: Improvements to Leader-Worker barrier #1498

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 90 additions & 73 deletions lib/runtime/src/utils/leader_worker_barrier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,22 +108,7 @@ async fn create_barrier_key<T: Serialize>(
client
.kv_create(key, serialized_data, lease_id)
.await
.map_err(|_| LeaderWorkerBarrierError::BarrierIdNotUnique)?;

Ok(())
}

/// Creates a worker-specific key in etcd
async fn create_worker_key(
client: &Client,
key: &str,
lease_id: Option<i64>,
) -> Result<(), LeaderWorkerBarrierError> {
// TODO: Same as above. This can fail for many reasons.
client
.kv_create(key.to_owned(), serde_json::to_vec(&()).unwrap(), lease_id)
.await
.map_err(|_| LeaderWorkerBarrierError::BarrierWorkerIdNotUnique)?;
.map_err(|_| LeaderWorkerBarrierError::IdNotUnique)?;

Ok(())
}
Expand All @@ -140,8 +125,7 @@ async fn wait_for_signal<T: DeserializeOwned>(
#[derive(Debug)]
pub enum LeaderWorkerBarrierError {
EtcdClientNotFound,
BarrierIdNotUnique,
BarrierWorkerIdNotUnique,
IdNotUnique,
EtcdError(anyhow::Error),
SerdeError(serde_json::Error),
Timeout,
Expand All @@ -150,14 +134,16 @@ pub enum LeaderWorkerBarrierError {
}

/// A barrier for a leader to wait for a specific number of workers to join.
pub struct LeaderBarrier<T> {
pub struct LeaderBarrier<LeaderData, WorkerData> {
barrier_id: String,
num_workers: usize,
timeout: Option<Duration>,
marker: PhantomData<T>,
marker: PhantomData<(LeaderData, WorkerData)>,
}

impl<T: Serialize + DeserializeOwned> LeaderBarrier<T> {
impl<LeaderData: Serialize + DeserializeOwned, WorkerData: Serialize + DeserializeOwned>
LeaderBarrier<LeaderData, WorkerData>
{
pub fn new(barrier_id: String, num_workers: usize, timeout: Option<Duration>) -> Self {
Self {
barrier_id,
Expand All @@ -174,8 +160,8 @@ impl<T: Serialize + DeserializeOwned> LeaderBarrier<T> {
pub async fn sync(
self,
rt: &DistributedRuntime,
data: &T,
) -> anyhow::Result<(), LeaderWorkerBarrierError> {
data: &LeaderData,
) -> anyhow::Result<HashMap<String, WorkerData>, LeaderWorkerBarrierError> {
let etcd_client = rt
.etcd_client()
.ok_or(LeaderWorkerBarrierError::EtcdClientNotFound)?;
Expand All @@ -193,13 +179,17 @@ impl<T: Serialize + DeserializeOwned> LeaderBarrier<T> {
self.signal_completion(&etcd_client, &worker_result, lease_id)
.await?;

worker_result.map(|_| ())
worker_result.map(|r| {
r.into_iter()
.map(|(k, v)| (k.split("/").last().unwrap().to_string(), v))
.collect()
})
}

async fn publish_barrier_data(
&self,
client: &Client,
data: &T,
data: &LeaderData,
lease_id: i64,
) -> Result<(), LeaderWorkerBarrierError> {
let key = barrier_key(&self.barrier_id, BARRIER_DATA);
Expand All @@ -209,21 +199,24 @@ impl<T: Serialize + DeserializeOwned> LeaderBarrier<T> {
async fn wait_for_workers(
&self,
client: &Client,
) -> Result<HashSet<String>, LeaderWorkerBarrierError> {
) -> Result<HashMap<String, WorkerData>, LeaderWorkerBarrierError> {
let key = barrier_key(&self.barrier_id, BARRIER_WORKER);
let workers = wait_for_key_count::<()>(client, key, self.num_workers, self.timeout).await?;
Ok(workers.into_keys().collect())
let workers = wait_for_key_count(client, key, self.num_workers, self.timeout).await?;
Ok(workers)
}

async fn signal_completion(
&self,
client: &Client,
worker_result: &Result<HashSet<String>, LeaderWorkerBarrierError>,
worker_result: &Result<HashMap<String, WorkerData>, LeaderWorkerBarrierError>,
lease_id: i64,
) -> Result<(), LeaderWorkerBarrierError> {
if let Ok(worker_result) = worker_result {
let key = barrier_key(&self.barrier_id, BARRIER_COMPLETE);
create_barrier_key(client, key, worker_result, Some(lease_id)).await?;

let workers = worker_result.keys().collect::<HashSet<_>>();

create_barrier_key(client, key, workers, Some(lease_id)).await?;
} else {
let key = barrier_key(&self.barrier_id, BARRIER_ABORT);
create_barrier_key(client, key, (), Some(lease_id)).await?;
Expand All @@ -234,13 +227,15 @@ impl<T: Serialize + DeserializeOwned> LeaderBarrier<T> {
}

// A barrier to synchronize a worker with a leader.
pub struct WorkerBarrier<T> {
pub struct WorkerBarrier<LeaderData, WorkerData> {
barrier_id: String,
worker_id: String,
marker: PhantomData<T>,
marker: PhantomData<(LeaderData, WorkerData)>,
}

impl<T: Serialize + DeserializeOwned> WorkerBarrier<T> {
impl<LeaderData: Serialize + DeserializeOwned, WorkerData: Serialize + DeserializeOwned>
WorkerBarrier<LeaderData, WorkerData>
{
pub fn new(barrier_id: String, worker_id: String) -> Self {
Self {
barrier_id,
Expand All @@ -259,7 +254,8 @@ impl<T: Serialize + DeserializeOwned> WorkerBarrier<T> {
pub async fn sync(
self,
rt: &DistributedRuntime,
) -> anyhow::Result<T, LeaderWorkerBarrierError> {
data: &WorkerData,
) -> anyhow::Result<LeaderData, LeaderWorkerBarrierError> {
let etcd_client = rt
.etcd_client()
.ok_or(LeaderWorkerBarrierError::EtcdClientNotFound)?;
Expand All @@ -270,20 +266,23 @@ impl<T: Serialize + DeserializeOwned> WorkerBarrier<T> {
let barrier_data = self.get_barrier_data(&etcd_client).await?;

// Register as a worker
let worker_key = self.register_worker(&etcd_client, lease_id).await?;
let worker_key = self.register_worker(&etcd_client, data, lease_id).await?;

// Wait for completion or abort signal
self.wait_for_completion(&etcd_client, worker_key).await?;

Ok(barrier_data)
}

async fn get_barrier_data(&self, client: &Client) -> Result<T, LeaderWorkerBarrierError> {
async fn get_barrier_data(
&self,
client: &Client,
) -> Result<LeaderData, LeaderWorkerBarrierError> {
let data_key = barrier_key(&self.barrier_id, BARRIER_DATA);
let abort_key = barrier_key(&self.barrier_id, BARRIER_ABORT);

tokio::select! {
result = wait_for_key_count::<T>(client, data_key, 1, None) => {
result = wait_for_key_count::<LeaderData>(client, data_key, 1, None) => {
result?.into_values().next()
.ok_or(LeaderWorkerBarrierError::EtcdError(anyhow::anyhow!("No data found")))
}
Expand All @@ -296,15 +295,15 @@ impl<T: Serialize + DeserializeOwned> WorkerBarrier<T> {
async fn register_worker(
&self,
client: &Client,
data: &WorkerData,
lease_id: i64,
) -> Result<String, LeaderWorkerBarrierError> {
let key = barrier_key(
&self.barrier_id,
&format!("{}/{}", BARRIER_WORKER, self.worker_id),
);
create_worker_key(client, &key, Some(lease_id))
.await
.map(|_| key)
create_barrier_key(client, key.clone(), data, Some(lease_id)).await?;
Ok(key)
}

async fn wait_for_completion(
Expand Down Expand Up @@ -354,15 +353,15 @@ mod tests {

assert!(drt.etcd_client().is_none());

let barrier = LeaderBarrier::new("test".to_string(), 2, None);
let worker = WorkerBarrier::<()>::new("test".to_string(), "worker".to_string());
let barrier = LeaderBarrier::<String, String>::new("test".to_string(), 2, None);
let worker = WorkerBarrier::<String, String>::new("test".to_string(), "worker".to_string());

assert!(matches!(
barrier.sync(&drt, &"test".to_string()).await,
Err(LeaderWorkerBarrierError::EtcdClientNotFound)
));
assert!(matches!(
worker.sync(&drt).await,
worker.sync(&drt, &"test".to_string()).await,
Err(LeaderWorkerBarrierError::EtcdClientNotFound)
));
}
Expand All @@ -374,19 +373,24 @@ mod tests {

let id = unique_id();

let leader = LeaderBarrier::new(id.clone(), 1, None);
let worker = WorkerBarrier::<String>::new(id.clone(), "worker".to_string());
let leader = LeaderBarrier::<String, String>::new(id.clone(), 1, None);
let worker = WorkerBarrier::<String, String>::new(id.clone(), "worker".to_string());

let drt_clone = drt.clone();
let leader_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
tokio::spawn(async move {
leader.sync(&drt_clone, &"test_data".to_string()).await?;
let worker_data = leader.sync(&drt_clone, &"test_data".to_string()).await?;
assert_eq!(worker_data.len(), 1);
assert_eq!(
worker_data.get("worker").unwrap(),
&"test_worker".to_string()
);
Ok(())
});

let worker_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
tokio::spawn(async move {
let res = worker.sync(&drt).await?;
let res = worker.sync(&drt, &"test_worker".to_string()).await?;
assert_eq!(res, "test_data".to_string());

Ok(())
Expand All @@ -405,31 +409,36 @@ mod tests {

let id = unique_id();

let leader1 = LeaderBarrier::new(id.clone(), 1, None);
let leader2 = LeaderBarrier::new(id.clone(), 1, None);
let leader1 = LeaderBarrier::<String, String>::new(id.clone(), 1, None);
let leader2 = LeaderBarrier::<String, String>::new(id.clone(), 1, None);

let worker = WorkerBarrier::<String>::new(id.clone(), "worker".to_string());
let worker = WorkerBarrier::<String, String>::new(id.clone(), "worker".to_string());

let drt_clone = drt.clone();
let leader1_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
tokio::spawn(async move {
leader1.sync(&drt_clone, &"test_data".to_string()).await?;
let worker_data = leader1.sync(&drt_clone, &"test_data".to_string()).await?;
assert_eq!(worker_data.len(), 1);
assert_eq!(
worker_data.get("worker").unwrap(),
&"test_worker".to_string()
);

// Now, try to sync leader 2.
let leader2_res = leader2.sync(&drt_clone, &"test_data2".to_string()).await;

// Leader 2 should fail because the barrier ID is the same as leader 1.
assert!(matches!(
leader2_res,
Err(LeaderWorkerBarrierError::BarrierIdNotUnique)
Err(LeaderWorkerBarrierError::IdNotUnique)
));

Ok(())
});

let worker_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
tokio::spawn(async move {
let res = worker.sync(&drt).await?;
let res = worker.sync(&drt, &"test_worker".to_string()).await?;
assert_eq!(res, "test_data".to_string());

Ok(())
Expand All @@ -448,26 +457,33 @@ mod tests {

let id = unique_id();

let leader = LeaderBarrier::new(id.clone(), 1, None);
let worker1 = WorkerBarrier::<String>::new(id.clone(), "worker".to_string());
let worker2 = WorkerBarrier::<String>::new(id.clone(), "worker".to_string());
let leader = LeaderBarrier::<String, String>::new(id.clone(), 1, None);
let worker1 = WorkerBarrier::<String, String>::new(id.clone(), "worker".to_string());
let worker2 = WorkerBarrier::<String, String>::new(id.clone(), "worker".to_string());

let drt_clone = drt.clone();
let leader_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
tokio::spawn(async move {
leader.sync(&drt_clone, &"test_data".to_string()).await?;
let worker_data = leader.sync(&drt_clone, &"test_data".to_string()).await?;
assert_eq!(worker_data.len(), 1);
assert_eq!(
worker_data.get("worker").unwrap(),
&"test_worker_1".to_string()
);

Ok(())
});

let worker_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
tokio::spawn(async move {
worker1.sync(&drt).await?;
let leader_data = worker1.sync(&drt, &"test_worker_1".to_string()).await?;
assert_eq!(leader_data, "test_data".to_string());

let worker2_res = worker2.sync(&drt).await;
let worker2_res = worker2.sync(&drt, &"test_worker_2".to_string()).await;

assert!(matches!(
worker2_res,
Err(LeaderWorkerBarrierError::BarrierWorkerIdNotUnique)
Err(LeaderWorkerBarrierError::IdNotUnique)
));

Ok(())
Expand All @@ -486,9 +502,9 @@ mod tests {

let id = unique_id();

let leader = LeaderBarrier::new(id.clone(), 2, Some(Duration::from_millis(100)));
let worker1 = WorkerBarrier::<()>::new(id.clone(), "worker1".to_string());
let worker2 = WorkerBarrier::<()>::new(id.clone(), "worker2".to_string());
let leader = LeaderBarrier::<(), ()>::new(id.clone(), 2, Some(Duration::from_millis(100)));
let worker1 = WorkerBarrier::<(), ()>::new(id.clone(), "worker1".to_string());
let worker2 = WorkerBarrier::<(), ()>::new(id.clone(), "worker2".to_string());

let drt_clone = drt.clone();
let leader_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
Expand All @@ -502,7 +518,7 @@ mod tests {
let drt_clone = drt.clone();
let worker1_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
tokio::spawn(async move {
let res = worker1.sync(&drt_clone).await;
let res = worker1.sync(&drt_clone, &()).await;
assert!(matches!(res, Err(LeaderWorkerBarrierError::Aborted)));

Ok(())
Expand All @@ -511,7 +527,7 @@ mod tests {
let worker2_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(200)).await;
let res = worker2.sync(&drt).await;
let res = worker2.sync(&drt, &()).await;
assert!(matches!(res, Err(LeaderWorkerBarrierError::Aborted)));

Ok(())
Expand All @@ -533,8 +549,9 @@ mod tests {
let id = unique_id();

// Get the leader to send a (), when the worker expects a String.
let leader = LeaderBarrier::new(id.clone(), 1, Some(Duration::from_millis(100)));
let worker1 = WorkerBarrier::<String>::new(id.clone(), "worker1".to_string());
let leader =
LeaderBarrier::<(), String>::new(id.clone(), 1, Some(Duration::from_millis(100)));
let worker1 = WorkerBarrier::<String, String>::new(id.clone(), "worker1".to_string());

let drt_clone = drt.clone();
let leader_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
Expand All @@ -549,7 +566,7 @@ mod tests {
let worker_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
tokio::spawn(async move {
assert!(matches!(
worker1.sync(&drt).await,
worker1.sync(&drt, &"test_worker".to_string()).await,
Err(LeaderWorkerBarrierError::SerdeError(_))
));

Expand All @@ -569,9 +586,9 @@ mod tests {

let id = unique_id();

let leader = LeaderBarrier::new(id.clone(), 1, None);
let worker1 = WorkerBarrier::<()>::new(id.clone(), "worker1".to_string());
let worker2 = WorkerBarrier::<()>::new(id.clone(), "worker2".to_string());
let leader = LeaderBarrier::<(), ()>::new(id.clone(), 1, None);
let worker1 = WorkerBarrier::<(), ()>::new(id.clone(), "worker1".to_string());
let worker2 = WorkerBarrier::<(), ()>::new(id.clone(), "worker2".to_string());

let drt_clone = drt.clone();
let leader_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
Expand All @@ -583,9 +600,9 @@ mod tests {
let worker_join: JoinHandle<Result<(), LeaderWorkerBarrierError>> =
tokio::spawn(async move {
let drt_clone = drt.clone();
let worker1_join = tokio::spawn(async move { worker1.sync(&drt_clone).await });
let worker1_join = tokio::spawn(async move { worker1.sync(&drt_clone, &()).await });

let worker2_join = tokio::spawn(async move { worker2.sync(&drt).await });
let worker2_join = tokio::spawn(async move { worker2.sync(&drt, &()).await });

let (worker1_res, worker2_res) = tokio::join!(worker1_join, worker2_join);

Expand Down
Loading