diff --git a/r2dbc-mysql/src/main/java/JasyncResult.kt b/r2dbc-mysql/src/main/java/JasyncResult.kt index 0130b939..14c48d59 100644 --- a/r2dbc-mysql/src/main/java/JasyncResult.kt +++ b/r2dbc-mysql/src/main/java/JasyncResult.kt @@ -2,6 +2,8 @@ package com.github.jasync.r2dbc.mysql import com.github.jasync.sql.db.ResultSet import io.r2dbc.spi.Result +import io.r2dbc.spi.Result.RowSegment +import io.r2dbc.spi.Row import io.r2dbc.spi.RowMetadata import org.reactivestreams.Publisher import reactor.core.publisher.Flux @@ -34,18 +36,60 @@ class JasyncResult( override fun map(mappingFunction: BiFunction): Publisher { return if (selectLastInsertId) { - Mono.fromSupplier { mappingFunction.apply(JasyncInsertSyntheticRow(generatedKeyName, lastInsertId), JasyncInsertSyntheticMetadata(generatedKeyName)) } + Mono.fromSupplier { + mappingFunction.apply( + JasyncInsertSyntheticRow(generatedKeyName, lastInsertId), + JasyncInsertSyntheticMetadata(generatedKeyName) + ) + } } else { Flux.fromIterable(resultSet) - .map { mappingFunction.apply(JasyncRow(it), metadata) } + .map { mappingFunction.apply(JasyncRow(it, metadata), metadata) } } } override fun filter(filter: Predicate): Result { - TODO("Not yet implemented") + return JasyncSegmentResult(this).filter(filter) } override fun flatMap(mappingFunction: Function>): Publisher { - TODO("Not yet implemented") + return JasyncSegmentResult(this).flatMap(mappingFunction) + } + + class JasyncSegmentResult private constructor( + private val segments: Flux, + private val result: JasyncResult + ) : Result { + constructor(result: JasyncResult) : this( + Flux.concat( + Flux.fromIterable(result.resultSet) + .map { JasyncRow(it, result.metadata) }, + Flux.just(Result.UpdateCount { result.rowsAffected }) + ), + result + ) + + override fun getRowsUpdated(): Publisher { + return result.rowsUpdated + } + + override fun map(mappingFunction: BiFunction): Publisher { + return segments + .handle { segment, sink -> + if (segment is RowSegment) { + sink.next(mappingFunction.apply(segment.row(), segment.row().metadata)) + } + } + } + + override fun filter(filter: Predicate): Result { + return JasyncSegmentResult(segments.filter(filter), result) + } + + override fun flatMap(mappingFunction: Function>): Publisher { + return segments.concatMap { segment: Result.Segment -> + mappingFunction.apply(segment) + } + } } } diff --git a/r2dbc-mysql/src/main/java/JasyncRow.kt b/r2dbc-mysql/src/main/java/JasyncRow.kt index 6abcfff9..d272c9c0 100644 --- a/r2dbc-mysql/src/main/java/JasyncRow.kt +++ b/r2dbc-mysql/src/main/java/JasyncRow.kt @@ -1,6 +1,7 @@ package com.github.jasync.r2dbc.mysql import com.github.jasync.sql.db.RowData +import io.r2dbc.spi.Result import io.r2dbc.spi.Row import io.r2dbc.spi.RowMetadata import java.math.BigDecimal @@ -9,7 +10,7 @@ import java.time.LocalDate import java.time.LocalDateTime import java.time.LocalTime -class JasyncRow(private val rowData: RowData) : Row { +class JasyncRow(private val rowData: RowData, private val metadata: JasyncMetadata) : Row, Result.RowSegment { override fun get(index: Int, type: Class): T? { return get(index as Any, type) @@ -20,10 +21,10 @@ class JasyncRow(private val rowData: RowData) : Row { } override fun getMetadata(): RowMetadata { - TODO("Not yet implemented") + return metadata } - @Suppress("UNCHECKED_CAST", "IMPLICIT_CAST_TO_ANY") + @Suppress("UNCHECKED_CAST") private fun get(identifier: Any, requestedType: Class): T? { val value = get(identifier) return when { @@ -92,4 +93,8 @@ class JasyncRow(private val rowData: RowData) : Row { else -> value } } + + override fun row(): Row { + return this + } } diff --git a/r2dbc-mysql/src/test/java/com/github/jasync/r2dbc/mysql/integ/JasyncR2dbcIntegTest.kt b/r2dbc-mysql/src/test/java/com/github/jasync/r2dbc/mysql/integ/JasyncR2dbcIntegTest.kt index 87a68053..fa4b28d1 100644 --- a/r2dbc-mysql/src/test/java/com/github/jasync/r2dbc/mysql/integ/JasyncR2dbcIntegTest.kt +++ b/r2dbc-mysql/src/test/java/com/github/jasync/r2dbc/mysql/integ/JasyncR2dbcIntegTest.kt @@ -5,6 +5,7 @@ import com.github.jasync.sql.db.mysql.MySQLConnection import com.github.jasync.sql.db.mysql.pool.MySQLConnectionFactory import com.github.jasync.sql.db.util.FP import io.mockk.mockk +import io.r2dbc.spi.Result import org.assertj.core.api.Assertions import org.awaitility.kotlin.await import org.junit.Test @@ -63,4 +64,35 @@ class JasyncR2dbcIntegTest : R2dbcConnectionHelper() { await.until { rows == 1 } } } + + @Test + fun `filter test`() { + withConnection { c -> + var rows = 0 + executeQuery(c, createTable) + executeQuery(c, """INSERT INTO users (name) VALUES ('Boogie Man'),('Dambeldor')""") + val mycf = object : MySQLConnectionFactory(mockk()) { + override fun create(): CompletableFuture { + return FP.successful(c) + } + } + val cf = JasyncConnectionFactory(mycf) + Mono.from(cf.create()) + .flatMapMany { connection -> + connection + .createStatement("SELECT name FROM users") + .execute() + } + .map { result -> + result + // we test this function + .filter { segment -> + segment is Result.RowSegment && segment.row().get("name") == "Dambeldor" + } + } + .doOnNext { rows++ } + .subscribe() + await.until { rows == 1 } + } + } }