Skip to content

Commit cbdbcce

Browse files
committed
* Implement confirmation tracking and await-ing in BasicPublishAsync
1 parent 56b1625 commit cbdbcce

File tree

4 files changed

+118
-50
lines changed

4 files changed

+118
-50
lines changed

projects/RabbitMQ.Client/Impl/ChannelBase.cs

Lines changed: 69 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,9 @@ internal abstract class ChannelBase : IChannel, IRecoverable
6060
private bool _publisherConfirmationsEnabled = false;
6161
private bool _publisherConfirmationTrackingEnabled = false;
6262
private ulong _nextPublishSeqNo = 0;
63-
private SemaphoreSlim _confirmSemaphore = new(1, 1);
64-
private LinkedList<ulong> _pendingDeliveryTags = new();
65-
private List<TaskCompletionSource<bool>> _confirmsTaskCompletionSources = new();
63+
private readonly SemaphoreSlim _confirmSemaphore = new(1, 1);
64+
private readonly LinkedList<ulong> _pendingDeliveryTags = new();
65+
private readonly Dictionary<ulong, TaskCompletionSource<bool>> _confirmsTaskCompletionSources = new();
6666

6767
private bool _onlyAcksReceived = true;
6868

@@ -508,7 +508,7 @@ await _confirmSemaphore.WaitAsync(reason.CancellationToken)
508508
if (_confirmsTaskCompletionSources?.Count > 0)
509509
{
510510
var exception = new AlreadyClosedException(reason);
511-
foreach (TaskCompletionSource<bool> confirmsTaskCompletionSource in _confirmsTaskCompletionSources)
511+
foreach (TaskCompletionSource<bool> confirmsTaskCompletionSource in _confirmsTaskCompletionSources.Values)
512512
{
513513
confirmsTaskCompletionSource.TrySetException(exception);
514514
}
@@ -983,6 +983,7 @@ public async ValueTask BasicPublishAsync<TProperties>(string exchange, string ro
983983
CancellationToken cancellationToken = default)
984984
where TProperties : IReadOnlyBasicProperties, IAmqpHeader
985985
{
986+
TaskCompletionSource<bool>? publisherConfirmationTcs = null;
986987
if (_publisherConfirmationsEnabled)
987988
{
988989
await _confirmSemaphore.WaitAsync(cancellationToken)
@@ -991,11 +992,9 @@ await _confirmSemaphore.WaitAsync(cancellationToken)
991992
{
992993
if (_publisherConfirmationTrackingEnabled)
993994
{
994-
if (_pendingDeliveryTags is null)
995-
{
996-
throw new InvalidOperationException(InternalConstants.BugFound);
997-
}
998995
_pendingDeliveryTags.AddLast(_nextPublishSeqNo);
996+
publisherConfirmationTcs = new(TaskCreationOptions.RunContinuationsAsynchronously);
997+
_confirmsTaskCompletionSources[_nextPublishSeqNo] = publisherConfirmationTcs;
999998
}
1000999

10011000
_nextPublishSeqNo++;
@@ -1039,7 +1038,7 @@ await ModelSendAsync(in cmd, in basicProperties, body, cancellationToken)
10391038
.ConfigureAwait(false);
10401039
}
10411040
}
1042-
catch
1041+
catch (Exception ex)
10431042
{
10441043
if (_publisherConfirmationsEnabled)
10451044
{
@@ -1059,7 +1058,21 @@ await _confirmSemaphore.WaitAsync(cancellationToken)
10591058
}
10601059
}
10611060

1062-
throw;
1061+
if (publisherConfirmationTcs is not null)
1062+
{
1063+
publisherConfirmationTcs.SetException(ex);
1064+
}
1065+
else
1066+
{
1067+
throw;
1068+
}
1069+
}
1070+
1071+
if (publisherConfirmationTcs is not null)
1072+
{
1073+
// TODO timeout?
1074+
await publisherConfirmationTcs.Task
1075+
.ConfigureAwait(false);
10631076
}
10641077
}
10651078

