Skip to content
This repository was archived by the owner on Jul 9, 2023. It is now read-only.

Commit 045ec22

Browse files
committed
NetworkStream read cancellation hack
1 parent 51abd03 commit 045ec22

File tree

10 files changed

+79
-44
lines changed

10 files changed

+79
-44
lines changed
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
using System;
2+
using Titanium.Web.Proxy.Network.Tcp;
3+
4+
namespace Titanium.Web.Proxy.EventArguments
5+
{
6+
public class EmptyProxyEventArgs : ProxyEventArgsBase
7+
{
8+
internal EmptyProxyEventArgs(TcpClientConnection clientConnection) : base(clientConnection)
9+
{
10+
}
11+
}
12+
}

src/Titanium.Web.Proxy/EventArguments/SessionEventArgs.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ private async Task readResponseBodyAsync(CancellationToken cancellationToken)
190190
private async Task<byte[]> readBodyAsync(bool isRequest, CancellationToken cancellationToken)
191191
{
192192
using var bodyStream = new MemoryStream();
193-
using var writer = new HttpStream(bodyStream, BufferPool);
193+
using var writer = new HttpStream(bodyStream, BufferPool, cancellationToken);
194194

195195
if (isRequest)
196196
{

src/Titanium.Web.Proxy/ExplicitClientHandler.cs

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,12 @@ private async Task handleClient(ExplicitProxyEndPoint endPoint, TcpClientConnect
3333
var cancellationTokenSource = new CancellationTokenSource();
3434
var cancellationToken = cancellationTokenSource.Token;
3535

36-
var clientStream = new HttpClientStream(clientConnection, clientConnection.GetStream(), BufferPool);
36+
var clientStream = new HttpClientStream(clientConnection, clientConnection.GetStream(), BufferPool, cancellationToken);
3737

3838
Task<TcpServerConnection>? prefetchConnectionTask = null;
3939
bool closeServerConnection = false;
4040
bool calledRequestHandler = false;
4141

42-
SslStream? sslStream = null;
43-
4442
try
4543
{
4644
TunnelConnectSessionEventArgs? connectArgs = null;
@@ -191,6 +189,7 @@ private async Task handleClient(ExplicitProxyEndPoint endPoint, TcpClientConnect
191189
}
192190

193191
X509Certificate2? certificate = null;
192+
SslStream? sslStream = null;
194193
try
195194
{
196195
sslStream = new SslStream(clientStream, false);
@@ -221,14 +220,16 @@ private async Task handleClient(ExplicitProxyEndPoint endPoint, TcpClientConnect
221220
#endif
222221

223222
// HTTPS server created - we can now decrypt the client's traffic
224-
clientStream = new HttpClientStream(clientStream.Connection, sslStream, BufferPool);
223+
clientStream = new HttpClientStream(clientStream.Connection, sslStream, BufferPool, cancellationToken);
225224
sslStream = null; // clientStream was created, no need to keep SSL stream reference
226225

227226
clientStream.DataRead += (o, args) => connectArgs.OnDecryptedDataSent(args.Buffer, args.Offset, args.Count);
228227
clientStream.DataWrite += (o, args) => connectArgs.OnDecryptedDataReceived(args.Buffer, args.Offset, args.Count);
229228
}
230229
catch (Exception e)
231230
{
231+
sslStream?.Dispose();
232+
232233
var certName = certificate?.GetNameInfo(X509NameType.SimpleName, false);
233234
throw new ProxyConnectException(
234235
$"Couldn't authenticate host '{connectHostname}' with certificate '{certName}'.", e, connectArgs);
@@ -401,12 +402,16 @@ await Http2Helper.SendHttp2(clientStream, connection.Stream,
401402
}
402403
finally
403404
{
405+
if (!cancellationTokenSource.IsCancellationRequested)
406+
{
407+
cancellationTokenSource.Cancel();
408+
}
409+
404410
if (!calledRequestHandler)
405411
{
406412
await tcpConnectionFactory.Release(prefetchConnectionTask, closeServerConnection);
407413
}
408414

409-
sslStream?.Dispose();
410415
clientStream.Dispose();
411416
}
412417
}

src/Titanium.Web.Proxy/Extensions/StreamExtensions.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ internal static async Task CopyToAsync(this Stream input, Stream output, Action<
4242
{
4343
// cancellation is not working on Socket ReadAsync
4444
// https://github.com/dotnet/corefx/issues/15033
45-
int num = await input.ReadAsync(buffer, 0, buffer.Length, CancellationToken.None)
46-
.withCancellation(cancellationToken);
45+
int num = await input.ReadAsync(buffer, 0, buffer.Length, cancellationToken)
46+
.WithCancellation(cancellationToken);
4747
int bytesRead;
4848
if ((bytesRead = num) != 0 && !cancellationToken.IsCancellationRequested)
4949
{
@@ -62,7 +62,7 @@ internal static async Task CopyToAsync(this Stream input, Stream output, Action<
6262
}
6363
}
6464

65-
private static async Task<T> withCancellation<T>(this Task<T> task, CancellationToken cancellationToken) where T : struct
65+
internal static async Task<T> WithCancellation<T>(this Task<T> task, CancellationToken cancellationToken) where T : struct
6666
{
6767
var tcs = new TaskCompletionSource<bool>();
6868
using (cancellationToken.Register(s => ((TaskCompletionSource<bool>)s).TrySetResult(true), tcs))

src/Titanium.Web.Proxy/Helpers/HttpClientStream.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ internal sealed class HttpClientStream : HttpStream
1212
{
1313
public TcpClientConnection Connection { get; }
1414

15-
internal HttpClientStream(TcpClientConnection connection, Stream stream, IBufferPool bufferPool)
16-
: base(stream, bufferPool)
15+
internal HttpClientStream(TcpClientConnection connection, Stream stream, IBufferPool bufferPool, CancellationToken cancellationToken)
16+
: base(stream, bufferPool, cancellationToken)
1717
{
1818
Connection = connection;
1919
}

src/Titanium.Web.Proxy/Helpers/HttpServerStream.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ namespace Titanium.Web.Proxy.Helpers
1010
{
1111
internal sealed class HttpServerStream : HttpStream
1212
{
13-
internal HttpServerStream(Stream stream, IBufferPool bufferPool)
14-
: base(stream, bufferPool)
13+
internal HttpServerStream(Stream stream, IBufferPool bufferPool, CancellationToken cancellationToken)
14+
: base(stream, bufferPool, cancellationToken)
1515
{
1616
}
1717

src/Titanium.Web.Proxy/Helpers/HttpStream.cs

Lines changed: 41 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
using System.Threading.Tasks;
1010
using Titanium.Web.Proxy.Compression;
1111
using Titanium.Web.Proxy.EventArguments;
12+
using Titanium.Web.Proxy.Extensions;
1213
using Titanium.Web.Proxy.Http;
1314
using Titanium.Web.Proxy.Models;
1415
using Titanium.Web.Proxy.Shared;
@@ -19,7 +20,7 @@ namespace Titanium.Web.Proxy.Helpers
1920
{
2021
internal class HttpStream : Stream, IHttpStreamWriter, IHttpStreamReader, IPeekStream
2122
{
22-
private readonly bool swallowException;
23+
private readonly bool isNetworkStream;
2324
private readonly bool leaveOpen;
2425
private readonly byte[] streamBuffer;
2526

@@ -37,6 +38,7 @@ internal class HttpStream : Stream, IHttpStreamWriter, IHttpStreamReader, IPeekS
3738
private bool closedRead;
3839

3940
private readonly IBufferPool bufferPool;
41+
private readonly CancellationToken cancellationToken;
4042

4143
public event EventHandler<DataEventArgs>? DataRead;
4244

@@ -71,18 +73,20 @@ static HttpStream()
7173
/// </summary>
7274
/// <param name="baseStream">The base stream.</param>
7375
/// <param name="bufferPool">Bufferpool.</param>
76+
/// <param name="cancellationToken">The cancellation token.</param>
7477
/// <param name="leaveOpen"><see langword="true" /> to leave the stream open after disposing the <see cref="T:CustomBufferedStream" /> object; otherwise, <see langword="false" />.</param>
75-
internal HttpStream(Stream baseStream, IBufferPool bufferPool, bool leaveOpen = false)
78+
internal HttpStream(Stream baseStream, IBufferPool bufferPool, CancellationToken cancellationToken, bool leaveOpen = false)
7679
{
7780
if (baseStream is NetworkStream)
7881
{
79-
swallowException = true;
82+
isNetworkStream = true;
8083
}
8184

8285
this.baseStream = baseStream;
8386
this.leaveOpen = leaveOpen;
8487
streamBuffer = bufferPool.GetBuffer();
8588
this.bufferPool = bufferPool;
89+
this.cancellationToken = cancellationToken;
8690
}
8791

8892
/// <summary>
@@ -102,7 +106,7 @@ public override void Flush()
102106
catch
103107
{
104108
closedWrite = true;
105-
if (!swallowException)
109+
if (!isNetworkStream)
106110
throw;
107111
}
108112
}
@@ -181,7 +185,7 @@ public override void Write(byte[] buffer, int offset, int count)
181185
catch
182186
{
183187
closedWrite = true;
184-
if (!swallowException)
188+
if (!isNetworkStream)
185189
throw;
186190
}
187191
}
@@ -228,7 +232,7 @@ public override async Task FlushAsync(CancellationToken cancellationToken)
228232
catch
229233
{
230234
closedWrite = true;
231-
if (!swallowException)
235+
if (!isNetworkStream)
232236
throw;
233237
}
234238
}
@@ -450,7 +454,7 @@ public override async Task WriteAsync(byte[] buffer, int offset, int count, Canc
450454
catch
451455
{
452456
closedWrite = true;
453-
if (!swallowException)
457+
if (!isNetworkStream)
454458
throw;
455459
}
456460
}
@@ -476,7 +480,7 @@ public override void WriteByte(byte value)
476480
catch
477481
{
478482
closedWrite = true;
479-
if (!swallowException)
483+
if (!isNetworkStream)
480484
throw;
481485
}
482486
finally
@@ -609,7 +613,7 @@ public bool FillBuffer()
609613
}
610614
catch
611615
{
612-
if (!swallowException)
616+
if (!isNetworkStream)
613617
throw;
614618
}
615619
finally
@@ -655,17 +659,28 @@ public async ValueTask<bool> FillBufferAsync(CancellationToken cancellationToken
655659
bool result = false;
656660
try
657661
{
658-
int readBytes = await baseStream.ReadAsync(streamBuffer, bufferLength, bytesToRead, cancellationToken);
662+
var readTask = baseStream.ReadAsync(streamBuffer, bufferLength, bytesToRead, cancellationToken);
663+
if (isNetworkStream)
664+
{
665+
readTask = readTask.WithCancellation(cancellationToken);
666+
}
667+
668+
int readBytes = await readTask;
659669
result = readBytes > 0;
660670
if (result)
661671
{
662672
OnDataRead(streamBuffer, bufferLength, readBytes);
663673
bufferLength += readBytes;
664674
}
665675
}
676+
catch (ObjectDisposedException)
677+
{
678+
if (!isNetworkStream)
679+
throw;
680+
}
666681
catch
667682
{
668-
if (!swallowException)
683+
if (!isNetworkStream)
669684
throw;
670685
}
671686
finally
@@ -771,14 +786,18 @@ public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, Asy
771786
return base.BeginRead(buffer, offset, count, callback, state);
772787
}
773788

774-
var vAsyncResult = this.ReadAsync(buffer, offset, count);
789+
var vAsyncResult = this.ReadAsync(buffer, offset, count, cancellationToken);
790+
if (isNetworkStream)
791+
{
792+
vAsyncResult = vAsyncResult.WithCancellation(cancellationToken);
793+
}
775794

776795
vAsyncResult.ContinueWith(pAsyncResult =>
777796
{
778797
// use TaskExtended to pass State as AsyncObject
779798
// callback will call EndRead (otherwise, it will block)
780799
callback?.Invoke(new TaskResult<int>(pAsyncResult, state));
781-
});
800+
}, cancellationToken);
782801

783802
return vAsyncResult;
784803
}
@@ -811,12 +830,12 @@ public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, As
811830
return base.BeginWrite(buffer, offset, count, callback, state);
812831
}
813832

814-
var vAsyncResult = this.WriteAsync(buffer, offset, count);
833+
var vAsyncResult = this.WriteAsync(buffer, offset, count, cancellationToken);
815834

816835
vAsyncResult.ContinueWith(pAsyncResult =>
817836
{
818837
callback?.Invoke(new TaskResult(pAsyncResult, state));
819-
});
838+
}, cancellationToken);
820839

821840
return vAsyncResult;
822841
}
@@ -868,7 +887,7 @@ private async ValueTask writeAsyncInternal(string value, bool addNewLine, Cancel
868887
catch
869888
{
870889
closedWrite = true;
871-
if (!swallowException)
890+
if (!isNetworkStream)
872891
throw;
873892
}
874893
finally
@@ -893,7 +912,7 @@ private async ValueTask writeAsyncInternal(string value, bool addNewLine, Cancel
893912
catch
894913
{
895914
closedWrite = true;
896-
if (!swallowException)
915+
if (!isNetworkStream)
897916
throw;
898917
}
899918
}
@@ -940,7 +959,7 @@ internal async ValueTask WriteAsync(byte[] data, bool flush = false, Cancellatio
940959
catch
941960
{
942961
closedWrite = true;
943-
if (!swallowException)
962+
if (!isNetworkStream)
944963
throw;
945964
}
946965
}
@@ -964,7 +983,7 @@ internal async Task WriteAsync(byte[] data, int offset, int count, bool flush,
964983
catch
965984
{
966985
closedWrite = true;
967-
if (!swallowException)
986+
if (!isNetworkStream)
968987
throw;
969988
}
970989
}
@@ -1011,7 +1030,7 @@ public async Task CopyBodyAsync(RequestResponseBase requestResponse, bool useOri
10111030

10121031
try
10131032
{
1014-
var http = new HttpStream(s, bufferPool, true);
1033+
var http = new HttpStream(s, bufferPool, cancellationToken, true);
10151034
await http.CopyBodyAsync(writer, false, -1, onCopy, cancellationToken);
10161035
}
10171036
finally
@@ -1196,7 +1215,7 @@ public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, Cancella
11961215
catch
11971216
{
11981217
closedWrite = true;
1199-
if (!swallowException)
1218+
if (!isNetworkStream)
12001219
throw;
12011220
}
12021221
}
@@ -1217,7 +1236,7 @@ public async Task WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken canc
12171236
}
12181237
catch
12191238
{
1220-
if (!swallowException)
1239+
if (!isNetworkStream)
12211240
throw;
12221241
}
12231242
}

src/Titanium.Web.Proxy/Network/Tcp/TcpConnectionFactory.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,7 @@ private async Task<TcpServerConnection> createServerConnection(string remoteHost
445445

446446
await proxyServer.InvokeServerConnectionCreateEvent(tcpClient);
447447

448-
stream = new HttpServerStream(tcpClient.GetStream(), proxyServer.BufferPool);
448+
stream = new HttpServerStream(tcpClient.GetStream(), proxyServer.BufferPool, cancellationToken);
449449

450450
if (externalProxy != null && (isConnect || isHttps))
451451
{
@@ -487,7 +487,7 @@ private async Task<TcpServerConnection> createServerConnection(string remoteHost
487487
(sender, targetHost, localCertificates, remoteCertificate, acceptableIssuers) =>
488488
proxyServer.SelectClientCertificate(sender, sessionArgs, targetHost, localCertificates,
489489
remoteCertificate, acceptableIssuers));
490-
stream = new HttpServerStream(sslStream, proxyServer.BufferPool);
490+
stream = new HttpServerStream(sslStream, proxyServer.BufferPool, cancellationToken);
491491

492492
var options = new SslClientAuthenticationOptions
493493
{

src/Titanium.Web.Proxy/RequestHandler.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ private async Task<RetryResult> handleHttpSessionRequest(SessionEventArgs args,
272272
cancellationToken);
273273

274274
// for connection pool, retry fails until cache is exhausted.
275-
return await retryPolicy<ServerConnectionException>().ExecuteAsync(async (connection) =>
275+
return await retryPolicy<ServerConnectionException>().ExecuteAsync(async connection =>
276276
{
277277
// set the connection and send request headers
278278
args.HttpClient.SetConnection(connection);

0 commit comments

Comments
 (0)