Skip to content

Commit 5ab0c7e

Browse files
committed
[JSON] Improve string decoding
* Don't use UTF8 decoder for decoding escaped strings * Only memcmp ASCII non-escaped strings
1 parent 4c84edb commit 5ab0c7e

File tree

2 files changed

+189
-77
lines changed

2 files changed

+189
-77
lines changed

Sources/SwiftCompilerPluginMessageHandling/JSON/JSONDecoding.swift

Lines changed: 162 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -94,14 +94,33 @@ func decodeFromJSON<T: Decodable>(json: UnsafeBufferPointer<UInt8>) throws -> T
9494

9595
private struct JSONMap {
9696
enum Descriptor: Int {
97-
case nullKeyword // [desc]
98-
case trueKeyword // [desc]
99-
case falseKeyword // [desc]
100-
case number // [desc, pointer, length]
101-
case simpleString // [desc, pointer, length]
102-
case string // [desc, pointer, length]
103-
case object // [desc, count, (key, value)...]
104-
case array // [desc, count, element...]
97+
98+
// MARK: - Keywords; size:1 [desc]
99+
100+
/// 'null'
101+
case nullKeyword
102+
/// 'true' size:1
103+
case trueKeyword
104+
/// 'false' size:1
105+
case falseKeyword
106+
107+
// MARK: - Scalar values; size:3 [desc, pointer, length]
108+
109+
/// Integer and floating number.
110+
case number
111+
/// ASCII non-escaped string.
112+
case asciiSimpleString
113+
/// Non escaped string.
114+
case simpleString
115+
/// String with escape sequences.
116+
case string
117+
118+
// MARK: - Collections; size: 2 + variable [desc, size, element...]
119+
120+
/// Object '{ ... }'. Elements are (key, value)...
121+
case object
122+
/// Array '[ ... ]'.
123+
case array
105124
}
106125
let data: [Int]
107126

@@ -261,16 +280,31 @@ private struct JSONScanner {
261280
mutating func scanString(start: Cursor) throws {
262281
ptr = start
263282
try expect("\"")
283+
264284
var hasEscape = false
285+
var hasNonASCII = false
265286
while hasData && ptr.pointee != UInt8(ascii: "\"") {
287+
// FIXME: Error for non-escaped control characters.
288+
// FIXME: Error for invalid UTF8 sequences.
266289
if ptr.pointee == UInt8(ascii: "\\") {
267290
hasEscape = true
268291
_ = try advance()
292+
} else if ptr.pointee >= 0x80 {
293+
hasNonASCII = true
269294
}
270295
_ = try advance()
271296
}
272297
try expect("\"")
273-
map.record(hasEscape ? .string : .simpleString, range: (start + 1)..<(ptr - 1))
298+
299+
let kind: JSONMap.Descriptor
300+
if hasEscape {
301+
kind = .string
302+
} else if hasNonASCII {
303+
kind = .simpleString
304+
} else {
305+
kind = .asciiSimpleString
306+
}
307+
map.record(kind, range: (start + 1)..<(ptr - 1))
274308
}
275309

276310
mutating func scanNumber(start: Cursor) throws {
@@ -376,7 +410,7 @@ private struct JSONMapValue {
376410
switch JSONMap.Descriptor(rawValue: data[0]) {
377411
case .nullKeyword, .trueKeyword, .falseKeyword:
378412
return 1
379-
case .number, .simpleString, .string:
413+
case .number, .asciiSimpleString, .simpleString, .string:
380414
return 3
381415
case .array, .object:
382416
return data[1]
@@ -412,79 +446,133 @@ extension JSONMapValue {
412446

413447
// MARK: Scalar values
414448
private enum _JSONStringParser {
415-
/// Decode .simpleString value from the buffer.
449+
/// Decode a non-escaped string value from the buffer.
450+
@inline(__always)
416451
static func decodeSimpleString(source: UnsafeBufferPointer<UInt8>) -> String {
417452
if source.count <= 0 {
418453
return ""
419454
}
420-
if #available(macOS 11.0, iOS 14.0, watchOS 7.0, tvOS 14.0, *) {
421-
return String(unsafeUninitializedCapacity: source.count) { buffer in
422-
buffer.initialize(fromContentsOf: source)
423-
}
424-
} else {
425-
return String(decoding: source, as: UTF8.self)
455+
return _makeString(unsafeUninitializedCapacity: source.count) { buffer in
456+
buffer.initialize(fromContentsOf: source)
426457
}
427458
}
428459

429-
/// Helper iterator decoding UTF8 sequence to UnicodeScalar stream.
430-
struct ScalarIterator<S: Sequence>: IteratorProtocol where S.Element == UInt8 {
431-
var backing: S.Iterator
432-
var decoder: UTF8
433-
init(_ source: S) {
434-
self.backing = source.makeIterator()
435-
self.decoder = UTF8()
436-
}
437-
mutating func next() -> UnicodeScalar? {
438-
switch decoder.decode(&backing) {
439-
case .scalarValue(let scalar): return scalar
440-
case .emptyInput: return nil
441-
case .error: fatalError("invalid")
460+
/// Decode a string value that includes escape sequences.
461+
static func decodeStringWithEscapes(source: UnsafeBufferPointer<UInt8>) -> String? {
462+
// JSON string with escape sequences must be 2 bytes or longer.
463+
assert(source.count > 0)
464+
465+
// Decode 'source' UTF-8 JSON string literal into the uninitialized
466+
// UTF-8 buffer. Upon error, return 0 and make an empty string.
467+
let decoded = _makeString(unsafeUninitializedCapacity: source.count) { buffer in
468+
469+
var cursor = source.baseAddress!
470+
let end = cursor + source.count
471+
var mark = cursor
472+
473+
var dest = buffer.baseAddress!
474+
475+
@inline(__always) func flush() {
476+
let count = mark.distance(to: cursor)
477+
dest.initialize(from: mark, count: count)
478+
dest += count
442479
}
443-
}
444-
}
445480

446-
static func decodeStringWithEscapes(source: UnsafeBufferPointer<UInt8>) -> String? {
447-
var string: String = ""
448-
string.reserveCapacity(source.count)
449-
var iter = ScalarIterator(source)
450-
while let scalar = iter.next() {
451-
// NOTE: We don't report detailed errors because we only care well-formed
452-
// payloads from the compiler.
453-
if scalar == "\\" {
454-
switch iter.next() {
455-
case "\"": string.append("\"")
456-
case "'": string.append("'")
457-
case "\\": string.append("\\")
458-
case "/": string.append("/")
459-
case "b": string.append("\u{08}")
460-
case "f": string.append("\u{0C}")
461-
case "n": string.append("\u{0A}")
462-
case "r": string.append("\u{0D}")
463-
case "t": string.append("\u{09}")
464-
case "u":
465-
// We don't care performance of this because \uFFFF style escape is
466-
// pretty rare. We only do it for control characters.
467-
let buffer: [UInt8] = [iter.next(), iter.next(), iter.next(), iter.next()]
468-
.compactMap { $0 }
469-
.compactMap { UInt8(exactly: $0.value) }
470-
471-
guard
472-
buffer.count == 4,
473-
let result: UInt16 = buffer.withUnsafeBufferPointer(_JSONNumberParser.parseHexIntegerDigits(source:)),
474-
let scalar = UnicodeScalar(result)
475-
else {
476-
return nil
477-
}
478-
string.append(Character(scalar))
479-
default:
480-
// invalid escape sequence
481-
return nil
481+
while cursor != end {
482+
if cursor.pointee != UInt8(ascii: "\\") {
483+
cursor += 1
484+
continue
485+
}
486+
487+
// Found an escape sequence. Flush the skipped source into the buffer.
488+
flush()
489+
490+
let hadError = decodeEscapeSequence(cursor: &cursor, end: end) {
491+
dest.initialize(to: $0)
492+
dest += 1
482493
}
483-
} else {
484-
string.append(Character(scalar))
494+
guard !hadError else { return 0 }
495+
496+
// Mark the position of the end of the escape sequence.
497+
mark = cursor
485498
}
499+
500+
// Flush the remaining non-escaped characters.
501+
flush()
502+
503+
return buffer.baseAddress!.distance(to: dest)
504+
}
505+
506+
// If any error is detected, empty string is created.
507+
return decoded.isEmpty ? nil : decoded
508+
}
509+
510+
/// Decode a JSON escape sequence, advance 'cursor' to end of the escape
511+
/// sequence, and call 'processCodeUnit' with the decoded value.
512+
/// Returns 'true' on error.
513+
///
514+
/// NOTE: We don't report detailed errors for now because we only care
515+
/// well-formed payloads from the compiler.
516+
private static func decodeEscapeSequence(
517+
cursor: inout UnsafePointer<UInt8>,
518+
end: UnsafePointer<UInt8>,
519+
into processCodeUnit: (UInt8) -> Void
520+
) -> Bool {
521+
assert(cursor.pointee == UInt8(ascii: "\\"))
522+
guard cursor.distance(to: end) >= 2 else { return true }
523+
524+
// Eat backslash and the next character.
525+
cursor += 2
526+
switch cursor[-1] {
527+
case UInt8(ascii: "\""): processCodeUnit(UInt8(ascii: "\""))
528+
case UInt8(ascii: "'"): processCodeUnit(UInt8(ascii: "'"))
529+
case UInt8(ascii: "\\"): processCodeUnit(UInt8(ascii: "\\"))
530+
case UInt8(ascii: "/"): processCodeUnit(UInt8(ascii: "/"))
531+
case UInt8(ascii: "b"): processCodeUnit(0x08)
532+
case UInt8(ascii: "f"): processCodeUnit(0x0C)
533+
case UInt8(ascii: "n"): processCodeUnit(0x0A)
534+
case UInt8(ascii: "r"): processCodeUnit(0x0D)
535+
case UInt8(ascii: "t"): processCodeUnit(0x09)
536+
case UInt8(ascii: "u"):
537+
guard cursor.distance(to: end) >= 4 else { return true }
538+
539+
// Parse 4 hex digits into a UTF-16 code unit.
540+
let result: UInt16? = _JSONNumberParser.parseHexIntegerDigits(
541+
source: UnsafeBufferPointer(start: cursor, count: 4)
542+
)
543+
guard let result else { return true }
544+
545+
// Transcode UTF-16 code unit to UTF-8.
546+
// FIXME: Support surrogate pairs.
547+
let hadError = transcode(
548+
CollectionOfOne(result).makeIterator(),
549+
from: UTF16.self,
550+
to: UTF8.self,
551+
stoppingOnError: true,
552+
into: processCodeUnit
553+
)
554+
guard !hadError else { return true }
555+
cursor += 4
556+
default:
557+
// invalid escape sequence.
558+
return true
559+
}
560+
return false
561+
}
562+
563+
/// SwiftStdlib 5.3 compatibility shim for
564+
/// 'String.init(unsafeUninitializedCapacity:initializingUTF8With:)'
565+
private static func _makeString(
566+
unsafeUninitializedCapacity capacity: Int,
567+
initializingUTF8With initializer: (UnsafeMutableBufferPointer<UInt8>) throws -> Int
568+
) rethrows -> String {
569+
if #available(macOS 11.0, iOS 14.0, watchOS 7.0, tvOS 14.0, *) {
570+
return try String(unsafeUninitializedCapacity: capacity, initializingUTF8With: initializer)
571+
} else {
572+
let buffer = UnsafeMutableBufferPointer<UInt8>.allocate(capacity: capacity)
573+
let count = try initializer(buffer)
574+
return String(decoding: buffer[..<count], as: UTF8.self)
486575
}
487-
return string
488576
}
489577
}
490578

@@ -574,7 +662,7 @@ extension JSONMapValue {
574662

575663
@inline(__always)
576664
func asString() -> String? {
577-
if self.is(.simpleString) {
665+
if self.is(.asciiSimpleString) || self.is(.simpleString) {
578666
return _JSONStringParser.decodeSimpleString(source: valueBuffer())
579667
}
580668
if self.is(.string) {
@@ -583,12 +671,12 @@ extension JSONMapValue {
583671
return nil
584672
}
585673

586-
/// Returns true if this value represents a string, and it equals to 'str'.
674+
/// Returns true if this value represents a string and equals to 'str'.
587675
///
588676
/// This is faster than 'value.asString() == str' because this doesn't
589677
/// instantiate 'Swift.String' unless there are escaped characters.
590678
func equals(to str: String) -> Bool {
591-
if self.is(.simpleString) {
679+
if self.is(.asciiSimpleString) {
592680
let lhs = valueBuffer()
593681
var str = str
594682
return str.withUTF8 { rhs in

Tests/SwiftCompilerPluginTest/JSONTests.swift

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,11 @@ final class JSONTests: XCTestCase {
9494
)
9595
}
9696

97-
func testUnicodeEscape() {
97+
func testEscapedString() {
9898
_testRoundTrip(
99-
of: "\n\u{A9}\u{0}\u{07}\u{1B}",
99+
of: "\n\"\\\u{A9}\u{0}\u{07}\u{1B}",
100100
expectedJSON: #"""
101-
"\n©\u0000\u0007\u001B"
101+
"\n\"\\©\u0000\u0007\u001B"
102102
"""#
103103
)
104104
}
@@ -130,6 +130,21 @@ final class JSONTests: XCTestCase {
130130
)
131131
}
132132

133+
func testInvalidStringDecoding() {
134+
_assertInvalidStrng(#""foo\"#) // EOF after '\'
135+
_assertInvalidStrng(#""\x""#) // Unknown character after '\'
136+
_assertInvalidStrng(#""\u1""#) // Missing 4 digits after '\u'
137+
_assertInvalidStrng(#""\u12""#)
138+
_assertInvalidStrng(#""\u123""#)
139+
_assertInvalidStrng(#""\uEFGH""#) // Invalid HEX characters.
140+
}
141+
142+
func testStringSurrogatePairDecoding() {
143+
// FIXME: Escaped surrogate pairs are not supported.
144+
// Currently parsed as "invalid", but this should be valid '𐐷' (U+10437) character
145+
_assertInvalidStrng(#"\uD801\uDC37"#)
146+
}
147+
133148
func testTypeCoercion() {
134149
_testRoundTripTypeCoercionFailure(of: [false, true], as: [Int].self)
135150
_testRoundTripTypeCoercionFailure(of: [false, true], as: [Int8].self)
@@ -195,10 +210,19 @@ final class JSONTests: XCTestCase {
195210
}
196211
}
197212

213+
private func _assertInvalidStrng(_ json: String) {
214+
do {
215+
var json = json
216+
_ = try json.withUTF8 { try JSON.decode(String.self, from: $0) }
217+
XCTFail("decoding should fail")
218+
} catch {}
219+
}
220+
198221
private func _assertParseError(_ json: String, message: String) {
199222
do {
200223
var json = json
201224
_ = try json.withUTF8 { try JSON.decode(Bool.self, from: $0) }
225+
XCTFail("decoding should fail")
202226
} catch DecodingError.dataCorrupted(let context) {
203227
XCTAssertEqual(
204228
String(describing: try XCTUnwrap(context.underlyingError)),

0 commit comments

Comments
 (0)