Skip to content

Commit e5be218

Browse files
Implement password encryption using an RSA public key
1 parent 9efb120 commit e5be218

26 files changed

+275
-61
lines changed

db-async-common/src/main/java/com/github/jasync/sql/db/Configuration.kt

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@ import io.netty.channel.nio.NioEventLoopGroup
1010
import io.netty.util.CharsetUtil
1111
import mu.KotlinLogging
1212
import java.nio.charset.Charset
13+
import java.nio.file.Path
1314
import java.time.Duration
1415
import java.util.concurrent.CompletionStage
1516
import java.util.concurrent.Executor
1617
import java.util.function.Supplier
18+
import kotlin.io.path.notExists
1719

1820
private val logger = KotlinLogging.logger {}
1921

@@ -27,6 +29,7 @@ private val logger = KotlinLogging.logger {}
2729
* @param password password, defaults to no password
2830
* @param database database name, defaults to no database
2931
* @param ssl ssl configuration
32+
* @param rsaPublicKey path to the RSA public key, used for password encryption over unsafe connections
3033
* @param charset charset for the connection, defaults to UTF-8, make sure you know what you are doing if you
3134
* change this
3235
* @param maximumMessageSize the maximum size a message from the server could possibly have, this limits possible
@@ -52,6 +55,7 @@ class Configuration @JvmOverloads constructor(
5255
val password: String? = null,
5356
val database: String? = null,
5457
val ssl: SSLConfiguration = SSLConfiguration(),
58+
val rsaPublicKey: Path? = null,
5559
val charset: Charset = CharsetUtil.UTF_8,
5660
val maximumMessageSize: Int = 16777216,
5761
val allocator: ByteBufAllocator = PooledByteBufAllocator.DEFAULT,
@@ -72,6 +76,10 @@ class Configuration @JvmOverloads constructor(
7276
"Please change eventLoopGroup configuration."
7377
}
7478
}
79+
80+
if (rsaPublicKey != null && rsaPublicKey.notExists()) {
81+
throw IllegalArgumentException("Public key file '$rsaPublicKey' does not exist")
82+
}
7583
}
7684

