Skip to content

Commit 7f63376

Browse files
committed
[BEAM-12388] Add caching to deferred dataframes
This adds caching to any dataframes using the InteractiveRunner.
1 parent 3e933b5 commit 7f63376

File tree

7 files changed

+333
-13
lines changed

7 files changed

+333
-13
lines changed

sdks/python/apache_beam/runners/interactive/caching/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,3 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616
#
17-
from apache_beam.runners.interactive.caching.streaming_cache import StreamingCache
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
from typing import *
18+
19+
import apache_beam as beam
20+
from apache_beam.dataframe import convert
21+
from apache_beam.dataframe import expressions
22+
23+
24+
class ExpressionCache(object):
  """Utility class for caching deferred DataFrames expressions.

  This cache is currently a light-weight wrapper around the
  TO_PCOLLECTION_CACHE in the beam.dataframes.convert module and the
  computed_pcollections in the interactive module.

  Example::

    df : beam.dataframe.DeferredDataFrame = ...
    ...
    cache = ExpressionCache()
    cache.replace_with_cached(df._expr)

  This will automatically link the instance to the existing caches. After it is
  created, the cache can then be used to modify an existing deferred dataframe
  expression tree to replace nodes with computed PCollections.

  This object can be created and destroyed whenever. This class holds no state
  and the only side-effect is modifying the given expression.
  """
  def __init__(self, pcollection_cache=None, computed_cache=None):
    # Fall back to the global caches only when the caller did not inject one.
    if pcollection_cache is None:
      self._pcollection_cache = convert.TO_PCOLLECTION_CACHE
    else:
      self._pcollection_cache = pcollection_cache

    if computed_cache is None:
      # Imported lazily (and only when actually needed) to avoid a circular
      # import and to keep this class usable with injected caches (e.g. in
      # tests) without a running interactive environment.
      from apache_beam.runners.interactive import interactive_environment as ie
      self._computed_cache = ie.current_env().computed_pcollections
    else:
      self._computed_cache = computed_cache

  def replace_with_cached(
      self, expr: expressions.Expression) -> Dict[str, expressions.Expression]:
    """Replaces any previously computed expressions with PlaceholderExpressions.

    This is used to correctly read any expressions that were cached in previous
    runs. This enables the InteractiveRunner to prune off old calculations from
    the expression tree.

    Args:
      expr: root of the deferred DataFrame expression tree; modified in place.

    Returns:
      A dict mapping each replaced expression's id to the
      PlaceholderExpression now standing in for it.
    """
    replaced_inputs: Dict[str, expressions.Expression] = {}
    self._replace_with_cached_recur(expr, replaced_inputs)
    return replaced_inputs

  def _replace_with_cached_recur(
      self,
      expr: expressions.Expression,
      replaced_inputs: Dict[str, expressions.Expression]) -> None:
    """Recursive call for `replace_with_cached`.

    Recurses through the expression tree and replaces any cached inputs with
    `PlaceholderExpression`s.
    """
    final_inputs = []

    for arg in expr.args():
      pc = self._get_cached(arg)

      # Only read from the cache when the PCollection has been fully
      # computed. This is so that no partial results are used.
      if self._is_computed(pc):
        # Reuse previously seen cached expressions. This is so that the same
        # value isn't cached multiple times.
        if arg._id in replaced_inputs:
          cached = replaced_inputs[arg._id]
        else:
          cached = expressions.PlaceholderExpression(
              arg.proxy(), self._pcollection_cache[arg._id])

          replaced_inputs[arg._id] = cached
        final_inputs.append(cached)
      else:
        final_inputs.append(arg)
        # Only recurse into subtrees that were not replaced: a placeholder
        # subsumes everything beneath a cached node.
        self._replace_with_cached_recur(arg, replaced_inputs)
    expr._args = tuple(final_inputs)

  def _get_cached(self,
                  expr: expressions.Expression) -> Optional[beam.PCollection]:
    """Returns the PCollection associated with the expression, if any."""
    return self._pcollection_cache.get(expr._id, None)

  def _is_computed(self, pc: beam.PCollection) -> bool:
    """Returns True if the PCollection has been run and computed."""
    return pc is not None and pc in self._computed_cache
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
import unittest
19+
20+
import apache_beam as beam
21+
from apache_beam.dataframe import expressions
22+
from apache_beam.runners.interactive.caching.expression_cache import ExpressionCache
23+
24+
25+
class ExpressionCacheTest(unittest.TestCase):
  """Tests for ExpressionCache using injected, in-memory caches."""
  def setUp(self):
    self._pcollection_cache = {}
    self._computed_cache = set()
    self._pipeline = beam.Pipeline()
    self.cache = ExpressionCache(self._pcollection_cache, self._computed_cache)

  def create_trace(self, expr):
    """Returns a preorder trace of `expr` and all of its descendants."""
    trace = [expr]
    for arg in expr.args():
      trace += self.create_trace(arg)
    return trace

  def mock_cache(self, expr):
    """Marks `expr` as cached and fully computed in the injected caches."""
    pcoll = beam.PCollection(self._pipeline)
    self._pcollection_cache[expr._id] = pcoll
    self._computed_cache.add(pcoll)

  def assertTraceTypes(self, expr, expected):
    """Asserts the preorder trace of `expr` has the expected node types."""
    actual_types = [type(e).__name__ for e in self.create_trace(expr)]
    expected_types = [e.__name__ for e in expected]
    self.assertListEqual(actual_types, expected_types)

  def test_only_replaces_cached(self):
    in_expr = expressions.ConstantExpression(0)
    comp_expr = expressions.ComputedExpression('test', lambda x: x, [in_expr])

    # Expect that no replacement of expressions is performed.
    expected_trace = [
        expressions.ComputedExpression, expressions.ConstantExpression
    ]
    self.assertTraceTypes(comp_expr, expected_trace)

    self.cache.replace_with_cached(comp_expr)

    self.assertTraceTypes(comp_expr, expected_trace)

    # Now "cache" the expression and assert that the cached expression was
    # replaced with a placeholder.
    self.mock_cache(in_expr)

    replaced = self.cache.replace_with_cached(comp_expr)

    expected_trace = [
        expressions.ComputedExpression, expressions.PlaceholderExpression
    ]
    self.assertTraceTypes(comp_expr, expected_trace)
    self.assertIn(in_expr._id, replaced)

  def test_only_replaces_inputs(self):
    arg_0_expr = expressions.ConstantExpression(0)
    ident_val = expressions.ComputedExpression(
        'ident', lambda x: x, [arg_0_expr])

    arg_1_expr = expressions.ConstantExpression(1)
    comp_expr = expressions.ComputedExpression(
        'add', lambda x, y: x + y, [ident_val, arg_1_expr])

    self.mock_cache(ident_val)

    replaced = self.cache.replace_with_cached(comp_expr)

    # Assert that ident_val was replaced and that its arguments were removed
    # from the expression tree.
    expected_trace = [
        expressions.ComputedExpression,
        expressions.PlaceholderExpression,
        expressions.ConstantExpression
    ]
    self.assertTraceTypes(comp_expr, expected_trace)
    self.assertIn(ident_val._id, replaced)
    self.assertNotIn(arg_0_expr, self.create_trace(comp_expr))

  def test_only_caches_same_input(self):
    arg_0_expr = expressions.ConstantExpression(0)
    ident_val = expressions.ComputedExpression(
        'ident', lambda x: x, [arg_0_expr])
    comp_expr = expressions.ComputedExpression(
        'add', lambda x, y: x + y, [ident_val, arg_0_expr])

    self.mock_cache(arg_0_expr)

    replaced = self.cache.replace_with_cached(comp_expr)

    # Assert that arg_0_expr, being an input to two computations, was replaced
    # with the same placeholder expression.
    expected_trace = [
        expressions.ComputedExpression,
        expressions.ComputedExpression,
        expressions.PlaceholderExpression,
        expressions.PlaceholderExpression
    ]
    actual_trace = self.create_trace(comp_expr)
    unique_placeholders = {
        t
        for t in actual_trace
        if isinstance(t, expressions.PlaceholderExpression)
    }
    self.assertTraceTypes(comp_expr, expected_trace)
    self.assertTrue(
        all(e == replaced[arg_0_expr._id] for e in unique_placeholders))
    self.assertIn(arg_0_expr._id, replaced)


