Skip to content

Commit 939fa99

Browse files
authored
[BEAM-1833] Preserve input names at graph construction and through proto translation. (apache#15202)
1 parent 739dbb8 commit 939fa99

File tree

7 files changed

+92
-66
lines changed

7 files changed

+92
-66
lines changed

sdks/python/apache_beam/pipeline.py

+33-25
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
from typing import FrozenSet
6262
from typing import Iterable
6363
from typing import List
64+
from typing import Mapping
6465
from typing import Optional
6566
from typing import Sequence
6667
from typing import Set
@@ -271,7 +272,7 @@ def _replace(self, override):
271272
output_replacements = {
272273
} # type: Dict[AppliedPTransform, List[Tuple[pvalue.PValue, Optional[str]]]]
273274
input_replacements = {
274-
} # type: Dict[AppliedPTransform, Sequence[Union[pvalue.PBegin, pvalue.PCollection]]]
275+
} # type: Dict[AppliedPTransform, Mapping[str, Union[pvalue.PBegin, pvalue.PCollection]]]
275276
side_input_replacements = {
276277
} # type: Dict[AppliedPTransform, List[pvalue.AsSideInput]]
277278

@@ -297,7 +298,7 @@ def _replace_if_needed(self, original_transform_node):
297298
original_transform_node.parent,
298299
replacement_transform,
299300
original_transform_node.full_label,
300-
original_transform_node.inputs)
301+
original_transform_node.main_inputs)
301302

302303
replacement_transform_node.resource_hints = (
303304
original_transform_node.resource_hints)
@@ -437,11 +438,11 @@ def visit_transform(self, transform_node):
437438
output_replacements[transform_node].append((tag, replacement))
438439

439440
if replace_input:
440-
new_input = [
441-
input if not input in output_map else output_map[input]
442-
for input in transform_node.inputs
443-
]
444-
input_replacements[transform_node] = new_input
441+
new_inputs = {
442+
tag: input if not input in output_map else output_map[input]
443+
for (tag, input) in transform_node.main_inputs.items()
444+
}
445+
input_replacements[transform_node] = new_inputs
445446