@@ -1068,6 +1081,7 @@ public async ValueTask BasicPublishAsync<TProperties>(CachedString exchange, Cac
10681081
CancellationToken cancellationToken = default)
10691082
where TProperties : IReadOnlyBasicProperties, IAmqpHeader
10701083
{
1084+
TaskCompletionSource<bool>? publisherConfirmationTcs = null;
10711085
if (_publisherConfirmationsEnabled)
10721086
{
10731087
await _confirmSemaphore.WaitAsync(cancellationToken)
@@ -1076,11 +1090,9 @@ await _confirmSemaphore.WaitAsync(cancellationToken)
10761090
{
10771091
if (_publisherConfirmationTrackingEnabled)
10781092
{
1079-
if (_pendingDeliveryTags is null)
1080-
{
1081-
throw new InvalidOperationException(InternalConstants.BugFound);
1082-
}
10831093
_pendingDeliveryTags.AddLast(_nextPublishSeqNo);
1094+
publisherConfirmationTcs = new(TaskCreationOptions.RunContinuationsAsynchronously);
1095+
_confirmsTaskCompletionSources[_nextPublishSeqNo] = publisherConfirmationTcs;
10841096
}
10851097

10861098
_nextPublishSeqNo++;
@@ -1124,7 +1136,7 @@ await ModelSendAsync(in cmd, in basicProperties, body, cancellationToken)
11241136
.ConfigureAwait(false);
11251137
}
11261138
}
1127-
catch
1139+
catch (Exception ex)
11281140
{
11291141
if (_publisherConfirmationsEnabled)
11301142
{
@@ -1144,7 +1156,21 @@ await _confirmSemaphore.WaitAsync(cancellationToken)
11441156
}
11451157
}
11461158

1147-
throw;
1159+
if (publisherConfirmationTcs is not null)
1160+
{
1161+
publisherConfirmationTcs.SetException(ex);
1162+
}
1163+
else
1164+
{
1165+
throw;
1166+
}
1167+
}
1168+
1169+
if (publisherConfirmationTcs is not null)
1170+
{
1171+
// TODO timeout?
1172+
await publisherConfirmationTcs.Task
1173+
.ConfigureAwait(false);
11481174
}
11491175
}
11501176

@@ -1253,11 +1279,6 @@ await ModelSendAsync(in method, k.CancellationToken)
12531279

