61
61
from typing import FrozenSet
62
62
from typing import Iterable
63
63
from typing import List
64
+ from typing import Mapping
64
65
from typing import Optional
65
66
from typing import Sequence
66
67
from typing import Set
@@ -271,7 +272,7 @@ def _replace(self, override):
271
272
output_replacements = {
272
273
} # type: Dict[AppliedPTransform, List[Tuple[pvalue.PValue, Optional[str]]]]
273
274
input_replacements = {
274
- } # type: Dict[AppliedPTransform, Sequence[ Union[pvalue.PBegin, pvalue.PCollection]]]
275
+ } # type: Dict[AppliedPTransform, Mapping[str, Union[pvalue.PBegin, pvalue.PCollection]]]
275
276
side_input_replacements = {
276
277
} # type: Dict[AppliedPTransform, List[pvalue.AsSideInput]]
277
278
@@ -297,7 +298,7 @@ def _replace_if_needed(self, original_transform_node):
297
298
original_transform_node .parent ,
298
299
replacement_transform ,
299
300
original_transform_node .full_label ,
300
- original_transform_node .inputs )
301
+ original_transform_node .main_inputs )
301
302
302
303
replacement_transform_node .resource_hints = (
303
304
original_transform_node .resource_hints )
@@ -437,11 +438,11 @@ def visit_transform(self, transform_node):
437
438
output_replacements [transform_node ].append ((tag , replacement ))
438
439
439
440
if replace_input :
440
- new_input = [
441
- input if not input in output_map else output_map [input ]
442
- for input in transform_node .inputs
443
- ]
444
- input_replacements [transform_node ] = new_input
441
+ new_inputs = {
442
+ tag : input if not input in output_map else output_map [input ]
443
+ for ( tag , input ) in transform_node .main_inputs . items ()
444
+ }
445
+ input_replacements [transform_node ] = new_inputs
445
446
446
447
if replace_side_inputs :
447
448
new_side_inputs = []
@@ -670,15 +671,18 @@ def apply(
670
671
671
672
pvalueish , inputs = transform ._extract_input_pvalues (pvalueish )
672
673
try :
673
- inputs = tuple (inputs )
674
- for leaf_input in inputs :
675
- if not isinstance (leaf_input , pvalue .PValue ):
676
- raise TypeError
674
+ if not isinstance (inputs , dict ):
675
+ inputs = {str (ix ): input for (ix , input ) in enumerate (inputs )}
677
676
except TypeError :
678
677
raise NotImplementedError (
679
678
'Unable to extract PValue inputs from %s; either %s does not accept '
680
679
'inputs of this format, or it does not properly override '
681
680
'_extract_input_pvalues' % (pvalueish , transform ))
681
+ for t , leaf_input in inputs .items ():
682
+ if not isinstance (leaf_input , pvalue .PValue ) or not isinstance (t , str ):
683
+ raise NotImplementedError (
684
+ '%s does not properly override _extract_input_pvalues, '
685
+ 'returned %s from %s' % (transform , inputs , pvalueish ))
682
686
683
687
current = AppliedPTransform (
684
688
self ._current_transform (), transform , full_label , inputs )
@@ -705,7 +709,8 @@ def apply(
705
709
if result .producer is None :
706
710
result .producer = current
707
711
708
- self ._infer_result_type (transform , inputs , result )
712
+ # TODO(BEAM-1833): Pass full tuples dict.
713
+ self ._infer_result_type (transform , tuple (inputs .values ()), result )
709
714
710
715
assert isinstance (result .producer .inputs , tuple )
711
716
# The DoOutputsTuple adds the PCollection to the outputs when accessed
@@ -940,7 +945,7 @@ def from_runner_api(
940
945
for id in proto .components .transforms :
941
946
transform = context .transforms .get_by_id (id )
942
947
if not transform .inputs and transform .transform .__class__ in has_pbegin :
943
- transform .inputs = ( pvalue .PBegin (p ), )
948
+ transform .main_inputs = { 'None' : pvalue .PBegin (p )}
944
949
945
950
if return_context :
946
951
return p , context # type: ignore # too complicated for now
@@ -1030,7 +1035,7 @@ def __init__(
1030
1035
parent , # type: Optional[AppliedPTransform]
1031
1036
transform , # type: Optional[ptransform.PTransform]
1032
1037
full_label , # type: str
1033
- inputs , # type: Optional[Sequence[ Union[pvalue.PBegin, pvalue.PCollection]]]
1038
+ main_inputs , # type: Optional[Mapping[str, Union[pvalue.PBegin, pvalue.PCollection]]]
1034
1039
environment_id = None , # type: Optional[str]
1035
1040
annotations = None , # type: Optional[Dict[str, bytes]]
1036
1041
):
@@ -1043,7 +1048,7 @@ def __init__(
1043
1048
# reusing PTransform instances in different contexts (apply() calls) without
1044
1049
# any interference. This is particularly useful for composite transforms.
1045
1050
self .full_label = full_label
1046
- self .inputs = inputs or ( )
1051
+ self .main_inputs = dict ( main_inputs or {} )
1047
1052
1048
1053
self .side_inputs = tuple () if transform is None else transform .side_inputs
1049
1054
self .outputs = {} # type: Dict[Union[str, int, None], pvalue.PValue]
@@ -1076,6 +1081,10 @@ def annotation_to_bytes(key, a: Any) -> bytes:
1076
1081
}
1077
1082
self .annotations = annotations
1078
1083
1084
+ @property
1085
+ def inputs (self ):
1086
+ return tuple (self .main_inputs .values ())
1087
+
1079
1088
def __repr__ (self ):
1080
1089
# type: () -> str
1081
1090
return "%s(%s, %s)" % (
@@ -1109,8 +1118,8 @@ def replace_output(
1109
1118
if isinstance (self .transform , external .ExternalTransform ):
1110
1119
self .transform .replace_named_outputs (self .named_outputs ())
1111
1120
1112
- def replace_inputs (self , inputs ):
1113
- self .inputs = inputs
1121
+ def replace_inputs (self , main_inputs ):
1122
+ self .main_inputs = main_inputs
1114
1123
1115
1124
# Importing locally to prevent circular dependency issues.
1116
1125
from apache_beam .transforms import external
@@ -1215,12 +1224,11 @@ def visit(
1215
1224
1216
1225
def named_inputs (self ):
1217
1226
# type: () -> Dict[str, pvalue.PValue]
1218
- # TODO(BEAM-1833): Push names up into the sdk construction.
1219
1227
if self .transform is None :
1220
- assert not self .inputs and not self .side_inputs
1228
+ assert not self .main_inputs and not self .side_inputs
1221
1229
return {}
1222
1230
else :
1223
- return self .transform ._named_inputs (self .inputs , self .side_inputs )
1231
+ return self .transform ._named_inputs (self .main_inputs , self .side_inputs )
1224
1232
1225
1233
def named_outputs (self ):
1226
1234
# type: () -> Dict[str, pvalue.PCollection]
@@ -1309,10 +1317,10 @@ def from_runner_api(
1309
1317
pardo_payload = None
1310
1318
side_input_tags = []
1311
1319
1312
- main_inputs = [
1313
- context .pcollections .get_by_id (id ) for tag ,
1314
- id in proto .inputs .items () if tag not in side_input_tags
1315
- ]
1320
+ main_inputs = {
1321
+ tag : context .pcollections .get_by_id (id )
1322
+ for ( tag , id ) in proto .inputs .items () if tag not in side_input_tags
1323
+ }
1316
1324
1317
1325
transform = ptransform .PTransform .from_runner_api (proto , context )
1318
1326
if transform and proto .environment_id :
@@ -1334,7 +1342,7 @@ def from_runner_api(
1334
1342
parent = None ,
1335
1343
transform = transform ,
1336
1344
full_label = proto .unique_name ,
1337
- inputs = main_inputs ,
1345
+ main_inputs = main_inputs ,
1338
1346
environment_id = None ,
1339
1347
annotations = proto .annotations )
1340
1348
0 commit comments