Skip to content

List Support for SQL IN Clause #434

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import com.github.jasync.sql.db.mysql.binary.encoder.DurationEncoder
import com.github.jasync.sql.db.mysql.binary.encoder.FloatEncoder
import com.github.jasync.sql.db.mysql.binary.encoder.IntegerEncoder
import com.github.jasync.sql.db.mysql.binary.encoder.JavaDateEncoder
import com.github.jasync.sql.db.mysql.binary.encoder.ListEncoder
import com.github.jasync.sql.db.mysql.binary.encoder.LocalDateEncoder
import com.github.jasync.sql.db.mysql.binary.encoder.LocalDateTimeEncoder
import com.github.jasync.sql.db.mysql.binary.encoder.LocalTimeEncoder
Expand All @@ -39,6 +40,7 @@ import java.time.OffsetDateTime
class BinaryRowEncoder(charset: Charset) {

private val stringEncoder = StringEncoder(charset)
private val listEncoder = ListEncoder(charset)
private val encoders: Map<Class<*>, BinaryEncoder> = mapOf(
String::class.java to this.stringEncoder,
BigInteger::class.java to this.stringEncoder,
Expand Down Expand Up @@ -67,12 +69,15 @@ class BinaryRowEncoder(charset: Charset) {
Duration::class.java to DurationEncoder,
ByteArray::class.java to ByteArrayEncoder,
Boolean::class.java to BooleanEncoder,
java.lang.Boolean::class.java to BooleanEncoder
java.lang.Boolean::class.java to BooleanEncoder,
java.util.ArrayList::class.java to this.listEncoder,
java.util.LinkedList::class.java to this.listEncoder
)

fun encoderFor(v: Any): BinaryEncoder {
return this.encoders.getOrElse(v::class.java) {
return when (v) {
is List<*> -> this.listEncoder
is CharSequence -> this.stringEncoder
is java.math.BigInteger -> this.stringEncoder
is BigDecimal -> this.stringEncoder
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package com.github.jasync.sql.db.mysql.binary.encoder

import com.github.jasync.sql.db.mysql.column.ColumnTypes
import io.netty.buffer.ByteBuf
import java.nio.charset.Charset

/**
* Encoder for List types (including ArrayList) in MySQL prepared statements.
* This encoder converts the list to a comma-separated string for use in SQL IN clauses.
*/
class ListEncoder(private val charset: Charset) : BinaryEncoder {

override fun encodesTo(): Int = ColumnTypes.FIELD_TYPE_VAR_STRING

override fun encode(value: Any, buffer: ByteBuf) {
if (value !is List<*>) {
throw IllegalArgumentException("Cannot encode non-List value with ListEncoder")
}

// Convert list to comma-separated string
val stringValue = value.filterNotNull().joinToString(",")

// Write the string as length-encoded string
val bytes = stringValue.toByteArray(charset)

// MySQL uses length coded binary for strings
// https://dev.mysql.com/doc/internals/en/string.html
if (bytes.size < 251) {
buffer.writeByte(bytes.size)
} else if (bytes.size < 65536) {
buffer.writeByte(252)
buffer.writeShortLE(bytes.size)
} else if (bytes.size < 16777216) {
buffer.writeByte(253)
buffer.writeMediumLE(bytes.size)
} else {
buffer.writeByte(254)
buffer.writeLongLE(bytes.size.toLong())
}

buffer.writeBytes(bytes)
}
}

/**
* Helper method to write a 3-byte integer in little-endian format
*/
private fun ByteBuf.writeMediumLE(value: Int) {
this.writeByte(value and 0xFF)
this.writeByte(value shr 8 and 0xFF)
this.writeByte(value shr 16 and 0xFF)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package com.github.jasync.sql.db.mysql

import com.github.jasync.sql.db.Connection
import org.assertj.core.api.Assertions.assertThat
import org.junit.Assume
import org.junit.Test
import java.util.ArrayList
import java.util.concurrent.TimeUnit

/**
* Test case demonstrating how to fix the "couldn't find mapping for class java.util.ArrayList" error
* when using IN clauses with jasync-sql
*/
class ArrayListInClauseTest : ConnectionHelper() {

data class Org(val id: Int, val name: String)

@Test
fun `test ArrayList in IN clause error and solution`() {
// Skip test if Docker environment is not available
try {
val connection = MySQLConnection(
defaultConfiguration
)
executeTest(connection)
} catch (e: Exception) {
// Log the exception and skip the test
println("Skipping test due to connection issue: ${e.message}")
Assume.assumeTrue("Database connection required for this test", false)
}
}

private fun executeTest(connection: Connection) {
try {
connection.connect().get(5, TimeUnit.SECONDS)

// Create a test table for organizations
val createTable = """
CREATE TEMPORARY TABLE organization (
id INT NOT NULL,
name VARCHAR(255) NOT NULL,
status VARCHAR(50) NOT NULL
)
""".trimIndent()

// Insert test data
val insertData = """
INSERT INTO organization (id, name, status) VALUES
(1, 'Company A', 'ok'),
(2, 'Company B', 'ok'),
(3, 'Company C', 'pending'),
(4, 'Company D', 'ok'),
(5, 'Company E', 'inactive')
""".trimIndent()

connection.sendQuery(createTable).get(5, TimeUnit.SECONDS)
connection.sendQuery(insertData).get(5, TimeUnit.SECONDS)

// Create a list of organization IDs to query
val orgsId = ArrayList<Int>()
orgsId.add(1)
orgsId.add(2)
orgsId.add(4)

// PROBLEM DEMONSTRATION: This will cause the error "couldn't find mapping for class java.util.ArrayList"
try {
println("\n\n=================================================")
println("REPRODUCING THE ISSUE: Using ArrayList in IN clause")
println("=================================================")
val sqlProblem = "SELECT id, name FROM organization WHERE status='ok' AND id IN (?)"
val listOfList = ArrayList<ArrayList<Int>>()
listOfList.add(orgsId)

println("SQL: $sqlProblem")
println("Parameters: $listOfList (${listOfList.javaClass.name})")

connection.sendPreparedStatement(sqlProblem, listOfList)
.get(5, TimeUnit.SECONDS)
.rows
.forEach { row ->
println("This should not execute as an error should be thrown")
}
} catch (e: Exception) {
println("============================================")
println("EXPECTED ERROR OCCURRED!")
println("Error message: ${e.message}")

var rootCause = e
while (rootCause.cause != null && rootCause.cause != rootCause) {
rootCause = rootCause.cause!! as Exception
}

println("Root cause: ${rootCause.javaClass.name}: ${rootCause.message}")
println("============================================\n")
}

// SOLUTION 1: Generate the correct number of placeholders
val sql1 = "SELECT id, name FROM organization WHERE status='ok' AND id IN (?, ?, ?)"
val result1 = ArrayList<Org>()

connection.sendPreparedStatement(sql1, listOf(1, 2, 4))
.get(5, TimeUnit.SECONDS)
.rows
.forEach { row ->
result1.add(Org(row.getInt(0)!!, row.getString(1)!!))
}

assertThat(result1.size).isEqualTo(3)
assertThat(result1.map { it.id }).containsExactlyInAnyOrder(1, 2, 4)

// SOLUTION 2: Dynamically generate placeholders based on the list size
val placeholders = orgsId.joinToString(", ") { "?" }
val sql2 = "SELECT id, name FROM organization WHERE status='ok' AND id IN ($placeholders)"
val result2 = ArrayList<Org>()

connection.sendPreparedStatement(sql2, orgsId)
.get(5, TimeUnit.SECONDS)
.rows
.forEach { row ->
result2.add(Org(row.getInt(0)!!, row.getString(1)!!))
}
println("Result2:")
assertThat(result2.size).isEqualTo(3)
assertThat(result2.map { it.id }).containsExactlyInAnyOrder(1, 2, 4)
assertThat(result2.map { it.name }).containsExactlyInAnyOrder("Company A", "Company B", "Company D")

connection.disconnect().get(5, TimeUnit.SECONDS)
} catch (e: Exception) {
connection.disconnect().get(5, TimeUnit.SECONDS)
throw e
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
package com.github.jasync.sql.db.mysql.binary.encoder

import io.netty.buffer.Unpooled
import org.assertj.core.api.Assertions.assertThat
import org.junit.Test
import java.nio.charset.StandardCharsets
import kotlin.test.assertEquals
import kotlin.test.fail

class ListEncoderTest {

private val charset = StandardCharsets.UTF_8
private val encoder = ListEncoder(charset)

@Test
fun `encodesTo should return VAR_STRING type`() {
assertEquals(com.github.jasync.sql.db.mysql.column.ColumnTypes.FIELD_TYPE_VAR_STRING, encoder.encodesTo())
}

@Test
fun `encode should write empty list as length-encoded empty string`() {
val buffer = Unpooled.buffer()
val list = emptyList<String>()
encoder.encode(list, buffer)

assertEquals(0.toByte(), buffer.readByte()) // Length of empty string is 0
assertEquals(0, buffer.readableBytes())
}

@Test
fun `encode should write single item list`() {
val buffer = Unpooled.buffer()
val list = listOf("hello")
encoder.encode(list, buffer)

val expectedString = "hello"
val expectedBytes = expectedString.toByteArray(charset)

assertEquals(expectedBytes.size.toByte(), buffer.readByte())
val writtenBytes = ByteArray(buffer.readableBytes())
buffer.readBytes(writtenBytes)
assertThat(writtenBytes).isEqualTo(expectedBytes)
}

@Test
fun `encode should write multiple item list as comma-separated string`() {
val buffer = Unpooled.buffer()
val list = listOf("hello", "world", "test")
encoder.encode(list, buffer)

val expectedString = "hello,world,test"
val expectedBytes = expectedString.toByteArray(charset)

assertEquals(expectedBytes.size.toByte(), buffer.readByte())
val writtenBytes = ByteArray(buffer.readableBytes())
buffer.readBytes(writtenBytes)
assertThat(writtenBytes).isEqualTo(expectedBytes)
}

@Test
fun `encode should filter out null values from list`() {
val buffer = Unpooled.buffer()
val list = listOf("one", null, "two", null, "three")
encoder.encode(list, buffer)

val expectedString = "one,two,three"
val expectedBytes = expectedString.toByteArray(charset)

assertEquals(expectedBytes.size.toByte(), buffer.readByte())
val writtenBytes = ByteArray(buffer.readableBytes())
buffer.readBytes(writtenBytes)
assertThat(writtenBytes).isEqualTo(expectedBytes)
}

@Test
fun `encode should handle list with only null values as empty list`() {
val buffer = Unpooled.buffer()
val list = listOf(null, null, null)
encoder.encode(list, buffer)

assertEquals(0.toByte(), buffer.readByte()) // Length of empty string is 0
assertEquals(0, buffer.readableBytes())
}

@Test
fun `encode should handle list of numbers`() {
val buffer = Unpooled.buffer()
val list = listOf(1, 20, 300)
encoder.encode(list, buffer)

val expectedString = "1,20,300"
val expectedBytes = expectedString.toByteArray(charset)

assertEquals(expectedBytes.size.toByte(), buffer.readByte())
val writtenBytes = ByteArray(buffer.readableBytes())
buffer.readBytes(writtenBytes)
assertThat(writtenBytes).isEqualTo(expectedBytes)
}

@Test
fun `encode should throw IllegalArgumentException for non-list value`() {
val buffer = Unpooled.buffer()
try {
encoder.encode("not a list", buffer)
fail("Expected IllegalArgumentException but no exception was thrown")
} catch (e: IllegalArgumentException) {
// Expected exception
}
}

@Test
fun `encode should handle string length less than 251`() {
val buffer = Unpooled.buffer()
val str = "a".repeat(250)
val list = listOf(str)
encoder.encode(list, buffer)

val expectedBytes = str.toByteArray(charset)
assertEquals(expectedBytes.size.toByte(), buffer.readByte()) // Length byte
val writtenBytes = ByteArray(buffer.readableBytes())
buffer.readBytes(writtenBytes)
assertThat(writtenBytes).isEqualTo(expectedBytes)
}

@Test
fun `encode should handle string length equal to 251`() {
val buffer = Unpooled.buffer()
val str = "a".repeat(251)
val list = listOf(str)
encoder.encode(list, buffer)

val expectedBytes = str.toByteArray(charset)
assertEquals(252.toByte(), buffer.readByte()) // Prefix for 2-byte length
assertEquals(expectedBytes.size.toShort(), buffer.readShortLE())
val writtenBytes = ByteArray(buffer.readableBytes())
buffer.readBytes(writtenBytes)
assertThat(writtenBytes).isEqualTo(expectedBytes)
}

@Test
fun `encode should handle string length between 251 and 65535`() {
val buffer = Unpooled.buffer()
val str = "a".repeat(300)
val list = listOf(str)
encoder.encode(list, buffer)

val expectedBytes = str.toByteArray(charset)
assertEquals(252.toByte(), buffer.readByte()) // Prefix for 2-byte length
assertEquals(expectedBytes.size.toShort(), buffer.readShortLE())
val writtenBytes = ByteArray(buffer.readableBytes())
buffer.readBytes(writtenBytes)
assertThat(writtenBytes).isEqualTo(expectedBytes)
}

@Test
fun `encode should handle string length equal to 65536`() {
val buffer = Unpooled.buffer()
val str = "a".repeat(65536)
val list = listOf(str)
encoder.encode(list, buffer)

val expectedBytes = str.toByteArray(charset)
assertEquals(253.toByte(), buffer.readByte()) // Prefix for 3-byte length
assertEquals(expectedBytes.size, readUnsignedMediumLE(buffer))
val writtenBytes = ByteArray(buffer.readableBytes())
buffer.readBytes(writtenBytes)
assertThat(writtenBytes).isEqualTo(expectedBytes)
}

// Helper to read unsigned medium for testing
private fun readUnsignedMediumLE(buffer: io.netty.buffer.ByteBuf): Int {
val b1 = buffer.readUnsignedByte().toInt()
val b2 = buffer.readUnsignedByte().toInt()
val b3 = buffer.readUnsignedByte().toInt()
return b1 or (b2 shl 8) or (b3 shl 16)
}
}