Skip to content

Commit 2c49bee

Browse files
madsodgaardgwynne
andauthored
Add proper support for Decimal (#194)
* Use `PostgresNumeric` for `Decimal` instead of String * Make `Decimal` conform to `PSQLCodable` * Fix support for text decimals * Add integration test for decimal string serialization * Test inserting decimal to text column Co-authored-by: Gwynne Raskind <[email protected]>
1 parent f91f23d commit 2c49bee

File tree

5 files changed

+128
-8
lines changed

5 files changed

+128
-8
lines changed

Sources/PostgresNIO/Data/PostgresData+Decimal.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ extension PostgresData {
1818

1919
extension Decimal: PostgresDataConvertible {
2020
public static var postgresDataType: PostgresDataType {
21-
return String.postgresDataType
21+
return .numeric
2222
}
2323

2424
public init?(postgresData: PostgresData) {
@@ -29,6 +29,6 @@ extension Decimal: PostgresDataConvertible {
2929
}
3030

3131
public var postgresData: PostgresData? {
32-
return .init(decimal: self)
32+
return .init(numeric: PostgresNumeric(decimal: self))
3333
}
3434
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import NIOCore
2+
import struct Foundation.Decimal
3+
4+
extension Decimal: PSQLCodable {
5+
var psqlType: PSQLDataType {
6+
.numeric
7+
}
8+
9+
var psqlFormat: PSQLFormat {
10+
.binary
11+
}
12+
13+
static func decode(from byteBuffer: inout ByteBuffer, type: PSQLDataType, format: PSQLFormat, context: PSQLDecodingContext) throws -> Decimal {
14+
switch (format, type) {
15+
case (.binary, .numeric):
16+
guard let numeric = PostgresNumeric(buffer: &byteBuffer) else {
17+
throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: byteBuffer, context: context)
18+
}
19+
return numeric.decimal
20+
case (.text, .numeric):
21+
guard let string = byteBuffer.readString(length: byteBuffer.readableBytes), let value = Decimal(string: string) else {
22+
throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: byteBuffer, context: context)
23+
}
24+
return value
25+
default:
26+
throw PSQLCastingError.failure(targetType: Self.self, type: type, postgresData: byteBuffer, context: context)
27+
}
28+
}
29+
30+
func encode(into byteBuffer: inout ByteBuffer, context: PSQLEncodingContext) {
31+
let numeric = PostgresNumeric(decimal: self)
32+
byteBuffer.writeInteger(numeric.ndigits)
33+
byteBuffer.writeInteger(numeric.weight)
34+
byteBuffer.writeInteger(numeric.sign)
35+
byteBuffer.writeInteger(numeric.dscale)
36+
var value = numeric.value
37+
byteBuffer.writeBuffer(&value)
38+
}
39+
}

Tests/IntegrationTests/PSQLIntegrationTests.swift

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,31 @@ final class IntegrationTests: XCTestCase {
251251
XCTAssertEqual(try row?.decode(column: "timestamptz", as: Date.self).description, "2016-01-18 00:20:03 +0000")
252252
}
253253

254+
func testDecodeDecimals() {
255+
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
256+
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
257+
let eventLoop = eventLoopGroup.next()
258+
259+
var conn: PSQLConnection?
260+
XCTAssertNoThrow(conn = try PSQLConnection.test(on: eventLoop).wait())
261+
defer { XCTAssertNoThrow(try conn?.close().wait()) }
262+
263+
var stream: PSQLRowStream?
264+
XCTAssertNoThrow(stream = try conn?.query("""
265+
SELECT
266+
$1::numeric as numeric,
267+
$2::numeric as numeric_negative
268+
""", [Decimal(string: "123456.789123")!, Decimal(string: "-123456.789123")!], logger: .psqlTest).wait())
269+
270+
var rows: [PSQLRow]?
271+
XCTAssertNoThrow(rows = try stream?.all().wait())
272+
XCTAssertEqual(rows?.count, 1)
273+
let row = rows?.first
274+
275+
XCTAssertEqual(try row?.decode(column: "numeric", as: Decimal.self), Decimal(string: "123456.789123")!)
276+
XCTAssertEqual(try row?.decode(column: "numeric_negative", as: Decimal.self), Decimal(string: "-123456.789123")!)
277+
}
278+
254279
func testDecodeUUID() {
255280
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1)
256281
defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }

Tests/IntegrationTests/PostgresNIOTests.swift

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -466,17 +466,41 @@ final class PostgresNIOTests: XCTestCase {
466466
var rows: PostgresQueryResult?
467467
XCTAssertNoThrow(rows = try conn?.query("""
468468
select
469-
$1::numeric::text as a,
470-
$2::numeric::text as b,
471-
$3::numeric::text as c
469+
$1::numeric as a,
470+
$2::numeric as b,
471+
$3::numeric as c
472472
""", [
473473
.init(numeric: a),
474474
.init(numeric: b),
475475
.init(numeric: c)
476476
]).wait())
477-
XCTAssertEqual(rows?.first?.column("a")?.string, "123456.789123")
478-
XCTAssertEqual(rows?.first?.column("b")?.string, "-123456.789123")
479-
XCTAssertEqual(rows?.first?.column("c")?.string, "3.14159265358979")
477+
XCTAssertEqual(rows?.first?.column("a")?.decimal, Decimal(string: "123456.789123")!)
478+
XCTAssertEqual(rows?.first?.column("b")?.decimal, Decimal(string: "-123456.789123")!)
479+
XCTAssertEqual(rows?.first?.column("c")?.decimal, Decimal(string: "3.14159265358979")!)
480+
}
481+
482+
func testDecimalStringSerialization() {
483+
var conn: PostgresConnection?
484+
XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait())
485+
defer { XCTAssertNoThrow( try conn?.close().wait() ) }
486+
487+
XCTAssertNoThrow(_ = try conn?.simpleQuery("DROP TABLE IF EXISTS \"table1\"").wait())
488+
XCTAssertNoThrow(_ = try conn?.simpleQuery("""
489+
CREATE TABLE table1 (
490+
"balance" text NOT NULL
491+
);
492+
""").wait())
493+
defer { XCTAssertNoThrow(_ = try conn?.simpleQuery("DROP TABLE \"table1\"").wait()) }
494+
495+
XCTAssertNoThrow(_ = try conn?.query("INSERT INTO table1 VALUES ($1)", [.init(decimal: Decimal(string: "123456.789123")!)]).wait())
496+
497+
var rows: PostgresQueryResult?
498+
XCTAssertNoThrow(rows = try conn?.query("""
499+
SELECT
500+
"balance"
501+
FROM table1
502+
""").wait())
503+
XCTAssertEqual(rows?.first?.column("balance")?.decimal, Decimal(string: "123456.789123")!)
480504
}
481505

