Skip to content

Add test to try and repro #1168 #1173

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
311 changes: 217 additions & 94 deletions projects/Unit/TestAsyncConsumer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,99 +30,206 @@
//---------------------------------------------------------------------------

using System;
using System.Security.Cryptography;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

using RabbitMQ.Client.Events;
using Xunit;
using Xunit.Abstractions;

namespace RabbitMQ.Client.Unit
{

public class TestAsyncConsumer
{
private readonly ITestOutputHelper _output;

public TestAsyncConsumer(ITestOutputHelper output)
{
_output = output;
}

[Fact]
public void TestBasicRoundtrip()
{
var cf = new ConnectionFactory{ DispatchConsumersAsync = true };
using(IConnection c = cf.CreateConnection())
using(IModel m = c.CreateModel())
{
QueueDeclareOk q = m.QueueDeclare();
byte[] body = System.Text.Encoding.UTF8.GetBytes("async-hi");
m.BasicPublish("", q.QueueName, body);
var consumer = new AsyncEventingBasicConsumer(m);
var are = new AutoResetEvent(false);
consumer.Received += async (o, a) =>
{
are.Set();
await Task.Yield();
};
string tag = m.BasicConsume(q.QueueName, true, consumer);
// ensure we get a delivery
bool waitRes = are.WaitOne(2000);
Assert.True(waitRes);
// unsubscribe and ensure no further deliveries
m.BasicCancel(tag);
m.BasicPublish("", q.QueueName, body);
bool waitResFalse = are.WaitOne(2000);
Assert.False(waitResFalse);
using(IModel m = c.CreateModel())
{
QueueDeclareOk q = m.QueueDeclare();
byte[] body = System.Text.Encoding.UTF8.GetBytes("async-hi");
m.BasicPublish("", q.QueueName, body);
var consumer = new AsyncEventingBasicConsumer(m);
var are = new AutoResetEvent(false);
consumer.Received += async (o, a) =>
{
are.Set();
await Task.Yield();
};
string tag = m.BasicConsume(q.QueueName, true, consumer);
// ensure we get a delivery
bool waitRes = are.WaitOne(2000);
Assert.True(waitRes);
// unsubscribe and ensure no further deliveries
m.BasicCancel(tag);
m.BasicPublish("", q.QueueName, body);
bool waitResFalse = are.WaitOne(2000);
Assert.False(waitResFalse);
}
}
}

[Fact]
public async Task TestBasicRoundtripConcurrent()
{
var cf = new ConnectionFactory{ DispatchConsumersAsync = true, ConsumerDispatchConcurrency = 2 };
using(IConnection c = cf.CreateConnection())
using(IModel m = c.CreateModel())
using (IConnection c = cf.CreateConnection())
{
QueueDeclareOk q = m.QueueDeclare();
const string publish1 = "async-hi-1";
byte[] body = Encoding.UTF8.GetBytes(publish1);
m.BasicPublish("", q.QueueName, body);
const string publish2 = "async-hi-2";
body = Encoding.UTF8.GetBytes(publish2);
m.BasicPublish("", q.QueueName, body);

var consumer = new AsyncEventingBasicConsumer(m);

var publish1SyncSource = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
var publish2SyncSource = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
var maximumWaitTime = TimeSpan.FromSeconds(5);
var tokenSource = new CancellationTokenSource(maximumWaitTime);
tokenSource.Token.Register(() =>
using (IModel m = c.CreateModel())
{
publish1SyncSource.TrySetResult(false);
publish2SyncSource.TrySetResult(false);
});
QueueDeclareOk q = m.QueueDeclare();
string publish1 = get_unique_string(16384);
byte[] body = Encoding.UTF8.GetBytes(publish1);
m.BasicPublish("", q.QueueName, body);

consumer.Received += async (o, a) =>
{
switch (Encoding.UTF8.GetString(a.Body.ToArray()))
string publish2 = get_unique_string(16384);
body = Encoding.UTF8.GetBytes(publish2);
m.BasicPublish("", q.QueueName, body);

var consumer = new AsyncEventingBasicConsumer(m);

var publish1SyncSource = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
var publish2SyncSource = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
var maximumWaitTime = TimeSpan.FromSeconds(5);
var tokenSource = new CancellationTokenSource(maximumWaitTime);
tokenSource.Token.Register(() =>
{
case publish1:
publish1SyncSource.TrySetResult(false);
publish2SyncSource.TrySetResult(false);
});

consumer.Received += async (o, a) =>
{
string decoded = Encoding.ASCII.GetString(a.Body.ToArray());
if (decoded == publish1)
{
publish1SyncSource.TrySetResult(true);
await publish2SyncSource.Task;
break;
case publish2:
}
else if (decoded == publish2)
{
publish2SyncSource.TrySetResult(true);
await publish1SyncSource.Task;
break;
}
};
}
};

m.BasicConsume(q.QueueName, true, consumer);
// ensure we get a delivery
m.BasicConsume(q.QueueName, true, consumer);
// ensure we get a delivery

await Task.WhenAll(publish1SyncSource.Task, publish2SyncSource.Task);
await Task.WhenAll(publish1SyncSource.Task, publish2SyncSource.Task);

Assert.True(publish1SyncSource.Task.Result, $"Non concurrent dispatch lead to deadlock after {maximumWaitTime}");
Assert.True(publish2SyncSource.Task.Result, $"Non concurrent dispatch lead to deadlock after {maximumWaitTime}");
Assert.True(publish1SyncSource.Task.Result, $"Non concurrent dispatch lead to deadlock after {maximumWaitTime}");
Assert.True(publish2SyncSource.Task.Result, $"Non concurrent dispatch lead to deadlock after {maximumWaitTime}");
}
}
}

