Skip to content

Remove unsafe code; use ASCII.GetBytes for WriteAscii #18404

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
5 commits merged into from
Feb 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11538,7 +11538,7 @@ internal unsafe void CopyToFast(ref BufferWriter<PipeWriter> output)
if (value != null)
{
output.Write(headerKey);
output.WriteAsciiNoValidation(value);
output.WriteAscii(value);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ static void CopyExtraHeaders(ref BufferWriter<PipeWriter> buffer, Dictionary<str
if (value != null)
{
buffer.Write(CrLf);
buffer.WriteAsciiNoValidation(kv.Key);
buffer.WriteAscii(kv.Key);
buffer.Write(ColonSpace);
buffer.WriteAsciiNoValidation(value);
buffer.WriteAscii(value);
}
}
}
Expand Down
12 changes: 6 additions & 6 deletions src/Servers/Kestrel/Core/test/PipelineExtensionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public void EncodesAsAscii(string input, byte[] expected)
{
var pipeWriter = _pipe.Writer;
var writer = new BufferWriter<PipeWriter>(pipeWriter);
writer.WriteAsciiNoValidation(input);
writer.WriteAscii(input);
writer.Commit();
pipeWriter.FlushAsync().GetAwaiter().GetResult();
pipeWriter.Complete();
Expand All @@ -111,13 +111,13 @@ public void EncodesAsAscii(string input, byte[] expected)
[InlineData("𤭢𐐝")]
// non-ascii characters stored in 16 bits
[InlineData("ñ٢⛄⛵")]
public void WriteAsciiNoValidationWritesOnlyOneBytePerChar(string input)
public void WriteAsciiWritesOnlyOneBytePerChar(string input)
{
// WriteAscii doesn't validate if characters are in the ASCII range
// but it shouldn't produce more than one byte per character
var writerBuffer = _pipe.Writer;
var writer = new BufferWriter<PipeWriter>(writerBuffer);
writer.WriteAsciiNoValidation(input);
writer.WriteAscii(input);
writer.Commit();
writerBuffer.FlushAsync().GetAwaiter().GetResult();
var reader = _pipe.Reader.ReadAsync().GetAwaiter().GetResult();
Expand All @@ -126,14 +126,14 @@ public void WriteAsciiNoValidationWritesOnlyOneBytePerChar(string input)
}

[Fact]
public void WriteAsciiNoValidation()
public void WriteAscii()
{
const byte maxAscii = 0x7f;
var writerBuffer = _pipe.Writer;
var writer = new BufferWriter<PipeWriter>(writerBuffer);
for (var i = 0; i < maxAscii; i++)
{
writer.WriteAsciiNoValidation(new string((char)i, 1));
writer.WriteAscii(new string((char)i, 1));
}
writer.Commit();
writerBuffer.FlushAsync().GetAwaiter().GetResult();
Expand Down Expand Up @@ -167,7 +167,7 @@ public void WritesAsciiAcrossBlockBoundaries(int stringLength, int gapSize)
Assert.Equal(gapSize, writer.Span.Length);

var bufferLength = writer.Span.Length;
writer.WriteAsciiNoValidation(testString);
writer.WriteAscii(testString);
Assert.NotEqual(bufferLength, writer.Span.Length);
writer.Commit();
writerBuffer.FlushAsync().GetAwaiter().GetResult();
Expand Down
2 changes: 1 addition & 1 deletion src/Servers/Kestrel/shared/KnownHeaders.cs
Original file line number Diff line number Diff line change
Expand Up @@ -945,7 +945,7 @@ internal unsafe void CopyToFast(ref BufferWriter<PipeWriter> output)
if (value != null)
{{
output.Write(headerKey);
output.WriteAsciiNoValidation(value);
output.WriteAscii(value);
}}
}}
}}
Expand Down
151 changes: 26 additions & 125 deletions src/Shared/ServerInfrastructure/BufferExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Buffers;
using System.Diagnostics;
using System.IO.Pipelines;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;

namespace System.Buffers
{
Expand Down Expand Up @@ -40,26 +40,19 @@ public static ArraySegment<byte> GetArray(this ReadOnlyMemory<byte> memory)
return result;
}