482506
func testMoney() {
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import XCTest
2+
import NIOCore
3+
@testable import PostgresNIO
4+
5+
class Decimal_PSQLCodableTests: XCTestCase {
6+
7+
func testRoundTrip() {
8+
let values: [Decimal] = [1.1, .pi, -5e-12]
9+
10+
for value in values {
11+
var buffer = ByteBuffer()
12+
value.encode(into: &buffer, context: .forTests())
13+
XCTAssertEqual(value.psqlType, .numeric)
14+
let data = PSQLData(bytes: buffer, dataType: .numeric, format: .binary)
15+
16+
var result: Decimal?
17+
XCTAssertNoThrow(result = try data.decode(as: Decimal.self, context: .forTests()))
18+
XCTAssertEqual(value, result)
19+
}
20+
}
21+
22+
func testDecodeFailureInvalidType() {
23+
var buffer = ByteBuffer()
24+
buffer.writeInteger(Int64(0))
25+
let data = PSQLData(bytes: buffer, dataType: .int8, format: .binary)
26+
27+
XCTAssertThrowsError(try data.decode(as: Decimal.self, context: .forTests())) { error in
28+
XCTAssert(error is PSQLCastingError)
29+
}
30+
}
31+
32+
}

0 commit comments

Comments
 (0)