7785
fun resolveCredentials(): CompletionStage<Credentials> = (credentialsProvider ?: StaticCredentialsProvider(username, password)).provide()
@@ -84,6 +92,7 @@ class Configuration @JvmOverloads constructor(
8492
password: String? = null,
8593
database: String? = null,
8694
ssl: SSLConfiguration? = null,
95+
rsaPublicKey: Path? = null,
8796
charset: Charset? = null,
8897
maximumMessageSize: Int? = null,
8998
allocator: ByteBufAllocator? = null,
@@ -104,6 +113,7 @@ class Configuration @JvmOverloads constructor(
104113
password = password ?: this.password,
105114
database = database ?: this.database,
106115
ssl = ssl ?: this.ssl,
116+
rsaPublicKey = rsaPublicKey ?: this.rsaPublicKey,
107117
charset = charset ?: this.charset,
108118
maximumMessageSize = maximumMessageSize ?: this.maximumMessageSize,
109119
allocator = allocator ?: this.allocator,
@@ -131,6 +141,7 @@ class Configuration @JvmOverloads constructor(
131141
if (password != other.password) return false
132142
if (database != other.database) return false
133143
if (ssl != other.ssl) return false
144+
if (rsaPublicKey != other.rsaPublicKey) return false
134145
if (charset != other.charset) return false
135146
if (maximumMessageSize != other.maximumMessageSize) return false
136147
if (allocator != other.allocator) return false
@@ -154,6 +165,7 @@ class Configuration @JvmOverloads constructor(
154165
result = 31 * result + (password?.hashCode() ?: 0)
155166
result = 31 * result + (database?.hashCode() ?: 0)
156167
result = 31 * result + ssl.hashCode()
168+
result = 31 * result + rsaPublicKey.hashCode()
157169
result = 31 * result + charset.hashCode()
158170
result = 31 * result + maximumMessageSize
159171
result = 31 * result + allocator.hashCode()
@@ -170,7 +182,7 @@ class Configuration @JvmOverloads constructor(
170182
}
171183

172184
override fun toString(): String {
173-
return "Configuration(username='$username', host='$host', port=$port, password=****, database=$database, ssl=$ssl, charset=$charset, maximumMessageSize=$maximumMessageSize, allocator=$allocator, connectionTimeout=$connectionTimeout, queryTimeout=$queryTimeout, applicationName=$applicationName, interceptors=$interceptors, eventLoopGroup=$eventLoopGroup, executionContext=$executionContext, currentSchema=$currentSchema, socketPath=$socketPath, credentialsProvider=$credentialsProvider)"
185+
return "Configuration(username='$username', host='$host', port=$port, password=****, database=$database, ssl=$ssl, rsaPublicKey=$rsaPublicKey, charset=$charset, maximumMessageSize=$maximumMessageSize, allocator=$allocator, connectionTimeout=$connectionTimeout, queryTimeout=$queryTimeout, applicationName=$applicationName, interceptors=$interceptors, eventLoopGroup=$eventLoopGroup, executionContext=$executionContext, currentSchema=$currentSchema, socketPath=$socketPath, credentialsProvider=$credentialsProvider)"
174186
}
175187
}
176188

mysql-async/src/main/java/com/github/jasync/sql/db/mysql/MySQLConnection.kt

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@ import com.github.jasync.sql.db.exceptions.InsufficientParametersException
1313
import com.github.jasync.sql.db.interceptor.PreparedStatementParams
1414
import com.github.jasync.sql.db.mysql.codec.MySQLConnectionHandler
1515
import com.github.jasync.sql.db.mysql.codec.MySQLHandlerDelegate
16+
import com.github.jasync.sql.db.mysql.encoder.auth.AuthenticationMethod
1617
import com.github.jasync.sql.db.mysql.exceptions.MySQLException
1718
import com.github.jasync.sql.db.mysql.message.client.AuthenticationSwitchResponse
1819
import com.github.jasync.sql.db.mysql.message.client.CapabilityRequestMessage
1920
import com.github.jasync.sql.db.mysql.message.client.HandshakeResponseMessage
21+
import com.github.jasync.sql.db.mysql.message.server.AuthMoreDataMessage
2022
import com.github.jasync.sql.db.mysql.message.server.AuthenticationSwitchRequest
2123
import com.github.jasync.sql.db.mysql.message.server.EOFMessage
2224
import com.github.jasync.sql.db.mysql.message.server.ErrorMessage
@@ -93,6 +95,8 @@ class MySQLConnection @JvmOverloads constructor(
9395
private var connected = false
9496
private var lastException: Throwable? = null
9597
private var serverVersion: Version? = null
98+
private var authenticationMethod: String? = null
99+
private var authenticationSeed: ByteArray? = null
96100

97101
object StatusFlags {
98102
// https://dev.mysql.com/doc/internals/en/status-flags.html
@@ -259,6 +263,8 @@ class MySQLConnection @JvmOverloads constructor(
259263
override fun onHandshake(message: HandshakeMessage) {
260264
this.serverVersion = parseVersion(message.serverVersion)
261265
this.serverStatus = message.statusFlags
266+
this.authenticationMethod = message.authenticationMethod
267+
this.authenticationSeed = message.seed
262268

263269
val switchToSsl = when (this.configuration.ssl.mode) {
264270
SSLConfiguration.Mode.Disable -> false
@@ -298,7 +304,8 @@ class MySQLConnection @JvmOverloads constructor(
298304
message.authenticationMethod,
299305
database = configuration.database,
300306
password = configuration.password,
301-
appName = configuration.applicationName
307+
appName = configuration.applicationName,
308+
configuration = configuration,
302309
)
303310

304311
if (!switchToSsl) {
@@ -336,7 +343,23 @@ class MySQLConnection @JvmOverloads constructor(
336343
}
337344

338345
override fun switchAuthentication(message: AuthenticationSwitchRequest) {
339-
this.connectionHandler.write(AuthenticationSwitchResponse(configuration.password, message))
346+
this.connectionHandler.write(AuthenticationSwitchResponse(configuration, message))
347+
}
348+
349+
override fun onAuthMoreData(message: AuthMoreDataMessage) {
350+
if (message.isSuccess()) {
351+
// Do nothing. This message will be followed by an `OkMessage`.
352+
return
353+
}
354+
355+
if (authenticationMethod != AuthenticationMethod.CachingSha2) {
356+
throw IllegalStateException(
357+
"AuthMoreDataMessage is only supported for '${AuthenticationMethod.CachingSha2}' method"
358+
)
359+
}
360+
361+
val m = AuthenticationSwitchRequest(AuthenticationMethod.Sha256, authenticationSeed!!)
362+
this.connectionHandler.write(AuthenticationSwitchResponse(configuration, m))
340363
}
341364

342365
override fun sendQueryDirect(query: String): CompletableFuture<QueryResult> {

mysql-async/src/main/java/com/github/jasync/sql/db/mysql/codec/MySQLConnectionHandler.kt

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import com.github.jasync.sql.db.Configuration
44
import com.github.jasync.sql.db.exceptions.DatabaseException
55
import com.github.jasync.sql.db.general.MutableResultSet
66
import com.github.jasync.sql.db.mysql.binary.BinaryRowDecoder
7-
import com.github.jasync.sql.db.mysql.encoder.auth.AuthenticationMethod
87
import com.github.jasync.sql.db.mysql.message.client.AuthenticationSwitchResponse
98
import com.github.jasync.sql.db.mysql.message.client.CapabilityRequestMessage
109
import com.github.jasync.sql.db.mysql.message.client.CloseStatementMessage
@@ -132,18 +131,7 @@ class MySQLConnectionHandler(
132131
this.handleEOF(message)
133132
}
134133
ServerMessage.AuthMoreData -> {
135-
val m = message as AuthMoreDataMessage
136-
137-
if (!m.isSuccess()) {
138-
if (!sslEstablished) {
139-
throw IllegalStateException(
140-
"Full authentication mode for ${AuthenticationMethod.CachingSha2} requires SSL"
141-
)
142-
}
143-
144-
val request = AuthenticationSwitchRequest(AuthenticationMethod.CachingSha2, null)
145-
handlerDelegate.switchAuthentication(request)
146-
}
134+
handlerDelegate.onAuthMoreData(message as AuthMoreDataMessage)
147135
}
148136
ServerMessage.ColumnDefinition -> {
149137
val m = message as ColumnDefinitionMessage

mysql-async/src/main/java/com/github/jasync/sql/db/mysql/codec/MySQLHandlerDelegate.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package com.github.jasync.sql.db.mysql.codec
22

33
import com.github.jasync.sql.db.ResultSet
4+
import com.github.jasync.sql.db.mysql.message.server.AuthMoreDataMessage
45
import com.github.jasync.sql.db.mysql.message.server.AuthenticationSwitchRequest
56
import com.github.jasync.sql.db.mysql.message.server.EOFMessage
67
import com.github.jasync.sql.db.mysql.message.server.ErrorMessage
@@ -18,5 +19,6 @@ interface MySQLHandlerDelegate {
1819
fun connected(ctx: ChannelHandlerContext)
1920
fun onResultSet(resultSet: ResultSet, message: EOFMessage)
2021
fun switchAuthentication(message: AuthenticationSwitchRequest)
22+
fun onAuthMoreData(message: AuthMoreDataMessage)
2123
fun unregistered()
2224
}

mysql-async/src/main/java/com/github/jasync/sql/db/mysql/encoder/AuthenticationSwitchResponseEncoder.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class AuthenticationSwitchResponseEncoder(val charset: Charset) : MessageEncoder
2121
val buffer = ByteBufferUtils.packetBuffer()
2222

2323
val bytes =
24-
authenticator.generateAuthentication(charset, switch.password, switch.request.seed)
24+
authenticator.generateAuthentication(charset, switch.configuration, switch.request.seed)
2525
buffer.writeBytes(bytes)
2626

2727
return buffer

mysql-async/src/main/java/com/github/jasync/sql/db/mysql/encoder/HandshakeResponseEncoder.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class HandshakeResponseEncoder(private val charset: Charset, private val headerE
3131
val authenticator = this.authenticationMethods.getOrElse(
3232
method
3333
) { throw UnsupportedAuthenticationMethodException(method) }
34-
val bytes = authenticator.generateAuthentication(charset, m.password, m.seed)
34+
val bytes = authenticator.generateAuthentication(charset, m.configuration, m.seed)
3535
buffer.writeByte(bytes.length)
3636
buffer.writeBytes(bytes)
3737
} else {

mysql-async/src/main/java/com/github/jasync/sql/db/mysql/encoder/auth/AuthenticationMethod.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
package com.github.jasync.sql.db.mysql.encoder.auth
22

3+
import com.github.jasync.sql.db.Configuration
34
import java.nio.charset.Charset
45

56
interface AuthenticationMethod {
67

7-
fun generateAuthentication(charset: Charset, password: String?, seed: ByteArray?): ByteArray
8+
fun generateAuthentication(charset: Charset, configuration: Configuration, seed: ByteArray): ByteArray
89

910
companion object {
1011
const val CachingSha2 = "caching_sha2_password"

mysql-async/src/main/java/com/github/jasync/sql/db/mysql/encoder/auth/AuthenticationScrambler.kt

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
package com.github.jasync.sql.db.mysql.encoder.auth
22

3-
import com.github.jasync.sql.db.util.length
43
import java.nio.charset.Charset
54
import java.security.MessageDigest
65
import kotlin.experimental.xor
@@ -32,11 +31,8 @@ object AuthenticationScrambler {
3231
}
3332

3433
val result = messageDigest.digest()
35-
var counter = 0
36-
37-
while (counter < result.length) {
38-
result[counter] = (result[counter] xor initialDigest[counter])
39-
counter += 1
34+
for ((index, byte) in result.withIndex()) {
35+
result[index] = byte xor initialDigest[index]
4036
}
4137

4238
return result

mysql-async/src/main/java/com/github/jasync/sql/db/mysql/encoder/auth/CachingSha2PasswordAuthentication.kt

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,17 @@
11
package com.github.jasync.sql.db.mysql.encoder.auth
22

3+
import com.github.jasync.sql.db.Configuration
34
import java.nio.charset.Charset
45

56
object CachingSha2PasswordAuthentication : AuthenticationMethod {
67

78
private val EmptyArray = ByteArray(0)
89

9-
override fun generateAuthentication(charset: Charset, password: String?, seed: ByteArray?): ByteArray {
10+
override fun generateAuthentication(charset: Charset, configuration: Configuration, seed: ByteArray): ByteArray {
11+
val password = configuration.password
12+
1013
return if (password != null) {
11-
if (seed != null) {
12-
// Fast authentication mode. Requires seed, but not SSL.
13-
AuthenticationScrambler.scramble411("SHA-256", password, charset, seed, false)
14-
} else {
15-
// Full authentication mode.
16-
// Since this sends the plaintext password, SSL is required.
17-
// Without SSL, the server always rejects the password.
18-
Sha256PasswordAuthentication.generateAuthentication(charset, password, null)
19-
}
14+
AuthenticationScrambler.scramble411("SHA-256", password, charset, seed, false)
2015
} else {
2116
EmptyArray
2217
}

mysql-async/src/main/java/com/github/jasync/sql/db/mysql/encoder/auth/MySQLNativePasswordAuthentication.kt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
package com.github.jasync.sql.db.mysql.encoder.auth
22

3+
import com.github.jasync.sql.db.Configuration
34
import java.nio.charset.Charset
45

56
object MySQLNativePasswordAuthentication : AuthenticationMethod {
67

78
private val EmptyArray = ByteArray(0)
89

9-
override fun generateAuthentication(charset: Charset, password: String?, seed: ByteArray?): ByteArray {
10-
requireNotNull(seed) { "Seed should not be null" }
10+
override fun generateAuthentication(charset: Charset, configuration: Configuration, seed: ByteArray): ByteArray {
11+
val password = configuration.password
1112

1213
return if (password != null) {
1314
AuthenticationScrambler.scramble411("SHA-1", password, charset, seed, true)

0 commit comments

Comments
 (0)