
Commit 40769b4

gatorsmile authored and marmbrus committed
[SPARK-11905][SQL] Support Persist/Cache and Unpersist in Dataset APIs
Persist and Unpersist exist in both the RDD and DataFrame APIs. I think they are still very critical in the Dataset APIs. Not sure if my understanding is correct? If so, could you help me check if the implementation is acceptable? Please provide your opinions. marmbrus rxin cloud-fan Thank you very much!

Author: gatorsmile <gatorsmile@gmail.com>
Author: xiaoli <lixiao1983@gmail.com>
Author: Xiao Li <xiaoli@Xiaos-MacBook-Pro.local>

Closes apache#9889 from gatorsmile/persistDS.

(cherry picked from commit 0a7bca2)
Signed-off-by: Michael Armbrust <michael@databricks.com>
1 parent 88bbce0 commit 40769b4
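In short, this commit gives Dataset the same caching surface that RDD and DataFrame already have. A minimal usage sketch of the methods added here, assuming a Spark 1.6-style `sqlContext` with `import sqlContext.implicits._` in scope (the sample data is illustrative):

import org.apache.spark.storage.StorageLevel

val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()

ds.cache()                            // same as persist(): default level MEMORY_AND_DISK
ds.count()                            // first action materializes the cached columnar data
ds.unpersist()                        // non-blocking removal of the cached blocks

ds.persist(StorageLevel.MEMORY_ONLY)  // explicit storage level
ds.count()
ds.unpersist(blocking = true)         // block until all cached blocks are deleted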

6 files changed: +162 −18 lines changed


sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 9 additions & 0 deletions
@@ -1584,6 +1584,7 @@ class DataFrame private[sql](
   def distinct(): DataFrame = dropDuplicates()
 
   /**
+   * Persist this [[DataFrame]] with the default storage level (`MEMORY_AND_DISK`).
    * @group basic
    * @since 1.3.0
    */
@@ -1593,12 +1594,17 @@
   }
 
   /**
+   * Persist this [[DataFrame]] with the default storage level (`MEMORY_AND_DISK`).
    * @group basic
    * @since 1.3.0
    */
   def cache(): this.type = persist()
 
   /**
+   * Persist this [[DataFrame]] with the given storage level.
+   * @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`,
+   *                 `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`,
+   *                 `MEMORY_AND_DISK_2`, etc.
    * @group basic
    * @since 1.3.0
    */
@@ -1608,6 +1614,8 @@
   }
 
   /**
+   * Mark the [[DataFrame]] as non-persistent, and remove all blocks for it from memory and disk.
+   * @param blocking Whether to block until all blocks are deleted.
    * @group basic
    * @since 1.3.0
    */
@@ -1617,6 +1625,7 @@
   }
 
   /**
+   * Mark the [[DataFrame]] as non-persistent, and remove all blocks for it from memory and disk.
    * @group basic
    * @since 1.3.0
    */
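The DataFrame changes are documentation only: they restore Scaladoc for the persist, cache, and unpersist methods, which already existed. A hedged usage sketch of the documented `newLevel` parameter (the DataFrame and level chosen here are illustrative):

import org.apache.spark.storage.StorageLevel

val df = sqlContext.range(0, 100)
df.persist(StorageLevel.MEMORY_AND_DISK_SER)   // any of the levels listed in the Scaladoc
df.count()
df.unpersist(blocking = true)                  // block until every cached block is removed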

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 47 additions & 3 deletions
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.{Queryable, QueryExecution}
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
 
 /**
@@ -565,7 +566,7 @@ class Dataset[T] private[sql](
    * combined.
    *
    * Note that, this function is not a typical set union operation, in that it does not eliminate
-   * duplicate items. As such, it is analagous to `UNION ALL` in SQL.
+   * duplicate items. As such, it is analogous to `UNION ALL` in SQL.
    * @since 1.6.0
    */
   def union(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Union)
@@ -618,7 +619,6 @@ class Dataset[T] private[sql](
       case _ => Alias(CreateStruct(rightOutput), "_2")()
     }
 
-
     implicit val tuple2Encoder: Encoder[(T, U)] =
       ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder)
     withPlan[(T, U)](other) { (left, right) =>
@@ -697,11 +697,55 @@ class Dataset[T] private[sql](
    */
   def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*)
 
+  /**
+   * Persist this [[Dataset]] with the default storage level (`MEMORY_AND_DISK`).
+   * @since 1.6.0
+   */
+  def persist(): this.type = {
+    sqlContext.cacheManager.cacheQuery(this)
+    this
+  }
+
+  /**
+   * Persist this [[Dataset]] with the default storage level (`MEMORY_AND_DISK`).
+   * @since 1.6.0
+   */
+  def cache(): this.type = persist()
+
+  /**
+   * Persist this [[Dataset]] with the given storage level.
+   * @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`,
+   *                 `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`,
+   *                 `MEMORY_AND_DISK_2`, etc.
+   * @group basic
+   * @since 1.6.0
+   */
+  def persist(newLevel: StorageLevel): this.type = {
+    sqlContext.cacheManager.cacheQuery(this, None, newLevel)
+    this
+  }
+
+  /**
+   * Mark the [[Dataset]] as non-persistent, and remove all blocks for it from memory and disk.
+   * @param blocking Whether to block until all blocks are deleted.
+   * @since 1.6.0
+   */
+  def unpersist(blocking: Boolean): this.type = {
+    sqlContext.cacheManager.tryUncacheQuery(this, blocking)
+    this
+  }
+
+  /**
+   * Mark the [[Dataset]] as non-persistent, and remove all blocks for it from memory and disk.
+   * @since 1.6.0
+   */
+  def unpersist(): this.type = unpersist(blocking = false)
+
   /* ******************** *
    *  Internal Functions  *
    * ******************** */
 
-  private[sql] def logicalPlan = queryExecution.analyzed
+  private[sql] def logicalPlan: LogicalPlan = queryExecution.analyzed
 
   private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] =
     new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), tEncoder)
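These Dataset methods are thin wrappers over CacheManager: persist() calls cacheQuery with the MEMORY_AND_DISK default, persist(newLevel) passes the level through, and unpersist() delegates to tryUncacheQuery with blocking = false. A short sketch of the typed API in use, again assuming `import sqlContext.implicits._` and illustrative sample data:

import org.apache.spark.storage.StorageLevel

val ds = Seq(("a", 1), ("b", 2)).toDS()
val shifted = ds.map { case (k, v) => (k, v + 1) }

shifted.persist(StorageLevel.MEMORY_ONLY_SER)
shifted.count()       // materializes the in-memory columnar cache
shifted.unpersist()   // goes through tryUncacheQuery, so calling it again is harmless
shifted.unpersist()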

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 9 additions & 0 deletions
@@ -338,6 +338,15 @@ class SQLContext private[sql](
     cacheManager.lookupCachedData(table(tableName)).nonEmpty
   }
 
+  /**
+   * Returns true if the [[Queryable]] is currently cached in-memory.
+   * @group cachemgmt
+   * @since 1.3.0
+   */
+  private[sql] def isCached(qName: Queryable): Boolean = {
+    cacheManager.lookupCachedData(qName).nonEmpty
+  }
+
   /**
    * Caches the specified table in-memory.
    * @group cachemgmt
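Note that the new isCached overload is private[sql], so only code inside the org.apache.spark.sql package (such as the DatasetCacheSuite added in this commit) can call it directly; it is the hook the tests use to verify that unpersist really dropped the cache entry. A hypothetical snippet of that kind of check, valid only where private[sql] members are visible:

val ds = Seq(1, 2, 3).toDS()
ds.cache()
ds.count()
assert(sqlContext.isCached(ds))    // the new Queryable overload
ds.unpersist()
assert(!sqlContext.isCached(ds))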

sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala

Lines changed: 14 additions & 13 deletions
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution
 import java.util.concurrent.locks.ReentrantReadWriteLock
 
 import org.apache.spark.Logging
-import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
 import org.apache.spark.storage.StorageLevel
@@ -75,12 +74,12 @@ private[sql] class CacheManager extends Logging {
   }
 
   /**
-   * Caches the data produced by the logical representation of the given [[DataFrame]]. Unlike
-   * `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because recomputing
-   * the in-memory columnar representation of the underlying table is expensive.
+   * Caches the data produced by the logical representation of the given [[Queryable]].
+   * Unlike `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because
+   * recomputing the in-memory columnar representation of the underlying table is expensive.
    */
   private[sql] def cacheQuery(
-      query: DataFrame,
+      query: Queryable,
       tableName: Option[String] = None,
       storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock {
     val planToCache = query.queryExecution.analyzed
@@ -95,23 +94,25 @@ private[sql] class CacheManager extends Logging {
         sqlContext.conf.useCompression,
         sqlContext.conf.columnBatchSize,
         storageLevel,
-        sqlContext.executePlan(query.logicalPlan).executedPlan,
+        sqlContext.executePlan(planToCache).executedPlan,
         tableName))
     }
   }
 
-  /** Removes the data for the given [[DataFrame]] from the cache */
-  private[sql] def uncacheQuery(query: DataFrame, blocking: Boolean = true): Unit = writeLock {
+  /** Removes the data for the given [[Queryable]] from the cache */
+  private[sql] def uncacheQuery(query: Queryable, blocking: Boolean = true): Unit = writeLock {
     val planToCache = query.queryExecution.analyzed
     val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
     require(dataIndex >= 0, s"Table $query is not cached.")
     cachedData(dataIndex).cachedRepresentation.uncache(blocking)
     cachedData.remove(dataIndex)
   }
 
-  /** Tries to remove the data for the given [[DataFrame]] from the cache if it's cached */
+  /** Tries to remove the data for the given [[Queryable]] from the cache
+   * if it's cached
+   */
   private[sql] def tryUncacheQuery(
-      query: DataFrame,
+      query: Queryable,
       blocking: Boolean = true): Boolean = writeLock {
     val planToCache = query.queryExecution.analyzed
     val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
@@ -123,12 +124,12 @@ private[sql] class CacheManager extends Logging {
     found
   }
 
-  /** Optionally returns cached data for the given [[DataFrame]] */
-  private[sql] def lookupCachedData(query: DataFrame): Option[CachedData] = readLock {
+  /** Optionally returns cached data for the given [[Queryable]] */
+  private[sql] def lookupCachedData(query: Queryable): Option[CachedData] = readLock {
     lookupCachedData(query.queryExecution.analyzed)
   }
 
-  /** Optionally returns cached data for the given LogicalPlan. */
+  /** Optionally returns cached data for the given [[LogicalPlan]]. */
   private[sql] def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock {
     cachedData.find(cd => plan.sameResult(cd.plan))
   }
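With Queryable in these signatures, CacheManager works for both DataFrames and Datasets. The two uncache paths behave differently: uncacheQuery fails its require(...) check when the plan was never cached, while tryUncacheQuery, which the new Dataset.unpersist uses, just reports whether anything was removed. A small hedged sketch of the user-visible consequence:

val ds = Seq("x", "y").toDS()

// Never persisted, but still safe: Dataset.unpersist goes through
// tryUncacheQuery, which returns false when no matching plan is cached.
ds.unpersist()

// Callers of uncacheQuery, by contrast, hit
// require(dataIndex >= 0, s"Table $query is not cached.") and fail fast.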
sql/core/src/test/scala/org/apache/spark/sql/DatasetCacheSuite.scala (new file)

Lines changed: 80 additions & 0 deletions

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql

import scala.language.postfixOps

import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext


class DatasetCacheSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("persist and unpersist") {
    val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS().select(expr("_2 + 1").as[Int])
    val cached = ds.cache()
    // count triggers the caching action. It should not throw.
    cached.count()
    // Make sure, the Dataset is indeed cached.
    assertCached(cached)
    // Check result.
    checkAnswer(
      cached,
      2, 3, 4)
    // Drop the cache.
    cached.unpersist()
    assert(!sqlContext.isCached(cached), "The Dataset should not be cached.")
  }

  test("persist and then rebind right encoder when join 2 datasets") {
    val ds1 = Seq("1", "2").toDS().as("a")
    val ds2 = Seq(2, 3).toDS().as("b")

    ds1.persist()
    assertCached(ds1)
    ds2.persist()
    assertCached(ds2)

    val joined = ds1.joinWith(ds2, $"a.value" === $"b.value")
    checkAnswer(joined, ("2", 2))
    assertCached(joined, 2)

    ds1.unpersist()
    assert(!sqlContext.isCached(ds1), "The Dataset ds1 should not be cached.")
    ds2.unpersist()
    assert(!sqlContext.isCached(ds2), "The Dataset ds2 should not be cached.")
  }

  test("persist and then groupBy columns asKey, map") {
    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
    val grouped = ds.groupBy($"_1").keyAs[String]
    val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) }
    agged.persist()

    checkAnswer(
      agged.filter(_._1 == "b"),
      ("b", 3))
    assertCached(agged.filter(_._1 == "b"))

    ds.unpersist()
    assert(!sqlContext.isCached(ds), "The Dataset ds should not be cached.")
    agged.unpersist()
    assert(!sqlContext.isCached(agged), "The Dataset agged should not be cached.")
  }
}

sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala

Lines changed: 3 additions & 2 deletions
@@ -24,6 +24,7 @@ import scala.collection.JavaConverters._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
+import org.apache.spark.sql.execution.Queryable
 
 abstract class QueryTest extends PlanTest {
 
@@ -163,9 +164,9 @@
   }
 
   /**
-   * Asserts that a given [[DataFrame]] will be executed using the given number of cached results.
+   * Asserts that a given [[Queryable]] will be executed using the given number of cached results.
    */
-  def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = {
+  def assertCached(query: Queryable, numCachedTables: Int = 1): Unit = {
     val planWithCaching = query.queryExecution.withCachedData
     val cachedData = planWithCaching collect {
       case cached: InMemoryRelation => cached
