|  | 
|  | 1 | +import Logging | 
|  | 2 | +import FluentKit | 
|  | 3 | +import FluentBenchmark | 
|  | 4 | +import FluentPostgresDriver | 
|  | 5 | +import XCTest | 
|  | 6 | +import PostgresKit | 
|  | 7 | + | 
|  | 8 | +final class FluentPostgresTransactionControlTests: XCTestCase { | 
|  | 9 | + | 
|  | 10 | + func testTransactionControl() throws { | 
|  | 11 | + try (self.db as! TransactionControlDatabase).beginTransaction().wait() | 
|  | 12 | + | 
|  | 13 | + let todo1 = Todo(title: "Test") | 
|  | 14 | + let todo2 = Todo(title: "Test2") | 
|  | 15 | + try todo1.save(on: self.db).wait() | 
|  | 16 | + try todo2.save(on: self.db).wait() | 
|  | 17 | + | 
|  | 18 | + try (self.db as! TransactionControlDatabase).commitTransaction().wait() | 
|  | 19 | + | 
|  | 20 | + let count = try Todo.query(on: self.db).count().wait() | 
|  | 21 | + XCTAssertEqual(count, 2) | 
|  | 22 | + } | 
|  | 23 | + | 
|  | 24 | + func testRollback() throws { | 
|  | 25 | + try (self.db as! TransactionControlDatabase).beginTransaction().wait() | 
|  | 26 | + | 
|  | 27 | + let todo1 = Todo(title: "Test") | 
|  | 28 | + | 
|  | 29 | + try todo1.save(on: self.db).wait() | 
|  | 30 | + | 
|  | 31 | + let duplicate = Todo(title: "Test") | 
|  | 32 | + var errorCaught = false | 
|  | 33 | + | 
|  | 34 | + do { | 
|  | 35 | + try duplicate.create(on: self.db).wait() | 
|  | 36 | + } catch { | 
|  | 37 | + errorCaught = true | 
|  | 38 | + try (self.db as! TransactionControlDatabase).rollbackTransaction().wait() | 
|  | 39 | + } | 
|  | 40 | + | 
|  | 41 | + if !errorCaught { | 
|  | 42 | + try (self.db as! TransactionControlDatabase).commitTransaction().wait() | 
|  | 43 | + } | 
|  | 44 | + | 
|  | 45 | + XCTAssertTrue(errorCaught) | 
|  | 46 | + let count2 = try Todo.query(on: self.db).count().wait() | 
|  | 47 | + XCTAssertEqual(count2, 0) | 
|  | 48 | + } | 
|  | 49 | + | 
|  | 50 | + var benchmarker: FluentBenchmarker { | 
|  | 51 | + return .init(databases: self.dbs) | 
|  | 52 | + } | 
|  | 53 | + var eventLoopGroup: EventLoopGroup! | 
|  | 54 | + var threadPool: NIOThreadPool! | 
|  | 55 | + var dbs: Databases! | 
|  | 56 | + var db: Database { | 
|  | 57 | + self.benchmarker.database | 
|  | 58 | + } | 
|  | 59 | + var postgres: PostgresDatabase { | 
|  | 60 | + self.db as! PostgresDatabase | 
|  | 61 | + } | 
|  | 62 | + | 
|  | 63 | + override func setUpWithError() throws { | 
|  | 64 | + try super.setUpWithError() | 
|  | 65 | + | 
|  | 66 | + XCTAssert(isLoggingConfigured) | 
|  | 67 | + self.eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) | 
|  | 68 | + self.threadPool = NIOThreadPool(numberOfThreads: 1) | 
|  | 69 | + self.dbs = Databases(threadPool: threadPool, on: self.eventLoopGroup) | 
|  | 70 | + | 
|  | 71 | + self.dbs.use(.testPostgres(subconfig: "A"), as: .a) | 
|  | 72 | + self.dbs.use(.testPostgres(subconfig: "B"), as: .b) | 
|  | 73 | + | 
|  | 74 | + let a = self.dbs.database(.a, logger: Logger(label: "test.fluent.a"), on: self.eventLoopGroup.next()) | 
|  | 75 | + _ = try (a as! PostgresDatabase).query("drop schema public cascade").wait() | 
|  | 76 | + _ = try (a as! PostgresDatabase).query("create schema public").wait() | 
|  | 77 | + | 
|  | 78 | + let b = self.dbs.database(.b, logger: Logger(label: "test.fluent.b"), on: self.eventLoopGroup.next()) | 
|  | 79 | + _ = try (b as! PostgresDatabase).query("drop schema public cascade").wait() | 
|  | 80 | + _ = try (b as! PostgresDatabase).query("create schema public").wait() | 
|  | 81 | + | 
|  | 82 | + try CreateTodo().prepare(on: self.db).wait() | 
|  | 83 | + } | 
|  | 84 | + | 
|  | 85 | + override func tearDownWithError() throws { | 
|  | 86 | + try CreateTodo().revert(on: self.db).wait() | 
|  | 87 | + self.dbs.shutdown() | 
|  | 88 | + try self.threadPool.syncShutdownGracefully() | 
|  | 89 | + try self.eventLoopGroup.syncShutdownGracefully() | 
|  | 90 | + try super.tearDownWithError() | 
|  | 91 | + } | 
|  | 92 | + | 
|  | 93 | + final class Todo: Model { | 
|  | 94 | + static let schema = "todos" | 
|  | 95 | + | 
|  | 96 | + @ID | 
|  | 97 | + var id: UUID? | 
|  | 98 | + | 
|  | 99 | + @Field(key: "title") | 
|  | 100 | + var title: String | 
|  | 101 | + | 
|  | 102 | + init() { } | 
|  | 103 | + init(title: String) { self.title = title; id = nil } | 
|  | 104 | + } | 
|  | 105 | + | 
|  | 106 | + struct CreateTodo: Migration { | 
|  | 107 | + func prepare(on database: Database) -> EventLoopFuture<Void> { | 
|  | 108 | + return database.schema("todos") | 
|  | 109 | + .id() | 
|  | 110 | + .field("title", .string, .required) | 
|  | 111 | + .unique(on: "title") | 
|  | 112 | + .create() | 
|  | 113 | + } | 
|  | 114 | + | 
|  | 115 | + func revert(on database: Database) -> EventLoopFuture<Void> { | 
|  | 116 | + return database.schema("todos").delete() | 
|  | 117 | + } | 
|  | 118 | + } | 
|  | 119 | +} | 
0 commit comments