Skip to content

Commit 85d46e1

Browse files
authored
Merge pull request apache#15357 from [BEAM-12781] Add memoization of BoundedSource encoding for SDFBoundedSourceReader
* Add memoization of BoundedSource encoding for SDFBoundedSourceReader * Fixup for unhashable types * Fixup * addressing comments * fix comment
1 parent c439abf commit 85d46e1

File tree

3 files changed

+68
-3
lines changed

3 files changed

+68
-3
lines changed

sdks/python/apache_beam/coders/coders.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,27 @@
1818
"""Collection of useful coders.
1919
2020
Only those coders listed in __all__ are part of the public API of this module.
21+
22+
## On usage of `pickle`, `dill` and `pickler` in Beam
23+
24+
In Beam, we generally use `pickle` for pipeline elements and `dill` for
25+
more complex types, like user functions.
26+
27+
`pickler` is Beam's own wrapping of dill + compression + error handling.
28+
It also serves as an API to mask the actual encoding layer (so we can
29+
change it from `dill` if necessary).
30+
31+
We created `_MemoizingPickleCoder` to improve performance when serializing
32+
complex user types for the execution of SDF. Specifically to address
33+
BEAM-12781, where many identical `BoundedSource` instances are being
34+
encoded.
35+
2136
"""
2237
# pytype: skip-file
2338

2439
import base64
2540
import pickle
41+
from functools import lru_cache
2642
from typing import TYPE_CHECKING
2743
from typing import Any
2844
from typing import Callable
@@ -741,6 +757,33 @@ def __hash__(self):
741757
return hash(type(self))
742758

743759

760+
class _MemoizingPickleCoder(_PickleCoderBase):
  """Coder using Python's pickle functionality with memoization."""
  def __init__(self, cache_size=16):
    super().__init__()
    # Bound on the LRU memoization cache used when encoding values.
    self.cache_size = cache_size

  def _create_impl(self):
    from apache_beam.internal import pickler
    raw_dumps = pickler.dumps

    # Memoize serialization of hashable values; typed=True keeps values of
    # different types (e.g. 1 vs 1.0) from sharing a cache entry.
    cached_dumps = lru_cache(maxsize=self.cache_size, typed=True)(raw_dumps)

    def _dumps_with_fallback(value):
      # Unhashable values cannot be lru_cache keys (TypeError); serialize
      # those directly without memoization.
      try:
        return cached_dumps(value)
      except TypeError:
        return raw_dumps(value)

    return coder_impl.CallbackCoderImpl(_dumps_with_fallback, pickler.loads)

  def as_deterministic_coder(self, step_label, error_message=None):
    return FastPrimitivesCoder(self, requires_deterministic=step_label)

  def to_type_hint(self):
    return Any
744787
class PickleCoder(_PickleCoderBase):
745788
"""Coder using Python's pickle functionality."""
746789
def _create_impl(self):

sdks/python/apache_beam/coders/coders_test_common.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,10 @@ def test_pickle_coder(self):
207207
coder = coders.PickleCoder()
208208
self.check_coder(coder, *self.test_values)
209209

210+
def test_memoizing_pickle_coder(self):
  """Checks that _MemoizingPickleCoder round-trips the shared test values."""
  coder = coders._MemoizingPickleCoder()
  self.check_coder(coder, *self.test_values)
210214
def test_deterministic_coder(self):
211215
coder = coders.FastPrimitivesCoder()
212216
deterministic_coder = coders.DeterministicFastPrimitivesCoder(coder, 'step')

sdks/python/apache_beam/io/iobase.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@
4444

4545
from apache_beam import coders
4646
from apache_beam import pvalue
47+
from apache_beam.coders.coders import _MemoizingPickleCoder
48+
from apache_beam.internal import pickler
4749
from apache_beam.portability import common_urns
4850
from apache_beam.portability import python_urns
4951
from apache_beam.portability.api import beam_runner_api_pb2
@@ -886,12 +888,14 @@ def get_desired_chunk_size(total_size):
886888

887889
def expand(self, pbegin):
888890
if isinstance(self.source, BoundedSource):
891+
coders.registry.register_coder(BoundedSource, _MemoizingPickleCoder)
889892
display_data = self.source.display_data() or {}
890893
display_data['source'] = self.source.__class__
894+
891895
return (
892896
pbegin
893897
| Impulse()
894-
| core.Map(lambda _: self.source)
898+
| core.Map(lambda _: self.source).with_output_types(BoundedSource)
895899
| SDFBoundedSourceReader(display_data))
896900
elif isinstance(self.source, ptransform.PTransform):
897901
# The Read transform can also admit a full PTransform as an input
@@ -1573,15 +1577,29 @@ def is_bounded(self):
15731577
return True
15741578

15751579

1580+
class _SDFBoundedSourceWrapperRestrictionCoder(coders.Coder):
  """Coder for _SDFBoundedSourceRestriction.

  Serializes a restriction by pickling the four fields of its underlying
  SourceBundle as a tuple; decoding rebuilds the bundle from that tuple.
  """
  def decode(self, value):
    # The pickled payload is the (weight, source, start, stop) field tuple.
    return _SDFBoundedSourceRestriction(SourceBundle(*pickler.loads(value)))

  def encode(self, restriction):
    bundle = restriction._source_bundle
    fields = (
        bundle.weight,
        bundle.source,
        bundle.start_position,
        bundle.stop_position)
    return pickler.dumps(fields)
1590+
1591+
15761592
class _SDFBoundedSourceRestrictionProvider(core.RestrictionProvider):
15771593
"""
15781594
A `RestrictionProvider` that is used by SDF for `BoundedSource`.
15791595
15801596
This restriction provider initializes restriction based on input
15811597
element that is expected to be of BoundedSource type.
15821598
"""
1583-
def __init__(self, desired_chunk_size=None):
1599+
def __init__(self, desired_chunk_size=None, restriction_coder=None):
  # desired_chunk_size: stored for later use when producing restrictions
  # (presumably passed to source splitting — not visible in this block).
  self._desired_chunk_size = desired_chunk_size
  # Falls back to the default restriction coder when none (or a falsy
  # value) is supplied by the caller.
  self._restriction_coder = (
      restriction_coder or _SDFBoundedSourceWrapperRestrictionCoder())
15851603

15861604
def _check_source(self, src):
15871605
if not isinstance(src, BoundedSource):
@@ -1618,7 +1636,7 @@ def restriction_size(self, element, restriction):
16181636
return restriction.weight()
16191637

16201638
def restriction_coder(self):
  """Returns the coder used for this provider's restrictions."""
  return self._restriction_coder
16221640

16231641

16241642
class SDFBoundedSourceReader(PTransform):

0 commit comments

Comments
 (0)