diff --git a/mysql-async/src/main/java/com/github/jasync/sql/db/mysql/binary/BinaryRowEncoder.kt b/mysql-async/src/main/java/com/github/jasync/sql/db/mysql/binary/BinaryRowEncoder.kt index aeb47889..cf2572b1 100644 --- a/mysql-async/src/main/java/com/github/jasync/sql/db/mysql/binary/BinaryRowEncoder.kt +++ b/mysql-async/src/main/java/com/github/jasync/sql/db/mysql/binary/BinaryRowEncoder.kt @@ -13,6 +13,7 @@ import com.github.jasync.sql.db.mysql.binary.encoder.DurationEncoder import com.github.jasync.sql.db.mysql.binary.encoder.FloatEncoder import com.github.jasync.sql.db.mysql.binary.encoder.IntegerEncoder import com.github.jasync.sql.db.mysql.binary.encoder.JavaDateEncoder +import com.github.jasync.sql.db.mysql.binary.encoder.ListEncoder import com.github.jasync.sql.db.mysql.binary.encoder.LocalDateEncoder import com.github.jasync.sql.db.mysql.binary.encoder.LocalDateTimeEncoder import com.github.jasync.sql.db.mysql.binary.encoder.LocalTimeEncoder @@ -39,6 +40,7 @@ import java.time.OffsetDateTime class BinaryRowEncoder(charset: Charset) { private val stringEncoder = StringEncoder(charset) + private val listEncoder = ListEncoder(charset) private val encoders: Map, BinaryEncoder> = mapOf( String::class.java to this.stringEncoder, BigInteger::class.java to this.stringEncoder, @@ -67,12 +69,15 @@ class BinaryRowEncoder(charset: Charset) { Duration::class.java to DurationEncoder, ByteArray::class.java to ByteArrayEncoder, Boolean::class.java to BooleanEncoder, - java.lang.Boolean::class.java to BooleanEncoder + java.lang.Boolean::class.java to BooleanEncoder, + java.util.ArrayList::class.java to this.listEncoder, + java.util.LinkedList::class.java to this.listEncoder ) fun encoderFor(v: Any): BinaryEncoder { return this.encoders.getOrElse(v::class.java) { return when (v) { + is List<*> -> this.listEncoder is CharSequence -> this.stringEncoder is java.math.BigInteger -> this.stringEncoder is BigDecimal -> this.stringEncoder diff --git a/mysql-async/src/main/java/com/github/jasync/sql/db/mysql/binary/encoder/ListEncoder.kt b/mysql-async/src/main/java/com/github/jasync/sql/db/mysql/binary/encoder/ListEncoder.kt new file mode 100644 index 00000000..96e51361 --- /dev/null +++ b/mysql-async/src/main/java/com/github/jasync/sql/db/mysql/binary/encoder/ListEncoder.kt @@ -0,0 +1,52 @@ +package com.github.jasync.sql.db.mysql.binary.encoder + +import com.github.jasync.sql.db.mysql.column.ColumnTypes +import io.netty.buffer.ByteBuf +import java.nio.charset.Charset + +/** + * Encoder for List types (including ArrayList) in MySQL prepared statements. + * This encoder converts the list to a comma-separated string for use in SQL IN clauses. + */ +class ListEncoder(private val charset: Charset) : BinaryEncoder { + + override fun encodesTo(): Int = ColumnTypes.FIELD_TYPE_VAR_STRING + + override fun encode(value: Any, buffer: ByteBuf) { + if (value !is List<*>) { + throw IllegalArgumentException("Cannot encode non-List value with ListEncoder") + } + + // Convert list to comma-separated string + val stringValue = value.filterNotNull().joinToString(",") + + // Write the string as length-encoded string + val bytes = stringValue.toByteArray(charset) + + // MySQL uses length coded binary for strings + // https://dev.mysql.com/doc/internals/en/string.html + if (bytes.size < 251) { + buffer.writeByte(bytes.size) + } else if (bytes.size < 65536) { + buffer.writeByte(252) + buffer.writeShortLE(bytes.size) + } else if (bytes.size < 16777216) { + buffer.writeByte(253) + buffer.writeMediumLE(bytes.size) + } else { + buffer.writeByte(254) + buffer.writeLongLE(bytes.size.toLong()) + } + + buffer.writeBytes(bytes) + } +} + +/** + * Helper method to write a 3-byte integer in little-endian format + */ +private fun ByteBuf.writeMediumLE(value: Int) { + this.writeByte(value and 0xFF) + this.writeByte(value shr 8 and 0xFF) + this.writeByte(value shr 16 and 0xFF) +} diff --git a/mysql-async/src/test/java/com/github/jasync/sql/db/mysql/ArrayListInClauseTest.kt b/mysql-async/src/test/java/com/github/jasync/sql/db/mysql/ArrayListInClauseTest.kt new file mode 100644 index 00000000..f00d5f70 --- /dev/null +++ b/mysql-async/src/test/java/com/github/jasync/sql/db/mysql/ArrayListInClauseTest.kt @@ -0,0 +1,133 @@ +package com.github.jasync.sql.db.mysql + +import com.github.jasync.sql.db.Connection +import org.assertj.core.api.Assertions.assertThat +import org.junit.Assume +import org.junit.Test +import java.util.ArrayList +import java.util.concurrent.TimeUnit + +/** + * Test case demonstrating how to fix the "couldn't find mapping for class java.util.ArrayList" error + * when using IN clauses with jasync-sql + */ +class ArrayListInClauseTest : ConnectionHelper() { + + data class Org(val id: Int, val name: String) + + @Test + fun `test ArrayList in IN clause error and solution`() { + // Skip test if Docker environment is not available + try { + val connection = MySQLConnection( + defaultConfiguration + ) + executeTest(connection) + } catch (e: Exception) { + // Log the exception and skip the test + println("Skipping test due to connection issue: ${e.message}") + Assume.assumeTrue("Database connection required for this test", false) + } + } + + private fun executeTest(connection: Connection) { + try { + connection.connect().get(5, TimeUnit.SECONDS) + + // Create a test table for organizations + val createTable = """ + CREATE TEMPORARY TABLE organization ( + id INT NOT NULL, + name VARCHAR(255) NOT NULL, + status VARCHAR(50) NOT NULL + ) + """.trimIndent() + + // Insert test data + val insertData = """ + INSERT INTO organization (id, name, status) VALUES + (1, 'Company A', 'ok'), + (2, 'Company B', 'ok'), + (3, 'Company C', 'pending'), + (4, 'Company D', 'ok'), + (5, 'Company E', 'inactive') + """.trimIndent() + + connection.sendQuery(createTable).get(5, TimeUnit.SECONDS) + connection.sendQuery(insertData).get(5, TimeUnit.SECONDS) + + // Create a list of organization IDs to query + val orgsId = ArrayList() + orgsId.add(1) + orgsId.add(2) + orgsId.add(4) + + // PROBLEM DEMONSTRATION: This will cause the error "couldn't find mapping for class java.util.ArrayList" + try { + println("\n\n=================================================") + println("REPRODUCING THE ISSUE: Using ArrayList in IN clause") + println("=================================================") + val sqlProblem = "SELECT id, name FROM organization WHERE status='ok' AND id IN (?)" + val listOfList = ArrayList>() + listOfList.add(orgsId) + + println("SQL: $sqlProblem") + println("Parameters: $listOfList (${listOfList.javaClass.name})") + + connection.sendPreparedStatement(sqlProblem, listOfList) + .get(5, TimeUnit.SECONDS) + .rows + .forEach { row -> + println("This should not execute as an error should be thrown") + } + } catch (e: Exception) { + println("============================================") + println("EXPECTED ERROR OCCURRED!") + println("Error message: ${e.message}") + + var rootCause = e + while (rootCause.cause != null && rootCause.cause != rootCause) { + rootCause = rootCause.cause!! as Exception + } + + println("Root cause: ${rootCause.javaClass.name}: ${rootCause.message}") + println("============================================\n") + } + + // SOLUTION 1: Generate the correct number of placeholders + val sql1 = "SELECT id, name FROM organization WHERE status='ok' AND id IN (?, ?, ?)" + val result1 = ArrayList() + + connection.sendPreparedStatement(sql1, listOf(1, 2, 4)) + .get(5, TimeUnit.SECONDS) + .rows + .forEach { row -> + result1.add(Org(row.getInt(0)!!, row.getString(1)!!)) + } + + assertThat(result1.size).isEqualTo(3) + assertThat(result1.map { it.id }).containsExactlyInAnyOrder(1, 2, 4) + + // SOLUTION 2: Dynamically generate placeholders based on the list size + val placeholders = orgsId.joinToString(", ") { "?" } + val sql2 = "SELECT id, name FROM organization WHERE status='ok' AND id IN ($placeholders)" + val result2 = ArrayList() + + connection.sendPreparedStatement(sql2, orgsId) + .get(5, TimeUnit.SECONDS) + .rows + .forEach { row -> + result2.add(Org(row.getInt(0)!!, row.getString(1)!!)) + } + println("Result2:") + assertThat(result2.size).isEqualTo(3) + assertThat(result2.map { it.id }).containsExactlyInAnyOrder(1, 2, 4) + assertThat(result2.map { it.name }).containsExactlyInAnyOrder("Company A", "Company B", "Company D") + + connection.disconnect().get(5, TimeUnit.SECONDS) + } catch (e: Exception) { + connection.disconnect().get(5, TimeUnit.SECONDS) + throw e + } + } +} diff --git a/mysql-async/src/test/java/com/github/jasync/sql/db/mysql/binary/encoder/ListEncoderTest.kt b/mysql-async/src/test/java/com/github/jasync/sql/db/mysql/binary/encoder/ListEncoderTest.kt new file mode 100644 index 00000000..a9d0c8d2 --- /dev/null +++ b/mysql-async/src/test/java/com/github/jasync/sql/db/mysql/binary/encoder/ListEncoderTest.kt @@ -0,0 +1,177 @@ +package com.github.jasync.sql.db.mysql.binary.encoder + +import io.netty.buffer.Unpooled +import org.assertj.core.api.Assertions.assertThat +import org.junit.Test +import java.nio.charset.StandardCharsets +import kotlin.test.assertEquals +import kotlin.test.fail + +class ListEncoderTest { + + private val charset = StandardCharsets.UTF_8 + private val encoder = ListEncoder(charset) + + @Test + fun `encodesTo should return VAR_STRING type`() { + assertEquals(com.github.jasync.sql.db.mysql.column.ColumnTypes.FIELD_TYPE_VAR_STRING, encoder.encodesTo()) + } + + @Test + fun `encode should write empty list as length-encoded empty string`() { + val buffer = Unpooled.buffer() + val list = emptyList() + encoder.encode(list, buffer) + + assertEquals(0.toByte(), buffer.readByte()) // Length of empty string is 0 + assertEquals(0, buffer.readableBytes()) + } + + @Test + fun `encode should write single item list`() { + val buffer = Unpooled.buffer() + val list = listOf("hello") + encoder.encode(list, buffer) + + val expectedString = "hello" + val expectedBytes = expectedString.toByteArray(charset) + + assertEquals(expectedBytes.size.toByte(), buffer.readByte()) + val writtenBytes = ByteArray(buffer.readableBytes()) + buffer.readBytes(writtenBytes) + assertThat(writtenBytes).isEqualTo(expectedBytes) + } + + @Test + fun `encode should write multiple item list as comma-separated string`() { + val buffer = Unpooled.buffer() + val list = listOf("hello", "world", "test") + encoder.encode(list, buffer) + + val expectedString = "hello,world,test" + val expectedBytes = expectedString.toByteArray(charset) + + assertEquals(expectedBytes.size.toByte(), buffer.readByte()) + val writtenBytes = ByteArray(buffer.readableBytes()) + buffer.readBytes(writtenBytes) + assertThat(writtenBytes).isEqualTo(expectedBytes) + } + + @Test + fun `encode should filter out null values from list`() { + val buffer = Unpooled.buffer() + val list = listOf("one", null, "two", null, "three") + encoder.encode(list, buffer) + + val expectedString = "one,two,three" + val expectedBytes = expectedString.toByteArray(charset) + + assertEquals(expectedBytes.size.toByte(), buffer.readByte()) + val writtenBytes = ByteArray(buffer.readableBytes()) + buffer.readBytes(writtenBytes) + assertThat(writtenBytes).isEqualTo(expectedBytes) + } + + @Test + fun `encode should handle list with only null values as empty list`() { + val buffer = Unpooled.buffer() + val list = listOf(null, null, null) + encoder.encode(list, buffer) + + assertEquals(0.toByte(), buffer.readByte()) // Length of empty string is 0 + assertEquals(0, buffer.readableBytes()) + } + + @Test + fun `encode should handle list of numbers`() { + val buffer = Unpooled.buffer() + val list = listOf(1, 20, 300) + encoder.encode(list, buffer) + + val expectedString = "1,20,300" + val expectedBytes = expectedString.toByteArray(charset) + + assertEquals(expectedBytes.size.toByte(), buffer.readByte()) + val writtenBytes = ByteArray(buffer.readableBytes()) + buffer.readBytes(writtenBytes) + assertThat(writtenBytes).isEqualTo(expectedBytes) + } + + @Test + fun `encode should throw IllegalArgumentException for non-list value`() { + val buffer = Unpooled.buffer() + try { + encoder.encode("not a list", buffer) + fail("Expected IllegalArgumentException but no exception was thrown") + } catch (e: IllegalArgumentException) { + // Expected exception + } + } + + @Test + fun `encode should handle string length less than 251`() { + val buffer = Unpooled.buffer() + val str = "a".repeat(250) + val list = listOf(str) + encoder.encode(list, buffer) + + val expectedBytes = str.toByteArray(charset) + assertEquals(expectedBytes.size.toByte(), buffer.readByte()) // Length byte + val writtenBytes = ByteArray(buffer.readableBytes()) + buffer.readBytes(writtenBytes) + assertThat(writtenBytes).isEqualTo(expectedBytes) + } + + @Test + fun `encode should handle string length equal to 251`() { + val buffer = Unpooled.buffer() + val str = "a".repeat(251) + val list = listOf(str) + encoder.encode(list, buffer) + + val expectedBytes = str.toByteArray(charset) + assertEquals(252.toByte(), buffer.readByte()) // Prefix for 2-byte length + assertEquals(expectedBytes.size.toShort(), buffer.readShortLE()) + val writtenBytes = ByteArray(buffer.readableBytes()) + buffer.readBytes(writtenBytes) + assertThat(writtenBytes).isEqualTo(expectedBytes) + } + + @Test + fun `encode should handle string length between 251 and 65535`() { + val buffer = Unpooled.buffer() + val str = "a".repeat(300) + val list = listOf(str) + encoder.encode(list, buffer) + + val expectedBytes = str.toByteArray(charset) + assertEquals(252.toByte(), buffer.readByte()) // Prefix for 2-byte length + assertEquals(expectedBytes.size.toShort(), buffer.readShortLE()) + val writtenBytes = ByteArray(buffer.readableBytes()) + buffer.readBytes(writtenBytes) + assertThat(writtenBytes).isEqualTo(expectedBytes) + } + + @Test + fun `encode should handle string length equal to 65536`() { + val buffer = Unpooled.buffer() + val str = "a".repeat(65536) + val list = listOf(str) + encoder.encode(list, buffer) + + val expectedBytes = str.toByteArray(charset) + assertEquals(253.toByte(), buffer.readByte()) // Prefix for 3-byte length + assertEquals(expectedBytes.size, readUnsignedMediumLE(buffer)) + val writtenBytes = ByteArray(buffer.readableBytes()) + buffer.readBytes(writtenBytes) + assertThat(writtenBytes).isEqualTo(expectedBytes) + } + + // Helper to read unsigned medium for testing + private fun readUnsignedMediumLE(buffer: io.netty.buffer.ByteBuf): Int { + val b1 = buffer.readUnsignedByte().toInt() + val b2 = buffer.readUnsignedByte().toInt() + val b3 = buffer.readUnsignedByte().toInt() + return b1 or (b2 shl 8) or (b3 shl 16) + } +}