Skip to content

Commit 087dc97

Browse files
committed
Merge pull request mauricio#54 from dylex/transaction
Add Connection.inTransaction to wrap queries in a transaction block
2 parents 11acaf4 + 3edda7a commit 087dc97

File tree

4 files changed

+124
-1
lines changed

4 files changed

+124
-1
lines changed

db-async-common/src/main/scala/com/github/mauricio/async/db/Connection.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,4 +115,25 @@ trait Connection {
115115

116116
def sendPreparedStatement(query: String, values: Seq[Any] = List()): Future[QueryResult]
117117

118+
/**
119+
*
120+
* Executes an (asynchronous) function within a transaction block.
121+
* If the function completes successfully, the transaction is committed, otherwise it is aborted.
122+
*
123+
* @param f operation to execute on this connection
124+
* @return result of f, conditional on transaction operations succeeding
125+
*/
126+
127+
def inTransaction[A](f : Connection => Future[A])(implicit executionContext : scala.concurrent.ExecutionContext) : Future[A] = {
128+
this.sendQuery("BEGIN").flatMap { _ =>
129+
val p = scala.concurrent.Promise[A]()
130+
f(this).onComplete { r =>
131+
this.sendQuery(if (r.isFailure) "ROLLBACK" else "COMMIT").onComplete {
132+
case scala.util.Failure(e) if r.isSuccess => p.failure(e)
133+
case _ => p.complete(r)
134+
}
135+
}
136+
p.future
137+
}
138+
}
118139
}

db-async-common/src/main/scala/com/github/mauricio/async/db/pool/ConnectionPool.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,4 +95,17 @@ class ConnectionPool[T <: Connection](
9595
def sendPreparedStatement(query: String, values: Seq[Any] = List()): Future[QueryResult] =
9696
this.use(_.sendPreparedStatement(query, values))(executionContext)
9797

98+
/**
99+
*
100+
* Picks one connection and executes an (asynchronous) function on it within a transaction block.
101+
* If the function completes successfully, the transaction is committed, otherwise it is aborted.
102+
* Either way, the connection is returned to the pool on completion.
103+
*
104+
* @param f operation to execute on a connection
105+
* @return result of f, conditional on transaction operations succeeding
106+
*/
107+
108+
override def inTransaction[A](f : Connection => Future[A])(implicit context : ExecutionContext = executionContext) : Future[A] =
109+
this.use(_.inTransaction[A](f)(context))(executionContext)
110+
98111
}
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
package com.github.mauricio.async.db.postgresql
2+
3+
import org.specs2.mutable.Specification
4+
import com.github.mauricio.async.db.util.Log
5+
import com.github.mauricio.async.db.exceptions.DatabaseException
6+
import scala.concurrent.ExecutionContext.Implicits.global
7+
import scala.util.control.Exception.catching
8+
9+
class TransactionSpec extends Specification with DatabaseTestHelper {
10+
11+
val log = Log.get[TransactionSpec]
12+
13+
val tableCreate = "CREATE TEMP TABLE transaction_test (x integer PRIMARY KEY)"
14+
def tableInsert(x : Int) = "INSERT INTO transaction_test VALUES (" + x.toString + ")"
15+
val tableSelect = "SELECT x FROM transaction_test ORDER BY x"
16+
17+
"transactions" should {
18+
19+
"commit simple inserts" in {
20+
withHandler { handler =>
21+
executeDdl(handler, tableCreate)
22+
await(handler.inTransaction { conn =>
23+
conn.sendQuery(tableInsert(1)).flatMap { _ =>
24+
conn.sendQuery(tableInsert(2))
25+
}
26+
})
27+
28+
val rows = executeQuery(handler, tableSelect).rows.get
29+
rows.length === 2
30+
rows(0)(0) === 1
31+
rows(1)(0) === 2
32+
}
33+
}
34+
35+
"rollback on error" in {
36+
withHandler { handler =>
37+
executeDdl(handler, tableCreate)
38+
catching(classOf[DatabaseException]).opt(
39+
await(handler.inTransaction { conn =>
40+
conn.sendQuery(tableInsert(1)).flatMap { _ =>
41+
conn.sendQuery(tableInsert(1))
42+
}
43+
})
44+
) === None
45+
46+
val rows = executeQuery(handler, tableSelect).rows.get
47+
rows.length === 0
48+
}
49+
50+
}
51+
52+
"rollback explicitly" in {
53+
withHandler { handler =>
54+
executeDdl(handler, tableCreate)
55+
await(handler.inTransaction { conn =>
56+
conn.sendQuery(tableInsert(1)).flatMap { _ =>
57+
conn.sendQuery("ROLLBACK")
58+
}
59+
})
60+
61+
val rows = executeQuery(handler, tableSelect).rows.get
62+
rows.length === 0
63+
}
64+
65+
}
66+
67+
"rollback to savepoint" in {
68+
withHandler { handler =>
69+
executeDdl(handler, tableCreate)
70+
await(handler.inTransaction { conn =>
71+
conn.sendQuery(tableInsert(1)).flatMap { _ =>
72+
conn.sendQuery("SAVEPOINT one").flatMap { _ =>
73+
conn.sendQuery(tableInsert(2)).flatMap { _ =>
74+
conn.sendQuery("ROLLBACK TO SAVEPOINT one")
75+
}
76+
}
77+
}
78+
})
79+
80+
val rows = executeQuery(handler, tableSelect).rows.get
81+
rows.length === 1
82+
rows(0)(0) === 1
83+
}
84+
85+
}
86+
87+
}
88+
89+
}

postgresql-async/src/test/scala/com/github/mauricio/async/db/postgresql/pool/SingleThreadedAsyncObjectPoolSpec.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ class SingleThreadedAsyncObjectPoolSpec extends Specification with DatabaseTestH
8484

8585
}
8686

87-
"it shoudl remove idle connections once the time limit has been reached" in {
87+
"it should remove idle connections once the time limit has been reached" in {
8888

8989
withPool({
9090
pool =>

0 commit comments

Comments
 (0)