Skip to content

Commit 1570906

Browse files
committed
Merge pull request mauricio#172 from alexdupre/ssl
Add SSL support.
2 parents 81fa14d + 0f9a587 commit 1570906

File tree

21 files changed

+364
-48
lines changed

21 files changed

+364
-48
lines changed

db-async-common/src/main/scala/com/github/mauricio/async/db/Configuration.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ object Configuration {
3737
* @param port database port, defaults to 5432
3838
* @param password password, defaults to no password
3939
* @param database database name, defaults to no database
40+
* @param ssl ssl configuration
4041
* @param charset charset for the connection, defaults to UTF-8, make sure you know what you are doing if you
4142
* change this
4243
* @param maximumMessageSize the maximum size a message from the server could possibly have, this limits possible
@@ -55,6 +56,7 @@ case class Configuration(username: String,
5556
port: Int = 5432,
5657
password: Option[String] = None,
5758
database: Option[String] = None,
59+
ssl: SSLConfiguration = SSLConfiguration(),
5860
charset: Charset = Configuration.DefaultCharset,
5961
maximumMessageSize: Int = 16777216,
6062
allocator: ByteBufAllocator = PooledByteBufAllocator.DEFAULT,
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package com.github.mauricio.async.db
2+
3+
import java.io.File
4+
5+
import SSLConfiguration.Mode
6+
7+
/**
8+
*
9+
* Contains the SSL configuration necessary to connect to a database.
10+
*
11+
* @param mode whether and with what priority a SSL connection will be negotiated, default disabled
12+
* @param rootCert path to PEM encoded trusted root certificates, None to use internal JDK cacerts, defaults to None
13+
*
14+
*/
15+
case class SSLConfiguration(mode: Mode.Value = Mode.Disable, rootCert: Option[java.io.File] = None)
16+
17+
object SSLConfiguration {
18+
19+
object Mode extends Enumeration {
20+
val Disable = Value("disable") // only try a non-SSL connection
21+
val Prefer = Value("prefer") // first try an SSL connection; if that fails, try a non-SSL connection
22+
val Require = Value("require") // only try an SSL connection, but don't verify Certificate Authority
23+
val VerifyCA = Value("verify-ca") // only try an SSL connection, and verify that the server certificate is issued by a trusted certificate authority (CA)
24+
val VerifyFull = Value("verify-full") // only try an SSL connection, verify that the server certificate is issued by a trusted CA and that the server host name matches that in the certificate
25+
}
26+
27+
def apply(properties: Map[String, String]): SSLConfiguration = SSLConfiguration(
28+
mode = Mode.withName(properties.get("sslmode").getOrElse("disable")),
29+
rootCert = properties.get("sslrootcert").map(new File(_))
30+
)
31+
}

postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/codec/MessageDecoder.scala

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
package com.github.mauricio.async.db.postgresql.codec
1818

1919
import com.github.mauricio.async.db.postgresql.exceptions.{MessageTooLongException}
20-
import com.github.mauricio.async.db.postgresql.messages.backend.ServerMessage
20+
import com.github.mauricio.async.db.postgresql.messages.backend.{ServerMessage, SSLResponseMessage}
2121
import com.github.mauricio.async.db.postgresql.parsers.{AuthenticationStartupParser, MessageParsersRegistry}
2222
import com.github.mauricio.async.db.util.{BufferDumper, Log}
2323
import java.nio.charset.Charset
@@ -31,15 +31,21 @@ object MessageDecoder {
3131
val DefaultMaximumSize = 16777216
3232
}
3333

34-
class MessageDecoder(charset: Charset, maximumMessageSize : Int = MessageDecoder.DefaultMaximumSize) extends ByteToMessageDecoder {
34+
class MessageDecoder(sslEnabled: Boolean, charset: Charset, maximumMessageSize : Int = MessageDecoder.DefaultMaximumSize) extends ByteToMessageDecoder {
3535

3636
import MessageDecoder.log
3737

3838
private val parser = new MessageParsersRegistry(charset)
3939

40+
private var sslChecked = false
41+
4042
override def decode(ctx: ChannelHandlerContext, b: ByteBuf, out: java.util.List[Object]): Unit = {
4143

42-
if (b.readableBytes() >= 5) {
44+
if (sslEnabled & !sslChecked) {
45+
val code = b.readByte()
46+
sslChecked = true
47+
out.add(new SSLResponseMessage(code == 'S'))
48+
} else if (b.readableBytes() >= 5) {
4349

4450
b.markReaderIndex()
4551

postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/codec/MessageEncoder.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,13 @@ class MessageEncoder(charset: Charset, encoderRegistry: ColumnEncoderRegistry) e
4444
override def encode(ctx: ChannelHandlerContext, msg: AnyRef, out: java.util.List[Object]) = {
4545

4646
val buffer = msg match {
47+
case SSLRequestMessage => SSLMessageEncoder.encode()
48+
case message: StartupMessage => startupEncoder.encode(message)
4749
case message: ClientMessage => {
4850
val encoder = (message.kind: @switch) match {
4951
case ServerMessage.Close => CloseMessageEncoder
5052
case ServerMessage.Execute => this.executeEncoder
5153
case ServerMessage.Parse => this.openEncoder
52-
case ServerMessage.Startup => this.startupEncoder
5354
case ServerMessage.Query => this.queryEncoder
5455
case ServerMessage.PasswordMessage => this.credentialEncoder
5556
case _ => throw new EncoderNotAvailableException(message)

postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/codec/PostgreSQLConnectionHandler.scala

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package com.github.mauricio.async.db.postgresql.codec
1818

1919
import com.github.mauricio.async.db.Configuration
20+
import com.github.mauricio.async.db.SSLConfiguration.Mode
2021
import com.github.mauricio.async.db.column.{ColumnDecoderRegistry, ColumnEncoderRegistry}
2122
import com.github.mauricio.async.db.postgresql.exceptions._
2223
import com.github.mauricio.async.db.postgresql.messages.backend._
@@ -38,6 +39,12 @@ import com.github.mauricio.async.db.postgresql.messages.backend.RowDescriptionMe
3839
import com.github.mauricio.async.db.postgresql.messages.backend.ParameterStatusMessage
3940
import io.netty.channel.socket.nio.NioSocketChannel
4041
import io.netty.handler.codec.CodecException
42+
import io.netty.handler.ssl.{SslContextBuilder, SslHandler}
43+
import io.netty.handler.ssl.util.InsecureTrustManagerFactory
44+
import io.netty.util.concurrent.FutureListener
45+
import javax.net.ssl.{SSLParameters, TrustManagerFactory}
46+
import java.security.KeyStore
47+
import java.io.FileInputStream
4148

4249
object PostgreSQLConnectionHandler {
4350
final val log = Log.get[PostgreSQLConnectionHandler]
@@ -79,7 +86,7 @@ class PostgreSQLConnectionHandler
7986

8087
override def initChannel(ch: channel.Channel): Unit = {
8188
ch.pipeline.addLast(
82-
new MessageDecoder(configuration.charset, configuration.maximumMessageSize),
89+
new MessageDecoder(configuration.ssl.mode != Mode.Disable, configuration.charset, configuration.maximumMessageSize),
8390
new MessageEncoder(configuration.charset, encoderRegistry),
8491
PostgreSQLConnectionHandler.this)
8592
}
@@ -120,13 +127,61 @@ class PostgreSQLConnectionHandler
120127
}
121128

122129
override def channelActive(ctx: ChannelHandlerContext): Unit = {
123-
ctx.writeAndFlush(new StartupMessage(this.properties))
130+
if (configuration.ssl.mode == Mode.Disable)
131+
ctx.writeAndFlush(new StartupMessage(this.properties))
132+
else
133+
ctx.writeAndFlush(SSLRequestMessage)
124134
}
125135

126136
override def channelRead0(ctx: ChannelHandlerContext, msg: Object): Unit = {
127137

128138
msg match {
129139

140+
case SSLResponseMessage(supported) =>
141+
if (supported) {
142+
val ctxBuilder = SslContextBuilder.forClient()
143+
if (configuration.ssl.mode >= Mode.VerifyCA) {
144+
configuration.ssl.rootCert.fold {
145+
val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
146+
val ks = KeyStore.getInstance(KeyStore.getDefaultType())
147+
val cacerts = new FileInputStream(System.getProperty("java.home") + "/lib/security/cacerts")
148+
try {
149+
ks.load(cacerts, "changeit".toCharArray)
150+
} finally {
151+
cacerts.close()
152+
}
153+
tmf.init(ks)
154+
ctxBuilder.trustManager(tmf)
155+
} { path =>
156+
ctxBuilder.trustManager(path)
157+
}
158+
} else {
159+
ctxBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE)
160+
}
161+
val sslContext = ctxBuilder.build()
162+
val sslEngine = sslContext.newEngine(ctx.alloc(), configuration.host, configuration.port)
163+
if (configuration.ssl.mode >= Mode.VerifyFull) {
164+
val sslParams = sslEngine.getSSLParameters()
165+
sslParams.setEndpointIdentificationAlgorithm("HTTPS")
166+
sslEngine.setSSLParameters(sslParams)
167+
}
168+
val handler = new SslHandler(sslEngine)
169+
ctx.pipeline().addFirst(handler)
170+
handler.handshakeFuture.addListener(new FutureListener[channel.Channel]() {
171+
def operationComplete(future: io.netty.util.concurrent.Future[channel.Channel]) {
172+
if (future.isSuccess()) {
173+
ctx.writeAndFlush(new StartupMessage(properties))
174+
} else {
175+
connectionDelegate.onError(future.cause())
176+
}
177+
}
178+
})
179+
} else if (configuration.ssl.mode < Mode.Require) {
180+
ctx.writeAndFlush(new StartupMessage(properties))
181+
} else {
182+
connectionDelegate.onError(new IllegalArgumentException("SSL is not supported on server"))
183+
}
184+
130185
case m: ServerMessage => {
131186

132187
(m.kind : @switch) match {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package com.github.mauricio.async.db.postgresql.encoders
2+
3+
import io.netty.buffer.ByteBuf
4+
import io.netty.buffer.Unpooled
5+
6+
object SSLMessageEncoder {
7+
8+
def encode(): ByteBuf = {
9+
val buffer = Unpooled.buffer()
10+
buffer.writeInt(8)
11+
buffer.writeShort(1234)
12+
buffer.writeShort(5679)
13+
buffer
14+
}
15+
16+
}

postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/encoders/StartupMessageEncoder.scala

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,11 @@ import com.github.mauricio.async.db.util.ByteBufferUtils
2121
import java.nio.charset.Charset
2222
import io.netty.buffer.{Unpooled, ByteBuf}
2323

24-
class StartupMessageEncoder(charset: Charset) extends Encoder {
24+
class StartupMessageEncoder(charset: Charset) {
2525

2626
//private val log = Log.getByName("StartupMessageEncoder")
2727

28-
override def encode(message: ClientMessage): ByteBuf = {
29-
30-
val startup = message.asInstanceOf[StartupMessage]
28+
def encode(startup: StartupMessage): ByteBuf = {
3129

3230
val buffer = Unpooled.buffer()
3331
buffer.writeInt(0)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
package com.github.mauricio.async.db.postgresql.messages.backend
2+
3+
case class SSLResponseMessage(supported: Boolean)

postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/messages/backend/ServerMessage.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ object ServerMessage {
4343
final val Query = 'Q'
4444
final val RowDescription = 'T'
4545
final val ReadyForQuery = 'Z'
46-
final val Startup = '0'
4746
final val Sync = 'S'
4847
}
4948

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
package com.github.mauricio.async.db.postgresql.messages.frontend
2+
3+
trait InitialClientMessage
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
package com.github.mauricio.async.db.postgresql.messages.frontend
2+
3+
import com.github.mauricio.async.db.postgresql.messages.backend.ServerMessage
4+
5+
object SSLRequestMessage extends InitialClientMessage

postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/messages/frontend/StartupMessage.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,4 @@
1616

1717
package com.github.mauricio.async.db.postgresql.messages.frontend
1818

19-
import com.github.mauricio.async.db.postgresql.messages.backend.ServerMessage
20-
21-
class StartupMessage(val parameters: List[(String, Any)]) extends ClientMessage(ServerMessage.Startup)
19+
class StartupMessage(val parameters: List[(String, Any)]) extends InitialClientMessage

postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/util/ParserURL.scala

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,28 +16,37 @@ object ParserURL {
1616
val PGPORT = "port"
1717
val PGDBNAME = "database"
1818
val PGHOST = "host"
19-
val PGUSERNAME = "username"
19+
val PGUSERNAME = "user"
2020
val PGPASSWORD = "password"
2121

2222
val DEFAULT_PORT = "5432"
2323

24-
private val pgurl1 = """(jdbc:postgresql):(?://([^/:]*|\[.+\])(?::(\d+))?)?(?:/([^/?]*))?(?:\?user=(.*)&password=(.*))?""".r
25-
private val pgurl2 = """(postgres|postgresql)://(.*):(.*)@(.*):(\d+)/(.*)""".r
24+
private val pgurl1 = """(jdbc:postgresql):(?://([^/:]*|\[.+\])(?::(\d+))?)?(?:/([^/?]*))?(?:\?(.*))?""".r
25+
private val pgurl2 = """(postgres|postgresql)://(.*):(.*)@(.*):(\d+)/([^/?]*)(?:\?(.*))?""".r
2626

2727
def parse(connectionURL: String): Map[String, String] = {
2828
val properties: Map[String, String] = Map()
2929

30+
def parseOptions(optionsStr: String): Map[String, String] =
31+
optionsStr.split("&").map { o =>
32+
o.span(_ != '=') match {
33+
case (name, value) => name -> value.drop(1)
34+
}
35+
}.toMap
36+
3037
connectionURL match {
31-
case pgurl1(protocol, server, port, dbname, username, password) => {
38+
case pgurl1(protocol, server, port, dbname, params) => {
3239
var result = properties
3340
if (server != null) result += (PGHOST -> unwrapIpv6address(server))
3441
if (dbname != null && dbname.nonEmpty) result += (PGDBNAME -> dbname)
35-
if(port != null) result += (PGPORT -> port)
36-
if(username != null) result = (result + (PGUSERNAME -> username) + (PGPASSWORD -> password))
42+
if (port != null) result += (PGPORT -> port)
43+
if (params != null) result ++= parseOptions(params)
3744
result
3845
}
39-
case pgurl2(protocol, username, password, server, port, dbname) => {
40-
properties + (PGHOST -> unwrapIpv6address(server)) + (PGPORT -> port) + (PGDBNAME -> dbname) + (PGUSERNAME -> username) + (PGPASSWORD -> password)
46+
case pgurl2(protocol, username, password, server, port, dbname, params) => {
47+
var result = properties + (PGHOST -> unwrapIpv6address(server)) + (PGPORT -> port) + (PGDBNAME -> dbname) + (PGUSERNAME -> username) + (PGPASSWORD -> password)
48+
if (params != null) result ++= parseOptions(params)
49+
result
4150
}
4251
case _ => {
4352
logger.warn(s"Connection url '$connectionURL' could not be parsed.")

postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/util/URLParser.scala

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,11 @@
1616

1717
package com.github.mauricio.async.db.postgresql.util
1818

19-
import com.github.mauricio.async.db.Configuration
19+
import com.github.mauricio.async.db.{Configuration, SSLConfiguration}
2020
import java.nio.charset.Charset
2121

2222
object URLParser {
2323

24-
private val Username = "username"
25-
private val Password = "password"
26-
2724
import Configuration.Default
2825

2926
def parse(url: String,
@@ -35,11 +32,12 @@ object URLParser {
3532
val port = properties.get(ParserURL.PGPORT).getOrElse(ParserURL.DEFAULT_PORT).toInt
3633

3734
new Configuration(
38-
username = properties.get(Username).getOrElse(Default.username),
39-
password = properties.get(Password),
35+
username = properties.get(ParserURL.PGUSERNAME).getOrElse(Default.username),
36+
password = properties.get(ParserURL.PGPASSWORD),
4037
database = properties.get(ParserURL.PGDBNAME),
4138
host = properties.getOrElse(ParserURL.PGHOST, Default.host),
4239
port = port,
40+
ssl = SSLConfiguration(properties),
4341
charset = charset
4442
)
4543

postgresql-async/src/test/scala/com/github/mauricio/async/db/postgresql/DatabaseTestHelper.scala

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@ package com.github.mauricio.async.db.postgresql
1818

1919
import com.github.mauricio.async.db.util.Log
2020
import com.github.mauricio.async.db.{Connection, Configuration}
21+
import java.io.File
2122
import java.util.concurrent.{TimeoutException, TimeUnit}
22-
import scala.Some
2323
import scala.concurrent.duration._
2424
import scala.concurrent.{Future, Await}
25+
import com.github.mauricio.async.db.SSLConfiguration
26+
import com.github.mauricio.async.db.SSLConfiguration.Mode
2527

2628
object DatabaseTestHelper {
2729
val log = Log.get[DatabaseTestHelper]
@@ -54,6 +56,16 @@ trait DatabaseTestHelper {
5456
withHandler(this.timeTestConfiguration, fn)
5557
}
5658

59+
def withSSLHandler[T](mode: SSLConfiguration.Mode.Value, host: String = "localhost", rootCert: Option[File] = Some(new File("script/server.crt")))(fn: (PostgreSQLConnection) => T): T = {
60+
val config = new Configuration(
61+
host = host,
62+
port = databasePort,
63+
username = "postgres",
64+
database = databaseName,
65+
ssl = SSLConfiguration(mode = mode, rootCert = rootCert))
66+
withHandler(config, fn)
67+
}
68+
5769
def withHandler[T](configuration: Configuration, fn: (PostgreSQLConnection) => T): T = {
5870

5971
val handler = new PostgreSQLConnection(configuration)

postgresql-async/src/test/scala/com/github/mauricio/async/db/postgresql/MessageDecoderSpec.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import java.util
2727

2828
class MessageDecoderSpec extends Specification {
2929

30-
val decoder = new MessageDecoder(CharsetUtil.UTF_8)
30+
val decoder = new MessageDecoder(false, CharsetUtil.UTF_8)
3131

3232
"message decoder" should {
3333

0 commit comments

Comments
 (0)