12541280
bool result = await k;
12551281
Debug.Assert(result);
1256-
1257-
// Note:
1258-
// Non-null means confirms are enabled
1259-
_confirmSemaphore = new SemaphoreSlim(1, 1);
1260-
12611282
return;
12621283
}
12631284
finally
@@ -1890,10 +1911,6 @@ internal async Task HandleAckNack(ulong deliveryTag, bool multiple, bool isNack,
18901911
// Only do this if confirms are enabled *and* the library is tracking confirmations
18911912
if (_publisherConfirmationsEnabled && _publisherConfirmationTrackingEnabled)
18921913
{
1893-
if (_pendingDeliveryTags is null)
1894-
{
1895-
throw new InvalidOperationException(InternalConstants.BugFound);
1896-
}
18971914
// let's take a lock so we can assume that deliveryTags are unique, never duplicated and always sorted
18981915
await _confirmSemaphore.WaitAsync(cancellationToken)
18991916
.ConfigureAwait(false);
@@ -1904,28 +1921,45 @@ await _confirmSemaphore.WaitAsync(cancellationToken)
19041921
{
19051922
if (multiple)
19061923
{
1907-
while (_pendingDeliveryTags.First!.Value < deliveryTag)
1908-
{
1909-
_pendingDeliveryTags.RemoveFirst();
1910-
}
1911-
1912-
if (_pendingDeliveryTags.First.Value == deliveryTag)
1924+
do
19131925
{
1914-
_pendingDeliveryTags.RemoveFirst();
1926+
if (_pendingDeliveryTags.First is null)
1927+
{
1928+
break;
1929+
}
1930+
else
1931+
{
1932+
ulong pendingDeliveryTag = _pendingDeliveryTags.First.Value;
1933+
if (pendingDeliveryTag > deliveryTag)
1934+
{
1935+
break;
1936+
}
1937+
else
1938+
{
1939+
TaskCompletionSource<bool> tcs = _confirmsTaskCompletionSources[pendingDeliveryTag];
1940+
tcs.SetResult(true);
1941+
_confirmsTaskCompletionSources.Remove(pendingDeliveryTag);
1942+
_pendingDeliveryTags.RemoveFirst();
1943+
}
1944+
}
19151945
}
1946+
while (true);
19161947
}
19171948
else
19181949
{
1950+
TaskCompletionSource<bool> tcs = _confirmsTaskCompletionSources[deliveryTag];
1951+
tcs.SetResult(true);
1952+
_confirmsTaskCompletionSources.Remove(deliveryTag);
19191953
_pendingDeliveryTags.Remove(deliveryTag);
19201954
}
19211955
}
19221956

1923-
_onlyAcksReceived = _onlyAcksReceived && !isNack;
1957+
_onlyAcksReceived = _onlyAcksReceived && false == isNack;
19241958

1925-
if (_pendingDeliveryTags.Count == 0 && _confirmsTaskCompletionSources!.Count > 0)
1959+
if (_pendingDeliveryTags.Count == 0 && _confirmsTaskCompletionSources.Count > 0)
19261960
{
19271961
// Done, mark tasks
1928-
foreach (TaskCompletionSource<bool> tcs in _confirmsTaskCompletionSources)
1962+
foreach (TaskCompletionSource<bool> tcs in _confirmsTaskCompletionSources.Values)
19291963
{
19301964
tcs.TrySetResult(_onlyAcksReceived);
19311965
}

projects/Test/Applications/PublisherConfirms/PublisherConfirms.cs

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,35 @@
1-
using System;
1+
// This source code is dual-licensed under the Apache License, version
2+
// 2.0, and the Mozilla Public License, version 2.0.
3+
//
4+
// The APL v2.0:
5+
//
6+
//---------------------------------------------------------------------------
7+
// Copyright (c) 2007-2024 Broadcom. All Rights Reserved.
8+
//
9+
// Licensed under the Apache License, Version 2.0 (the "License");
10+
// you may not use this file except in compliance with the License.
11+
// You may obtain a copy of the License at
12+
//
13+
// https://www.apache.org/licenses/LICENSE-2.0
14+
//
15+
// Unless required by applicable law or agreed to in writing, software
16+
// distributed under the License is distributed on an "AS IS" BASIS,
17+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18+
// See the License for the specific language governing permissions and
19+
// limitations under the License.
20+
//---------------------------------------------------------------------------
21+
//
22+
// The MPL v2.0:
23+
//
24+
//---------------------------------------------------------------------------
25+
// This Source Code Form is subject to the terms of the Mozilla Public
26+
// License, v. 2.0. If a copy of the MPL was not distributed with this
27+
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
28+
//
29+
// Copyright (c) 2007-2024 Broadcom. All Rights Reserved.
30+
//---------------------------------------------------------------------------
31+
32+
using System;
233
using System.Collections.Generic;
334
using System.Diagnostics;
435
using System.Text;
@@ -9,6 +40,8 @@
940
const int MESSAGE_COUNT = 50_000;
1041
bool debug = false;
1142

43+
#pragma warning disable CS8321 // Local function is declared but never used
44+
1245
await PublishMessagesIndividuallyAsync();
1346
await PublishMessagesInBatchAsync();
1447
await HandlePublishConfirmsAsynchronously();
@@ -40,8 +73,6 @@ static async Task PublishMessagesIndividuallyAsync()
4073
await channel.BasicPublishAsync(exchange: string.Empty, routingKey: queueName, body: body);
4174
}
4275

43-
// await channel.WaitForConfirmsOrDieAsync();
44-
4576
sw.Stop();
4677

4778
Console.WriteLine($"{DateTime.Now} [INFO] published {MESSAGE_COUNT:N0} messages individually in {sw.ElapsedMilliseconds:N0} ms");
@@ -52,7 +83,8 @@ static async Task PublishMessagesInBatchAsync()
5283
Console.WriteLine($"{DateTime.Now} [INFO] publishing {MESSAGE_COUNT:N0} messages and handling confirms in batches");
5384

5485
await using IConnection connection = await CreateConnectionAsync();
55-
await using IChannel channel = await connection.CreateChannelAsync();
86+
await using IChannel channel = await connection.CreateChannelAsync(publisherConfirmationsEnabled: true,
87+
publisherConfirmationTrackingEnabled: true);
5688

5789
// declare a server-named queue
5890
QueueDeclareOk queueDeclareResult = await channel.QueueDeclareAsync();
@@ -76,8 +108,6 @@ static async Task PublishMessagesInBatchAsync()
76108
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5));
77109
await Task.WhenAll(publishTasks).WaitAsync(cts.Token);
78110
publishTasks.Clear();
79-
80-
// await channel.WaitForConfirmsOrDieAsync(cts.Token);
81111
outstandingMessageCount = 0;
82112
}
83113
}
@@ -100,7 +130,8 @@ async Task HandlePublishConfirmsAsynchronously()
100130

101131
// NOTE: setting trackConfirmations to false because this program
102132
// is tracking them itself.
103-
await using IChannel channel = await connection.CreateChannelAsync(publisherConfirmationTrackingEnabled: false);
133+
await using IChannel channel = await connection.CreateChannelAsync(publisherConfirmationsEnabled: true,
134+
publisherConfirmationTrackingEnabled: false);
104135