if __name__ == '__main__':
  unittest.main()

sdks/python/apache_beam/runners/interactive/interactive_beam.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,12 @@
3737
import pandas as pd
3838

3939
import apache_beam as beam
40-
from apache_beam.dataframe.convert import to_pcollection
4140
from apache_beam.dataframe.frame_base import DeferredBase
4241
from apache_beam.runners.interactive import interactive_environment as ie
4342
from apache_beam.runners.interactive.display import pipeline_graph
4443
from apache_beam.runners.interactive.display.pcoll_visualization import visualize
4544
from apache_beam.runners.interactive.options import interactive_options
45+
from apache_beam.runners.interactive.utils import deferred_df_to_pcollection
4646
from apache_beam.runners.interactive.utils import elements_to_df
4747
from apache_beam.runners.interactive.utils import progress_indicated
4848
from apache_beam.runners.runner import PipelineState
@@ -455,10 +455,7 @@ def show(
455455
element_types = {}
456456
for pcoll in flatten_pcolls:
457457
if isinstance(pcoll, DeferredBase):
458-
proxy = pcoll._expr.proxy()
459-
pcoll = to_pcollection(
460-
pcoll, yield_elements='pandas', label=str(pcoll._expr))
461-
element_type = proxy
458+
pcoll, element_type = deferred_df_to_pcollection(pcoll)
462459
watch({'anonymous_pcollection_{}'.format(id(pcoll)): pcoll})
463460
else:
464461
element_type = pcoll.element_type
@@ -569,11 +566,7 @@ def collect(pcoll, n='inf', duration='inf', include_window_info=False):
569566
# collect the result in elements_to_df.
570567
if isinstance(pcoll, DeferredBase):
571568
# Get the proxy so we can get the output shape of the DataFrame.
572-
# TODO(BEAM-11064): Once type hints are implemented for pandas, use those
573-
# instead of the proxy.
574-
element_type = pcoll._expr.proxy()
575-
pcoll = to_pcollection(
576-
pcoll, yield_elements='pandas', label=str(pcoll._expr))
569+
pcoll, element_type = deferred_df_to_pcollection(pcoll)
577570
watch({'anonymous_pcollection_{}'.format(id(pcoll)): pcoll})
578571
else:
579572
element_type = pcoll.element_type

sdks/python/apache_beam/runners/interactive/interactive_runner_test.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,81 @@ def test_dataframes_same_cell_twice(self):
408408
df_expected['cube'],
409409
ib.collect(df['cube'], n=10).reset_index(drop=True))
410410