internal static unsafe void WriteAsciiNoValidation(ref this BufferWriter<PipeWriter> buffer, string data)
internal static void WriteAscii(ref this BufferWriter<PipeWriter> buffer, string data)
{
if (string.IsNullOrEmpty(data))
{
return;
}

var dest = buffer.Span;
var destLength = dest.Length;
var sourceLength = data.Length;

// Fast path, try copying to the available memory directly
if (sourceLength <= destLength)
// Fast path, try encoding to the available memory directly
if (sourceLength <= dest.Length)
{
fixed (char* input = data)
fixed (byte* output = dest)
{
EncodeAsciiCharsToBytes(input, output, sourceLength);
}

Encoding.ASCII.GetBytes(data, dest);
buffer.Advance(sourceLength);
}
else
Expand Down Expand Up @@ -140,123 +133,31 @@ private static void WriteNumericMultiWrite(ref this BufferWriter<PipeWriter> buf
}

[MethodImpl(MethodImplOptions.NoInlining)]
private unsafe static void WriteAsciiMultiWrite(ref this BufferWriter<PipeWriter> buffer, string data)
private static void WriteAsciiMultiWrite(ref this BufferWriter<PipeWriter> buffer, string data)
{
var remaining = data.Length;

fixed (char* input = data)
{
var inputSlice = input;

while (remaining > 0)
{
var writable = Math.Min(remaining, buffer.Span.Length);

if (writable == 0)
{
buffer.Ensure();
continue;
}

fixed (byte* output = buffer.Span)
{
EncodeAsciiCharsToBytes(inputSlice, output, writable);
}

inputSlice += writable;
remaining -= writable;

buffer.Advance(writable);
}
}
}

private static unsafe void EncodeAsciiCharsToBytes(char* input, byte* output, int length)
{
// Note: does not support big-endian platforms and does not check for non-ASCII chars
const int Shift16Shift24 = (1 << 16) | (1 << 24);
const int Shift8Identity = (1 << 8) | (1);

// Narrow each char to a single byte; non-ASCII chars are not detected and no count is returned
int i = 0;
// Use Intrinsic switch
if (IntPtr.Size == 8) // 64 bit
{
if (length < 4) goto trailing;

int unaligned = (int)(((ulong)input) & 0x7) >> 1;
// Unaligned chars
for (; i < unaligned; i++)
{
char ch = *(input + i);
*(output + i) = (byte)ch; // Cast convert
}

// Aligned
int ulongDoubleCount = (length - i) & ~0x7;
for (; i < ulongDoubleCount; i += 8)
{
ulong inputUlong0 = *(ulong*)(input + i);
ulong inputUlong1 = *(ulong*)(input + i + 4);
// Pack 8 ASCII chars (two ulongs of 4 chars each) into 8 bytes
*(uint*)(output + i) =
((uint)((inputUlong0 * Shift16Shift24) >> 24) & 0xffff) |
((uint)((inputUlong0 * Shift8Identity) >> 24) & 0xffff0000);
*(uint*)(output + i + 4) =
((uint)((inputUlong1 * Shift16Shift24) >> 24) & 0xffff) |
((uint)((inputUlong1 * Shift8Identity) >> 24) & 0xffff0000);
}
if (length - 4 > i)
{
ulong inputUlong = *(ulong*)(input + i);
// Pack 4 ASCII chars (one ulong) into 4 bytes
*(uint*)(output + i) =
((uint)((inputUlong * Shift16Shift24) >> 24) & 0xffff) |
((uint)((inputUlong * Shift8Identity) >> 24) & 0xffff0000);
i += 4;
}

trailing:
for (; i < length; i++)
{
char ch = *(input + i);
*(output + i) = (byte)ch; // Cast convert
}
}
else // 32 bit
var dataLength = data.Length;
var offset = 0;
var bytes = buffer.Span;
do
{
// Unaligned chars
if ((unchecked((int)input) & 0x2) != 0)
{
char ch = *input;
i = 1;
*(output) = (byte)ch; // Cast convert
}

// Aligned
int uintCount = (length - i) & ~0x3;
for (; i < uintCount; i += 4)
var writable = Math.Min(dataLength - offset, bytes.Length);
// Zero length spans are possible, though unlikely.
// ASCII.GetBytes and .Advance will both handle them so we won't special case for them.
Encoding.ASCII.GetBytes(data.AsSpan(offset, writable), bytes);
buffer.Advance(writable);

offset += writable;
if (offset >= dataLength)
{
uint inputUint0 = *(uint*)(input + i);
uint inputUint1 = *(uint*)(input + i + 2);
// Pack 4 ASCII chars into 4 bytes
*(ushort*)(output + i) = (ushort)(inputUint0 | (inputUint0 >> 8));
*(ushort*)(output + i + 2) = (ushort)(inputUint1 | (inputUint1 >> 8));
}
if (length - 1 > i)
{
uint inputUint = *(uint*)(input + i);
// Pack 2 ASCII chars into 2 bytes
*(ushort*)(output + i) = (ushort)(inputUint | (inputUint >> 8));
i += 2;
Debug.Assert(offset == dataLength);
// Encoded everything
break;
}

if (i < length)
{
char ch = *(input + i);
*(output + i) = (byte)ch; // Cast convert
}
}
// Get new span, more to encode.
buffer.Ensure();
bytes = buffer.Span;
} while (true);
}

private static byte[] NumericBytesScratch => _numericBytesScratch ?? CreateNumericBytesScratch();
Expand Down