diff --git a/Sources/ConnectionPoolModule/ConnectionLease.swift b/Sources/ConnectionPoolModule/ConnectionLease.swift new file mode 100644 index 00000000..77591a58 --- /dev/null +++ b/Sources/ConnectionPoolModule/ConnectionLease.swift @@ -0,0 +1,17 @@ +public struct ConnectionLease: Sendable { + public var connection: Connection + + @usableFromInline + let _release: @Sendable (Connection) -> () + + @inlinable + public init(connection: Connection, release: @escaping @Sendable (Connection) -> Void) { + self.connection = connection + self._release = release + } + + @inlinable + public func release() { + self._release(self.connection) + } +} diff --git a/Sources/ConnectionPoolModule/ConnectionPool.swift b/Sources/ConnectionPoolModule/ConnectionPool.swift index b460b263..ee72337d 100644 --- a/Sources/ConnectionPoolModule/ConnectionPool.swift +++ b/Sources/ConnectionPoolModule/ConnectionPool.swift @@ -88,7 +88,7 @@ public protocol ConnectionRequestProtocol: Sendable { /// A function that is called with a connection or a /// `PoolError`. - func complete(with: Result) + func complete(with: Result, ConnectionPoolError>) } @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) @@ -402,8 +402,11 @@ public final class ConnectionPool< /*private*/ func runRequestAction(_ action: StateMachine.RequestAction) { switch action { case .leaseConnection(let requests, let connection): + let lease = ConnectionLease(connection: connection) { connection in + self.releaseConnection(connection) + } for request in requests { - request.complete(with: .success(connection)) + request.complete(with: .success(lease)) } case .failRequest(let request, let error): diff --git a/Sources/ConnectionPoolModule/ConnectionPoolError.swift b/Sources/ConnectionPoolModule/ConnectionPoolError.swift index 1f1e1d2c..3abfe778 100644 --- a/Sources/ConnectionPoolModule/ConnectionPoolError.swift +++ b/Sources/ConnectionPoolModule/ConnectionPoolError.swift @@ -1,16 +1,25 @@ public struct ConnectionPoolError: Error, Hashable { - enum Base: Error, Hashable { + @usableFromInline + enum Base: Error, Hashable, Sendable { case requestCancelled case poolShutdown } - private let base: Base + @usableFromInline + let base: Base + @inlinable init(_ base: Base) { self.base = base } /// The connection requests got cancelled - public static let requestCancelled = ConnectionPoolError(.requestCancelled) + @inlinable + public static var requestCancelled: Self { + ConnectionPoolError(.requestCancelled) + } /// The connection requests can't be fulfilled as the pool has already been shutdown - public static let poolShutdown = ConnectionPoolError(.poolShutdown) + @inlinable + public static var poolShutdown: Self { + ConnectionPoolError(.poolShutdown) + } } diff --git a/Sources/ConnectionPoolModule/ConnectionRequest.swift b/Sources/ConnectionPoolModule/ConnectionRequest.swift index 1d1c55da..d6654a27 100644 --- a/Sources/ConnectionPoolModule/ConnectionRequest.swift +++ b/Sources/ConnectionPoolModule/ConnectionRequest.swift @@ -5,18 +5,18 @@ public struct ConnectionRequest: ConnectionRequest public var id: ID @usableFromInline - private(set) var continuation: CheckedContinuation + private(set) var continuation: CheckedContinuation, any Error> @inlinable init( id: Int, - continuation: CheckedContinuation + continuation: CheckedContinuation, any Error> ) { self.id = id self.continuation = continuation } - public func complete(with result: Result) { + public func complete(with result: Result, ConnectionPoolError>) { self.continuation.resume(with: result) } } @@ -46,7 +46,7 @@ extension ConnectionPool where Request == ConnectionRequest { } @inlinable - public func leaseConnection() async throws -> Connection { + public func leaseConnection() async throws -> ConnectionLease { let requestID = requestIDGenerator.next() let connection = try await withTaskCancellationHandler { @@ -54,7 +54,7 @@ extension ConnectionPool where Request == ConnectionRequest { throw CancellationError() } - return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation, Error>) in let request = Request( id: requestID, continuation: continuation @@ -71,8 +71,8 @@ extension ConnectionPool where Request == ConnectionRequest { @inlinable public func withConnection(_ closure: (Connection) async throws -> Result) async throws -> Result { - let connection = try await self.leaseConnection() - defer { self.releaseConnection(connection) } - return try await closure(connection) + let lease = try await self.leaseConnection() + defer { lease.release() } + return try await closure(lease.connection) } } diff --git a/Sources/ConnectionPoolModule/TinyFastSequence.swift b/Sources/ConnectionPoolModule/TinyFastSequence.swift index dff8a30b..df140c98 100644 --- a/Sources/ConnectionPoolModule/TinyFastSequence.swift +++ b/Sources/ConnectionPoolModule/TinyFastSequence.swift @@ -29,6 +29,12 @@ struct TinyFastSequence: Sequence { self.base = .none(reserveCapacity: 0) case 1: self.base = .one(collection.first!, reserveCapacity: 0) + case 2: + self.base = .two( + collection.first!, + collection[collection.index(after: collection.startIndex)], + reserveCapacity: 0 + ) default: if let collection = collection as? Array { self.base = .n(collection) @@ -46,7 +52,7 @@ struct TinyFastSequence: Sequence { case 1: self.base = .one(max2Sequence.first!, reserveCapacity: 0) case 2: - self.base = .n(Array(max2Sequence)) + self.base = .two(max2Sequence.first!, max2Sequence.second!, reserveCapacity: 0) default: fatalError() } @@ -169,7 +175,7 @@ struct TinyFastSequence: Sequence { case .n(let array): if self.index < array.endIndex { - defer { self.index += 1} + defer { self.index += 1 } return array[self.index] } return nil diff --git a/Sources/ConnectionPoolTestUtils/MockRequest.swift b/Sources/ConnectionPoolTestUtils/MockRequest.swift index 5e4e2fc0..3dd8b0fb 100644 --- a/Sources/ConnectionPoolTestUtils/MockRequest.swift +++ b/Sources/ConnectionPoolTestUtils/MockRequest.swift @@ -1,8 +1,6 @@ import _ConnectionPoolModule -public final class MockRequest: ConnectionRequestProtocol, Hashable, Sendable { - public typealias Connection = MockConnection - +public final class MockRequest: ConnectionRequestProtocol, Hashable, Sendable { public struct ID: Hashable, Sendable { var objectID: ObjectIdentifier @@ -11,7 +9,7 @@ public final class MockRequest: ConnectionRequestProtocol, Hashable, Sendable { } } - public init() {} + public init(connectionType: Connection.Type = Connection.self) {} public var id: ID { ID(self) } @@ -23,7 +21,7 @@ public final class MockRequest: ConnectionRequestProtocol, Hashable, Sendable { hasher.combine(self.id) } - public func complete(with: Result) { + public func complete(with: Result, ConnectionPoolError>) { } } diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 9d264bcc..8560b948 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -752,6 +752,12 @@ struct ConnectionStateMachine { return self.modify(with: action) } + mutating func copyInResponseReceived( + _ copyInResponse: PostgresBackendMessage.CopyInResponse + ) -> ConnectionAction { + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.copyInResponse(copyInResponse))) + } + mutating func emptyQueryResponseReceived() -> ConnectionAction { guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else { return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.emptyQueryResponse)) diff --git a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift index 087a6c24..5708b6b9 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift @@ -91,7 +91,7 @@ struct ExtendedQueryStateMachine { mutating func cancel() -> Action { switch self.state { case .initialized: - preconditionFailure("Start must be called immediatly after the query was created") + preconditionFailure("Start must be called immediately after the query was created") case .messagesSent(let queryContext), .parseCompleteReceived(let queryContext), @@ -322,6 +322,12 @@ struct ExtendedQueryStateMachine { } } + mutating func copyInResponseReceived( + _ copyInResponse: PostgresBackendMessage.CopyInResponse + ) -> Action { + return self.setAndFireError(.unexpectedBackendMessage(.copyInResponse(copyInResponse))) + } + mutating func emptyQueryResponseReceived() -> Action { guard case .bindCompleteReceived(let queryContext) = self.state else { return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse)) diff --git a/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift index 6bd09e78..7e8376a7 100644 --- a/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift @@ -29,6 +29,12 @@ extension String: PostgresDecodable { context: PostgresDecodingContext ) throws { switch (format, type) { + case (.binary, .jsonb): + // Discard the version byte + guard let version = buffer.readInteger(as: UInt8.self), version == 1 else { + throw PostgresDecodingError.Code.failure + } + self = buffer.readString(length: buffer.readableBytes)! case (_, .varchar), (_, .bpchar), (_, .text), diff --git a/Sources/PostgresNIO/New/Messages/CopyInMessage.swift b/Sources/PostgresNIO/New/Messages/CopyInMessage.swift new file mode 100644 index 00000000..46dec648 --- /dev/null +++ b/Sources/PostgresNIO/New/Messages/CopyInMessage.swift @@ -0,0 +1,44 @@ +extension PostgresBackendMessage { + struct CopyInResponse: Hashable { + enum Format: Int8 { + case textual = 0 + case binary = 1 + } + + let format: Format + let columnFormats: [Format] + + static func decode(from buffer: inout ByteBuffer) throws -> Self { + guard let rawFormat = buffer.readInteger(endianness: .big, as: Int8.self) else { + throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(1, actual: buffer.readableBytes) + } + guard let format = Format(rawValue: rawFormat) else { + throw PSQLPartialDecodingError.unexpectedValue(value: rawFormat) + } + + guard let numColumns = buffer.readInteger(endianness: .big, as: Int16.self) else { + throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(2, actual: buffer.readableBytes) + } + var columnFormatCodes: [Format] = [] + columnFormatCodes.reserveCapacity(Int(numColumns)) + + for _ in 0.. Self { + return PSQLPartialDecodingError( + description: "Unknown message kind: \(messageID)", + file: file, line: line) + } } extension ByteBuffer { diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 0a14849a..baf801e5 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -136,6 +136,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler { action = self.state.closeCompletedReceived() case .commandComplete(let commandTag): action = self.state.commandCompletedReceived(commandTag) + case .copyInResponse(let copyInResponse): + action = self.state.copyInResponseReceived(copyInResponse) case .dataRow(let dataRow): action = self.state.dataRowReceived(dataRow) case .emptyQueryResponse: diff --git a/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift b/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift index 97805418..6ca4cc27 100644 --- a/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift +++ b/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift @@ -167,6 +167,28 @@ struct PostgresFrontendMessageEncoder { self.buffer.writeMultipleIntegers(UInt32(8), Self.sslRequestCode) } + /// Adds the `CopyData` message ID and `dataLength` to the message buffer but not the actual data. + /// + /// The caller of this function is expected to write the encoder's message buffer to the backend after calling this + /// function, followed by sending the actual data to the backend. + mutating func copyDataHeader(dataLength: UInt32) { + self.clearIfNeeded() + self.buffer.psqlWriteMultipleIntegers(id: .copyData, length: dataLength) + } + + mutating func copyDone() { + self.clearIfNeeded() + self.buffer.psqlWriteMultipleIntegers(id: .copyDone, length: 0) + } + + mutating func copyFail(message: String) { + self.clearIfNeeded() + var messageBuffer = ByteBuffer() + messageBuffer.writeNullTerminatedString(message) + self.buffer.psqlWriteMultipleIntegers(id: .copyFail, length: UInt32(messageBuffer.readableBytes)) + self.buffer.writeImmutableBuffer(messageBuffer) + } + mutating func sync() { self.clearIfNeeded() self.buffer.psqlWriteMultipleIntegers(id: .sync, length: 0) @@ -197,6 +219,9 @@ struct PostgresFrontendMessageEncoder { private enum FrontendMessageID: UInt8, Hashable, Sendable { case bind = 66 // B case close = 67 // C + case copyData = 100 // d + case copyDone = 99 // c + case copyFail = 102 // f case describe = 68 // D case execute = 69 // E case flush = 72 // H diff --git a/Sources/PostgresNIO/Pool/PostgresClient.swift b/Sources/PostgresNIO/Pool/PostgresClient.swift index d54e34eb..0279be07 100644 --- a/Sources/PostgresNIO/Pool/PostgresClient.swift +++ b/Sources/PostgresNIO/Pool/PostgresClient.swift @@ -301,11 +301,11 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { /// - Returns: The closure's return value. @_disfavoredOverload public func withConnection(_ closure: (PostgresConnection) async throws -> Result) async throws -> Result { - let connection = try await self.leaseConnection() + let lease = try await self.leaseConnection() - defer { self.pool.releaseConnection(connection) } + defer { lease.release() } - return try await closure(connection) + return try await closure(lease.connection) } #if compiler(>=6.0) @@ -319,11 +319,11 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { // DO NOT FIX THE WHITESPACE IN THE NEXT LINE UNTIL 5.10 IS UNSUPPORTED // https://github.com/swiftlang/swift/issues/79285 _ closure: (PostgresConnection) async throws -> sending Result) async throws -> sending Result { - let connection = try await self.leaseConnection() + let lease = try await self.leaseConnection() - defer { self.pool.releaseConnection(connection) } + defer { lease.release() } - return try await closure(connection) + return try await closure(lease.connection) } /// Lease a connection, which is in an open transaction state, for the provided `closure`'s lifetime. @@ -404,7 +404,8 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { throw PSQLError(code: .tooManyParameters, query: query, file: file, line: line) } - let connection = try await self.leaseConnection() + let lease = try await self.leaseConnection() + let connection = lease.connection var logger = logger logger[postgresMetadataKey: .connectionID] = "\(connection.id)" @@ -419,12 +420,12 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { connection.channel.write(HandlerTask.extendedQuery(context), promise: nil) promise.futureResult.whenFailure { _ in - self.pool.releaseConnection(connection) + lease.release() } return try await promise.futureResult.map { $0.asyncSequence(onFinish: { - self.pool.releaseConnection(connection) + lease.release() }) }.get() } catch var error as PSQLError { @@ -446,7 +447,8 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { let logger = logger ?? Self.loggingDisabled do { - let connection = try await self.leaseConnection() + let lease = try await self.leaseConnection() + let connection = lease.connection let promise = connection.channel.eventLoop.makePromise(of: PSQLRowStream.self) let task = HandlerTask.executePreparedStatement(.init( @@ -460,11 +462,11 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { connection.channel.write(task, promise: nil) promise.futureResult.whenFailure { _ in - self.pool.releaseConnection(connection) + lease.release() } return try await promise.futureResult - .map { $0.asyncSequence(onFinish: { self.pool.releaseConnection(connection) }) } + .map { $0.asyncSequence(onFinish: { lease.release() }) } .get() .map { try preparedStatement.decodeRow($0) } } catch var error as PSQLError { @@ -504,7 +506,7 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { // MARK: - Private Methods - - private func leaseConnection() async throws -> PostgresConnection { + private func leaseConnection() async throws -> ConnectionLease { if !self.runningAtomic.load(ordering: .relaxed) { self.backgroundLogger.warning("Trying to lease connection from `PostgresClient`, but `PostgresClient.run()` hasn't been called yet.") } diff --git a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift index c745b4a0..c1ba89cb 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift @@ -39,15 +39,13 @@ final class ConnectionPoolTests: XCTestCase { do { for _ in 0..<1000 { async let connectionFuture = try await pool.leaseConnection() - var leasedConnection: MockConnection? + var connectionLease: ConnectionLease? XCTAssertEqual(factory.pendingConnectionAttemptsCount, 0) - leasedConnection = try await connectionFuture - XCTAssertNotNil(leasedConnection) - XCTAssert(createdConnection === leasedConnection) + connectionLease = try await connectionFuture + XCTAssertNotNil(connectionLease) + XCTAssert(createdConnection === connectionLease?.connection) - if let leasedConnection { - pool.releaseConnection(leasedConnection) - } + connectionLease?.release() } } catch { XCTFail("Unexpected error: \(error)") @@ -195,8 +193,8 @@ final class ConnectionPoolTests: XCTestCase { for _ in 0..]() for request in requests { let connection = try await request.future.success - connections.append(connection) + connectionLeases.append(connection) } // Ensure that we got 4 distinct connections - XCTAssertEqual(Set(connections.lazy.map(\.id)).count, 4) + XCTAssertEqual(Set(connectionLeases.lazy.map(\.connection.id)).count, 4) // release all 4 leased connections - for connection in connections { - pool.releaseConnection(connection) + for lease in connectionLeases { + lease.release() } // shutdown @@ -727,7 +725,7 @@ final class ConnectionPoolTests: XCTestCase { // create 4 connection requests let requests = (0..<10).map { ConnectionFuture(id: $0) } pool.leaseConnections(requests) - var connections = [MockConnection]() + var connectionLeases = [ConnectionLease]() await factory.nextConnectAttempt { connectionID in return 10 @@ -735,15 +733,15 @@ final class ConnectionPoolTests: XCTestCase { for request in requests { let connection = try await request.future.success - connections.append(connection) + connectionLeases.append(connection) } // Ensure that all requests got the same connection - XCTAssertEqual(Set(connections.lazy.map(\.id)).count, 1) + XCTAssertEqual(Set(connectionLeases.lazy.map(\.connection.id)).count, 1) // release all 10 leased streams - for connection in connections { - pool.releaseConnection(connection) + for lease in connectionLeases { + lease.release() } for _ in 0..<9 { @@ -792,41 +790,41 @@ final class ConnectionPoolTests: XCTestCase { // create 4 connection requests var requests = (0..<21).map { ConnectionFuture(id: $0) } pool.leaseConnections(requests) - var connections = [MockConnection]() + var connectionLease = [ConnectionLease]() await factory.nextConnectAttempt { connectionID in return 1 } - let connection = try await requests.first!.future.success - connections.append(connection) + let lease = try await requests.first!.future.success + connectionLease.append(lease) requests.removeFirst() - pool.connectionReceivedNewMaxStreamSetting(connection, newMaxStreamSetting: 21) + pool.connectionReceivedNewMaxStreamSetting(lease.connection, newMaxStreamSetting: 21) for (_, request) in requests.enumerated() { let connection = try await request.future.success - connections.append(connection) + connectionLease.append(connection) } // Ensure that all requests got the same connection - XCTAssertEqual(Set(connections.lazy.map(\.id)).count, 1) + XCTAssertEqual(Set(connectionLease.lazy.map(\.connection.id)).count, 1) requests = (22..<42).map { ConnectionFuture(id: $0) } pool.leaseConnections(requests) // release all 21 leased streams in a single call - pool.releaseConnection(connection, streams: 21) + pool.releaseConnection(lease.connection, streams: 21) // ensure all 20 new requests got fulfilled for request in requests { let connection = try await request.future.success - connections.append(connection) + connectionLease.append(connection) } // release all 20 leased streams one by one for _ in requests { - pool.releaseConnection(connection, streams: 1) + pool.releaseConnection(lease.connection, streams: 1) } // shutdown @@ -840,14 +838,14 @@ final class ConnectionPoolTests: XCTestCase { struct ConnectionFuture: ConnectionRequestProtocol { let id: Int - let future: Future + let future: Future> init(id: Int) { self.id = id - self.future = Future(of: MockConnection.self) + self.future = Future(of: ConnectionLease.self) } - func complete(with result: Result) { + func complete(with result: Result, ConnectionPoolError>) { switch result { case .success(let success): self.future.yield(value: success) diff --git a/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift index 537efbd9..2952bf8b 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift @@ -6,13 +6,14 @@ final class ConnectionRequestTests: XCTestCase { func testHappyPath() async throws { let mockConnection = MockConnection(id: 1) - let connection = try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + let lease = try await withCheckedThrowingContinuation { (continuation: CheckedContinuation, any Error>) in let request = ConnectionRequest(id: 42, continuation: continuation) XCTAssertEqual(request.id, 42) - continuation.resume(with: .success(mockConnection)) + let lease = ConnectionLease(connection: mockConnection) { _ in } + continuation.resume(with: .success(lease)) } - XCTAssert(connection === mockConnection) + XCTAssert(lease.connection === mockConnection) } func testSadPath() async throws { diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift index b74b86cc..ddd6a71e 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift @@ -11,7 +11,7 @@ final class PoolStateMachine_RequestQueueTests: XCTestCase { var queue = TestQueue() XCTAssert(queue.isEmpty) - let request1 = MockRequest() + let request1 = MockRequest(connectionType: MockConnection.self) queue.queue(request1) XCTAssertEqual(queue.count, 1) XCTAssertFalse(queue.isEmpty) @@ -25,11 +25,11 @@ final class PoolStateMachine_RequestQueueTests: XCTestCase { var queue = TestQueue() XCTAssert(queue.isEmpty) - var request1 = MockRequest() + var request1 = MockRequest(connectionType: MockConnection.self) queue.queue(request1) - var request2 = MockRequest() + var request2 = MockRequest(connectionType: MockConnection.self) queue.queue(request2) - var request3 = MockRequest() + var request3 = MockRequest(connectionType: MockConnection.self) queue.queue(request3) do { @@ -49,11 +49,11 @@ final class PoolStateMachine_RequestQueueTests: XCTestCase { var queue = TestQueue() XCTAssert(queue.isEmpty) - var request1 = MockRequest() + var request1 = MockRequest(connectionType: MockConnection.self) queue.queue(request1) - var request2 = MockRequest() + var request2 = MockRequest(connectionType: MockConnection.self) queue.queue(request2) - var request3 = MockRequest() + var request3 = MockRequest(connectionType: MockConnection.self) queue.queue(request3) do { @@ -76,11 +76,11 @@ final class PoolStateMachine_RequestQueueTests: XCTestCase { var queue = TestQueue() XCTAssert(queue.isEmpty) - var request1 = MockRequest() + var request1 = MockRequest(connectionType: MockConnection.self) queue.queue(request1) - var request2 = MockRequest() + var request2 = MockRequest(connectionType: MockConnection.self) queue.queue(request2) - var request3 = MockRequest() + var request3 = MockRequest(connectionType: MockConnection.self) queue.queue(request3) do { @@ -113,11 +113,11 @@ final class PoolStateMachine_RequestQueueTests: XCTestCase { var queue = TestQueue() XCTAssert(queue.isEmpty) - var request1 = MockRequest() + var request1 = MockRequest(connectionType: MockConnection.self) queue.queue(request1) - var request2 = MockRequest() + var request2 = MockRequest(connectionType: MockConnection.self) queue.queue(request2) - var request3 = MockRequest() + var request3 = MockRequest(connectionType: MockConnection.self) queue.queue(request3) do { diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift index c0b6ddcd..08afdf8e 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift @@ -7,8 +7,8 @@ typealias TestPoolStateMachine = PoolStateMachine< MockConnection, ConnectionIDGenerator, MockConnection.ID, - MockRequest, - MockRequest.ID, + MockRequest, + MockRequest.ID, MockTimerCancellationToken > @@ -75,7 +75,7 @@ final class PoolStateMachineTests: XCTestCase { XCTAssertEqual(createdAction1.connection, .scheduleTimers([])) // lease connection 1 - let request1 = MockRequest() + let request1 = MockRequest(connectionType: MockConnection.self) let leaseRequest1 = stateMachine.leaseConnection(request1) XCTAssertEqual(leaseRequest1.connection, .cancelTimers([])) XCTAssertEqual(leaseRequest1.request, .leaseConnection(.init(element: request1), connection1)) @@ -84,13 +84,13 @@ final class PoolStateMachineTests: XCTestCase { XCTAssertEqual(stateMachine.releaseConnection(connection1, streams: 1), .none()) // lease connection 1 - let request2 = MockRequest() + let request2 = MockRequest(connectionType: MockConnection.self) let leaseRequest2 = stateMachine.leaseConnection(request2) XCTAssertEqual(leaseRequest2.connection, .cancelTimers([])) XCTAssertEqual(leaseRequest2.request, .leaseConnection(.init(element: request2), connection1)) // request connection while none is available - let request3 = MockRequest() + let request3 = MockRequest(connectionType: MockConnection.self) let leaseRequest3 = stateMachine.leaseConnection(request3) XCTAssertEqual(leaseRequest3.connection, .makeConnection(.init(connectionID: 1), [])) XCTAssertEqual(leaseRequest3.request, .none) @@ -132,7 +132,7 @@ final class PoolStateMachineTests: XCTestCase { XCTAssertEqual(requests.count, 0) // request connection while none exists - let request1 = MockRequest() + let request1 = MockRequest(connectionType: MockConnection.self) let leaseRequest1 = stateMachine.leaseConnection(request1) XCTAssertEqual(leaseRequest1.connection, .makeConnection(.init(connectionID: 0), [])) XCTAssertEqual(leaseRequest1.request, .none) @@ -144,7 +144,7 @@ final class PoolStateMachineTests: XCTestCase { XCTAssertEqual(createdAction1.connection, .none) // request connection while none is available - let request2 = MockRequest() + let request2 = MockRequest(connectionType: MockConnection.self) let leaseRequest2 = stateMachine.leaseConnection(request2) XCTAssertEqual(leaseRequest2.connection, .makeConnection(.init(connectionID: 1), [])) XCTAssertEqual(leaseRequest2.request, .none) @@ -195,13 +195,13 @@ final class PoolStateMachineTests: XCTestCase { XCTAssertEqual(createdAction1.connection, .scheduleTimers([])) // lease connection 1 - let request1 = MockRequest() + let request1 = MockRequest(connectionType: MockConnection.self) let leaseRequest1 = stateMachine.leaseConnection(request1) XCTAssertEqual(leaseRequest1.connection, .cancelTimers([])) XCTAssertEqual(leaseRequest1.request, .leaseConnection(.init(element: request1), connection1)) // request connection while none is available - let request2 = MockRequest() + let request2 = MockRequest(connectionType: MockConnection.self) let leaseRequest2 = stateMachine.leaseConnection(request2) XCTAssertEqual(leaseRequest2.connection, .makeConnection(.init(connectionID: 1), [])) XCTAssertEqual(leaseRequest2.request, .none) @@ -245,7 +245,7 @@ final class PoolStateMachineTests: XCTestCase { XCTAssertEqual(requests.count, 0) // request connection while none exists - let request1 = MockRequest() + let request1 = MockRequest(connectionType: MockConnection.self) let leaseRequest1 = stateMachine.leaseConnection(request1) XCTAssertEqual(leaseRequest1.connection, .makeConnection(.init(connectionID: 0), [])) XCTAssertEqual(leaseRequest1.request, .none) @@ -287,7 +287,7 @@ final class PoolStateMachineTests: XCTestCase { XCTAssertEqual(requests.count, 0) // request connection while none exists - let request1 = MockRequest() + let request1 = MockRequest(connectionType: MockConnection.self) let leaseRequest1 = stateMachine.leaseConnection(request1) XCTAssertEqual(leaseRequest1.connection, .makeConnection(.init(connectionID: 0), [])) XCTAssertEqual(leaseRequest1.request, .none) @@ -309,7 +309,7 @@ final class PoolStateMachineTests: XCTestCase { connection1.closeIfClosing() // request connection while none exists anymore - let request2 = MockRequest() + let request2 = MockRequest(connectionType: MockConnection.self) let leaseRequest2 = stateMachine.leaseConnection(request2) XCTAssertEqual(leaseRequest2.connection, .makeConnection(.init(connectionID: 1), [])) XCTAssertEqual(leaseRequest2.request, .none) @@ -354,7 +354,7 @@ final class PoolStateMachineTests: XCTestCase { XCTAssertEqual(requests.count, 1) // one connection should exist - let request = MockRequest() + let request = MockRequest(connectionType: MockConnection.self) let leaseRequest = stateMachine.leaseConnection(request) XCTAssertEqual(leaseRequest.connection, .none) XCTAssertEqual(leaseRequest.request, .none) diff --git a/Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift b/Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift index 1a2836b9..b2f04544 100644 --- a/Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift +++ b/Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift @@ -47,10 +47,18 @@ final class TinyFastSequenceTests: XCTestCase { var twoElemSequence = TinyFastSequence([1, 2]) twoElemSequence.reserveCapacity(8) + twoElemSequence.append(3) guard case .n(let array) = twoElemSequence.base else { return XCTFail("Expected sequence to be backed by an array") } XCTAssertEqual(array.capacity, 8) + + var threeElemSequence = TinyFastSequence([1, 2, 3]) + threeElemSequence.reserveCapacity(8) + guard case .n(let array) = threeElemSequence.base else { + return XCTFail("Expected sequence to be backed by an array") + } + XCTAssertEqual(array.capacity, 8) } func testNewSequenceSlowPath() { diff --git a/Tests/IntegrationTests/PostgresClientTests.swift b/Tests/IntegrationTests/PostgresClientTests.swift index eaf3663f..9ac92754 100644 --- a/Tests/IntegrationTests/PostgresClientTests.swift +++ b/Tests/IntegrationTests/PostgresClientTests.swift @@ -338,7 +338,7 @@ final class PostgresClientTests: XCTestCase { ) var count = 0 - for try await (id, label) in rows.decode((Int, String).self) { + for try await _ in rows.decode((Int, String).self) { count += 1 } XCTAssertEqual(count, 1) diff --git a/Tests/IntegrationTests/PostgresNIOTests.swift b/Tests/IntegrationTests/PostgresNIOTests.swift index 9a58f050..5d27e36a 100644 --- a/Tests/IntegrationTests/PostgresNIOTests.swift +++ b/Tests/IntegrationTests/PostgresNIOTests.swift @@ -930,6 +930,22 @@ final class PostgresNIOTests: XCTestCase { } } + func testJSONBDecodeString() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow(try conn?.close().wait()) } + + do { + var rows: PostgresQueryResult? + XCTAssertNoThrow(rows = try conn?.query("select '{\"hello\": \"world\"}'::jsonb as data").wait()) + + var resultString: String? + XCTAssertNoThrow(resultString = try rows?.first?.decode(String.self, context: .default)) + + XCTAssertEqual(resultString, "{\"hello\": \"world\"}") + } + } + func testInt4RangeSerialize() async throws { let conn: PostgresConnection = try await PostgresConnection.test(on: eventLoop).get() self.addTeardownBlock { diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift index ae484acc..872664af 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift @@ -114,7 +114,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { .failQuery(promise, with: psqlError, cleanupContext: .init(action: .close, tasks: [], error: psqlError, closePromise: nil))) } - func testExtendedQueryIsCancelledImmediatly() { + func testExtendedQueryIsCancelledImmediately() { var state = ConnectionStateMachine.readyForQuery() let logger = Logger.psqlTest diff --git a/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift index aadeabff..c1843c2a 100644 --- a/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift @@ -52,4 +52,15 @@ class String_PSQLCodableTests: XCTestCase { XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) } } + + func testDecodeFromJSONB() { + let json = #"{"hello": "world"}"# + var buffer = ByteBuffer() + buffer.writeInteger(UInt8(1)) + buffer.writeString(json) + + var decoded: String? + XCTAssertNoThrow(decoded = try String(from: &buffer, type: .jsonb, format: .binary, context: .default)) + XCTAssertEqual(decoded, json) + } } diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift index 9614bf1e..0c6b37ef 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift @@ -28,6 +28,8 @@ struct PSQLBackendMessageEncoder: MessageToByteEncoder { case .commandComplete(let string): self.encode(messageID: message.id, payload: StringPayload(string), into: &buffer) + case .copyInResponse(let copyInResponse): + self.encode(messageID: message.id, payload: copyInResponse, into: &buffer) case .dataRow(let row): self.encode(messageID: message.id, payload: row, into: &buffer) @@ -99,6 +101,8 @@ extension PostgresBackendMessage { return .closeComplete case .commandComplete: return .commandComplete + case .copyInResponse: + return .copyInResponse case .dataRow: return .dataRow case .emptyQueryResponse: @@ -184,6 +188,16 @@ extension PostgresBackendMessage.BackendKeyData: PSQLMessagePayloadEncodable { } } +extension PostgresBackendMessage.CopyInResponse: PSQLMessagePayloadEncodable { + public func encode(into buffer: inout ByteBuffer) { + buffer.writeInteger(Int8(self.format.rawValue)) + buffer.writeInteger(Int16(self.columnFormats.count)) + for columnFormat in columnFormats { + buffer.writeInteger(Int16(columnFormat.rawValue)) + } + } +} + extension DataRow: PSQLMessagePayloadEncodable { public func encode(into buffer: inout ByteBuffer) { buffer.writeInteger(self.columnCount, as: Int16.self) diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift index 55ccd0a9..d913da22 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift @@ -168,6 +168,18 @@ extension PostgresFrontendMessage { ) ) + case .copyData: + return .copyData(CopyData(data: buffer)) + + case .copyDone: + return .copyDone + + case .copyFail: + guard let message = buffer.readNullTerminatedString() else { + throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) + } + return .copyFail(CopyFail(message: message)) + case .close: preconditionFailure("TODO: Unimplemented") diff --git a/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift b/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift index 2532959a..5fc8144b 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift @@ -36,6 +36,14 @@ enum PostgresFrontendMessage: Equatable { let secretKey: Int32 } + struct CopyData: Hashable { + var data: ByteBuffer + } + + struct CopyFail: Hashable { + var message: String + } + enum Close: Hashable { case preparedStatement(String) case portal(String) @@ -170,6 +178,9 @@ enum PostgresFrontendMessage: Equatable { case bind(Bind) case cancel(Cancel) + case copyData(CopyData) + case copyDone + case copyFail(CopyFail) case close(Close) case describe(Describe) case execute(Execute) @@ -186,6 +197,9 @@ enum PostgresFrontendMessage: Equatable { enum ID: UInt8, Equatable { case bind + case copyData + case copyDone + case copyFail case close case describe case execute @@ -201,12 +215,18 @@ enum PostgresFrontendMessage: Equatable { switch rawValue { case UInt8(ascii: "B"): self = .bind + case UInt8(ascii: "c"): + self = .copyDone case UInt8(ascii: "C"): self = .close + case UInt8(ascii: "d"): + self = .copyData case UInt8(ascii: "D"): self = .describe case UInt8(ascii: "E"): self = .execute + case UInt8(ascii: "f"): + self = .copyFail case UInt8(ascii: "H"): self = .flush case UInt8(ascii: "P"): @@ -230,6 +250,12 @@ enum PostgresFrontendMessage: Equatable { switch self { case .bind: return UInt8(ascii: "B") + case .copyData: + return UInt8(ascii: "d") + case .copyDone: + return UInt8(ascii: "c") + case .copyFail: + return UInt8(ascii: "f") case .close: return UInt8(ascii: "C") case .describe: @@ -263,6 +289,12 @@ extension PostgresFrontendMessage { return .bind case .cancel: preconditionFailure("Cancel messages don't have an identifier") + case .copyData: + return .copyData + case .copyDone: + return .copyDone + case .copyFail: + return .copyFail case .close: return .close case .describe: diff --git a/Tests/PostgresNIOTests/New/Messages/CopyTests.swift b/Tests/PostgresNIOTests/New/Messages/CopyTests.swift new file mode 100644 index 00000000..de686ae5 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Messages/CopyTests.swift @@ -0,0 +1,147 @@ +import XCTest +import NIOCore +import NIOTestUtils +@testable import PostgresNIO + +class CopyTests: XCTestCase { + func testDecodeCopyInResponseMessage() throws { + let expected: [PostgresBackendMessage] = [ + .copyInResponse(.init(format: .textual, columnFormats: [.textual, .textual])), + .copyInResponse(.init(format: .binary, columnFormats: [.binary, .binary])), + .copyInResponse(.init(format: .binary, columnFormats: [.textual, .binary])) + ] + + var buffer = ByteBuffer() + + for message in expected { + guard case .copyInResponse(let message) = message else { + return XCTFail("Expected only to get copyInResponse here!") + } + buffer.writeBackendMessage(id: .copyInResponse ) { buffer in + buffer.writeInteger(Int8(message.format.rawValue)) + buffer.writeInteger(Int16(message.columnFormats.count)) + for columnFormat in message.columnFormats { + buffer.writeInteger(UInt16(columnFormat.rawValue)) + } + } + } + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, expected)], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) } + ) + } + + func testDecodeFailureBecauseOfEmptyMessage() { + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: .copyInResponse) { _ in} + + XCTAssertThrowsError( + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) } + ) + ) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } + + + func testDecodeFailureBecauseOfInvalidFormat() { + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: .copyInResponse) { buffer in + buffer.writeInteger(Int8(20)) // Only 0 and 1 are valid formats + } + + XCTAssertThrowsError( + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) } + ) + ) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } + + func testDecodeFailureBecauseOfMissingColumnNumber() { + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: .copyInResponse) { buffer in + buffer.writeInteger(Int8(0)) + } + + XCTAssertThrowsError( + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) } + ) + ) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } + + + func testDecodeFailureBecauseOfMissingColumns() { + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: .copyInResponse) { buffer in + buffer.writeInteger(Int8(0)) + buffer.writeInteger(Int16(20)) // 20 columns promised, none given + } + + XCTAssertThrowsError( + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) } + ) + ) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } + + func testDecodeFailureBecauseOfInvalidColumnFormat() { + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: .copyInResponse) { buffer in + buffer.writeInteger(Int8(0)) + buffer.writeInteger(Int16(1)) + buffer.writeInteger(Int8(20)) // Only 0 and 1 are valid formats + } + + XCTAssertThrowsError( + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) } + ) + ) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } + + func testEncodeCopyDataHeader() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.copyDataHeader(dataLength: 3) + var byteBuffer = encoder.flushBuffer() + + XCTAssertEqual(byteBuffer.readableBytes, 5) + XCTAssertEqual(PostgresFrontendMessage.ID.copyData.rawValue, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), 7) + } + + func testEncodeCopyDone() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.copyDone() + var byteBuffer = encoder.flushBuffer() + + XCTAssertEqual(byteBuffer.readableBytes, 5) + XCTAssertEqual(PostgresFrontendMessage.ID.copyDone.rawValue, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), 4) + } + + func testEncodeCopyFail() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.copyFail(message: "Oh, no :(") + var byteBuffer = encoder.flushBuffer() + + XCTAssertEqual(byteBuffer.readableBytes, 15) + XCTAssertEqual(PostgresFrontendMessage.ID.copyFail.rawValue, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), 14) + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "Oh, no :(") + } +}