Commit a5212a5

Author: xinyunh

Add support to UDF and modify the test case

1 parent: 1cb181e

File tree: 4 files changed, +36 / -20 lines

src/main/scala/org/apache/spark/sql/hbase/HBaseCustomFilter.scala
Lines changed: 12 additions & 7 deletions
@@ -166,7 +166,7 @@ private[hbase] class HBaseCustomFilter extends FilterBase with Writable {
    * @param node the node to reset children on
    * @return
    */
-  def resetNode(node: Node) = {
+  private def resetNode(node: Node) = {
     if (node != null && node.cpr != null) {
       node.currentValue = node.cpr.start.orNull
       if (node.currentValue != null && !node.cpr.startInclusive) {
@@ -401,7 +401,7 @@ private[hbase] class HBaseCustomFilter extends FilterBase with Writable {
    * @param node the node to start with
    * @return (return code, the row key after successful increment)
    */
-  def increment(node: Node): (ReturnCode, HBaseRawType) = {
+  private def increment(node: Node): (ReturnCode, HBaseRawType) = {
     var currentNode: Node = node
     while (currentNode.parent != null) {
       if (addOne(currentNode)) {
@@ -439,7 +439,7 @@ private[hbase] class HBaseCustomFilter extends FilterBase with Writable {
    * @param node the node to add 1 to
    * @return whether the addition can be made within the value domain
    */
-  def addOne(node: Node): Boolean = {
+  private def addOne(node: Node): Boolean = {
     val dt = node.dt
     val value = node.currentValue
     var canAddOne: Boolean = true
@@ -569,7 +569,7 @@ private[hbase] class HBaseCustomFilter extends FilterBase with Writable {
    * do a full evaluation for the remaining predicate based on all the cell values
    * @param kvs the list of cells
    */
-  def fullEvalution(kvs: java.util.List[Cell]) = {
+  private def fullEvalution(kvs: java.util.List[Cell]) = {
     resetRow(workingRow)
     cellMap.clear()
     for (i <- 0 to kvs.size() - 1) {
@@ -609,9 +609,14 @@ private[hbase] class HBaseCustomFilter extends FilterBase with Writable {
   }

   override def filterRowCells(kvs: java.util.List[Cell]) = {
-    if (remainingPredicate != null) {
-      fullEvalution(kvs)
-    }
+    // In the coprocessor, if the call to filterKeyValue returns INCLUDE on the very last
+    // record, the scanner runs past the end and never calls filterKeyValue() before reaching
+    // here, leading to an empty kvs and a subsequent NPE. This is observed with HBase 0.98.5.
+    //
+    // If a later HBase release has this addressed, this check becomes unnecessary and can
+    // be removed to save some CPU cycles.
+    if (kvs.isEmpty) filterRowFlag = true
+    else if (remainingPredicate != null) fullEvalution(kvs)
   }

   override def hasFilterRow: Boolean = {
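
For readers unfamiliar with the HBase filter lifecycle, the guard introduced above can be illustrated in isolation. A minimal sketch, assuming the HBase 0.98-era FilterBase callbacks; the class name and the evaluate helper are illustrative stand-ins, not the commit's actual code (only filterRowFlag mirrors a field the diff references):

```scala
import java.util.{List => JList}
import org.apache.hadoop.hbase.Cell
import org.apache.hadoop.hbase.filter.FilterBase

// Illustrative sketch: a row filter that tolerates the empty cell list
// HBase 0.98.5 may hand to filterRowCells on the region's last row.
class DefensiveRowFilter extends FilterBase {
  // set when the current row should be excluded by filterRow()
  private var filterRowFlag: Boolean = false

  override def filterRowCells(kvs: JList[Cell]): Unit = {
    if (kvs.isEmpty) filterRowFlag = true  // nothing to evaluate: drop the row
    else filterRowFlag = !evaluate(kvs)    // keep the row only if the predicate holds
  }

  // tells the scanner that filterRow() must be consulted for each row
  override def hasFilterRow: Boolean = true

  // returning true excludes the whole row
  override def filterRow(): Boolean = filterRowFlag

  // stand-in for evaluating the remaining predicate over the row's cells
  private def evaluate(kvs: JList[Cell]): Boolean = true
}
```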

src/main/scala/org/apache/spark/sql/hbase/HBaseSQLReaderRDD.scala
Lines changed: 9 additions & 2 deletions
@@ -221,14 +221,21 @@ class HBaseSQLReaderRDD(val relation: HBaseRelation,
       s
     }

+    if (!useCustomFilter) {
+      def addOtherFilter(rdd: RDD[Row]): Unit = rdd match {
+        case hcsRDD: HBaseCoprocessorSQLReaderRDD => hcsRDD.otherFilters = otherFilters
+        case _ => if (rdd.dependencies.nonEmpty) addOtherFilter(rdd.firstParent[Row])
+      }
+      addOtherFilter(newSubplanRDD)
+    }
+
     val outputDataType: Seq[DataType] = subplan.get.output.map(attr => attr.dataType)
     val taskContextPara: (Int, Int, Long, Int) = TaskContext.get() match {
       case t: TaskContextImpl => (t.stageId, t.partitionId, t.taskAttemptId, t.attemptNumber)
       case _ => (0, 0, 0L, 0)
     }

-    scan.setAttribute(CoprocessorConstants.COINDEX,
-      Bytes.toBytes(partitionIndex))
+    scan.setAttribute(CoprocessorConstants.COINDEX, Bytes.toBytes(partitionIndex))
     scan.setAttribute(CoprocessorConstants.COTYPE, HBaseSerializer.serialize(outputDataType))
     scan.setAttribute(CoprocessorConstants.COKEY, HBaseSerializer.serialize(newSubplanRDD))
     scan.setAttribute(CoprocessorConstants.COTASK, HBaseSerializer.serialize(taskContextPara))
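
The addOtherFilter helper above walks the RDD lineage to hand otherFilters down to the coprocessor RDD, wherever it sits in the parent chain. A standalone sketch of the same pattern, with a hypothetical Configurable trait standing in for HBaseCoprocessorSQLReaderRDD, and dependencies.head.rdd in place of firstParent (which is protected[spark] outside Spark's own packages):

```scala
import org.apache.spark.rdd.RDD

object LineageWalk {
  // Hypothetical stand-in for the coprocessor RDD type being searched for.
  trait Configurable { def configure(): Unit }

  // Recurse down the first-parent chain until an RDD of the target type is
  // found; stop quietly at a leaf with no parents.
  def walkLineage(rdd: RDD[_]): Unit = rdd match {
    case c: Configurable => c.configure()
    case _ if rdd.dependencies.nonEmpty => walkLineage(rdd.dependencies.head.rdd)
    case _ => // leaf reached without a match; nothing to configure
  }
}
```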

src/main/scala/org/apache/spark/sql/hbase/SparkSqlRegionObserver.scala
Lines changed: 3 additions & 7 deletions
@@ -47,14 +47,10 @@ class HBaseCoprocessorSQLReaderRDD(var relation: HBaseRelation,
   private def createIterator(context: TaskContext): Iterator[Row] = {
     val otherFilter: (Row) => Boolean = {
       if (otherFilters.isDefined) {
-        if (relation.deploySuccessfully.isDefined && relation.deploySuccessfully.get) {
-          null
+        if (codegenEnabled) {
+          GeneratePredicate.generate(otherFilters.get, finalOutput)
         } else {
-          if (codegenEnabled) {
-            GeneratePredicate.generate(otherFilters.get, finalOutput)
-          } else {
-            InterpretedPredicate.create(otherFilters.get, finalOutput)
-          }
+          InterpretedPredicate.create(otherFilters.get, finalOutput)
         }
       } else null
     }
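
After this change, predicate construction depends only on the codegen flag rather than on deployment state. The shape of that choice, factored out with the same Catalyst internals the file already uses (GeneratePredicate.generate and InterpretedPredicate.create are Spark 1.x internal APIs, so the exact imports and signatures are version-dependent assumptions):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, InterpretedPredicate}
import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate

object PredicateChoice {
  // Build the row predicate once per partition: compiled via codegen when
  // enabled, otherwise the interpreted evaluator. Same semantics, different
  // evaluation strategy.
  def buildPredicate(codegenEnabled: Boolean,
                     expr: Expression,
                     output: Seq[Attribute]): Row => Boolean =
    if (codegenEnabled) GeneratePredicate.generate(expr, output)
    else InterpretedPredicate.create(expr, output)
}
```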

src/test/scala/org/apache/spark/sql/hbase/HBaseAdditionalQuerySuite.scala
Lines changed: 12 additions & 4 deletions
@@ -179,14 +179,22 @@ class HBaseAdditionalQuerySuite extends TestBase {

   test("DataFrame Test") {
     val teachers: DataFrame = TestHbase.sql("Select * from spark_teacher_3key")
-    teachers.orderBy(Column("grade").asc, Column("class").asc).show(3)
+    val result = teachers.orderBy(Column("grade").asc, Column("class").asc)
+      .select("teacher_name").limit(3).collect()
+    result.foreach(println)
+    val exparr = Array(Array("teacher_1_1_1"), Array("teacher_1_2_1"), Array("teacher_1_3_1"))
+    val res = {
+      for (rx <- exparr.indices)
+        yield compareWithTol(result(rx).toSeq, exparr(rx), s"Row$rx failed")
+    }.foldLeft(true) { case (res1, newres) => res1 && newres }
+    assert(res, "One or more rows did not match expected")
   }

   test("UDF Test") {
-    def myFilter(date: String) = date contains "_1_2"
+    def myFilter(s: String) = s contains "_1_2"
     TestHbase.udf.register("myFilter", myFilter _)
-    val result = TestHbase.sql("Select * from spark_teacher_3key WHERE myFilter(teacher_name)")
-    result.foreach(println)
+    val result = TestHbase.sql("Select count(*) from spark_teacher_3key WHERE myFilter(teacher_name)")
+    result.foreach(r => require(r.getLong(0) == 3L))
   }

   test("group test for presplit table with coprocessor but without codegen") {
