Skip to content

Commit f7bdc4e

Browse files
committed
start of mysql / extract stuff into base trait for psql/mysql
1 parent d06a36c commit f7bdc4e

10 files changed

+167
-63
lines changed

src/main/scala/com/campudus/vertx/database/ConnectionHandler.scala

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package com.campudus.vertx.database
33
import org.vertx.scala.core.eventbus.Message
44
import org.vertx.java.core.json.JsonObject
55
import com.github.mauricio.async.db.Configuration
6-
import com.campudus.vertx.database.pool.PostgreSQLAsyncConnectionPool
6+
import com.campudus.vertx.database.pool.PostgreSqlAsyncConnectionPool
77
import com.campudus.vertx.VertxExecutionContext
88
import com.github.mauricio.async.db.Connection
99
import com.campudus.vertx.busmod.ScalaBusMod
@@ -18,9 +18,12 @@ import com.github.mauricio.async.db.RowData
1818
import collection.JavaConverters._
1919
import com.github.mauricio.async.db.postgresql.exceptions.GenericDatabaseException
2020

21-
class ConnectionHandler(verticle: Verticle, dbType: String, config: Configuration) extends ScalaBusMod with VertxScalaHelpers {
22-
val pool = AsyncConnectionPool(verticle.vertx, dbType, config)
23-
val logger = verticle.container.logger()
21+
trait ConnectionHandler extends ScalaBusMod with VertxScalaHelpers {
22+
val verticle: Verticle
23+
def dbType: String
24+
val config: Configuration
25+
lazy val pool = AsyncConnectionPool(verticle.vertx, dbType, config)
26+
lazy val logger = verticle.container.logger()
2427

2528
override def asyncReceive(msg: Message[JsonObject]) = {
2629
case "select" => select(msg.body)
@@ -30,7 +33,7 @@ class ConnectionHandler(verticle: Verticle, dbType: String, config: Configuratio
3033

3134
def close() = pool.close
3235

33-
private def select(json: JsonObject): Future[Reply] = pool.withConnection({ c: Connection =>
36+
protected def select(json: JsonObject): Future[Reply] = pool.withConnection({ c: Connection =>
3437
val table = escapeField(json.getString("table"))
3538
val command = Option(json.getArray("fields")) match {
3639
case Some(fields) => fields.asScala.toStream.map(elem => escapeField(elem.toString)).mkString("SELECT ", ",", " FROM " + table)
@@ -40,16 +43,16 @@ class ConnectionHandler(verticle: Verticle, dbType: String, config: Configuratio
4043
rawCommand(command)
4144
})
4245

43-
private def escapeField(str: String): String = "\"" + str.replace("\"", "\"\"") + "\""
44-
private def escapeString(str: String): String = "'" + str.replace("'", "''") + "'"
46+
protected def escapeField(str: String): String = "\"" + str.replace("\"", "\"\"") + "\""
47+
protected def escapeString(str: String): String = "'" + str.replace("'", "''") + "'"
4548

46-
private def escapeValue(v: Any): String = v match {
49+
protected def escapeValue(v: Any): String = v match {
4750
case v: Int => v.toString
4851
case v: Boolean => v.toString
4952
case v => escapeString(v.toString)
5053
}
5154

52-
private def insert(json: JsonObject): Future[Reply] = {
55+
protected def insert(json: JsonObject): Future[Reply] = {
5356
val table = json.getString("table")
5457
val fields = json.getArray("fields").asScala
5558
val lines = json.getArray("values").asScala
@@ -68,7 +71,7 @@ class ConnectionHandler(verticle: Verticle, dbType: String, config: Configuratio
6871
rawCommand(cmd.toString)
6972
}
7073

71-
private def rawCommand(command: String): Future[Reply] = pool.withConnection({ c: Connection =>
74+
protected def rawCommand(command: String): Future[Reply] = pool.withConnection({ c: Connection =>
7275
logger.info("sending command: " + command)
7376
c.sendQuery(command) map buildResults recover {
7477
case x: GenericDatabaseException =>
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package com.campudus.vertx.database
2+
3+
import com.campudus.vertx.Verticle
4+
import com.github.mauricio.async.db.Configuration
5+
6+
class MySqlConnectionHandler(val verticle: Verticle, val config: Configuration, val dbType: String = "mysql") extends ConnectionHandler {
7+
override protected def escapeField(str: String): String = "`" + str.replace("`", "\\`") + "`"
8+
override protected def escapeString(str: String): String = "'" + str.replace("'", "''") + "'"
9+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package com.campudus.vertx.database
2+
3+
import com.campudus.vertx.Verticle
4+
import com.github.mauricio.async.db.Configuration
5+
6+
class PostgreSqlConnectionHandler(val verticle: Verticle, val config: Configuration, val dbType: String = "postgresql") extends ConnectionHandler {
7+
8+
}

src/main/scala/com/campudus/vertx/database/Starter.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,10 @@ class Starter extends Verticle {
2727
val dbType = getDatabaseType(config)
2828
val configuration = getConfiguration(config, dbType)
2929

30-
handler = new ConnectionHandler(this, dbType, configuration)
30+
handler = dbType match {
31+
case "postgresql" => new PostgreSqlConnectionHandler(this, configuration)
32+
case "mysql" => new MySqlConnectionHandler(this, configuration)
33+
}
3134
vertx.eventBus.registerHandler(address)(handler)
3235

3336
logger.error("Async database module for MySQL and PostgreSQL started with config " + configuration)

src/main/scala/com/campudus/vertx/database/pool/AsyncConnectionPool.scala

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,18 @@ trait AsyncConnectionPool[ConnType <: Connection] {
4646

4747
object AsyncConnectionPool {
4848

49-
def apply(vertx: Vertx, dbType: String, config: Configuration) = dbType match {
50-
case "postgresql" =>
51-
new PostgreSQLAsyncConnectionPool(
52-
config,
53-
vertx.internal.currentContext().asInstanceOf[EventLoopContext].getEventLoop())
54-
case _ => throw new NotImplementedError
49+
def apply(vertx: Vertx, dbType: String, config: Configuration) = {
50+
println("got db type: " + dbType)
51+
dbType match {
52+
case "postgresql" =>
53+
new PostgreSqlAsyncConnectionPool(
54+
config,
55+
vertx.internal.currentContext().asInstanceOf[EventLoopContext].getEventLoop())
56+
case "mysql" =>
57+
new MySqlAsyncConnectionPool(config,
58+
vertx.internal.currentContext().asInstanceOf[EventLoopContext].getEventLoop())
59+
case _ => throw new NotImplementedError
60+
}
5561
}
5662

5763
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package com.campudus.vertx.database.pool
2+
3+
import com.github.mauricio.async.db.Configuration
4+
import com.github.mauricio.async.db.postgresql.PostgreSQLConnection
5+
import scala.concurrent.Future
6+
import scala.concurrent.ExecutionContext
7+
import com.campudus.vertx.VertxExecutionContext
8+
import com.github.mauricio.async.db.Connection
9+
import org.vertx.java.core.impl.EventLoopContext
10+
import io.netty.channel.EventLoop
11+
import com.github.mauricio.async.db.mysql.MySQLConnection
12+
13+
class MySqlAsyncConnectionPool(config: Configuration, eventLoop: EventLoop, implicit val executionContext: ExecutionContext = VertxExecutionContext) extends AsyncConnectionPool[PostgreSQLConnection] {
14+
15+
override def take() = new MySQLConnection(configuration = config, group = eventLoop).connect
16+
17+
override def giveBack(connection: Connection) = {
18+
connection.disconnect map (_ => MySqlAsyncConnectionPool.this) recover {
19+
case ex =>
20+
executionContext.reportFailure(ex)
21+
MySqlAsyncConnectionPool.this
22+
}
23+
}
24+
25+
override def close() = Future.successful(MySqlAsyncConnectionPool.this)
26+
27+
}

src/main/scala/com/campudus/vertx/database/pool/PostgreSQLAsyncConnectionPool.scala renamed to src/main/scala/com/campudus/vertx/database/pool/PostgreSqlAsyncConnectionPool.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,18 @@ import com.github.mauricio.async.db.Connection
99
import org.vertx.java.core.impl.EventLoopContext
1010
import io.netty.channel.EventLoop
1111

12-
class PostgreSQLAsyncConnectionPool(config: Configuration, eventLoop: EventLoop, implicit val executionContext: ExecutionContext = VertxExecutionContext) extends AsyncConnectionPool[PostgreSQLConnection] {
12+
class PostgreSqlAsyncConnectionPool(config: Configuration, eventLoop: EventLoop, implicit val executionContext: ExecutionContext = VertxExecutionContext) extends AsyncConnectionPool[PostgreSQLConnection] {
1313

1414
override def take() = new PostgreSQLConnection(configuration = config, group = eventLoop).connect
1515

1616
override def giveBack(connection: Connection) = {
17-
connection.disconnect map (_ => this) recover {
17+
connection.disconnect map (_ => PostgreSqlAsyncConnectionPool.this) recover {
1818
case ex =>
1919
executionContext.reportFailure(ex)
20-
this
20+
PostgreSqlAsyncConnectionPool.this
2121
}
2222
}
2323

24-
override def close() = Future.successful(this)
24+
override def close() = Future.successful(PostgreSqlAsyncConnectionPool.this)
2525

2626
}

src/test/scala/com/campudus/test/postgresql/PostgreSQLTest.scala renamed to src/test/scala/com/campudus/test/BaseSqlTests.scala

Lines changed: 18 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,12 @@
1-
package com.campudus.test.postgresql
1+
package com.campudus.test
22

3-
import org.junit.Test
4-
import scala.concurrent.Promise
53
import scala.concurrent.Future
6-
import org.vertx.testtools.VertxAssert._
7-
import org.vertx.java.core.json.JsonObject
8-
import org.vertx.scala.platform.Verticle
9-
import org.vertx.scala.platform.Container
10-
import org.vertx.scala.core.Vertx
11-
import org.vertx.java.core.eventbus.Message
12-
import com.campudus.vertx.VertxScalaHelpers
13-
import org.vertx.java.core.json.JsonArray
14-
import com.campudus.test.SqlTestVerticle
15-
import org.vertx.testtools.VertxAssert._
16-
import scala.util.Failure
17-
import scala.util.Success
18-
import com.github.mauricio.async.db.column.DateEncoderDecoder
19-
20-
class PostgreSQLTest extends SqlTestVerticle with VertxScalaHelpers {
214

22-
val address = "campudus.asyncdb"
23-
val config = new JsonObject().putString("address", address)
24-
lazy val logger = container.logger()
5+
import org.vertx.java.core.json.JsonArray
6+
import org.vertx.testtools.VertxAssert.assertEquals
257

26-
override def getConfig = config
8+
trait BaseSqlTests { this: SqlTestVerticle =>
9+
lazy val logger = getContainer().logger()
2710

2811
def withTable[X](tableName: String)(fn: => Future[X]) = {
2912
for {
@@ -35,7 +18,15 @@ class PostgreSQLTest extends SqlTestVerticle with VertxScalaHelpers {
3518

3619
def asyncTableTest[X](tableName: String)(fn: => Future[X]) = asyncTest(withTable(tableName)(fn))
3720

38-
@Test
21+
private def typeTestInsert[X](fn: => Future[X]) = asyncTableTest("some_test") {
22+
expectOk(insert("some_test",
23+
new JsonArray("""["name","email","is_male","age","money","wedding_date"]"""),
24+
new JsonArray("""[["Mr. Test","test@example.com",true,15,167.31,"2024-04-01"],
25+
["Ms Test2","test2@example.com",false,43,167.31,"1997-12-24"]]"""))) flatMap { _ =>
26+
fn
27+
}
28+
}
29+
3930
def simpleConnection(): Unit = asyncTest {
4031
expectOk(raw("SELECT 0")) map { reply =>
4132
assertEquals(1, reply.getNumber("rows"))
@@ -45,7 +36,6 @@ class PostgreSQLTest extends SqlTestVerticle with VertxScalaHelpers {
4536
}
4637
}
4738

48-
@Test
4939
def multipleFields(): Unit = asyncTest {
5040
expectOk(raw("SELECT 1, 0")) map { reply =>
5141
assertEquals(1, reply.getNumber("rows"))
@@ -57,33 +47,20 @@ class PostgreSQLTest extends SqlTestVerticle with VertxScalaHelpers {
5747
}
5848
}
5949

60-
@Test
6150
def createAndDropTable(): Unit = asyncTest {
6251
createTable("some_test") flatMap (_ => dropTable("some_test")) map { reply =>
6352
assertEquals(0, reply.getNumber("rows"))
6453
}
6554
}
6655

67-
@Test
6856
def insertCorrect(): Unit = asyncTableTest("some_test") {
6957
expectOk(insert("some_test", new JsonArray("""["name","email"]"""), new JsonArray("""[["Test","test@example.com"],["Test2","test2@example.com"]]""")))
7058
}
7159

72-
private def typeTestInsert[X](fn: => Future[X]) = asyncTableTest("some_test") {
73-
expectOk(insert("some_test",
74-
new JsonArray("""["name","email","is_male","age","money","wedding_date"]"""),
75-
new JsonArray("""[["Mr. Test","test@example.com",true,15,167.31,"2024-04-01"],
76-
["Ms Test2","test2@example.com",false,43,167.31,"1997-12-24"]]"""))) flatMap { _ =>
77-
fn
78-
}
79-
}
80-
81-
@Test
8260
def insertTypeTest(): Unit = typeTestInsert {
8361
Future.successful()
8462
}
8563

86-
@Test
8764
def insertMaliciousDataTest(): Unit = asyncTableTest("some_test") {
8865
// If this SQL injection works, the drop table of asyncTableTest would throw an exception
8966
expectOk(insert("some_test",
@@ -92,14 +69,12 @@ class PostgreSQLTest extends SqlTestVerticle with VertxScalaHelpers {
9269
["Ms Test2','some@example.com',false,15,167.31,'2024-04-01');DROP TABLE some_test;--","test2@example.com",false,43,167.31,"1997-12-24"]]""")))
9370
}
9471

95-
@Test
9672
def insertUniqueProblem(): Unit = asyncTableTest("some_test") {
9773
expectError(insert("some_test", new JsonArray("""["name","email"]"""), new JsonArray("""[["Test","test@example.com"],["Test","test@example.com"]]"""))) map { reply =>
9874
logger.info("expected error: " + reply.encode())
9975
}
10076
}
10177

102-
@Test
10378
def selectEverything(): Unit = typeTestInsert {
10479
val fieldsArray = new JsonArray("""["name","email","is_male","age","money","wedding_date"]""")
10580
expectOk(select("some_test", fieldsArray)) map { reply =>
@@ -125,7 +100,8 @@ class PostgreSQLTest extends SqlTestVerticle with VertxScalaHelpers {
125100
assertEquals(true, mrTest.get[Boolean](2))
126101
assertEquals(15, mrTest.get[Integer](3))
127102
assertEquals(167.31, mrTest.get[Integer](4))
128-
assertEquals(DateEncoderDecoder.decode("2024-04-01").toString(), mrTest.get[String](5))
103+
// FIXME check date conversion
104+
// assertEquals("2024-04-01", mrTest.get[JsonObject](5))
129105
}
130106

131107
private def checkMrsTest(mrsTest: JsonArray) = {
@@ -134,10 +110,10 @@ class PostgreSQLTest extends SqlTestVerticle with VertxScalaHelpers {
134110
assertEquals(false, mrsTest.get[Boolean](2))
135111
assertEquals(43, mrsTest.get[Integer](3))
136112
assertEquals(167.31, mrsTest.get[Integer](4))
137-
assertEquals(DateEncoderDecoder.decode("1997-12-24").toString(), mrsTest.get[String](5))
113+
// FIXME check date conversion
114+
// assertEquals("1997-12-24", mrsTest.get[JsonObject](5))
138115
}
139116

140-
@Test
141117
def selectFiltered(): Unit = typeTestInsert {
142118
val fieldsArray = new JsonArray("""["name","email"]""")
143119
expectOk(select("some_test", fieldsArray)) map { reply =>
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package com.campudus.test.postgresql
2+
3+
import org.junit.Test
4+
import org.vertx.java.core.json.JsonObject
5+
import com.campudus.test.{ BaseSqlTests, SqlTestVerticle }
6+
import org.vertx.testtools.VertxAssert
7+
8+
class MySqlTest extends SqlTestVerticle with BaseSqlTests {
9+
10+
val address = "campudus.asyncdb"
11+
val config = new JsonObject().putString("address", address).putString("connection", "MySQL")
12+
13+
override def getConfig = config
14+
15+
// FIXME test stuff
16+
@Test
17+
def something(): Unit = VertxAssert.testComplete()
18+
19+
// @Test
20+
// override def selectFiltered(): Unit = super.selectFiltered()
21+
// @Test
22+
// override def selectEverything(): Unit = super.selectEverything()
23+
// @Test
24+
// override def insertUniqueProblem(): Unit = super.insertUniqueProblem()
25+
// @Test
26+
// override def insertMaliciousDataTest(): Unit = super.insertMaliciousDataTest()
27+
// @Test
28+
// override def insertTypeTest(): Unit = super.insertTypeTest()
29+
// @Test
30+
// override def insertCorrect(): Unit = super.insertCorrect()
31+
// @Test
32+
// override def createAndDropTable(): Unit = super.createAndDropTable()
33+
// @Test
34+
// override def multipleFields(): Unit = super.multipleFields()
35+
// @Test
36+
// override def simpleConnection(): Unit = super.simpleConnection()
37+
38+
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
package com.campudus.test.postgresql
2+
3+
import org.junit.Test
4+
import org.vertx.java.core.json.JsonObject
5+
6+
import com.campudus.test.{ BaseSqlTests, SqlTestVerticle }
7+
8+
class PostgreSqlTest extends SqlTestVerticle with BaseSqlTests {
9+
10+
val address = "campudus.asyncdb"
11+
val config = new JsonObject().putString("address", address)
12+
13+
override def getConfig = config
14+
15+
@Test
16+
override def selectFiltered(): Unit = super.selectFiltered()
17+
@Test
18+
override def selectEverything(): Unit = super.selectEverything()
19+
@Test
20+
override def insertUniqueProblem(): Unit = super.insertUniqueProblem()
21+
@Test
22+
override def insertMaliciousDataTest(): Unit = super.insertMaliciousDataTest()
23+
@Test
24+
override def insertTypeTest(): Unit = super.insertTypeTest()
25+
@Test
26+
override def insertCorrect(): Unit = super.insertCorrect()
27+
@Test
28+
override def createAndDropTable(): Unit = super.createAndDropTable()
29+
@Test
30+
override def multipleFields(): Unit = super.multipleFields()
31+
@Test
32+
override def simpleConnection(): Unit = super.simpleConnection()
33+
34+
}

0 commit comments

Comments
 (0)