105136
// declare a server-named queue
106137
QueueDeclareOk queueDeclareResult = await channel.QueueDeclareAsync();
@@ -185,7 +216,9 @@ async Task CleanOutstandingConfirms(ulong deliveryTag, bool multiple)
185216
{
186217
semaphore.Release();
187218
}
188-
publishTasks.Add(channel.BasicPublishAsync(exchange: string.Empty, routingKey: queueName, body: body).AsTask());
219+
220+
ValueTask pt = channel.BasicPublishAsync(exchange: string.Empty, routingKey: queueName, body: body);
221+
publishTasks.Add(pt.AsTask());
189222
}
190223

191224
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10));

projects/Test/Common/IntegrationFixture.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,8 @@ public virtual async Task InitializeAsync()
153153

154154
if (_openChannel)
155155
{
156-
_channel = await _conn.CreateChannelAsync();
156+
_channel = await _conn.CreateChannelAsync(publisherConfirmationsEnabled: true,
157+
publisherConfirmationTrackingEnabled: true);
157158
}
158159

159160
if (IsVerbose)

projects/Test/Integration/TestBasicPublish.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ public override Task InitializeAsync()
5959
public async Task TestBasicRoundtripArray()
6060
{
6161
_conn = await _connFactory.CreateConnectionAsync();
62-
_channel = await _conn.CreateChannelAsync();
62+
_channel = await _conn.CreateChannelAsync(publisherConfirmationsEnabled: true, publisherConfirmationTrackingEnabled: true);
6363

6464
QueueDeclareOk q = await _channel.QueueDeclareAsync();
6565
var bp = new BasicProperties();
@@ -87,7 +87,7 @@ public async Task TestBasicRoundtripArray()
8787
public async Task TestBasicRoundtripCachedString()
8888
{
8989
_conn = await _connFactory.CreateConnectionAsync();
90-
_channel = await _conn.CreateChannelAsync();
90+
_channel = await _conn.CreateChannelAsync(publisherConfirmationsEnabled: true, publisherConfirmationTrackingEnabled: true);
9191

9292
CachedString exchangeName = new CachedString(string.Empty);
9393
CachedString queueName = new CachedString((await _channel.QueueDeclareAsync()).QueueName);
@@ -115,7 +115,7 @@ public async Task TestBasicRoundtripCachedString()
115115
public async Task TestBasicRoundtripReadOnlyMemory()
116116
{
117117
_conn = await _connFactory.CreateConnectionAsync();
118-
_channel = await _conn.CreateChannelAsync();
118+
_channel = await _conn.CreateChannelAsync(publisherConfirmationsEnabled: true, publisherConfirmationTrackingEnabled: true);
119119

120120
QueueDeclareOk q = await _channel.QueueDeclareAsync();
121121
byte[] sendBody = _encoding.GetBytes("hi");
@@ -142,7 +142,7 @@ public async Task TestBasicRoundtripReadOnlyMemory()
142142
public async Task CanNotModifyPayloadAfterPublish()
143143
{
144144
_conn = await _connFactory.CreateConnectionAsync();
145-
_channel = await _conn.CreateChannelAsync();
145+
_channel = await _conn.CreateChannelAsync(publisherConfirmationsEnabled: true, publisherConfirmationTrackingEnabled: true);
146146

147147
QueueDeclareOk q = await _channel.QueueDeclareAsync();
148148
byte[] sendBody = new byte[1000];
@@ -203,7 +203,7 @@ public async Task TestMaxInboundMessageBodySize()
203203
Assert.Equal(maxMsgSize, cf.Endpoint.MaxInboundMessageBodySize);
204204
Assert.Equal(maxMsgSize, conn.Endpoint.MaxInboundMessageBodySize);
205205

206-
await using (IChannel channel = await conn.CreateChannelAsync())
206+
await using (IChannel channel = await conn.CreateChannelAsync(publisherConfirmationsEnabled: true, publisherConfirmationTrackingEnabled: true))
207207
{
208208
channel.ChannelShutdownAsync += (o, a) =>
209209
{
@@ -286,7 +286,7 @@ public async Task TestMaxInboundMessageBodySize()
286286
public async Task TestPropertiesRoundtrip_Headers()
287287
{
288288
_conn = await _connFactory.CreateConnectionAsync();
289-
_channel = await _conn.CreateChannelAsync();
289+
_channel = await _conn.CreateChannelAsync(publisherConfirmationsEnabled: true, publisherConfirmationTrackingEnabled: true);
290290

291291
var subject = new BasicProperties
292292
{

0 commit comments

Comments
 (0)