411+
@unittest.skipIf(
412+
not ie.current_env().is_interactive_ready,
413+
'[interactive] dependency is not installed.')
414+
@unittest.skipIf(sys.platform == "win32", "[BEAM-10627]")
415+
@patch('IPython.get_ipython', new_callable=mock_get_ipython)
416+
def test_dataframe_caching(self, cell):
417+
418+
# Create a pipeline that exercises the DataFrame API. This will also use
419+
# caching in the background.
420+
with cell: # Cell 1
421+
p = beam.Pipeline(interactive_runner.InteractiveRunner())
422+
ib.watch({'p': p})
423+
424+
with cell: # Cell 2
425+
data = p | beam.Create([
426+
1, 2, 3
427+
]) | beam.Map(lambda x: beam.Row(square=x * x, cube=x * x * x))
428+
429+
with beam.dataframe.allow_non_parallel_operations():
430+
df = to_dataframe(data).reset_index(drop=True)
431+
432+
ib.collect(df)
433+
434+
with cell: # Cell 3
435+
df['output'] = df['square'] * df['cube']
436+
ib.collect(df)
437+
438+
with cell: # Cell 4
439+
df['output'] = 0
440+
ib.collect(df)
441+
442+
# We use a trace through the graph to perform an isomorphism test. The end
443+
# output should look like a linear graph. This indicates that the dataframe
444+
# transform was correctly broken into separate pieces to cache. If caching
445+
# isn't enabled, all the dataframe computation nodes are connected to a
446+
# single shared node.
447+
trace = []
448+
449+
# Only look at the top-level transforms for the isomorphism. The test
450+
# doesn't care about the transform implementations, just the overall shape.
451+
class TopLevelTracer(beam.pipeline.PipelineVisitor):
452+
def _find_root_producer(self, node: beam.pipeline.AppliedPTransform):
453+
if node is None or not node.full_label:
454+
return None
455+
456+
parent = self._find_root_producer(node.parent)
457+
if parent is None:
458+
return node
459+
460+
return parent
461+
462+
def _add_to_trace(self, node, trace):
463+
if '/' not in str(node):
464+
if node.inputs:
465+
producer = self._find_root_producer(node.inputs[0].producer)
466+
producer_name = producer.full_label if producer else ''
467+
trace.append((producer_name, node.full_label))
468+
469+
def visit_transform(self, node: beam.pipeline.AppliedPTransform):
470+
self._add_to_trace(node, trace)
471+
472+
def enter_composite_transform(
473+
self, node: beam.pipeline.AppliedPTransform):
474+
self._add_to_trace(node, trace)
475+
476+
p.visit(TopLevelTracer())
477+
478+
# Do the isomorphism test which states that the topological sort of the
479+
# graph yields a linear graph.
480+
trace_string = '\n'.join(str(t) for t in trace)
481+
prev_producer = ''
482+
for producer, consumer in trace:
483+
self.assertEqual(producer, prev_producer, trace_string)
484+
prev_producer = consumer
485+
411486

412487
if __name__ == '__main__':
413488
unittest.main()

sdks/python/apache_beam/runners/interactive/recording_manager.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import pandas as pd
2424

2525
import apache_beam as beam
26-
from apache_beam.dataframe.convert import to_pcollection
2726
from apache_beam.dataframe.frame_base import DeferredBase
2827
from apache_beam.portability.api.beam_runner_api_pb2 import TestStreamPayload
2928
from apache_beam.runners.interactive import background_caching_job as bcj
@@ -310,7 +309,7 @@ def _watch(self, pcolls):
310309
# TODO(BEAM-12388): investigate the mixing pcollections in multiple
311310
# pipelines error when using the default label.
312311
for df in watched_dataframes:
313-
pcoll = to_pcollection(df, yield_elements='pandas', label=str(df._expr))
312+
pcoll, _ = utils.deferred_df_to_pcollection(df)
314313
watched_pcollections.add(pcoll)
315314
for pcoll in pcolls:
316315
if pcoll not in watched_pcollections:

0 commit comments

Comments
 (0)