Skip to content

Commit 642e7de

Browse files
committed
fixed transaction
Signed-off-by: Joern Bernhardt <jb@campudus.com>
1 parent 3ccc053 commit 642e7de

File tree

2 files changed

+51
-39
lines changed

2 files changed

+51
-39
lines changed

src/main/scala/io/vertx/asyncsql/database/ConnectionHandler.scala

Lines changed: 46 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package io.vertx.asyncsql.database
22

33
import scala.collection.JavaConverters.iterableAsScalaIterableConverter
4-
import scala.concurrent.Future
4+
import scala.concurrent.{Promise, Future}
55
import org.vertx.scala.core.json.{JsonElement, JsonArray, JsonObject, Json}
66
import org.vertx.scala.core.logging.Logger
77
import com.github.mauricio.async.db.{Configuration, Connection, QueryResult, RowData}
@@ -31,71 +31,81 @@ trait ConnectionHandler extends ScalaBusMod {
3131

3232
def transactionEnd: String = "COMMIT;"
3333

34+
def transactionRollback: String = "ROLLBACK;"
35+
3436
def statementDelimiter: String = ";"
3537

38+
private def timeout = 500L /* FIXME from config file! */
39+
3640
import org.vertx.scala.core.eventbus._
3741

3842
private def receiver(withConnectionFn: (Connection => Future[SyncReply]) => Future[SyncReply]): Receive = (msg: Message[JsonObject]) => {
3943
def sendAsyncWithPool(fn: Connection => Future[QueryResult]) = AsyncReply(sendWithPool(withConnectionFn)(fn))
4044

4145
{
42-
case "select" => sendAsyncWithPool(rawCommand(selectCommand(msg.body)))
43-
case "insert" => sendAsyncWithPool(rawCommand(insertCommand(msg.body)))
44-
case "prepared" => sendAsyncWithPool(prepared(msg.body))
45-
case "raw" => sendAsyncWithPool(rawCommand(msg.body.getString("command")))
46-
case "transaction" => transaction(withConnectionFn)(msg.body)
46+
case "select" => sendAsyncWithPool(rawCommand(selectCommand(msg.body())))
47+
case "insert" => sendAsyncWithPool(rawCommand(insertCommand(msg.body())))
48+
case "prepared" => sendAsyncWithPool(prepared(msg.body()))
49+
case "raw" => sendAsyncWithPool(rawCommand(msg.body().getString("command")))
4750
}
4851
}
4952

50-
override def receive: Receive = { msg: Message[JsonObject] =>
53+
private def regularReceive: Receive = { msg: Message[JsonObject] =>
5154
receiver(pool.withConnection)(msg).orElse {
5255
case "start" => startTransaction(msg)
56+
case "transaction" => transaction(pool.withConnection)(msg.body())
5357
}
5458
}
5559

60+
override def receive: Receive = regularReceive
61+
5662

5763
//------------------
5864
//New transaction stuff
59-
//TODO reformat when finished
65+
private def mapRepliesToTransactionReceive(c: Connection): BusModReply => BusModReply = {
66+
case AsyncReply(receiveEndFuture) => AsyncReply(receiveEndFuture.map(mapRepliesToTransactionReceive(c)))
67+
case Ok(v, None) => Ok(v, Some(ReceiverWithTimeout(inTransactionReceive(c), timeout, () => failTransaction(c))))
68+
case x => x
69+
}
6070

61-
protected def endTransaction() = {
62-
/* FIXME */
63-
logger.info("ending transaction!")
71+
private def inTransactionReceive(c: Connection): Receive = { msg: Message[JsonObject] =>
72+
def withConnection[T](fn: Connection => Future[T]): Future[T] = fn(c)
73+
74+
receiver(withConnection)(msg).andThen({
75+
case x: BusModReply => mapRepliesToTransactionReceive(c)(x)
76+
case x => x
77+
}).orElse {
78+
case "end" => endTransaction(c)
79+
}
6480
}
6581

6682
protected def startTransaction(msg: Message[JsonObject]) = AsyncReply {
67-
pool.withConnection({ c =>
68-
def withConnection[T](fn: Connection => Future[T]): Future[T] = fn(c)
69-
83+
pool.take().flatMap { c =>
7084
c.sendQuery(transactionStart) map { _ =>
71-
Ok(Json.obj(), Some(ReceiverWithTimeout(transactionQueryReply(withConnection), 500L /* FIXME from config file! */ , endTransaction)))
85+
Ok(Json.obj(), Some(ReceiverWithTimeout(inTransactionReceive(c), timeout, () => failTransaction(c))))
7286
}
73-
})
87+
}
7488
}
7589

76-
private def transactionQueryReply(withConnection: (Connection => Future[SyncReply]) => Future[SyncReply]): Receive = { msg =>
77-
val action = msg.body.getString("action")
90+
protected def failTransaction(c: Connection) = {
91+
logger.info("NO REPLY BACK -> FAIL TRANSACTION!")
92+
c.sendQuery(transactionRollback).andThen({
93+
case _ => pool.giveBack(c)
94+
})
95+
}
7896

79-
receiver(withConnection)(msg).orElse{
80-
case "start" => Error("cannot send 'start' action when inside of transaction!")
81-
case "end" =>
82-
logger.info("got action end!")
83-
// AsyncReply{withConnection}
97+
protected def endTransaction(c: Connection) = {
98+
logger.info("ending transaction!")
99+
AsyncReply {
100+
(for {
101+
qr <- c.sendQuery(transactionEnd)
102+
_ <- pool.giveBack(c)
103+
} yield {
84104
Ok()
105+
}) recover {
106+
case ex => Error("Could not give back connection to pool", "CONNECTION_POOL_EXCEPTION", Json.obj("exception" -> ex))
107+
}
85108
}
86-
87-
val opt = pf.lift(action).map({
88-
case Ok(v, None) => Ok(v, Some(ReceiverWithTimeout(transactionQueryReply(withConnection), 500L /* FIXME from config file! */ , endTransaction)))
89-
case x => x
90-
}: Function[BusModReceiveEnd, BusModReceiveEnd])
91-
92-
opt.getOrElse {
93-
case "start" => Error("cannot send 'start' action when inside of transaction!")
94-
case "end" =>
95-
logger.info("got action end!")
96-
// AsyncReply{withConnection}
97-
Ok()
98-
}: PartialFunction[String, BusModReceiveEnd]
99109
}
100110

101111
//------------------

src/test/scala/io/vertx/asyncsql/test/BaseSqlTests.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -300,24 +300,26 @@ trait BaseSqlTests {
300300
@Test
301301
def startAndEndTransaction(): Unit = {
302302
expectOkMsg(Json.obj("action" -> "start")) map { msg =>
303+
logger.info("Should be in transaction!")
303304
msg.replyWithTimeout(raw("SELECT 15"), 500L, {
304305
case Success(reply) => Option(reply.body().getArray("results")) map { arr =>
305306
assertEquals("ok", reply.body().getString("status"))
306307
assertEquals(1, arr.size())
307308
assertEquals(15, arr
308309
.get[JsonArray](0)
309-
.get[Int](0))
310+
.get[Number](0).longValue())
311+
logger.info("First select DONE!")
310312
reply.replyWithTimeout(Json.obj("action" -> "end"), 500L, {
311313
case Success(endReply) =>
312314
assertEquals("ok", endReply.body().getString("status"))
313315
testComplete()
314316
case Failure(ex) =>
315-
logger.error("timeout", ex)
317+
logger.error("timeout when waiting for final reply (end transaction)", ex)
316318
fail(s"got a timeout when expected end reply ${ex.toString}")
317319
}: Try[Message[JsonObject]] => Unit)
318320
}
319321
case Failure(ex) =>
320-
logger.error("timeout", ex)
322+
logger.error("timeout when waiting for SELECT reply", ex)
321323
fail(s"got a timeout when expected reply ${ex.toString}")
322324
}: Try[Message[JsonObject]] => Unit)
323325
}

0 commit comments

Comments
 (0)