Skip to content

Commit 70f271b

Browse files
maropugatorsmile
authored andcommitted
[SPARK-12446][SQL][BACKPORT-1.6] Add unit tests for JDBCRDD internal functions
No tests done for JDBCRDD#compileFilter. Author: Takeshi YAMAMURO <linguin.m.sgmail.com> Closes apache#10409 from maropu/AddTestsInJdbcRdd. (cherry picked from commit 8c1b867) Author: Takeshi YAMAMURO <linguin.m.s@gmail.com> Closes apache#16124 from dongjoon-hyun/SPARK-12446-BRANCH-1.6.
1 parent 8f25cb2 commit 70f271b

File tree

2 files changed

+54
-30
lines changed

2 files changed

+54
-30
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,37 @@ private[sql] object JDBCRDD extends Logging {
165165
* @return A Catalyst schema corresponding to columns in the given order.
166166
*/
167167
private def pruneSchema(schema: StructType, columns: Array[String]): StructType = {
168-
val fieldMap = Map(schema.fields map { x => x.metadata.getString("name") -> x }: _*)
169-
new StructType(columns map { name => fieldMap(name) })
168+
val fieldMap = Map(schema.fields.map(x => x.metadata.getString("name") -> x): _*)
169+
new StructType(columns.map(name => fieldMap(name)))
170+
}
171+
172+
/**
173+
* Converts value to SQL expression.
174+
*/
175+
private def compileValue(value: Any): Any = value match {
176+
case stringValue: String => s"'${escapeSql(stringValue)}'"
177+
case timestampValue: Timestamp => "'" + timestampValue + "'"
178+
case dateValue: Date => "'" + dateValue + "'"
179+
case _ => value
180+
}
181+
182+
private def escapeSql(value: String): String =
183+
if (value == null) null else StringUtils.replace(value, "'", "''")
184+
185+
/**
186+
* Turns a single Filter into a String representing a SQL expression.
187+
* Returns null for an unhandled filter.
188+
*/
189+
private def compileFilter(f: Filter): String = f match {
190+
case EqualTo(attr, value) => s"$attr = ${compileValue(value)}"
191+
case Not(EqualTo(attr, value)) => s"$attr != ${compileValue(value)}"
192+
case LessThan(attr, value) => s"$attr < ${compileValue(value)}"
193+
case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}"
194+
case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}"
195+
case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}"
196+
case IsNull(attr) => s"$attr IS NULL"
197+
case IsNotNull(attr) => s"$attr IS NOT NULL"
198+
case _ => null
170199
}
171200

172201

@@ -240,37 +269,12 @@ private[sql] class JDBCRDD(
240269
if (sb.length == 0) "1" else sb.substring(1)
241270
}
242271

243-
/**
244-
* Converts value to SQL expression.
245-
*/
246-
private def compileValue(value: Any): Any = value match {
247-
case stringValue: String => s"'${escapeSql(stringValue)}'"
248-
case timestampValue: Timestamp => "'" + timestampValue + "'"
249-
case dateValue: Date => "'" + dateValue + "'"
250-
case _ => value
251-
}
252-
253-
private def escapeSql(value: String): String =
254-
if (value == null) null else StringUtils.replace(value, "'", "''")
255-
256-
/**
257-
* Turns a single Filter into a String representing a SQL expression.
258-
* Returns null for an unhandled filter.
259-
*/
260-
private def compileFilter(f: Filter): String = f match {
261-
case EqualTo(attr, value) => s"$attr = ${compileValue(value)}"
262-
case LessThan(attr, value) => s"$attr < ${compileValue(value)}"
263-
case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}"
264-
case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}"
265-
case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}"
266-
case _ => null
267-
}
268272

269273
/**
270274
* `filters`, but as a WHERE clause suitable for injection into a SQL query.
271275
*/
272276
private val filterWhereClause: String = {
273-
val filterStrings = filters map compileFilter filter (_ != null)
277+
val filterStrings = filters.map(JDBCRDD.compileFilter).filter(_ != null)
274278
if (filterStrings.size > 0) {
275279
val sb = new StringBuilder("WHERE ")
276280
filterStrings.foreach(x => sb.append(x).append(" AND "))

sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,22 @@
1818
package org.apache.spark.sql.jdbc
1919

2020
import java.math.BigDecimal
21-
import java.sql.DriverManager
21+
import java.sql.{Date, DriverManager, Timestamp}
2222
import java.util.{Calendar, GregorianCalendar, Properties}
2323

2424
import org.h2.jdbc.JdbcSQLException
2525
import org.scalatest.BeforeAndAfter
26+
import org.scalatest.PrivateMethodTester
2627

2728
import org.apache.spark.SparkFunSuite
29+
import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD
2830
import org.apache.spark.sql.test.SharedSQLContext
2931
import org.apache.spark.sql.types._
32+
import org.apache.spark.sql.sources._
3033
import org.apache.spark.util.Utils
3134

32-
class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext {
35+
class JDBCSuite extends SparkFunSuite
36+
with BeforeAndAfter with PrivateMethodTester with SharedSQLContext {
3337
import testImplicits._
3438

3539
val url = "jdbc:h2:mem:testdb0"
@@ -427,6 +431,22 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext
427431
assert(DerbyColumns === Seq(""""abc"""", """"key""""))
428432
}
429433

434+
test("compile filters") {
435+
val compileFilter = PrivateMethod[String]('compileFilter)
436+
def doCompileFilter(f: Filter): String = JDBCRDD invokePrivate compileFilter(f)
437+
assert(doCompileFilter(EqualTo("col0", 3)) === "col0 = 3")
438+
assert(doCompileFilter(Not(EqualTo("col1", "abc"))) === "col1 != 'abc'")
439+
assert(doCompileFilter(LessThan("col0", 5)) === "col0 < 5")
440+
assert(doCompileFilter(LessThan("col3",
441+
Timestamp.valueOf("1995-11-21 00:00:00.0"))) === "col3 < '1995-11-21 00:00:00.0'")
442+
assert(doCompileFilter(LessThan("col4", Date.valueOf("1983-08-04"))) === "col4 < '1983-08-04'")
443+
assert(doCompileFilter(LessThanOrEqual("col0", 5)) === "col0 <= 5")
444+
assert(doCompileFilter(GreaterThan("col0", 3)) === "col0 > 3")
445+
assert(doCompileFilter(GreaterThanOrEqual("col0", 3)) === "col0 >= 3")
446+
assert(doCompileFilter(IsNull("col1")) === "col1 IS NULL")
447+
assert(doCompileFilter(IsNotNull("col1")) === "col1 IS NOT NULL")
448+
}
449+
430450
test("Dialect unregister") {
431451
JdbcDialects.registerDialect(testH2Dialect)
432452
JdbcDialects.unregisterDialect(testH2Dialect)

0 commit comments

Comments
 (0)