Skip to content

Commit 528d86e

Browse files
committed
Fixing wrong conditional that was causing the code to evaluate a query when all it should have done was accepting the connection
1 parent feabd1d commit 528d86e

File tree

7 files changed

+150
-46
lines changed

7 files changed

+150
-46
lines changed

db-async-common/src/main/scala/com/github/mauricio/async/db/util/NettyUtils.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,6 @@ import io.netty.channel.nio.NioEventLoopGroup
1919

2020
object NettyUtils {
2121

22-
lazy val DetaultEventLoopGroup = new NioEventLoopGroup()
22+
lazy val DetaultEventLoopGroup = new NioEventLoopGroup(0, DaemonThreadsFactory)
2323

2424
}

mysql-async/src/main/scala/com/github/mauricio/async/db/mysql/MySQLConnection.scala

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
package com.github.mauricio.async.db.mysql
1818

1919
import com.github.mauricio.async.db._
20-
import com.github.mauricio.async.db.exceptions.ConnectionStillRunningQueryException
20+
import com.github.mauricio.async.db.exceptions.{ConnectionNotConnectedException, ConnectionStillRunningQueryException}
2121
import com.github.mauricio.async.db.mysql.codec.{MySQLHandlerDelegate, MySQLConnectionHandler}
2222
import com.github.mauricio.async.db.mysql.exceptions.MySQLException
2323
import com.github.mauricio.async.db.mysql.message.client._
@@ -54,9 +54,16 @@ class MySQLConnection(
5454
charsetMapper.toInt(configuration.charset)
5555

5656
private final val connectionCount = MySQLConnection.Counter.incrementAndGet()
57+
private final val connectionId = s"[mysql-connection-$connectionCount]"
5758
private implicit val internalPool = executionContext
5859

59-
private final val connectionHandler = new MySQLConnectionHandler(configuration, charsetMapper, this, group, executionContext)
60+
private final val connectionHandler = new MySQLConnectionHandler(
61+
configuration,
62+
charsetMapper,
63+
this,
64+
group,
65+
executionContext,
66+
connectionId)
6067

6168
private final val connectionPromise = Promise[Connection]()
6269
private final val disconnectionPromise = Promise[Connection]()
@@ -98,17 +105,17 @@ class MySQLConnection(
98105
}
99106

100107
override def connected(ctx: ChannelHandlerContext) {
101-
log.debug("Connected to {}", ctx.channel.remoteAddress)
108+
log.debug(s"$connectionId Connected to {}", ctx.channel.remoteAddress)
102109
this.connected = true
103110
}
104111

105112
override def exceptionCaught(throwable: Throwable) {
106-
log.error("Transport failure", throwable)
113+
log.error(s"$connectionId Transport failure ", throwable)
107114
setException(throwable)
108115
}
109116

110117
override def onError(message: ErrorMessage) {
111-
log.error("Received an error message -> {}", message)
118+
log.error(s"$connectionId Received an error message -> {}", message)
112119
val exception = new MySQLException(message)
113120
exception.fillInStackTrace()
114121
this.setException(exception)
@@ -121,20 +128,21 @@ class MySQLConnection(
121128
}
122129

123130
override def onOk(message: OkMessage) {
124-
this.connectionPromise.trySuccess(this)
125-
126-
if (this.isQuerying) {
127-
this.succeedQueryPromise(
128-
new MySQLQueryResult(
129-
message.affectedRows,
130-
message.message,
131-
message.lastInsertId,
132-
message.statusFlags,
133-
message.warnings
131+
if ( !this.connectionPromise.isCompleted ) {
132+
this.connectionPromise.success(this)
133+
} else {
134+
if (this.isQuerying) {
135+
this.succeedQueryPromise(
136+
new MySQLQueryResult(
137+
message.affectedRows,
138+
message.message,
139+
message.lastInsertId,
140+
message.statusFlags,
141+
message.warnings
142+
)
134143
)
135-
)
144+
}
136145
}
137-
138146
}
139147

140148
def onEOF(message: EOFMessage) {
@@ -152,7 +160,6 @@ class MySQLConnection(
152160
}
153161

154162
override def onHandshake(message: HandshakeMessage) {
155-
156163
this.serverVersion = Version(message.serverVersion)
157164

158165
this.connectionHandler.write(new HandshakeResponseMessage(

mysql-async/src/main/scala/com/github/mauricio/async/db/mysql/codec/MySQLConnectionHandler.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,16 @@ class MySQLConnectionHandler(
5858
charsetMapper: CharsetMapper,
5959
handlerDelegate: MySQLHandlerDelegate,
6060
group : EventLoopGroup,
61-
executionContext : ExecutionContext
61+
executionContext : ExecutionContext,
62+
connectionId : String
6263
)
6364
extends SimpleChannelInboundHandler[Object] {
6465

6566
private implicit val internalPool = executionContext
6667

6768
private final val bootstrap = new Bootstrap().group(this.group)
6869
private final val connectionPromise = Promise[MySQLConnectionHandler]
69-
private final val decoder = new MySQLFrameDecoder(configuration.charset)
70+
private final val decoder = new MySQLFrameDecoder(configuration.charset, connectionId)
7071
private final val encoder = new MySQLOneToOneEncoder(configuration.charset, charsetMapper)
7172
private final val currentParameters = new ArrayBuffer[ColumnDefinitionMessage]()
7273
private final val currentColumns = new ArrayBuffer[ColumnDefinitionMessage]()

mysql-async/src/main/scala/com/github/mauricio/async/db/mysql/codec/MySQLFrameDecoder.scala

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,40 +22,42 @@ import com.github.mauricio.async.db.mysql.message.server._
2222
import com.github.mauricio.async.db.util.ChannelUtils.read3BytesInt
2323
import com.github.mauricio.async.db.util.ChannelWrapper.bufferToWrapper
2424
import com.github.mauricio.async.db.util.Log
25-
import java.nio.charset.Charset
26-
27-
import com.github.mauricio.async.db.mysql.MySQLHelper
28-
import io.netty.handler.codec.ByteToMessageDecoder
29-
import io.netty.channel.ChannelHandlerContext
3025
import io.netty.buffer.ByteBuf
26+
import io.netty.channel.ChannelHandlerContext
27+
import io.netty.handler.codec.ByteToMessageDecoder
3128
import java.nio.ByteOrder
29+
import java.nio.charset.Charset
30+
import java.util.concurrent.atomic.AtomicInteger
31+
3232

3333
object MySQLFrameDecoder {
3434
val log = Log.get[MySQLFrameDecoder]
3535
}
3636

37-
class MySQLFrameDecoder(charset: Charset) extends ByteToMessageDecoder {
37+
class MySQLFrameDecoder(charset: Charset, connectionId : String) extends ByteToMessageDecoder {
38+
3839

40+
private final val messagesCount = new AtomicInteger()
3941
private final val handshakeDecoder = new HandshakeV10Decoder(charset)
4042
private final val errorDecoder = new ErrorDecoder(charset)
4143
private final val okDecoder = new OkDecoder(charset)
4244
private final val columnDecoder = new ColumnDefinitionDecoder(charset, new DecoderRegistry(charset))
4345
private final val rowDecoder = new ResultSetRowDecoder(charset)
4446
private final val preparedStatementPrepareDecoder = new PreparedStatementPrepareResponseDecoder()
4547

46-
private[codec] var processingColumns = false
47-
private[codec] var processingParams = false
48-
private[codec] var isInQuery = false
49-
private[codec] var isPreparedStatementPrepare = false
50-
private[codec] var isPreparedStatementExecute = false
51-
private[codec] var isPreparedStatementExecuteRows = false
48+
@volatile private[codec] var processingColumns = false
49+
@volatile private[codec] var processingParams = false
50+
@volatile private[codec] var isInQuery = false
51+
@volatile private[codec] var isPreparedStatementPrepare = false
52+
@volatile private[codec] var isPreparedStatementExecute = false
53+
@volatile private[codec] var isPreparedStatementExecuteRows = false
5254

53-
private[codec] var totalParams = 0L
54-
private[codec] var processedParams = 0L
55-
private[codec] var totalColumns = 0L
56-
private[codec] var processedColumns = 0L
55+
@volatile private[codec] var totalParams = 0L
56+
@volatile private[codec] var processedParams = 0L
57+
@volatile private[codec] var totalColumns = 0L
58+
@volatile private[codec] var processedColumns = 0L
5759

58-
private var hasReadColumnsCount = false
60+
@volatile private var hasReadColumnsCount = false
5961

6062
def decode(ctx: ChannelHandlerContext, buffer: ByteBuf, out: java.util.List[Object]): Unit = {
6163
if (buffer.readableBytes() > 4) {
@@ -68,6 +70,8 @@ class MySQLFrameDecoder(charset: Charset) extends ByteToMessageDecoder {
6870

6971
if (buffer.readableBytes() >= size) {
7072

73+
messagesCount.incrementAndGet()
74+
7175
val messageType = buffer.getByte(buffer.readerIndex())
7276

7377
if (size < 0) {
@@ -77,7 +81,7 @@ class MySQLFrameDecoder(charset: Charset) extends ByteToMessageDecoder {
7781
// TODO: Remove once https://github.com/netty/netty/issues/1704 is fixed
7882
val slice = buffer.readSlice(size).order(ByteOrder.LITTLE_ENDIAN)
7983
//val dump = MySQLHelper.dumpAsHex(slice)
80-
//log.debug(s"Dump of message is - $messageType - $size isInQuery $isInQuery processingColumns $processingColumns processedColumns $processedColumns processingParams $processingParams processedParams $processedParams \n{}", dump)
84+
//log.debug(s"$connectionId [${messagesCount.get()}] Dump of message is - $messageType - $size isInQuery $isInQuery processingColumns $processingColumns processedColumns $processedColumns processingParams $processingParams processedParams $processedParams \n{}", dump)
8185

8286
slice.readByte()
8387

@@ -232,6 +236,7 @@ class MySQLFrameDecoder(charset: Charset) extends ByteToMessageDecoder {
232236
this.isPreparedStatementExecuteRows = false
233237
this.isInQuery = false
234238
this.processingColumns = false
239+
this.processingParams = false
235240
this.totalColumns = 0
236241
this.processedColumns = 0
237242
this.totalParams = 0

mysql-async/src/main/scala/com/github/mauricio/async/db/mysql/decoder/PreparedStatementPrepareResponseDecoder.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,17 @@
1616

1717
package com.github.mauricio.async.db.mysql.decoder
1818

19-
import io.netty.buffer.ByteBuf
2019
import com.github.mauricio.async.db.mysql.message.server.{PreparedStatementPrepareResponse, ServerMessage}
21-
import com.github.mauricio.async.db.mysql.MySQLHelper
2220
import com.github.mauricio.async.db.util.Log
21+
import io.netty.buffer.ByteBuf
2322

2423
class PreparedStatementPrepareResponseDecoder extends MessageDecoder {
2524

2625
final val log = Log.get[PreparedStatementPrepareResponseDecoder]
2726

2827
def decode(buffer: ByteBuf): ServerMessage = {
2928

30-
//val dump = MySQLHelper.dumpAsHex(buffer, buffer.readableBytes())
29+
//val dump = MySQLHelper.dumpAsHex(buffer)
3130
//log.debug("prepared statement response dump is \n{}", dump)
3231

3332
val statementId = Array[Byte]( buffer.readByte(), buffer.readByte(), buffer.readByte(), buffer.readByte() )
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
/*
2+
* Copyright 2013 Maurício Linhares
3+
*
4+
* Maurício Linhares licenses this file to you under the Apache License,
5+
* version 2.0 (the "License"); you may not use this file except in compliance
6+
* with the License. You may obtain a copy of the License at:
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12+
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13+
* License for the specific language governing permissions and limitations
14+
* under the License.
15+
*/
16+
17+
package com.github.mauricio.async.db.mysql
18+
19+
import com.github.mauricio.async.db.util.Log
20+
import java.util.concurrent.atomic.AtomicInteger
21+
22+
/**
23+
* Mainly a way to try to figure out why sometimes MySQL will fail with a bad prepared statement response message.
24+
*/
25+
26+
object ConcurrentlyRunTest extends ConnectionHelper with Runnable {
27+
28+
private val log = Log.getByName(this.getClass.getName)
29+
private val counter = new AtomicInteger()
30+
private val failures = new AtomicInteger()
31+
32+
def run() {
33+
1.until(50).foreach(x => execute(counter.incrementAndGet()))
34+
}
35+
36+
def main(args : Array[String]) {
37+
38+
log.info("Starting executing code")
39+
40+
val threads = 1.until(10).map(x => new Thread(this))
41+
42+
threads.foreach {t => t.start()}
43+
44+
while ( !threads.forall(t => t.isAlive) ) {
45+
Thread.sleep(5000)
46+
}
47+
48+
log.info(s"Finished executing code, failed execution ${failures.get()} times")
49+
50+
}
51+
52+
53+
def execute(count : Int) {
54+
try {
55+
log.info(s"====> run $count")
56+
val create = """CREATE TEMPORARY TABLE posts (
57+
| id INT NOT NULL AUTO_INCREMENT,
58+
| some_text TEXT not null,
59+
| some_date DATE,
60+
| primary key (id) )""".stripMargin
61+
62+
val insert = "insert into posts (some_text) values (?)"
63+
val select = "select * from posts limit 100"
64+
65+
withConnection {
66+
connection =>
67+
executeQuery(connection, create)
68+
69+
executePreparedStatement(connection, insert, "this is some text here")
70+
71+
val row = executeQuery(connection, select).rows.get(0)
72+
assert(row("id") == 1)
73+
assert(row("some_text") == "this is some text here")
74+
assert(row("some_date") == null)
75+
76+
val queryRow = executePreparedStatement(connection, select).rows.get(0)
77+
78+
assert(queryRow("id") == 1)
79+
assert(queryRow("some_text") == "this is some text here")
80+
assert(queryRow("some_date") == null)
81+
82+
}
83+
} catch {
84+
case e : Exception => {
85+
failures.incrementAndGet()
86+
log.error( s"Failed to execute on run $count - ${e.getMessage}", e)
87+
}
88+
}
89+
90+
}
91+
92+
}

mysql-async/src/test/scala/com/github/mauricio/async/db/mysql/codec/MySQLFrameDecoderSpec.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ class MySQLFrameDecoderSpec extends Specification {
6868

6969
"on a query process it should correctly send an OK" in {
7070

71-
val decoder = new MySQLFrameDecoder(charset)
71+
val decoder = new MySQLFrameDecoder(charset, "[mysql-connection]")
7272
val embedder = new EmbeddedChannel(decoder)
7373
embedder.config.setAllocator(LittleEndianByteBufAllocator.INSTANCE)
7474

@@ -88,7 +88,7 @@ class MySQLFrameDecoderSpec extends Specification {
8888

8989
"on query process it should correctly send an error" in {
9090

91-
val decoder = new MySQLFrameDecoder(charset)
91+
val decoder = new MySQLFrameDecoder(charset, "[mysql-connection]")
9292
val embedder = new EmbeddedChannel(decoder)
9393
embedder.config.setAllocator(LittleEndianByteBufAllocator.INSTANCE)
9494

@@ -111,7 +111,7 @@ class MySQLFrameDecoderSpec extends Specification {
111111

112112
"on query process it should correctly handle a result set" in {
113113

114-
val decoder = new MySQLFrameDecoder(charset)
114+
val decoder = new MySQLFrameDecoder(charset, "[mysql-connection]")
115115
val embedder = new EmbeddedChannel(decoder)
116116
embedder.config.setAllocator(LittleEndianByteBufAllocator.INSTANCE)
117117

@@ -165,7 +165,7 @@ class MySQLFrameDecoderSpec extends Specification {
165165
}
166166

167167
def createPipeline(): EmbeddedChannel = {
168-
val channel = new EmbeddedChannel(new MySQLFrameDecoder(charset))
168+
val channel = new EmbeddedChannel(new MySQLFrameDecoder(charset, "[mysql-connection]"))
169169
channel.config.setAllocator(LittleEndianByteBufAllocator.INSTANCE)
170170
channel
171171
}

0 commit comments

Comments
 (0)