Skip to content

Commit 369a6f7

Browse files
authored
Merge pull request #1162 from mot256/fnel/6.x/fix-buffer-overflow
Fix buffer overflow in WireFormatting.WriteLongstr
2 parents 1fac6fc + 9b340dc commit 369a6f7

File tree

2 files changed

+198
-4
lines changed

2 files changed

+198
-4
lines changed

projects/RabbitMQ.Client/client/impl/WireFormatting.cs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,7 @@ public static unsafe int WriteShortstr(Span<byte> span, string val)
428428
{
429429
try
430430
{
431-
int bytesWritten = Encoding.UTF8.GetBytes(chars, val.Length, bytes, maxLength);
431+
int bytesWritten = val.Length > 0 ? Encoding.UTF8.GetBytes(chars, val.Length, bytes, maxLength) : 0;
432432
span[0] = (byte)bytesWritten;
433433
return bytesWritten + 1;
434434
}
@@ -441,12 +441,20 @@ public static unsafe int WriteShortstr(Span<byte> span, string val)
441441

442442
public static unsafe int WriteLongstr(Span<byte> span, string val)
443443
{
444+
int maxLength = span.Length - 4;
444445
fixed (char* chars = val)
445446
fixed (byte* bytes = &span.Slice(4).GetPinnableReference())
446447
{
447-
int bytesWritten = Encoding.UTF8.GetBytes(chars, val.Length, bytes, span.Length);
448-
NetworkOrderSerializer.WriteUInt32(span, (uint)bytesWritten);
449-
return bytesWritten + 4;
448+
try
449+
{
450+
int bytesWritten = val.Length > 0 ? Encoding.UTF8.GetBytes(chars, val.Length, bytes, maxLength) : 0;
451+
NetworkOrderSerializer.WriteUInt32(span, (uint)bytesWritten);
452+
return bytesWritten + 4;
453+
}
454+
catch (ArgumentException)
455+
{
456+
throw new ArgumentOutOfRangeException(nameof(val), val, $"Value exceeds the maximum allowed length of {maxLength} bytes.");
457+
}
450458
}
451459
}
452460

projects/Unit/TestWireFormatting.cs

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
using System;
2+
using System.Linq;
3+
using System.Text;
4+
using NUnit.Framework;
5+
using RabbitMQ.Client.Impl;
6+
7+
namespace RabbitMQ.Client.Unit
8+
{
9+
[TestFixture]
10+
internal class TestWireFormatting : WireFormattingFixture
11+
{
12+
[TestCase("", 1 + 0)]
13+
[TestCase("1", 1 + 1)]
14+
[TestCase("12", 1 + 2)]
15+
[TestCase("ǽ", 1 + 2, Description = "Latin Small Letter AE With Acute (U+01FD) amounts to 2 bytes")]
16+
public void TestWriteShortstr_BytesWritten(string inputStr, int expectedBytesWritten)
17+
{
18+
byte[] arr = new byte[expectedBytesWritten];
19+
Assert.AreEqual(expectedBytesWritten, WireFormatting.WriteShortstr(arr, inputStr));
20+
Assert.AreEqual(expectedBytesWritten - 1, arr[0]);
21+
}
22+
23+
[TestCase("12", 0, 0)]
24+
[TestCase("12", 10, 0)]
25+
[TestCase("12", 1, 1)]
26+
[TestCase("12", 10, 1)]
27+
[TestCase("12", 2, 2)]
28+
[TestCase("12345", 5, 5)]
29+
[TestCase("12", 20, 2)]
30+
[TestCase("12345", 50, 5)]
31+
[TestCase("ǽ", 2, 2, Description = "Latin Small Letter AE With Acute (U+01FD) amounts to 2 bytes. length byte + 2 bytes should not fit in span of length 2")]
32+
[TestCase("ǽ", 4, 2, Description = "Latin Small Letter AE With Acute (U+01FD) amounts to 2 bytes. length byte + 2 bytes should not fit in span of length 2")]
33+
public void TestWriteShortstr_FailsOnSpanLengthViolation(string inputStr, int bufferSize, int spanLength)
34+
{
35+
byte[] arr = new byte[bufferSize];
36+
Assert.Throws<ArgumentOutOfRangeException>(() => WireFormatting.WriteShortstr(arr.AsSpan(0, spanLength), inputStr));
37+
38+
if (bufferSize > spanLength)
39+
{
40+
// Ensure that even though we got an exception, the method never wrote further than the provided span (possible due to unsafe code)
41+
for (int i = spanLength; i < bufferSize - spanLength; i++)
42+
{
43+
Assert.Zero(arr[i]);
44+
}
45+
}
46+
}
47+
48+
[TestCase("", 4 + 0)]
49+
[TestCase("1", 4 + 1)]
50+
[TestCase("12", 4 + 2)]
51+
[TestCase("ǽ", 4 + 2, Description = "Latin Small Letter AE With Acute (U+01FD) amounts to 2 bytes")]
52+
public void TestWriteLongstr_BytesWritten(string inputStr, int expectedBytesWritten)
53+
{
54+
byte[] arr = new byte[expectedBytesWritten];
55+
Assert.AreEqual(expectedBytesWritten, WireFormatting.WriteLongstr(arr, inputStr));
56+
57+
int expectedLengthWritten = expectedBytesWritten - 4;
58+
Assert.AreEqual(expectedLengthWritten >> 24 & 0xFF, arr[0]);
59+
Assert.AreEqual(expectedLengthWritten >> 16 & 0xFF, arr[1]);
60+
Assert.AreEqual(expectedLengthWritten >> 8 & 0xFF, arr[2]);
61+
Assert.AreEqual(expectedLengthWritten & 0xFF, arr[3]);
62+
}
63+
64+
[Test]
65+
public void TestWriteLongstr_BytesWritten_VeryLarge()
66+
{
67+
TestWriteLongstr_BytesWritten(new string('*', 0x01020304), 4 + 0x01020304);
68+
}
69+
70+
[Test]
71+
public void TestWriteLongstr_BytesWritten_VeryLarge2()
72+
{
73+
TestWriteLongstr_BytesWritten(new string('ǽ', 0x01020304), 4 + (0x01020304 * 2));
74+
}
75+
76+
[TestCase("12", 0, 0)]
77+
[TestCase("12", 10, 0)]
78+
[TestCase("12", 1, 1)]
79+
[TestCase("12", 10, 1)]
80+
[TestCase("12", 4, 4)]
81+
[TestCase("12", 10, 4)]
82+
[TestCase("12", 5, 5)]
83+
[TestCase("12345", 5, 5)]
84+
[TestCase("12", 15, 5)]
85+
[TestCase("12345", 15, 5)]
86+
[TestCase("ǽ", 2, 2, Description = "Latin Small Letter AE With Acute (U+01FD) amounts to 2 bytes. 4 length bytes + 2 bytes should not fit in span of length 2")]
87+
public void TestWriteLongstr_FailsOnSpanLengthViolation(string inputStr, int bufferSize, int spanLength)
88+
{
89+
byte[] arr = new byte[bufferSize];
90+
Assert.Throws<ArgumentOutOfRangeException>(() => WireFormatting.WriteLongstr(arr.AsSpan(0, spanLength), inputStr));
91+
92+
if (bufferSize > spanLength)
93+
{
94+
// Ensure that even though we got an exception, the method never wrote further than the provided span (possible due to unsafe code)
95+
for (int i = spanLength; i < bufferSize - spanLength; i++)
96+
{
97+
Assert.Zero(arr[i]);
98+
}
99+
}
100+
}
101+
102+
[TestCaseSource(nameof(GetTestReadShortstrData))]
103+
public void TestReadShortstr_BytesRead(byte[] inputBuffer, int expectedBytesRead, string expectedString)
104+
{
105+
Assert.AreEqual(expectedString, WireFormatting.ReadShortstr(inputBuffer, out int bytesRead));
106+
Assert.AreEqual(expectedBytesRead, bytesRead);
107+
}
108+
109+
private static object[][] GetTestReadShortstrData()
110+
{
111+
return new object[][]
112+
{
113+
new object[] { new byte[] { 0 }, 1, "" },
114+
new object[] { new byte[] { 1, (byte)'1' }, 1 + 1, "1" },
115+
new object[] { new byte[] { 2, (byte)'1', (byte)'2' }, 1 + 2, "12" },
116+
new object[] { new byte[] { 2, 0xC7, 0xBD }, 1 + 2, "ǽ" } };
117+
}
118+
119+
[TestCaseSource(nameof(GetTestReadShortstrFailsOnSpanLengthViolationData))]
120+
public void TestReadShortstr_FailsOnSpanLengthViolation(byte[] inputBuffer)
121+
{
122+
int bytesRead = 0;
123+
Assert.Throws<ArgumentOutOfRangeException>(() => WireFormatting.ReadShortstr(inputBuffer, out bytesRead));
124+
Assert.AreEqual(0, bytesRead);
125+
}
126+
127+
private static object[][] GetTestReadShortstrFailsOnSpanLengthViolationData()
128+
{
129+
return new object[][]
130+
{
131+
new object[] { new byte[] { 1 } },
132+
new object[] { new byte[] { 2, (byte)'1' } },
133+
new object[] { new byte[] { 255 } }
134+
};
135+
}
136+
137+
[TestCaseSource(nameof(GetTestReadLongstrData))]
138+
public void TestReadLongstr_BytesRead(byte[] inputBuffer, int expectedBytesRead, string expectedString)
139+
{
140+
byte[] stringRead = WireFormatting.ReadLongstr(inputBuffer);
141+
Assert.True(Encoding.UTF8.GetBytes(expectedString).SequenceEqual(stringRead));
142+
Assert.AreEqual(expectedBytesRead, stringRead.Length + 4);
143+
}
144+
145+
private static object[][] GetTestReadLongstrData()
146+
{
147+
return new object[][]
148+
{
149+
new object[] { new byte[] { 0, 0, 0, 0 }, 4, "" },
150+
new object[] { new byte[] { 0, 0, 0, 1, (byte)'1' }, 4 + 1, "1" },
151+
new object[] { new byte[] { 0, 0, 0, 2, (byte)'1', (byte)'2' }, 4 + 2, "12" },
152+
new object[] { new byte[] { 0, 0, 0, 2, 0xC7, 0xBD }, 4 + 2, "ǽ" },
153+
};
154+
}
155+
156+
[Test]
157+
public void TestReadLongstr_BytesRead_VeryLarge()
158+
{
159+
var str = new string('*', 0x01020304);
160+
var chars = str.ToCharArray();
161+
var buffer = new byte[4 + Encoding.UTF8.GetMaxByteCount(chars.Length)];
162+
buffer[0] = 0x01;
163+
buffer[1] = 0x02;
164+
buffer[2] = 0x03;
165+
buffer[3] = 0x04;
166+
Encoding.UTF8.GetEncoder().GetBytes(chars, 0, chars.Length, buffer, 4, true);
167+
168+
TestReadLongstr_BytesRead(buffer, 4 + 0x01020304, str);
169+
}
170+
171+
[Test]
172+
public void TestReadLongstr_BytesRead_VeryLarge2()
173+
{
174+
var str = new string('ǽ', 0x01020304);
175+
var chars = str.ToCharArray();
176+
var buffer = new byte[4 + Encoding.UTF8.GetMaxByteCount(chars.Length)];
177+
buffer[0] = 0x02;
178+
buffer[1] = 0x04;
179+
buffer[2] = 0x06;
180+
buffer[3] = 0x08;
181+
Encoding.UTF8.GetEncoder().GetBytes(chars, 0, chars.Length, buffer, 4, true);
182+
183+
TestReadLongstr_BytesRead(buffer, 4 + (0x01020304 * 2), str);
184+
}
185+
}
186+
}

0 commit comments

Comments
 (0)