Skip to content

[stdlib] prevent preventable buffer overflows when constructing Strings from C strings. #42221

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Apr 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions stdlib/private/StdlibUnittest/StdlibUnittest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2182,9 +2182,10 @@ func _getSystemVersionPlistProperty(_ propertyName: String) -> String? {
func _getSystemVersionPlistProperty(_ propertyName: String) -> String? {
var count = 0
sysctlbyname("kern.osproductversion", nil, &count, nil, 0)
var s = [CChar](repeating: 0, count: count)
sysctlbyname("kern.osproductversion", &s, &count, nil, 0)
return String(cString: &s)
return withUnsafeTemporaryAllocation(of: CChar.self, capacity: count) {
sysctlbyname("kern.osproductversion", $0.baseAddress, &count, nil, 0)
return String(cString: $0.baseAddress!)
}
}
#endif
#endif
Expand Down
20 changes: 10 additions & 10 deletions stdlib/private/SwiftReflectionTest/SwiftReflectionTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -373,25 +373,25 @@ internal func reflect(instanceAddress: UInt,
shouldUnwrapClassExistential: Bool = false) {
while let command = readLine(strippingNewline: true) {
switch command {
case String(validatingUTF8: RequestInstanceKind)!:
case RequestInstanceKind:
sendValue(kind.rawValue)
case String(validatingUTF8: RequestShouldUnwrapClassExistential)!:
case RequestShouldUnwrapClassExistential:
sendValue(shouldUnwrapClassExistential)
case String(validatingUTF8: RequestInstanceAddress)!:
case RequestInstanceAddress:
sendValue(instanceAddress)
case String(validatingUTF8: RequestReflectionInfos)!:
case RequestReflectionInfos:
sendReflectionInfos()
case String(validatingUTF8: RequestImages)!:
case RequestImages:
sendImages()
case String(validatingUTF8: RequestReadBytes)!:
case RequestReadBytes:
sendBytes()
case String(validatingUTF8: RequestSymbolAddress)!:
case RequestSymbolAddress:
sendSymbolAddress()
case String(validatingUTF8: RequestStringLength)!:
case RequestStringLength:
sendStringLength()
case String(validatingUTF8: RequestPointerSize)!:
case RequestPointerSize:
sendPointerSize()
case String(validatingUTF8: RequestDone)!:
case RequestDone:
return
default:
fatalError("Unknown request received: '\(Array(command.utf8))'!")
Expand Down
227 changes: 218 additions & 9 deletions stdlib/public/core/CString.swift
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,84 @@ extension String {
/// }
/// // Prints "Caf�"
///
/// - Parameter cString: A pointer to a null-terminated UTF-8 code sequence.
public init(cString: UnsafePointer<CChar>) {
let len = UTF8._nullCodeUnitOffset(in: cString)
/// - Parameter nullTerminatedUTF8: A pointer to a null-terminated UTF-8 code sequence.
public init(cString nullTerminatedUTF8: UnsafePointer<CChar>) {
let len = UTF8._nullCodeUnitOffset(in: nullTerminatedUTF8)
self = String._fromUTF8Repairing(
UnsafeBufferPointer(start: cString._asUInt8, count: len)).0
UnsafeBufferPointer(start: nullTerminatedUTF8._asUInt8, count: len)).0
}

/// Creates a new string by copying the null-terminated UTF-8 data stored in
/// the given array.
///
/// The array's storage is viewed as `UInt8` code units and handed to
/// `init(_checkingCString:)`, which looks for the terminator inside the
/// array's own bounds.
///
/// - Parameter nullTerminatedUTF8: An array containing a null-terminated
///   UTF-8 code sequence.
@inlinable
@_alwaysEmitIntoClient
public init(cString nullTerminatedUTF8: [CChar]) {
  self = nullTerminatedUTF8.withUnsafeBytes { rawBytes -> String in
    let codeUnits = rawBytes.assumingMemoryBound(to: UInt8.self)
    return String(_checkingCString: codeUnits)
  }
}

/// Builds a string from the code units that precede the first zero byte in
/// `bytes`, decoding them with `String._fromUTF8Repairing`.
///
/// The search for the terminator never leaves the buffer; if no zero byte is
/// present the initializer traps.
///
/// - Parameter bytes: A buffer expected to contain a zero byte.
@_alwaysEmitIntoClient
private init(_checkingCString bytes: UnsafeBufferPointer<UInt8>) {
  guard let terminator = bytes.firstIndex(of: 0) else {
    _preconditionFailure(
      "input of String.init(cString:) must be null-terminated"
    )
  }
  // A terminator was found, so the buffer holds at least one element and
  // therefore has a base address.
  let start = bytes.baseAddress._unsafelyUnwrappedUnchecked
  let contents = UnsafeBufferPointer(start: start, count: terminator)
  self = String._fromUTF8Repairing(contents).0
}

/// Creates a string from a single `CChar` passed by `inout` reference.
///
/// One code unit can only be a null-terminated sequence when it is the
/// terminator itself, so the result is the empty string when the value is
/// zero; any other value traps.
///
/// - Parameter nullTerminatedUTF8: A code unit that must equal zero.
@inlinable
@_alwaysEmitIntoClient
@available(*, deprecated, message: "Use String(_ scalar: Unicode.Scalar)")
public init(cString nullTerminatedUTF8: inout CChar) {
  if nullTerminatedUTF8 != 0 {
    _preconditionFailure(
      "input of String.init(cString:) must be null-terminated"
    )
  }
  self.init()
}

/// Creates a new string by copying the null-terminated UTF-8 data referenced
/// by the given pointer.
///
/// This is identical to `init(cString: UnsafePointer<CChar>)` but operates on
/// an unsigned sequence of bytes.
public init(cString: UnsafePointer<UInt8>) {
let len = UTF8._nullCodeUnitOffset(in: cString)
public init(cString nullTerminatedUTF8: UnsafePointer<UInt8>) {
let len = UTF8._nullCodeUnitOffset(in: nullTerminatedUTF8)
self = String._fromUTF8Repairing(
UnsafeBufferPointer(start: cString, count: len)).0
UnsafeBufferPointer(start: nullTerminatedUTF8, count: len)).0
}

/// Creates a new string by copying the null-terminated UTF-8 data stored in
/// the given array of unsigned bytes.
///
/// The array's contiguous storage is forwarded to
/// `init(_checkingCString:)`, which looks for the terminator inside the
/// array's own bounds.
///
/// - Parameter nullTerminatedUTF8: An array containing a null-terminated
///   UTF-8 code sequence.
@inlinable
@_alwaysEmitIntoClient
public init(cString nullTerminatedUTF8: [UInt8]) {
  self = nullTerminatedUTF8.withUnsafeBufferPointer { codeUnits -> String in
    return String(_checkingCString: codeUnits)
  }
}

/// Creates a string from the null-terminated UTF-8 representation of another
/// string.
///
/// The argument is exposed through `withCString` and read back with the
/// pointer-based `init(cString:)`.
///
/// - Parameter nullTerminatedUTF8: The string whose C representation is
///   copied.
@inlinable
@_alwaysEmitIntoClient
@available(*, deprecated, message: "Use a copy of the String argument")
public init(cString nullTerminatedUTF8: String) {
  self = nullTerminatedUTF8.withCString { utf8 -> String in
    return String(cString: utf8)
  }
}

/// Creates a string from a single `UInt8` passed by `inout` reference.
///
/// One code unit can only be a null-terminated sequence when it is the
/// terminator itself, so the result is the empty string when the value is
/// zero; any other value traps.
///
/// - Parameter nullTerminatedUTF8: A code unit that must equal zero.
@inlinable
@_alwaysEmitIntoClient
@available(*, deprecated, message: "Use String(_ scalar: Unicode.Scalar)")
public init(cString nullTerminatedUTF8: inout UInt8) {
  if nullTerminatedUTF8 != 0 {
    _preconditionFailure(
      "input of String.init(cString:) must be null-terminated"
    )
  }
  self.init()
}

/// Creates a new string by copying and validating the null-terminated UTF-8
Expand Down Expand Up @@ -95,6 +157,40 @@ extension String {
self = str
}

/// Creates a new string by copying and validating the null-terminated UTF-8
/// data stored in the given array.
///
/// Only the elements before the first zero are examined. The initializer
/// traps when the array holds no zero element, and yields `nil` when
/// `String._tryFromUTF8` rejects the code units.
///
/// - Parameter cString: An array containing a null-terminated sequence of
///   code units.
@inlinable
@_alwaysEmitIntoClient
public init?(validatingUTF8 cString: [CChar]) {
  guard let terminator = cString.firstIndex(of: 0) else {
    _preconditionFailure(
      "input of String.init(validatingUTF8:) must be null-terminated"
    )
  }
  let contents = cString.prefix(terminator)
  let validated: String? = contents.withUnsafeBytes { rawBytes in
    String._tryFromUTF8(rawBytes.assumingMemoryBound(to: UInt8.self))
  }
  if let string = validated {
    self = string
  } else {
    return nil
  }
}

/// Creates a string from the null-terminated UTF-8 representation of another
/// string.
///
/// The argument is exposed through `withCString` and read back with the
/// pointer-based `init(cString:)`; this body has no path that produces
/// `nil`.
///
/// - Parameter cString: The string whose C representation is copied.
@inlinable
@_alwaysEmitIntoClient
@available(*, deprecated, message: "Use a copy of the String argument")
public init?(validatingUTF8 cString: String) {
  self = cString.withCString { utf8 -> String in
    return String(cString: utf8)
  }
}

/// Creates a string from a single `CChar` passed by `inout` reference.
///
/// One code unit can only be a null-terminated sequence when it is the
/// terminator itself, so the result is the empty string when the value is
/// zero; any other value traps. This body has no path that produces `nil`.
///
/// - Parameter cString: A code unit that must equal zero.
@inlinable
@_alwaysEmitIntoClient
@available(*, deprecated, message: "Use String(_ scalar: Unicode.Scalar)")
public init?(validatingUTF8 cString: inout CChar) {
  if cString != 0 {
    _preconditionFailure(
      "input of String.init(validatingUTF8:) must be null-terminated"
    )
  }
  self.init()
}

/// Creates a new string by copying the null-terminated data referenced by
/// the given pointer using the specified encoding.
///
Expand Down Expand Up @@ -166,6 +262,77 @@ extension String {
return String._fromCodeUnits(
codeUnits, encoding: encoding, repair: isRepairing)
}

/// Creates a new string by decoding the null-terminated code units stored in
/// the given array using the specified encoding.
///
/// Only the elements before the first zero code unit are decoded; the
/// function traps when the array holds no zero element.
///
/// - Parameters:
///   - cString: An array containing a null-terminated sequence of code units
///     in `encoding`.
///   - encoding: The encoding in which the code units are interpreted.
///   - isRepairing: Forwarded to the decoder as its repair flag. When `false`
///     on the UTF-8 path, code units rejected by `String._tryFromUTF8` make
///     the function return `nil`.
/// - Returns: The decoded string paired with a flag reporting whether repairs
///   were made, or `nil` when decoding fails.
@_specialize(where Encoding == Unicode.UTF8)
@_specialize(where Encoding == Unicode.UTF16)
@inlinable // Fold away specializations
@_alwaysEmitIntoClient
public static func decodeCString<Encoding: _UnicodeEncoding>(
  _ cString: [Encoding.CodeUnit],
  as encoding: Encoding.Type,
  repairingInvalidCodeUnits isRepairing: Bool = true
) -> (result: String, repairsMade: Bool)? {
  guard let terminator = cString.firstIndex(of: 0) else {
    _preconditionFailure(
      "input of decodeCString(_:as:repairingInvalidCodeUnits:) must be null-terminated"
    )
  }
  // Everything in front of the terminator is the payload to decode.
  let contents = cString.prefix(terminator)

  guard _fastPath(encoding == Unicode.UTF8.self) else {
    // Any other encoding goes through the generic transcoding entry point.
    return contents.withUnsafeBufferPointer {
      codeUnits -> (result: String, repairsMade: Bool)? in
      String._fromCodeUnits(codeUnits, encoding: encoding, repair: isRepairing)
    }
  }

  // UTF-8: view the storage as raw bytes and use the dedicated decoders.
  return contents.withUnsafeBytes {
    rawBytes -> (result: String, repairsMade: Bool)? in
    let utf8 = rawBytes.assumingMemoryBound(to: UInt8.self)
    if isRepairing {
      return String._fromUTF8Repairing(utf8)
    }
    guard let validated = String._tryFromUTF8(utf8) else {
      return nil
    }
    return (validated, false)
  }
}

/// Decodes the null-terminated representation of another string in the
/// specified encoding.
///
/// The argument is transcoded with `withCString(encodedAs:)` and the
/// resulting pointer is passed on to the pointer-based `decodeCString`.
///
/// - Parameters:
///   - cString: The string whose encoded representation is decoded.
///   - encoding: The encoding used both to expose and to decode the string.
///   - isRepairing: Forwarded unchanged to the pointer-based `decodeCString`.
/// - Returns: Whatever the pointer-based `decodeCString` returns for the
///   transcoded data.
@_specialize(where Encoding == Unicode.UTF8)
@_specialize(where Encoding == Unicode.UTF16)
@inlinable
@_alwaysEmitIntoClient
@available(*, deprecated, message: "Use a copy of the String argument")
public static func decodeCString<Encoding: _UnicodeEncoding>(
  _ cString: String,
  as encoding: Encoding.Type,
  repairingInvalidCodeUnits isRepairing: Bool = true
) -> (result: String, repairsMade: Bool)? {
  let decoded = cString.withCString(encodedAs: encoding) {
    codeUnits -> (result: String, repairsMade: Bool)? in
    String.decodeCString(
      codeUnits, as: encoding, repairingInvalidCodeUnits: isRepairing
    )
  }
  return decoded
}

/// Decodes a single code unit passed by `inout` reference.
///
/// One code unit can only be a null-terminated sequence when it is the
/// terminator itself, so the result is always the empty string with no
/// repairs when the value is zero; any other value traps. Neither `encoding`
/// nor `isRepairing` influences the result.
///
/// - Parameters:
///   - cString: A code unit that must equal zero.
///   - encoding: The nominal encoding of the code unit.
///   - isRepairing: Unused by this overload.
/// - Returns: `("", false)`.
@_specialize(where Encoding == Unicode.UTF8)
@_specialize(where Encoding == Unicode.UTF16)
@inlinable
@_alwaysEmitIntoClient
@available(*, deprecated, message: "Use String(_ scalar: Unicode.Scalar)")
public static func decodeCString<Encoding: _UnicodeEncoding>(
  _ cString: inout Encoding.CodeUnit,
  as encoding: Encoding.Type,
  repairingInvalidCodeUnits isRepairing: Bool = true
) -> (result: String, repairsMade: Bool)? {
  if cString != 0 {
    _preconditionFailure(
      "input of decodeCString(_:as:repairingInvalidCodeUnits:) must be null-terminated"
    )
  }
  return (result: "", repairsMade: false)
}

/// Creates a string from the null-terminated sequence of bytes at the given
/// pointer.
///
Expand All @@ -179,10 +346,52 @@ extension String {
@_specialize(where Encoding == Unicode.UTF16)
@inlinable // Fold away specializations
public init<Encoding: Unicode.Encoding>(
decodingCString ptr: UnsafePointer<Encoding.CodeUnit>,
decodingCString nullTerminatedCodeUnits: UnsafePointer<Encoding.CodeUnit>,
as sourceEncoding: Encoding.Type
) {
self = String.decodeCString(ptr, as: sourceEncoding)!.0
self = String.decodeCString(nullTerminatedCodeUnits, as: sourceEncoding)!.0
}

/// Creates a string from the null-terminated sequence of code units stored
/// in the given array.
///
/// Decoding is delegated to `decodeCString(_:as:)` with its default repair
/// setting, and the string component of the result is kept.
///
/// - Parameters:
///   - nullTerminatedCodeUnits: An array containing a null-terminated
///     sequence of code units in `sourceEncoding`.
///   - sourceEncoding: The encoding in which the code units are interpreted.
@_specialize(where Encoding == Unicode.UTF8)
@_specialize(where Encoding == Unicode.UTF16)
@inlinable // Fold away specializations
@_alwaysEmitIntoClient
public init<Encoding: Unicode.Encoding>(
  decodingCString nullTerminatedCodeUnits: [Encoding.CodeUnit],
  as sourceEncoding: Encoding.Type
) {
  let decoded = String.decodeCString(
    nullTerminatedCodeUnits, as: sourceEncoding
  )!
  self = decoded.result
}

/// Creates a string by decoding the null-terminated representation of
/// another string in the specified encoding.
///
/// The argument is transcoded with `withCString(encodedAs:)` and the
/// resulting pointer is read back with the pointer-based
/// `init(decodingCString:as:)`.
///
/// - Parameters:
///   - nullTerminatedCodeUnits: The string whose encoded representation is
///     decoded.
///   - sourceEncoding: The encoding used both to expose and to decode the
///     string.
@_specialize(where Encoding == Unicode.UTF8)
@_specialize(where Encoding == Unicode.UTF16)
@inlinable
@_alwaysEmitIntoClient
@available(*, deprecated, message: "Use a copy of the String argument")
public init<Encoding: _UnicodeEncoding>(
  decodingCString nullTerminatedCodeUnits: String,
  as sourceEncoding: Encoding.Type
) {
  self = nullTerminatedCodeUnits.withCString(encodedAs: sourceEncoding) {
    codeUnits -> String in
    return String(decodingCString: codeUnits, as: sourceEncoding)
  }
}

/// Creates a string from a single code unit passed by `inout` reference.
///
/// One code unit can only be a null-terminated sequence when it is the
/// terminator itself, so the result is the empty string when the value is
/// zero; any other value traps. `sourceEncoding` does not influence the
/// result.
///
/// - Parameters:
///   - nullTerminatedCodeUnits: A code unit that must equal zero.
///   - sourceEncoding: The nominal encoding of the code unit.
@_specialize(where Encoding == Unicode.UTF8)
@_specialize(where Encoding == Unicode.UTF16)
@inlinable // Fold away specializations
@_alwaysEmitIntoClient
@available(*, deprecated, message: "Use String(_ scalar: Unicode.Scalar)")
public init<Encoding: Unicode.Encoding>(
  decodingCString nullTerminatedCodeUnits: inout Encoding.CodeUnit,
  as sourceEncoding: Encoding.Type
) {
  if nullTerminatedCodeUnits != 0 {
    _preconditionFailure(
      "input of String.init(decodingCString:as:) must be null-terminated"
    )
  }
  self.init()
}
}

Expand Down
4 changes: 3 additions & 1 deletion test/Prototypes/PatternMatching.swift
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,9 @@ extension Collection where Iterator.Element == UTF8.CodeUnit {
a.reserveCapacity(count + 1)
a.append(contentsOf: self)
a.append(0)
return String(reflecting: String(cString: a))
return a.withUnsafeBufferPointer {
String(reflecting: String(cString: $0.baseAddress!))
}
}
}

Expand Down
Loading