446447
if replace_side_inputs:
447448
new_side_inputs = []
@@ -670,15 +671,18 @@ def apply(
670671

671672
pvalueish, inputs = transform._extract_input_pvalues(pvalueish)
672673
try:
673-
inputs = tuple(inputs)
674-
for leaf_input in inputs:
675-
if not isinstance(leaf_input, pvalue.PValue):
676-
raise TypeError
674+
if not isinstance(inputs, dict):
675+
inputs = {str(ix): input for (ix, input) in enumerate(inputs)}
677676
except TypeError:
678677
raise NotImplementedError(
679678
'Unable to extract PValue inputs from %s; either %s does not accept '
680679
'inputs of this format, or it does not properly override '
681680
'_extract_input_pvalues' % (pvalueish, transform))
681+
for t, leaf_input in inputs.items():
682+
if not isinstance(leaf_input, pvalue.PValue) or not isinstance(t, str):
683+
raise NotImplementedError(
684+
'%s does not properly override _extract_input_pvalues, '
685+
'returned %s from %s' % (transform, inputs, pvalueish))
682686

683687
current = AppliedPTransform(
684688
self._current_transform(), transform, full_label, inputs)
@@ -705,7 +709,8 @@ def apply(
705709
if result.producer is None:
706710
result.producer = current
707711

708-
self._infer_result_type(transform, inputs, result)
712+
# TODO(BEAM-1833): Pass full tuples dict.
713+
self._infer_result_type(transform, tuple(inputs.values()), result)
709714

710715
assert isinstance(result.producer.inputs, tuple)
711716
# The DoOutputsTuple adds the PCollection to the outputs when accessed
@@ -940,7 +945,7 @@ def from_runner_api(
940945
for id in proto.components.transforms:
941946
transform = context.transforms.get_by_id(id)
942947
if not transform.inputs and transform.transform.__class__ in has_pbegin:
943-
transform.inputs = (pvalue.PBegin(p), )
948+
transform.main_inputs = {'None': pvalue.PBegin(p)}
944949

945950
if return_context:
946951
return p, context # type: ignore # too complicated for now
@@ -1030,7 +1035,7 @@ def __init__(
10301035
parent, # type: Optional[AppliedPTransform]
10311036
transform, # type: Optional[ptransform.PTransform]
10321037
full_label, # type: str
1033-
inputs, # type: Optional[Sequence[Union[pvalue.PBegin, pvalue.PCollection]]]
1038+
main_inputs, # type: Optional[Mapping[str, Union[pvalue.PBegin, pvalue.PCollection]]]
10341039
environment_id=None, # type: Optional[str]
10351040
annotations=None, # type: Optional[Dict[str, bytes]]
10361041
):
@@ -1043,7 +1048,7 @@ def __init__(
10431048
# reusing PTransform instances in different contexts (apply() calls) without
10441049
# any interference. This is particularly useful for composite transforms.
10451050
self.full_label = full_label
1046-
self.inputs = inputs or ()
1051+
self.main_inputs = dict(main_inputs or {})
10471052

10481053
self.side_inputs = tuple() if transform is None else transform.side_inputs
10491054
self.outputs = {} # type: Dict[Union[str, int, None], pvalue.PValue]
@@ -1076,6 +1081,10 @@ def annotation_to_bytes(key, a: Any) -> bytes:
10761081
}
10771082
self.annotations = annotations
10781083

1084+
@property
1085+
def inputs(self):
1086+
return tuple(self.main_inputs.values())
1087+
10791088
def __repr__(self):
10801089
# type: () -> str
10811090
return "%s(%s, %s)" % (
@@ -1109,8 +1118,8 @@ def replace_output(
11091118
if isinstance(self.transform, external.ExternalTransform):
11101119
self.transform.replace_named_outputs(self.named_outputs())
11111120

1112-
def replace_inputs(self, inputs):
1113-
self.inputs = inputs
1121+
def replace_inputs(self, main_inputs):
1122+
self.main_inputs = main_inputs
11141123

11151124
# Importing locally to prevent circular dependency issues.
11161125
from apache_beam.transforms import external
@@ -1215,12 +1224,11 @@ def visit(
12151224

12161225
def named_inputs(self):
12171226
# type: () -> Dict[str, pvalue.PValue]
1218-
# TODO(BEAM-1833): Push names up into the sdk construction.
12191227
if self.transform is None:
1220-
assert not self.inputs and not self.side_inputs
1228+
assert not self.main_inputs and not self.side_inputs
12211229
return {}
12221230
else:
1223-
return self.transform._named_inputs(self.inputs, self.side_inputs)
1231+
return self.transform._named_inputs(self.main_inputs, self.side_inputs)
12241232

12251233
def named_outputs(self):
12261234
# type: () -> Dict[str, pvalue.PCollection]
@@ -1309,10 +1317,10 @@ def from_runner_api(
13091317
pardo_payload = None
13101318
side_input_tags = []
13111319

1312-
main_inputs = [
1313-
context.pcollections.get_by_id(id) for tag,
1314-
id in proto.inputs.items() if tag not in side_input_tags
1315-
]
1320+
main_inputs = {
1321+
tag: context.pcollections.get_by_id(id)
1322+
for (tag, id) in proto.inputs.items() if tag not in side_input_tags
1323+
}
13161324

13171325
transform = ptransform.PTransform.from_runner_api(proto, context)
13181326
if transform and proto.environment_id:
@@ -1334,7 +1342,7 @@ def from_runner_api(
13341342
parent=None,
13351343
transform=transform,
13361344
full_label=proto.unique_name,
1337-
inputs=main_inputs,
1345+
main_inputs=main_inputs,
13381346
environment_id=None,
13391347
annotations=proto.annotations)
13401348

sdks/python/apache_beam/pipeline_test.py

+18
Original file line numberDiff line numberDiff line change
@@ -972,6 +972,24 @@ def expand(self, p):
972972
for transform_id in runner_api_proto.components.transforms:
973973
self.assertRegex(transform_id, r'[a-zA-Z0-9-_]+')
974974

975+
def test_input_names(self):
976+
class MyPTransform(beam.PTransform):
977+
def expand(self, pcolls):
978+
return pcolls.values() | beam.Flatten()
979+
980+
p = beam.Pipeline()
981+
input_names = set('ABC')
982+
inputs = {x: p | x >> beam.Create([x]) for x in input_names}
983+
inputs | MyPTransform() # pylint: disable=expression-not-assigned
984+
runner_api_proto = Pipeline.to_runner_api(p)
985+
986+
for transform_proto in runner_api_proto.components.transforms.values():
987+
if transform_proto.unique_name == 'MyPTransform':
988+
self.assertEqual(set(transform_proto.inputs.keys()), input_names)
989+
break
990+
else:
991+
self.fail('Unable to find transform.')
992+
975993
def test_display_data(self):
976994
class MyParentTransform(beam.PTransform):
977995
def expand(self, p):

sdks/python/apache_beam/runners/dataflow/dataflow_runner.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ def visit_transform(self, transform_node):
305305
parent,
306306
beam.Map(lambda x: (b'', x)),
307307
transform_node.full_label + '/MapToVoidKey%s' % ix,
308-
(side_input.pvalue, ))
308+
{'input': side_input.pvalue})
309309
new_side_input.pvalue.producer = map_to_void_key
310310
map_to_void_key.add_output(new_side_input.pvalue, None)
311311
parent.add_part(map_to_void_key)

sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,8 @@ def test_group_by_key_input_visitor_with_valid_inputs(self):
348348
pcoll2.element_type = typehints.Any
349349
pcoll3.element_type = typehints.KV[typehints.Any, typehints.Any]
350350
for pcoll in [pcoll1, pcoll2, pcoll3]:
351-
applied = AppliedPTransform(None, beam.GroupByKey(), "label", [pcoll])
351+
applied = AppliedPTransform(
352+
None, beam.GroupByKey(), "label", {'pcoll': pcoll})
352353
applied.outputs[None] = PCollection(None)
353354
common.group_by_key_input_visitor().visit_transform(applied)
354355
self.assertEqual(
@@ -367,15 +368,15 @@ def test_group_by_key_input_visitor_with_invalid_inputs(self):
367368
for pcoll in [pcoll1, pcoll2]:
368369
with self.assertRaisesRegex(ValueError, err_msg):
369370
common.group_by_key_input_visitor().visit_transform(
370-
AppliedPTransform(None, beam.GroupByKey(), "label", [pcoll]))
371+
AppliedPTransform(None, beam.GroupByKey(), "label", {'in': pcoll}))
371372

372373
def test_group_by_key_input_visitor_for_non_gbk_transforms(self):
373374
p = TestPipeline()
374375
pcoll = PCollection(p)
375376
for transform in [beam.Flatten(), beam.Map(lambda x: x)]:
376377
pcoll.element_type = typehints.Any
377378
common.group_by_key_input_visitor().visit_transform(
378-
AppliedPTransform(None, transform, "label", [pcoll]))
379+
AppliedPTransform(None, transform, "label", {'in': pcoll}))
379380
self.assertEqual(pcoll.element_type, typehints.Any)
380381

381382
def test_flatten_input_with_visitor_with_single_input(self):
@@ -387,19 +388,19 @@ def test_flatten_input_with_visitor_with_multiple_inputs(self):
387388

388389
def _test_flatten_input_visitor(self, input_type, output_type, num_inputs):
389390
p = TestPipeline()
390-
inputs = []
391-
for _ in range(num_inputs):
391+
inputs = {}
392+
for ix in range(num_inputs):
392393
input_pcoll = PCollection(p)
393394
input_pcoll.element_type = input_type
394-
inputs.append(input_pcoll)
395+
inputs[str(ix)] = input_pcoll
395396
output_pcoll = PCollection(p)
396397
output_pcoll.element_type = output_type
397398

398399
flatten = AppliedPTransform(None, beam.Flatten(), "label", inputs)
399400
flatten.add_output(output_pcoll, None)
400401
DataflowRunner.flatten_input_visitor().visit_transform(flatten)
401402
for _ in range(num_inputs):
402-
self.assertEqual(inputs[0].element_type, output_type)
403+
self.assertEqual(inputs['0'].element_type, output_type)
403404

404405
def test_gbk_then_flatten_input_visitor(self):
405406
p = TestPipeline(
@@ -442,7 +443,7 @@ def test_side_input_visitor(self):
442443
z: (x, y, z),
443444
beam.pvalue.AsSingleton(pc),
444445
beam.pvalue.AsMultiMap(pc))
445-
applied_transform = AppliedPTransform(None, transform, "label", [pc])
446+
applied_transform = AppliedPTransform(None, transform, "label", {'pc': pc})
446447
DataflowRunner.side_input_visitor(
447448
use_fn_api=True).visit_transform(applied_transform)
448449
self.assertEqual(2, len(applied_transform.side_inputs))

sdks/python/apache_beam/runners/interactive/pipeline_instrument.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -685,18 +685,18 @@ def enter_composite_transform(self, transform_node):
685685

686686
def visit_transform(self, transform_node):
687687
if transform_node.inputs:
688-
input_list = list(transform_node.inputs)
689-
for i, input_pcoll in enumerate(input_list):
688+
main_inputs = dict(transform_node.main_inputs)
689+
for tag, input_pcoll in main_inputs.items():
690690
key = self._pin.cache_key(input_pcoll)
691691

692692
# Replace the input pcollection with the cached pcollection (if it
693693
# has been cached).
694694
if key in self._pin._cached_pcoll_read:
695695
# Ignore this pcoll in the final pruned instrumented pipeline.
696696
self._pin._ignored_targets.add(input_pcoll)
697-
input_list[i] = self._pin._cached_pcoll_read[key]
697+
main_inputs[tag] = self._pin._cached_pcoll_read[key]
698698
# Update the transform with its new inputs.
699-
transform_node.inputs = tuple(input_list)
699+
transform_node.main_inputs = main_inputs
700700

701701
v = ReadCacheWireVisitor(self)
702702
pipeline.visit(v)

sdks/python/apache_beam/runners/interactive/pipeline_instrument_test.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -297,11 +297,11 @@ def enter_composite_transform(self, transform_node):
297297

298298
def visit_transform(self, transform_node):
299299
if transform_node.inputs:
300-
input_list = list(transform_node.inputs)
301-
for i in range(len(input_list)):
302-
if input_list[i] == init_pcoll:
303-
input_list[i] = cached_init_pcoll
304-
transform_node.inputs = tuple(input_list)
300+
main_inputs = dict(transform_node.main_inputs)
301+
for tag in main_inputs.keys():
302+
if main_inputs[tag] == init_pcoll:
303+
main_inputs[tag] = cached_init_pcoll
304+
transform_node.main_inputs = main_inputs
305305

306306
v = TestReadCacheWireVisitor()
307307
p_origin.visit(v)

0 commit comments

Comments
 (0)