Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 45 additions & 46 deletions Sources/PostgreSQL/Codable/PostgreSQLValueEncoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,48 +20,9 @@ struct PostgreSQLDataEncoder {
return try convertible.convertToPostgreSQLData()
}

do {
let encoder = _Encoder()
try encodable.encode(to: encoder)
if let data = encoder.data {
return data
} else {
let type: PostgreSQLDataFormat
if let present = encoder.array.first?.type {
type = present
} else if
let array = Swift.type(of: encodable) as? AnyArray.Type,
let psql = array.anyElementType as? PostgreSQLDataTypeStaticRepresentable.Type
{
if let format = psql.postgreSQLDataType.dataFormat {
type = format
} else {
WARNING("Could not determine PostgreSQL array data type: \(psql.postgreSQLDataType)")
type = .null
}
} else {
WARNING("Could not determine PostgreSQL array data type: \(Swift.type(of: encodable))")
type = .null
}
// encode array
var data = Data()
data += Data.of(Int32(1).bigEndian) // non-null
data += Data.of(Int32(0).bigEndian) // b
data += Data.of(type.raw.bigEndian)
data += Data.of(Int32(encoder.array.count).bigEndian) // length
data += Data.of(Int32(1).bigEndian) // dimensions

for element in encoder.array {
switch element.storage {
case .binary(let value):
data += Data.of(Int32(value.count).bigEndian)
data += value
default: data += Data.of(Int32(0).bigEndian)
}
}
return PostgreSQLData(type.arrayType ?? .null, binary: data)
}
} catch is _KeyedError {
let encoder = _Encoder()
try encodable.encode(to: encoder)
if encoder.keyedEncoding {
struct AnyEncodable: Encodable {
var encodable: Encodable
init(_ encodable: Encodable) {
Expand All @@ -73,6 +34,43 @@ struct PostgreSQLDataEncoder {
}
}
return try PostgreSQLData(.jsonb, binary: [0x01] + JSONEncoder().encode(AnyEncodable(encodable)))
} else if let data = encoder.data {
return data
} else {
let type: PostgreSQLDataFormat
if let present = encoder.array.first?.type {
type = present
} else if
let array = Swift.type(of: encodable) as? AnyArray.Type,
let psql = array.anyElementType as? PostgreSQLDataTypeStaticRepresentable.Type
{
if let format = psql.postgreSQLDataType.dataFormat {
type = format
} else {
WARNING("Could not determine PostgreSQL array data type: \(psql.postgreSQLDataType)")
type = .null
}
} else {
WARNING("Could not determine PostgreSQL array data type: \(Swift.type(of: encodable))")
type = .null
}
// encode array
var data = Data()
data += Data.of(Int32(1).bigEndian) // non-null
data += Data.of(Int32(0).bigEndian) // b
data += Data.of(type.raw.bigEndian)
data += Data.of(Int32(encoder.array.count).bigEndian) // length
data += Data.of(Int32(1).bigEndian) // dimensions

for element in encoder.array {
switch element.storage {
case .binary(let value):
data += Data.of(Int32(value.count).bigEndian)
data += value
default: data += Data.of(Int32(0).bigEndian)
}
}
return PostgreSQLData(type.arrayType ?? .null, binary: data)
}
}

Expand All @@ -84,10 +82,12 @@ struct PostgreSQLDataEncoder {
let userInfo: [CodingUserInfoKey: Any] = [:]
var data: PostgreSQLData?
var array: [PostgreSQLData]
var keyedEncoding: Bool

init() {
self.data = nil
self.array = []
self.keyedEncoding = false
}

func container<Key>(keyedBy type: Key.Type) -> KeyedEncodingContainer<Key> where Key : CodingKey {
Expand Down Expand Up @@ -163,21 +163,20 @@ struct PostgreSQLDataEncoder {
}
}

private struct _KeyedError: Error { }

private struct _KeyedEncodingContainer<Key>: KeyedEncodingContainerProtocol where Key: CodingKey {
let codingPath: [CodingKey] = []
let encoder: _Encoder
init(encoder: _Encoder) {
self.encoder = encoder
encoder.keyedEncoding = true
}

mutating func encodeNil(forKey key: Key) throws {
throw _KeyedError()
return
}

mutating func encode<T>(_ value: T, forKey key: Key) throws where T : Encodable {
throw _KeyedError()
return
}

mutating func nestedContainer<NestedKey>(keyedBy keyType: NestedKey.Type, forKey key: Key) -> KeyedEncodingContainer<NestedKey> where NestedKey: CodingKey {
Expand Down
7 changes: 7 additions & 0 deletions Tests/PostgreSQLTests/PostgreSQLConnectionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,12 @@ class PostgreSQLConnectionTests: XCTestCase {
print(row)
}

func testRowCodableEmptyKeyed() throws {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add this to the allTests array at the bottom of the file?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops. I always forget allTests. Fixed.

let components = DateComponents()
let row = try PostgreSQLDataEncoder().encode(components)
XCTAssert(row.type == .jsonb)
}

func testRowCodableTypes() throws {
let conn = try PostgreSQLConnection.makeTest()

Expand Down Expand Up @@ -619,6 +625,7 @@ class PostgreSQLConnectionTests: XCTestCase {
("testDataDecoder", testDataDecoder),
("testRowDecoder", testRowDecoder),
("testRowCodableNested", testRowCodableNested),
("testRowCodableEmptyKeyed", testRowCodableEmptyKeyed),
("testRowCodableTypes", testRowCodableTypes),
("testTimeTz", testTimeTz),
("testListen", testListen),
Expand Down