[Fact]
public async Task TestBasicRoundtripConcurrentManyMessages()
{
const int publish_total = 4096;
string queueName = $"{nameof(TestBasicRoundtripConcurrentManyMessages)}-{Guid.NewGuid()}";

string publish1 = get_unique_string(32768);
byte[] body1 = Encoding.ASCII.GetBytes(publish1);
string publish2 = get_unique_string(32768);
byte[] body2 = Encoding.ASCII.GetBytes(publish2);

var cf = new ConnectionFactory{ DispatchConsumersAsync = true, ConsumerDispatchConcurrency = 2 };

using (IConnection c = cf.CreateConnection())
{
using (IModel m = c.CreateModel())
{
QueueDeclareOk q = m.QueueDeclare(queue: queueName, exclusive: false, durable: true);
Assert.Equal(q.QueueName, queueName);
}
}

Task publishTask = Task.Run(() =>
{
using (IConnection c = cf.CreateConnection())
{
using (IModel m = c.CreateModel())
{
QueueDeclareOk q = m.QueueDeclare(queue: queueName, exclusive: false, durable: true);
for (int i = 0; i < publish_total; i++)
{
m.BasicPublish(string.Empty, queueName, body1);
m.BasicPublish(string.Empty, queueName, body2);
}
}
}
});

Task consumeTask = Task.Run(() =>
{
var publish1SyncSource = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
var publish2SyncSource = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
var maximumWaitTime = TimeSpan.FromSeconds(10);
var tokenSource = new CancellationTokenSource(maximumWaitTime);
tokenSource.Token.Register(() =>
{
publish1SyncSource.TrySetResult(false);
publish2SyncSource.TrySetResult(false);
});

using (IConnection c = cf.CreateConnection())
{
using (IModel m = c.CreateModel())
{
var consumer = new AsyncEventingBasicConsumer(m);

int publish1_count = 0;
int publish2_count = 0;

consumer.Received += async (o, a) =>
{
string decoded = Encoding.ASCII.GetString(a.Body.ToArray());
if (decoded == publish1)
{
if (Interlocked.Increment(ref publish1_count) >= publish_total)
{
publish1SyncSource.TrySetResult(true);
await publish2SyncSource.Task;
}
}
else if (decoded == publish2)
{
if (Interlocked.Increment(ref publish2_count) >= publish_total)
{
publish2SyncSource.TrySetResult(true);
await publish1SyncSource.Task;
}
}
};

m.BasicConsume(queueName, true, consumer);

// ensure we get a delivery
Task.WhenAll(publish1SyncSource.Task, publish2SyncSource.Task);

Assert.True(publish1SyncSource.Task.Result, $"Non concurrent dispatch lead to deadlock after {maximumWaitTime}");
Assert.True(publish2SyncSource.Task.Result, $"Non concurrent dispatch lead to deadlock after {maximumWaitTime}");
}
}
});

await Task.WhenAll(publishTask, consumeTask);
}

