18
18
19
19
package org .apache .flink .ml .feature .onehotencoder ;
20
20
21
- import org .apache .flink .api .common .functions .FlatMapFunction ;
22
- import org .apache .flink .api .common .functions .MapPartitionFunction ;
21
+ import org .apache .flink .api .common .state .ListState ;
22
+ import org .apache .flink .api .common .state .ListStateDescriptor ;
23
+ import org .apache .flink .api .common .typeinfo .BasicTypeInfo ;
24
+ import org .apache .flink .api .common .typeinfo .TypeInformation ;
23
25
import org .apache .flink .api .java .tuple .Tuple2 ;
26
+ import org .apache .flink .api .java .typeutils .ObjectArrayTypeInfo ;
27
+ import org .apache .flink .api .java .typeutils .TupleTypeInfo ;
28
+ import org .apache .flink .iteration .operator .OperatorStateUtils ;
24
29
import org .apache .flink .ml .api .Estimator ;
25
- import org .apache .flink .ml .common .datastream .DataStreamUtils ;
26
30
import org .apache .flink .ml .common .param .HasHandleInvalid ;
27
31
import org .apache .flink .ml .param .Param ;
28
32
import org .apache .flink .ml .util .ParamUtils ;
29
33
import org .apache .flink .ml .util .ReadWriteUtils ;
34
+ import org .apache .flink .runtime .state .StateInitializationContext ;
35
+ import org .apache .flink .runtime .state .StateSnapshotContext ;
30
36
import org .apache .flink .streaming .api .datastream .DataStream ;
37
+ import org .apache .flink .streaming .api .operators .AbstractStreamOperator ;
38
+ import org .apache .flink .streaming .api .operators .BoundedOneInput ;
39
+ import org .apache .flink .streaming .api .operators .OneInputStreamOperator ;
40
+ import org .apache .flink .streaming .runtime .streamrecord .StreamRecord ;
31
41
import org .apache .flink .table .api .Table ;
32
42
import org .apache .flink .table .api .bridge .java .StreamTableEnvironment ;
33
43
import org .apache .flink .table .api .internal .TableImpl ;
34
44
import org .apache .flink .types .Row ;
35
- import org .apache .flink .util .Collector ;
36
45
import org .apache .flink .util .Preconditions ;
37
46
38
47
import java .io .IOException ;
48
+ import java .util .Arrays ;
49
+ import java .util .Collections ;
39
50
import java .util .HashMap ;
40
51
import java .util .Map ;
41
52
@@ -68,13 +79,20 @@ public OneHotEncoderModel fit(Table... inputs) {
68
79
69
80
StreamTableEnvironment tEnv =
70
81
(StreamTableEnvironment ) ((TableImpl ) inputs [0 ]).getTableEnvironment ();
71
- DataStream <Tuple2 <Integer , Integer >> columnsAndValues =
72
- tEnv .toDataStream (inputs [0 ]).flatMap (new ExtractInputColsValueFunction (inputCols ));
82
+ DataStream <Integer []> localMaxIndices =
83
+ tEnv .toDataStream (inputs [0 ])
84
+ .transform (
85
+ "ExtractInputValueAndFindMaxIndexOperator" ,
86
+ ObjectArrayTypeInfo .getInfoFor (BasicTypeInfo .INT_TYPE_INFO ),
87
+ new ExtractInputValueAndFindMaxIndexOperator (inputCols ));
73
88
74
89
DataStream <Tuple2 <Integer , Integer >> modelData =
75
- DataStreamUtils .mapPartition (
76
- columnsAndValues .keyBy (columnIdAndValue -> columnIdAndValue .f0 ),
77
- new FindMaxIndexFunction ());
90
+ localMaxIndices
91
+ .transform (
92
+ "GenerateModelDataOperator" ,
93
+ TupleTypeInfo .getBasicTupleTypeInfo (Integer .class , Integer .class ),
94
+ new GenerateModelDataOperator ())
95
+ .setParallelism (1 );
78
96
79
97
OneHotEncoderModel model =
80
98
new OneHotEncoderModel ().setModelData (tEnv .fromDataStream (modelData ));
@@ -97,50 +115,129 @@ public Map<Param<?>, Object> getParamMap() {
97
115
}
98
116
99
117
/**
100
- * Extract values of input columns of input data.
101
- *
102
- * <p>Input: rows of input data containing designated input columns
103
- *
104
- * <p>Output: Pairs of column index and value stored in those columns
118
+ * Operator to extract the integer values from input columns and to find the max index value for
119
+ * each column.
105
120
*/
106
- private static class ExtractInputColsValueFunction
107
- implements FlatMapFunction <Row , Tuple2 <Integer , Integer >> {
121
+ private static class ExtractInputValueAndFindMaxIndexOperator
122
+ extends AbstractStreamOperator <Integer []>
123
+ implements OneInputStreamOperator <Row , Integer []>, BoundedOneInput {
124
+
108
125
private final String [] inputCols ;
109
126
110
- private ExtractInputColsValueFunction (String [] inputCols ) {
127
+ private ListState <Integer []> maxIndicesState ;
128
+
129
+ private Integer [] maxIndices ;
130
+
131
+ private ExtractInputValueAndFindMaxIndexOperator (String [] inputCols ) {
111
132
this .inputCols = inputCols ;
112
133
}
113
134
114
135
@ Override
115
- public void flatMap (Row row , Collector <Tuple2 <Integer , Integer >> collector ) {
136
+ public void initializeState (StateInitializationContext context ) throws Exception {
137
+ super .initializeState (context );
138
+
139
+ TypeInformation <Integer []> type =
140
+ ObjectArrayTypeInfo .getInfoFor (BasicTypeInfo .INT_TYPE_INFO );
141
+
142
+ maxIndicesState =
143
+ context .getOperatorStateStore ()
144
+ .getListState (new ListStateDescriptor <>("maxIndices" , type ));
145
+
146
+ maxIndices =
147
+ OperatorStateUtils .getUniqueElement (maxIndicesState , "maxIndices" )
148
+ .orElse (initMaxIndices ());
149
+ }
150
+
151
+ private Integer [] initMaxIndices () {
152
+ Integer [] indices = new Integer [inputCols .length ];
153
+ Arrays .fill (indices , Integer .MIN_VALUE );
154
+ return indices ;
155
+ }
156
+
157
+ @ Override
158
+ public void snapshotState (StateSnapshotContext context ) throws Exception {
159
+ super .snapshotState (context );
160
+ maxIndicesState .update (Collections .singletonList (maxIndices ));
161
+ }
162
+
163
+ @ Override
164
+ public void processElement (StreamRecord <Row > streamRecord ) {
165
+ Row row = streamRecord .getValue ();
116
166
for (int i = 0 ; i < inputCols .length ; i ++) {
117
167
Number number = (Number ) row .getField (inputCols [i ]);
118
- Preconditions .checkArgument (
119
- number .intValue () == number .doubleValue (),
120
- String .format ("Value %s cannot be parsed as indexed integer." , number ));
121
- Preconditions .checkArgument (
122
- number .intValue () >= 0 , "Negative value not supported." );
123
- collector .collect (new Tuple2 <>(i , number .intValue ()));
168
+ int value = number .intValue ();
169
+
170
+ if (value != number .doubleValue ()) {
171
+ throw new IllegalArgumentException (
172
+ String .format ("Value %s cannot be parsed as indexed integer." , number ));
173
+ }
174
+ Preconditions .checkArgument (value >= 0 , "Negative value not supported." );
175
+
176
+ if (value > maxIndices [i ]) {
177
+ maxIndices [i ] = value ;
178
+ }
124
179
}
125
180
}
181
+
182
+ @ Override
183
+ public void endInput () {
184
+ output .collect (new StreamRecord <>(maxIndices ));
185
+ }
126
186
}
127
187
128
- /** Function to find the max index value for each column. */
129
- private static class FindMaxIndexFunction
130
- implements MapPartitionFunction <Tuple2 <Integer , Integer >, Tuple2 <Integer , Integer >> {
188
+ /**
189
+ * Collects and reduces the max index value in each column and produces the model data.
190
+ *
191
+ * <p>Output: Pairs of column index and max index value in this column.
192
+ */
193
+ private static class GenerateModelDataOperator
194
+ extends AbstractStreamOperator <Tuple2 <Integer , Integer >>
195
+ implements OneInputStreamOperator <Integer [], Tuple2 <Integer , Integer >>,
196
+ BoundedOneInput {
197
+
198
+ private ListState <Integer []> maxIndicesState ;
199
+
200
+ private Integer [] maxIndices ;
131
201
132
202
@ Override
133
- public void mapPartition (
134
- Iterable <Tuple2 <Integer , Integer >> iterable ,
135
- Collector <Tuple2 <Integer , Integer >> collector ) {
136
- Map <Integer , Integer > map = new HashMap <>();
137
- for (Tuple2 <Integer , Integer > value : iterable ) {
138
- map .put (
139
- value .f0 ,
140
- Math .max (map .getOrDefault (value .f0 , Integer .MIN_VALUE ), value .f1 ));
203
+ public void initializeState (StateInitializationContext context ) throws Exception {
204
+ super .initializeState (context );
205
+
206
+ TypeInformation <Integer []> type =
207
+ ObjectArrayTypeInfo .getInfoFor (BasicTypeInfo .INT_TYPE_INFO );
208
+
209
+ maxIndicesState =
210
+ context .getOperatorStateStore ()
211
+ .getListState (new ListStateDescriptor <>("maxIndices" , type ));
212
+
213
+ maxIndices =
214
+ OperatorStateUtils .getUniqueElement (maxIndicesState , "maxIndices" ).orElse (null );
215
+ }
216
+
217
+ @ Override
218
+ public void snapshotState (StateSnapshotContext context ) throws Exception {
219
+ super .snapshotState (context );
220
+ maxIndicesState .update (Collections .singletonList (maxIndices ));
221
+ }
222
+
223
+ @ Override
224
+ public void processElement (StreamRecord <Integer []> streamRecord ) {
225
+ if (maxIndices == null ) {
226
+ maxIndices = streamRecord .getValue ();
227
+ } else {
228
+ Integer [] indices = streamRecord .getValue ();
229
+ for (int i = 0 ; i < maxIndices .length ; i ++) {
230
+ if (indices [i ] > maxIndices [i ]) {
231
+ maxIndices [i ] = indices [i ];
232
+ }
233
+ }
141
234
}
142
- for (Map .Entry <Integer , Integer > entry : map .entrySet ()) {
143
- collector .collect (new Tuple2 <>(entry .getKey (), entry .getValue ()));
235
+ }
236
+
237
+ @ Override
238
+ public void endInput () {
239
+ for (int i = 0 ; i < maxIndices .length ; i ++) {
240
+ output .collect (new StreamRecord <>(Tuple2 .of (i , maxIndices [i ])));
144
241
}
145
242
}
146
243
}
0 commit comments