Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import io.netty.channel.nio.NioEventLoopGroup
import io.netty.util.CharsetUtil
import mu.KotlinLogging
import java.nio.charset.Charset
import java.nio.file.Path
import java.time.Duration
import java.util.concurrent.CompletionStage
import java.util.concurrent.Executor
Expand Down Expand Up @@ -43,6 +44,7 @@ private val logger = KotlinLogging.logger {}
* @param currentSchema optional database schema - postgresql only.
* @param socketPath path to unix domain socket file (on the local machine)
* @param credentialsProvider a credential provider used to inject credentials on demand
* @param rsaPublicKey path to the RSA public key, used for password encryption over unsafe connections
*
*/
class Configuration @JvmOverloads constructor(
Expand All @@ -63,7 +65,8 @@ class Configuration @JvmOverloads constructor(
val executionContext: Executor = ExecutorServiceUtils.CommonPool,
val currentSchema: String? = null,
val socketPath: String? = null,
val credentialsProvider: CredentialsProvider? = null
val credentialsProvider: CredentialsProvider? = null,
val rsaPublicKey: Path? = null,
) {
init {
if (socketPath != null && eventLoopGroup is NioEventLoopGroup) {
Expand Down Expand Up @@ -96,6 +99,7 @@ class Configuration @JvmOverloads constructor(
currentSchema: String? = null,
socketPath: String? = null,
credentialsProvider: CredentialsProvider? = null,
rsaPublicKey: Path? = null,
): Configuration {
return Configuration(
username = username ?: this.username,
Expand All @@ -116,6 +120,7 @@ class Configuration @JvmOverloads constructor(
currentSchema = currentSchema ?: this.currentSchema,
socketPath = socketPath ?: this.socketPath,
credentialsProvider = credentialsProvider ?: this.credentialsProvider,
rsaPublicKey = rsaPublicKey ?: this.rsaPublicKey,
)
}

Expand Down Expand Up @@ -143,6 +148,7 @@ class Configuration @JvmOverloads constructor(
if (currentSchema != other.currentSchema) return false
if (socketPath != other.socketPath) return false
if (credentialsProvider != other.credentialsProvider) return false
if (rsaPublicKey != other.rsaPublicKey) return false

return true
}
Expand All @@ -166,11 +172,12 @@ class Configuration @JvmOverloads constructor(
result = 31 * result + (currentSchema?.hashCode() ?: 0)
result = 31 * result + (socketPath?.hashCode() ?: 0)
result = 31 * result + (credentialsProvider?.hashCode() ?: 0)
result = 31 * result + (rsaPublicKey?.hashCode() ?: 0)
return result
}

override fun toString(): String {
return "Configuration(username='$username', host='$host', port=$port, password=****, database=$database, ssl=$ssl, charset=$charset, maximumMessageSize=$maximumMessageSize, allocator=$allocator, connectionTimeout=$connectionTimeout, queryTimeout=$queryTimeout, applicationName=$applicationName, interceptors=$interceptors, eventLoopGroup=$eventLoopGroup, executionContext=$executionContext, currentSchema=$currentSchema, socketPath=$socketPath, credentialsProvider=$credentialsProvider)"
return "Configuration(username='$username', host='$host', port=$port, password=****, database=$database, ssl=$ssl, charset=$charset, maximumMessageSize=$maximumMessageSize, allocator=$allocator, connectionTimeout=$connectionTimeout, queryTimeout=$queryTimeout, applicationName=$applicationName, interceptors=$interceptors, eventLoopGroup=$eventLoopGroup, executionContext=$executionContext, currentSchema=$currentSchema, socketPath=$socketPath, credentialsProvider=$credentialsProvider, rsaPublicKey=$rsaPublicKey)"
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ import com.github.jasync.sql.db.exceptions.InsufficientParametersException
import com.github.jasync.sql.db.interceptor.PreparedStatementParams
import com.github.jasync.sql.db.mysql.codec.MySQLConnectionHandler
import com.github.jasync.sql.db.mysql.codec.MySQLHandlerDelegate
import com.github.jasync.sql.db.mysql.encoder.auth.AuthenticationMethod
import com.github.jasync.sql.db.mysql.exceptions.MySQLException
import com.github.jasync.sql.db.mysql.message.client.AuthenticationSwitchResponse
import com.github.jasync.sql.db.mysql.message.client.CapabilityRequestMessage
import com.github.jasync.sql.db.mysql.message.client.HandshakeResponseMessage
import com.github.jasync.sql.db.mysql.message.server.AuthMoreDataMessage
import com.github.jasync.sql.db.mysql.message.server.AuthenticationSwitchRequest
import com.github.jasync.sql.db.mysql.message.server.EOFMessage
import com.github.jasync.sql.db.mysql.message.server.ErrorMessage
Expand Down Expand Up @@ -93,6 +95,8 @@ class MySQLConnection @JvmOverloads constructor(
private var connected = false
private var lastException: Throwable? = null
private var serverVersion: Version? = null
private var authenticationMethod: String? = null
private var authenticationSeed: ByteArray? = null

object StatusFlags {
// https://dev.mysql.com/doc/internals/en/status-flags.html
Expand Down Expand Up @@ -259,6 +263,8 @@ class MySQLConnection @JvmOverloads constructor(
override fun onHandshake(message: HandshakeMessage) {
this.serverVersion = parseVersion(message.serverVersion)
this.serverStatus = message.statusFlags
this.authenticationMethod = message.authenticationMethod
this.authenticationSeed = message.seed

val switchToSsl = when (this.configuration.ssl.mode) {
SSLConfiguration.Mode.Disable -> false
Expand Down Expand Up @@ -298,7 +304,9 @@ class MySQLConnection @JvmOverloads constructor(
message.authenticationMethod,
database = configuration.database,
password = configuration.password,
appName = configuration.applicationName
appName = configuration.applicationName,
sslConfiguration = configuration.ssl,
rsaPublicKey = configuration.rsaPublicKey,
)

if (!switchToSsl) {
Expand Down Expand Up @@ -336,7 +344,37 @@ class MySQLConnection @JvmOverloads constructor(
}

override fun switchAuthentication(message: AuthenticationSwitchRequest) {
this.connectionHandler.write(AuthenticationSwitchResponse(configuration.password, message))
val response = AuthenticationSwitchResponse(
configuration.password,
configuration.ssl,
configuration.rsaPublicKey,
message
)

this.connectionHandler.write(response)
}

override fun onAuthMoreData(message: AuthMoreDataMessage) {
if (message.isSuccess()) {
// Do nothing. This message will be followed by an `OkMessage`.
return
}

if (authenticationMethod != AuthenticationMethod.CachingSha2) {
throw IllegalStateException(
"AuthMoreDataMessage is only supported for '${AuthenticationMethod.CachingSha2}' method"
)
}

val request = AuthenticationSwitchRequest(AuthenticationMethod.Sha256, authenticationSeed!!)
val response = AuthenticationSwitchResponse(
configuration.password,
configuration.ssl,
configuration.rsaPublicKey,
request
)

this.connectionHandler.write(response)
}

override fun sendQueryDirect(query: String): CompletableFuture<QueryResult> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import com.github.jasync.sql.db.Configuration
import com.github.jasync.sql.db.exceptions.DatabaseException
import com.github.jasync.sql.db.general.MutableResultSet
import com.github.jasync.sql.db.mysql.binary.BinaryRowDecoder
import com.github.jasync.sql.db.mysql.encoder.auth.AuthenticationMethod
import com.github.jasync.sql.db.mysql.message.client.AuthenticationSwitchResponse
import com.github.jasync.sql.db.mysql.message.client.CapabilityRequestMessage
import com.github.jasync.sql.db.mysql.message.client.CloseStatementMessage
Expand Down Expand Up @@ -132,18 +131,7 @@ class MySQLConnectionHandler(
this.handleEOF(message)
}
ServerMessage.AuthMoreData -> {
val m = message as AuthMoreDataMessage

if (!m.isSuccess()) {
if (!sslEstablished) {
throw IllegalStateException(
"Full authentication mode for ${AuthenticationMethod.CachingSha2} requires SSL"
)
}

val request = AuthenticationSwitchRequest(AuthenticationMethod.CachingSha2, null)
handlerDelegate.switchAuthentication(request)
}
handlerDelegate.onAuthMoreData(message as AuthMoreDataMessage)
}
ServerMessage.ColumnDefinition -> {
val m = message as ColumnDefinitionMessage
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.github.jasync.sql.db.mysql.codec

import com.github.jasync.sql.db.ResultSet
import com.github.jasync.sql.db.mysql.message.server.AuthMoreDataMessage
import com.github.jasync.sql.db.mysql.message.server.AuthenticationSwitchRequest
import com.github.jasync.sql.db.mysql.message.server.EOFMessage
import com.github.jasync.sql.db.mysql.message.server.ErrorMessage
Expand All @@ -18,5 +19,6 @@ interface MySQLHandlerDelegate {
fun connected(ctx: ChannelHandlerContext)
fun onResultSet(resultSet: ResultSet, message: EOFMessage)
fun switchAuthentication(message: AuthenticationSwitchRequest)
fun onAuthMoreData(message: AuthMoreDataMessage)
fun unregistered()
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,13 @@ class AuthenticationSwitchResponseEncoder(val charset: Charset) : MessageEncoder

val buffer = ByteBufferUtils.packetBuffer()

val bytes =
authenticator.generateAuthentication(charset, switch.password, switch.request.seed)
val bytes = authenticator.generateAuthentication(
charset,
switch.password,
switch.request.seed,
switch.sslConfiguration,
switch.rsaPublicKey,
)
buffer.writeBytes(bytes)

return buffer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class HandshakeResponseEncoder(private val charset: Charset, private val headerE
val authenticator = this.authenticationMethods.getOrElse(
method
) { throw UnsupportedAuthenticationMethodException(method) }
val bytes = authenticator.generateAuthentication(charset, m.password, m.seed)
val bytes = authenticator.generateAuthentication(charset, m.password, m.seed, m.sslConfiguration, m.rsaPublicKey)
buffer.writeByte(bytes.length)
buffer.writeBytes(bytes)
} else {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
package com.github.jasync.sql.db.mysql.encoder.auth

import com.github.jasync.sql.db.SSLConfiguration
import java.nio.charset.Charset
import java.nio.file.Path

interface AuthenticationMethod {

fun generateAuthentication(charset: Charset, password: String?, seed: ByteArray?): ByteArray
fun generateAuthentication(
charset: Charset,
password: String?,
seed: ByteArray,
sslConfiguration: SSLConfiguration,
rsaPublicKey: Path?,
): ByteArray

companion object {
const val CachingSha2 = "caching_sha2_password"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.github.jasync.sql.db.mysql.encoder.auth

import com.github.jasync.sql.db.util.length
import java.nio.charset.Charset
import java.security.MessageDigest
import kotlin.experimental.xor
Expand Down Expand Up @@ -32,11 +31,8 @@ object AuthenticationScrambler {
}

val result = messageDigest.digest()
var counter = 0

while (counter < result.length) {
result[counter] = (result[counter] xor initialDigest[counter])
counter += 1
for ((index, byte) in result.withIndex()) {
result[index] = byte xor initialDigest[index]
}

return result
Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
package com.github.jasync.sql.db.mysql.encoder.auth

import com.github.jasync.sql.db.SSLConfiguration
import java.nio.charset.Charset
import java.nio.file.Path

object CachingSha2PasswordAuthentication : AuthenticationMethod {

private val EmptyArray = ByteArray(0)

override fun generateAuthentication(charset: Charset, password: String?, seed: ByteArray?): ByteArray {
override fun generateAuthentication(
charset: Charset,
password: String?,
seed: ByteArray,
sslConfiguration: SSLConfiguration,
rsaPublicKey: Path?,
): ByteArray {
return if (password != null) {
if (seed != null) {
// Fast authentication mode. Requires seed, but not SSL.
AuthenticationScrambler.scramble411("SHA-256", password, charset, seed, false)
} else {
// Full authentication mode.
// Since this sends the plaintext password, SSL is required.
// Without SSL, the server always rejects the password.
Sha256PasswordAuthentication.generateAuthentication(charset, password, null)
}
AuthenticationScrambler.scramble411("SHA-256", password, charset, seed, false)
} else {
EmptyArray
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
package com.github.jasync.sql.db.mysql.encoder.auth

import com.github.jasync.sql.db.SSLConfiguration
import java.nio.charset.Charset
import java.nio.file.Path

object MySQLNativePasswordAuthentication : AuthenticationMethod {

private val EmptyArray = ByteArray(0)

override fun generateAuthentication(charset: Charset, password: String?, seed: ByteArray?): ByteArray {
requireNotNull(seed) { "Seed should not be null" }

override fun generateAuthentication(
charset: Charset,
password: String?,
seed: ByteArray,
sslConfiguration: SSLConfiguration,
rsaPublicKey: Path?,
): ByteArray {
return if (password != null) {
AuthenticationScrambler.scramble411("SHA-1", password, charset, seed, true)
} else {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
package com.github.jasync.sql.db.mysql.encoder.auth

import com.github.jasync.sql.db.SSLConfiguration
import com.github.jasync.sql.db.util.length
import java.nio.charset.Charset
import java.nio.file.Path
import kotlin.math.floor

@Suppress("RedundantExplicitType", "UNUSED_VALUE", "VARIABLE_WITH_REDUNDANT_INITIALIZER")
object OldPasswordAuthentication : AuthenticationMethod {

private val EmptyArray = ByteArray(0)

override fun generateAuthentication(charset: Charset, password: String?, seed: ByteArray?): ByteArray {
requireNotNull(seed) { "Seed should not be null" }

override fun generateAuthentication(
charset: Charset,
password: String?,
seed: ByteArray,
sslConfiguration: SSLConfiguration,
rsaPublicKey: Path?,
): ByteArray {
return when {
!password.isNullOrEmpty() -> {
// The native authentication handshake will provide a 20-byte challenge.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Authentication methods
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome!


This driver implements multiple authentication methods available in MySQL/MariaDB and PostgresSQL.
The step-by-step authentication flow and implementation details are described below.

## `caching_sha2_password`

This is the default authentication method since MySQL 8.0.
Official documentation can be found [here][caching-sha2-password].

The fast authentication flow (using password scrambling) is as follows:
1. During the handshake, MySQL server sends the authentication seed (nonce).
2. The driver scrambles the password using SHA-256 with `AuthenticationScrambler`, and sends `HandshakeResponse`.
3. If the password entry is cached on the server, it performs fast authentication, and returns `AuthMoreData`
message indicating success (`data=3`). This is followed by `OkMessage` and the authentication flow completes.
4. In case the password is not cached, the server requires us to switch to full authentication, and returns
`AuthMoreData` with `data=4`.

If we need to perform the full authentication flow (using SHA-256 hashing), the process is as follows:
1. If we're connected over SSL, we can send `AuthenticationSwitchResponse` with a plaintext password.
Note that if we try to do the same over an unsafe connection, the server always rejects the password.
2. If we are not connected over SSL, we can use the provided `rsaPublicKey` (used by the server) to encrypt the
password, and send it as `AuthenticationSwitchResponse`. See `Sha256PasswordAuthentication` for
implementation details.
3. If `rsaPublicKey` is not specified, the public key used to encrypt the password can be fetched from the
server. **This is currently not supported by the driver.**
4. If the authentication was successful, the server caches the password entry, and returns `OkMessage`.
The next authentication request for the specified user can therefore be done with fast authentication.

## `sha256_password`

This authentication method has been deprecated in favor of `caching_sha2_password` in MySQL 8.0, and works the
same as its full authentication flow.

## `mysql_native_password`

This was the default authentication method until MySQL 8.0.
Official documentation can be found [here][mysql-native-password].

The authentication flow is as follows:
1. During the handshake, MySQL server sends the authentication seed (nonce).
2. The driver scrambles the password using SHA-1 with `AuthenticationScrambler`, and sends `HandshakeResponse`.
3. If the password is correct the server sends `OkMessage`, and `ErrorMessage` otherwise.

## `mysql_old_password`

This method was mainly used before MySQL 4.1. It was deprecated in MySQL 5.6 and removed in MySQL 5.7.
Official documentation can be found [here][mysql-old-password].

The authentication flow is as follows:
1. During the handshake, MySQL server sends the authentication seed (nonce). This can either be 8 bytes on older
versions of MySQL, or 20 bytes if the server uses the `mysql_native_password` method as the default. In the
latter case, the driver uses the first 8 bytes of the seed.
2. The driver hashes the password using a proprietary algorithm, and sends `HandshakeResponse`. See
`OldPasswordAuthentication` for implementation details.
3. If the password is correct the server sends `OkMessage`, and `ErrorMessage` otherwise.

[caching-sha2-password]: https://dev.mysql.com/doc/dev/mysql-server/8.0.32/page_caching_sha2_authentication_exchanges.html
[mysql-native-password]: https://dev.mysql.com/doc/dev/mysql-server/8.0.32/page_protocol_connection_phase_authentication_methods_native_password_authentication.html
[mysql-old-password]: https://dev.mysql.com/doc/dev/mysql-server/8.0.32/page_protocol_connection_phase_authentication_methods.html#page_protocol_connection_phase_authentication_methods_old_password_authentication
Loading