Skip to content

Commit de08eb8

Browse files
[FLINK-27096] Optimize OneHotEncoder performance
This closes apache#113.
1 parent 10c1ef4 commit de08eb8

File tree

3 files changed

+195
-38
lines changed

3 files changed

+195
-38
lines changed

flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DoubleGenerator.java

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
import org.apache.flink.api.common.typeinfo.TypeInformation;
2222
import org.apache.flink.api.common.typeinfo.Types;
2323
import org.apache.flink.api.java.typeutils.RowTypeInfo;
24+
import org.apache.flink.ml.param.IntParam;
25+
import org.apache.flink.ml.param.Param;
26+
import org.apache.flink.ml.param.ParamValidators;
2427
import org.apache.flink.types.Row;
2528
import org.apache.flink.util.Preconditions;
2629

@@ -29,19 +32,41 @@
2932
/** A DataGenerator which creates a table of doubles. */
3033
public class DoubleGenerator extends InputTableGenerator<DoubleGenerator> {
3134

35+
/**
 * Parameter selecting what kind of values the generator produces. Per its description, a
 * positive arity yields integer-valued features in [0, arity - 1] and zero yields continuous
 * doubles in [0, 1). getRowGenerators() implements this with random.nextInt(arity) versus
 * random.nextDouble().
 *
 * <p>NOTE(review): the literal 0 is presumably the default value and ParamValidators.gtEq(0)
 * presumably rejects negative settings; confirm against IntParam and ParamValidators.
 */
public static final Param<Integer> ARITY =
36+
new IntParam(
37+
"arity",
38+
"Arity of the generated double values. "
39+
+ "If set to positive value, each feature would be an integer in range [0, arity - 1]. "
40+
+ "If set to zero, each feature would be a continuous double in range [0, 1).",
41+
0,
42+
ParamValidators.gtEq(0));
43+
44+
/** Returns the value currently stored for the {@link #ARITY} parameter. */
public int getArity() {
45+
return get(ARITY);
46+
}
47+
48+
/**
 * Stores {@code value} under the {@link #ARITY} parameter.
 *
 * @param value the arity to use; see the ARITY description for the meaning of zero and
 *     positive values
 * @return whatever {@code set} returns, typed as this generator class
 */
public DoubleGenerator setArity(int value) {
49+
return set(ARITY, value);
50+
}
51+
3252
@Override
3353
protected RowGenerator[] getRowGenerators() {
3454
String[][] colNames = getColNames();
3555
Preconditions.checkState(colNames.length == 1);
3656
int numOutputCols = colNames[0].length;
57+
int arity = getArity();
3758

3859
return new RowGenerator[] {
3960
new RowGenerator(getNumValues(), getSeed()) {
4061
@Override
4162
public Row nextRow() {
4263
Row r = new Row(numOutputCols);
4364
for (int i = 0; i < numOutputCols; i++) {
44-
r.setField(i, random.nextDouble());
65+
if (arity > 0) {
66+
r.setField(i, (double) random.nextInt(arity));
67+
} else {
68+
r.setField(i, random.nextDouble());
69+
}
4570
}
4671
return r;
4772
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one or more
2+
// contributor license agreements. See the NOTICE file distributed with
3+
// this work for additional information regarding copyright ownership.
4+
// The ASF licenses this file to You under the Apache License, Version 2.0
5+
// (the "License"); you may not use this file except in compliance with
6+
// the License. You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
{
17+
"version": 1,
18+
"OneHotEncoder": {
19+
"stage": {
20+
"className": "org.apache.flink.ml.feature.onehotencoder.OneHotEncoder",
21+
"paramMap": {
22+
"inputCols": ["input"],
23+
"outputCols": ["output"]
24+
}
25+
},
26+
"inputData": {
27+
"className": "org.apache.flink.ml.benchmark.datagenerator.common.DoubleGenerator",
28+
"paramMap": {
29+
"colNames": [["input"]],
30+
"arity": 10,
31+
"numValues": 100000
32+
}
33+
}
34+
}
35+
}

flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoder.java

Lines changed: 134 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -18,24 +18,35 @@
1818

1919
package org.apache.flink.ml.feature.onehotencoder;
2020

21-
import org.apache.flink.api.common.functions.FlatMapFunction;
22-
import org.apache.flink.api.common.functions.MapPartitionFunction;
21+
import org.apache.flink.api.common.state.ListState;
22+
import org.apache.flink.api.common.state.ListStateDescriptor;
23+
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
24+
import org.apache.flink.api.common.typeinfo.TypeInformation;
2325
import org.apache.flink.api.java.tuple.Tuple2;
26+
import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
27+
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
28+
import org.apache.flink.iteration.operator.OperatorStateUtils;
2429
import org.apache.flink.ml.api.Estimator;
25-
import org.apache.flink.ml.common.datastream.DataStreamUtils;
2630
import org.apache.flink.ml.common.param.HasHandleInvalid;
2731
import org.apache.flink.ml.param.Param;
2832
import org.apache.flink.ml.util.ParamUtils;
2933
import org.apache.flink.ml.util.ReadWriteUtils;
34+
import org.apache.flink.runtime.state.StateInitializationContext;
35+
import org.apache.flink.runtime.state.StateSnapshotContext;
3036
import org.apache.flink.streaming.api.datastream.DataStream;
37+
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
38+
import org.apache.flink.streaming.api.operators.BoundedOneInput;
39+
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
40+
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
3141
import org.apache.flink.table.api.Table;
3242
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
3343
import org.apache.flink.table.api.internal.TableImpl;
3444
import org.apache.flink.types.Row;
35-
import org.apache.flink.util.Collector;
3645
import org.apache.flink.util.Preconditions;
3746

3847
import java.io.IOException;
48+
import java.util.Arrays;
49+
import java.util.Collections;
3950
import java.util.HashMap;
4051
import java.util.Map;
4152

@@ -68,13 +79,20 @@ public OneHotEncoderModel fit(Table... inputs) {
6879

6980
StreamTableEnvironment tEnv =
7081
(StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
71-
DataStream<Tuple2<Integer, Integer>> columnsAndValues =
72-
tEnv.toDataStream(inputs[0]).flatMap(new ExtractInputColsValueFunction(inputCols));
82+
DataStream<Integer[]> localMaxIndices =
83+
tEnv.toDataStream(inputs[0])
84+
.transform(
85+
"ExtractInputValueAndFindMaxIndexOperator",
86+
ObjectArrayTypeInfo.getInfoFor(BasicTypeInfo.INT_TYPE_INFO),
87+
new ExtractInputValueAndFindMaxIndexOperator(inputCols));
7388

7489
DataStream<Tuple2<Integer, Integer>> modelData =
75-
DataStreamUtils.mapPartition(
76-
columnsAndValues.keyBy(columnIdAndValue -> columnIdAndValue.f0),
77-
new FindMaxIndexFunction());
90+
localMaxIndices
91+
.transform(
92+
"GenerateModelDataOperator",
93+
TupleTypeInfo.getBasicTupleTypeInfo(Integer.class, Integer.class),
94+
new GenerateModelDataOperator())
95+
.setParallelism(1);
7896

7997
OneHotEncoderModel model =
8098
new OneHotEncoderModel().setModelData(tEnv.fromDataStream(modelData));
@@ -97,50 +115,129 @@ public Map<Param<?>, Object> getParamMap() {
97115
}
98116

99117
/**
100-
* Extract values of input columns of input data.
101-
*
102-
* <p>Input: rows of input data containing designated input columns
103-
*
104-
* <p>Output: Pairs of column index and value stored in those columns
118+
* Operator to extract the integer values from input columns and to find the max index value for
119+
* each column.
105120
*/
106-
private static class ExtractInputColsValueFunction
107-
implements FlatMapFunction<Row, Tuple2<Integer, Integer>> {
121+
private static class ExtractInputValueAndFindMaxIndexOperator
122+
extends AbstractStreamOperator<Integer[]>
123+
implements OneInputStreamOperator<Row, Integer[]>, BoundedOneInput {
124+
108125
private final String[] inputCols;
109126

110-
private ExtractInputColsValueFunction(String[] inputCols) {
127+
private ListState<Integer[]> maxIndicesState;
128+
129+
private Integer[] maxIndices;
130+
131+
/**
 * Creates the operator.
 *
 * @param inputCols names of the columns read from each row in processElement; the array
 *     length also fixes the size of the per-column maxima array
 */
private ExtractInputValueAndFindMaxIndexOperator(String[] inputCols) {
111132
this.inputCols = inputCols;
112133
}
113134

114135
@Override
115-
public void flatMap(Row row, Collector<Tuple2<Integer, Integer>> collector) {
136+
public void initializeState(StateInitializationContext context) throws Exception {
137+
super.initializeState(context);
138+
139+
TypeInformation<Integer[]> type =
140+
ObjectArrayTypeInfo.getInfoFor(BasicTypeInfo.INT_TYPE_INFO);
141+
142+
maxIndicesState =
143+
context.getOperatorStateStore()
144+
.getListState(new ListStateDescriptor<>("maxIndices", type));
145+
146+
maxIndices =
147+
OperatorStateUtils.getUniqueElement(maxIndicesState, "maxIndices")
148+
.orElse(initMaxIndices());
149+
}
150+
151+
/**
 * Builds the starting per-column maxima: one slot per input column, each set to
 * Integer.MIN_VALUE. processElement only accepts values >= 0, so any accepted value replaces
 * this placeholder.
 */
private Integer[] initMaxIndices() {
152+
Integer[] indices = new Integer[inputCols.length];
153+
Arrays.fill(indices, Integer.MIN_VALUE);
154+
return indices;
155+
}
156+
157+
/**
 * Delegates to the parent snapshot, then replaces the content of the "maxIndices" list state
 * with a single element: the current per-column maxima array.
 */
@Override
158+
public void snapshotState(StateSnapshotContext context) throws Exception {
159+
super.snapshotState(context);
160+
maxIndicesState.update(Collections.singletonList(maxIndices));
161+
}
162+
163+
/**
 * Reads every configured input column from the row, validates the value, and raises the
 * running maximum of that column when the value exceeds it. Nothing is emitted here; the
 * maxima are emitted once in endInput().
 *
 * <p>Throws IllegalArgumentException when a value does not equal its own int conversion (for
 * example a fractional value) or when it is negative.
 *
 * <p>NOTE(review): a null field would fail on number.intValue() with a NullPointerException;
 * confirm whether null inputs can reach this operator.
 */
@Override
164+
public void processElement(StreamRecord<Row> streamRecord) {
165+
Row row = streamRecord.getValue();
116166
for (int i = 0; i < inputCols.length; i++) {
117167
Number number = (Number) row.getField(inputCols[i]);
118-
Preconditions.checkArgument(
119-
number.intValue() == number.doubleValue(),
120-
String.format("Value %s cannot be parsed as indexed integer.", number));
121-
Preconditions.checkArgument(
122-
number.intValue() >= 0, "Negative value not supported.");
123-
collector.collect(new Tuple2<>(i, number.intValue()));
168+
int value = number.intValue();
169+
170+
// The error message is formatted only inside the failing branch, unlike the removed
// Preconditions.checkArgument call above, which built the message for every value.
if (value != number.doubleValue()) {
171+
throw new IllegalArgumentException(
172+
String.format("Value %s cannot be parsed as indexed integer.", number));
173+
}
174+
Preconditions.checkArgument(value >= 0, "Negative value not supported.");
175+
176+
// Keep only the largest index seen so far for column i.
if (value > maxIndices[i]) {
177+
maxIndices[i] = value;
178+
}
124179
}
125180
}
181+
182+
/**
 * Emits the per-column maxima accumulated by this operator instance as a single record.
 * This is the only place the operator produces output.
 */
@Override
183+
public void endInput() {
184+
output.collect(new StreamRecord<>(maxIndices));
185+
}
126186
}
127187

128-
/** Function to find the max index value for each column. */
129-
private static class FindMaxIndexFunction
130-
implements MapPartitionFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> {
188+
/**
189+
* Collects and reduces the max index value in each column and produces the model data.
190+
*
191+
* <p>Output: Pairs of column index and max index value in this column.
192+
*/
193+
private static class GenerateModelDataOperator
194+
extends AbstractStreamOperator<Tuple2<Integer, Integer>>
195+
implements OneInputStreamOperator<Integer[], Tuple2<Integer, Integer>>,
196+
BoundedOneInput {
197+
198+
private ListState<Integer[]> maxIndicesState;
199+
200+
private Integer[] maxIndices;
131201

132202
@Override
133-
public void mapPartition(
134-
Iterable<Tuple2<Integer, Integer>> iterable,
135-
Collector<Tuple2<Integer, Integer>> collector) {
136-
Map<Integer, Integer> map = new HashMap<>();
137-
for (Tuple2<Integer, Integer> value : iterable) {
138-
map.put(
139-
value.f0,
140-
Math.max(map.getOrDefault(value.f0, Integer.MIN_VALUE), value.f1));
203+
// Registers the "maxIndices" operator list state and loads the reduced maxima from it.
// When the state yields no element, maxIndices is left null; processElement then adopts the
// first incoming array as the accumulator.
public void initializeState(StateInitializationContext context) throws Exception {
204+
super.initializeState(context);
205+
206+
TypeInformation<Integer[]> type =
207+
ObjectArrayTypeInfo.getInfoFor(BasicTypeInfo.INT_TYPE_INFO);
208+
209+
maxIndicesState =
210+
context.getOperatorStateStore()
211+
.getListState(new ListStateDescriptor<>("maxIndices", type));
212+
213+
maxIndices =
214+
OperatorStateUtils.getUniqueElement(maxIndicesState, "maxIndices").orElse(null);
215+
}
216+
217+
@Override
218+
public void snapshotState(StateSnapshotContext context) throws Exception {
219+
super.snapshotState(context);
220+
maxIndicesState.update(Collections.singletonList(maxIndices));
221+
}
222+
223+
/**
 * Folds one incoming array of per-column maxima into the accumulator. The first array
 * received becomes the accumulator itself; each later array is merged by taking the larger
 * value per position.
 *
 * <p>NOTE(review): the first array is kept by reference rather than copied, and the merge
 * loop assumes every incoming array is at least as long as the accumulator; confirm both hold
 * for the upstream operator.
 */
@Override
224+
public void processElement(StreamRecord<Integer[]> streamRecord) {
225+
if (maxIndices == null) {
226+
maxIndices = streamRecord.getValue();
227+
} else {
228+
Integer[] indices = streamRecord.getValue();
229+
for (int i = 0; i < maxIndices.length; i++) {
230+
if (indices[i] > maxIndices[i]) {
231+
maxIndices[i] = indices[i];
232+
}
233+
}
141234
}
142-
for (Map.Entry<Integer, Integer> entry : map.entrySet()) {
143-
collector.collect(new Tuple2<>(entry.getKey(), entry.getValue()));
235+
}
236+
237+
/**
 * Emits the model data: one (column index, max index) pair for every position of the
 * accumulated maxima array.
 *
 * <p>NOTE(review): maxIndices is dereferenced without a null check, so this relies on at
 * least one record having been processed (or restored from state) before the input ends;
 * confirm the upstream operator always emits.
 */
@Override
238+
public void endInput() {
239+
for (int i = 0; i < maxIndices.length; i++) {
240+
output.collect(new StreamRecord<>(Tuple2.of(i, maxIndices[i])));
144241
}
145242
}
146243
}

0 commit comments

Comments
 (0)