[Fact]
public void TestBasicRoundtripNoWait()
{
Expand Down Expand Up @@ -164,47 +271,49 @@ public void ConcurrentEventingTestForReceived()

var cf = new ConnectionFactory{ DispatchConsumersAsync = true };
using (IConnection c = cf.CreateConnection())
using (IModel m = c.CreateModel())
{
QueueDeclareOk q = m.QueueDeclare();
var consumer = new AsyncEventingBasicConsumer(m);
m.BasicConsume(q.QueueName, true, consumer);
var countdownEvent = new CountdownEvent(NumberOfThreads);
var tasks = new Task[NumberOfThreads];
for (int i = 0; i < NumberOfThreads; i++)
using (IModel m = c.CreateModel())
{
int threadIndex = i;
tasks[i] = Task.Run(() =>
QueueDeclareOk q = m.QueueDeclare();
var consumer = new AsyncEventingBasicConsumer(m);
m.BasicConsume(q.QueueName, true, consumer);
var countdownEvent = new CountdownEvent(NumberOfThreads);
var tasks = new Task[NumberOfThreads];
for (int i = 0; i < NumberOfThreads; i++)
{
countdownEvent.Signal();
countdownEvent.Wait();
int start = threadIndex * NumberOfRegistrations;
for (int j = start; j < start + NumberOfRegistrations; j++)
int threadIndex = i;
tasks[i] = Task.Run(() =>
{
int receivedIndex = j;
consumer.Received += (sender, eventArgs) =>
countdownEvent.Signal();
countdownEvent.Wait();
int start = threadIndex * NumberOfRegistrations;
for (int j = start; j < start + NumberOfRegistrations; j++)
{
called[receivedIndex] = 1;
return Task.CompletedTask;
};
}
});
}
int receivedIndex = j;
consumer.Received += (sender, eventArgs) =>
{
called[receivedIndex] = 1;
return Task.CompletedTask;
};
}
});
}

countdownEvent.Wait();
Task.WaitAll(tasks);
countdownEvent.Wait();
Task.WaitAll(tasks);

// Add last receiver
var are = new AutoResetEvent(false);
consumer.Received += (o, a) =>
{
are.Set();
return Task.CompletedTask;
};
// Add last receiver
var are = new AutoResetEvent(false);
consumer.Received += (o, a) =>
{
are.Set();
return Task.CompletedTask;
};

// Send message
m.BasicPublish("", q.QueueName, ReadOnlyMemory<byte>.Empty);
are.WaitOne(TimingFixture.TestTimeout);
// Send message
m.BasicPublish("", q.QueueName, ReadOnlyMemory<byte>.Empty);
are.WaitOne(TimingFixture.TestTimeout);
}
}

// Check received messages
Expand All @@ -216,13 +325,27 @@ public void NonAsyncConsumerShouldThrowInvalidOperationException()
{
var cf = new ConnectionFactory{ DispatchConsumersAsync = true };
using(IConnection c = cf.CreateConnection())
using(IModel m = c.CreateModel())
{
QueueDeclareOk q = m.QueueDeclare();
byte[] body = System.Text.Encoding.UTF8.GetBytes("async-hi");
m.BasicPublish("", q.QueueName, body);
var consumer = new EventingBasicConsumer(m);
Assert.Throws<InvalidOperationException>(() => m.BasicConsume(q.QueueName, false, consumer));
using(IModel m = c.CreateModel())
{
QueueDeclareOk q = m.QueueDeclare();
byte[] body = System.Text.Encoding.UTF8.GetBytes("async-hi");
m.BasicPublish("", q.QueueName, body);
var consumer = new EventingBasicConsumer(m);
Assert.Throws<InvalidOperationException>(() => m.BasicConsume(q.QueueName, false, consumer));
}
}
}

private string get_unique_string(int string_length)
{
using (var rng = RandomNumberGenerator.Create())
{
var bit_count = (string_length * 6);
var byte_count = ((bit_count + 7) / 8); // rounded up
var bytes = new byte[byte_count];
rng.GetBytes(bytes);
return Convert.ToBase64String(bytes);
}
}
}
Expand Down
1 change: 0 additions & 1 deletion projects/Unit/TestEventingConsumer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@

namespace RabbitMQ.Client.Unit
{

public class TestEventingConsumer : IntegrationFixture
{
public TestEventingConsumer(ITestOutputHelper output) : base(output)
Expand Down