Skip to content

Commit 341df45

Browse files
zhipeng93lindong28
authored andcommitted
[FLINK-27877] Reduce the length of the operator chain for generating input table
1 parent ba77607 commit 341df45

File tree

6 files changed

+275
-264
lines changed

6 files changed

+275
-264
lines changed

flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorArrayGenerator.java

Lines changed: 32 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -18,100 +18,50 @@
1818

1919
package org.apache.flink.ml.benchmark.datagenerator.common;
2020

21-
import org.apache.flink.api.common.functions.RichMapFunction;
22-
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
23-
import org.apache.flink.api.java.tuple.Tuple2;
24-
import org.apache.flink.configuration.Configuration;
25-
import org.apache.flink.ml.benchmark.datagenerator.InputDataGenerator;
21+
import org.apache.flink.api.common.typeinfo.TypeInformation;
22+
import org.apache.flink.api.java.typeutils.RowTypeInfo;
2623
import org.apache.flink.ml.benchmark.datagenerator.param.HasArraySize;
2724
import org.apache.flink.ml.benchmark.datagenerator.param.HasVectorDim;
28-
import org.apache.flink.ml.common.datastream.TableUtils;
2925
import org.apache.flink.ml.linalg.DenseVector;
30-
import org.apache.flink.ml.param.Param;
31-
import org.apache.flink.ml.util.ParamUtils;
32-
import org.apache.flink.streaming.api.datastream.DataStream;
33-
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
34-
import org.apache.flink.table.api.DataTypes;
35-
import org.apache.flink.table.api.Schema;
36-
import org.apache.flink.table.api.Table;
37-
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
38-
import org.apache.flink.util.NumberSequenceIterator;
26+
import org.apache.flink.types.Row;
3927
import org.apache.flink.util.Preconditions;
4028

41-
import java.util.HashMap;
42-
import java.util.Map;
43-
import java.util.Random;
44-
4529
/** A DataGenerator which creates a table of DenseVector array. */
46-
public class DenseVectorArrayGenerator
47-
implements InputDataGenerator<DenseVectorArrayGenerator>,
48-
HasArraySize<DenseVectorArrayGenerator>,
30+
public class DenseVectorArrayGenerator extends InputTableGenerator<DenseVectorArrayGenerator>
31+
implements HasArraySize<DenseVectorArrayGenerator>,
4932
HasVectorDim<DenseVectorArrayGenerator> {
50-
private final Map<Param<?>, Object> paramMap = new HashMap<>();
51-
52-
public DenseVectorArrayGenerator() {
53-
ParamUtils.initializeMapWithDefaultValues(paramMap, this);
54-
}
5533

5634
@Override
57-
public Table[] getData(StreamTableEnvironment tEnv) {
58-
StreamExecutionEnvironment env = TableUtils.getExecutionEnvironment(tEnv);
59-
60-
DataStream<DenseVector[]> dataStream =
61-
env.fromParallelCollection(
62-
new NumberSequenceIterator(1L, getNumValues()),
63-
BasicTypeInfo.LONG_TYPE_INFO)
64-
.map(
65-
new GenerateRandomContinuousVectorArrayFunction(
66-
getSeed(), getVectorDim(), getArraySize()));
67-
68-
Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector[].class)).build();
69-
Table dataTable = tEnv.fromDataStream(dataStream, schema);
70-
if (getColNames() != null) {
71-
Preconditions.checkState(getColNames().length == 1);
72-
Preconditions.checkState(getColNames()[0].length == 1);
73-
dataTable = dataTable.as(getColNames()[0][0]);
74-
}
75-
76-
return new Table[] {dataTable};
77-
}
35+
protected RowGenerator[] getRowGenerators() {
36+
String[][] columnNames = getColNames();
37+
Preconditions.checkState(columnNames.length == 1);
38+
Preconditions.checkState(columnNames[0].length == 1);
39+
int arraySize = getArraySize();
40+
int vectorDim = getVectorDim();
7841

79-
private static class GenerateRandomContinuousVectorArrayFunction
80-
extends RichMapFunction<Long, DenseVector[]> {
81-
private final int vectorDim;
82-
private final long initSeed;
83-
private final int arraySize;
84-
private Random random;
85-
86-
private GenerateRandomContinuousVectorArrayFunction(
87-
long initSeed, int vectorDim, int arraySize) {
88-
this.vectorDim = vectorDim;
89-
this.initSeed = initSeed;
90-
this.arraySize = arraySize;
91-
}
92-
93-
@Override
94-
public void open(Configuration parameters) throws Exception {
95-
super.open(parameters);
96-
int index = getRuntimeContext().getIndexOfThisSubtask();
97-
random = new Random(Tuple2.of(initSeed, index).hashCode());
98-
}
42+
return new RowGenerator[] {
43+
new RowGenerator(getNumValues(), getSeed()) {
44+
@Override
45+
protected Row nextRow() {
46+
DenseVector[] result = new DenseVector[arraySize];
47+
for (int i = 0; i < arraySize; i++) {
48+
result[i] = new DenseVector(vectorDim);
49+
for (int j = 0; j < vectorDim; j++) {
50+
result[i].values[j] = random.nextDouble();
51+
}
52+
}
53+
Row row = new Row(1);
54+
row.setField(0, result);
55+
return row;
56+
}
9957

100-
@Override
101-
public DenseVector[] map(Long value) {
102-
DenseVector[] result = new DenseVector[arraySize];
103-
for (int i = 0; i < arraySize; i++) {
104-
result[i] = new DenseVector(vectorDim);
105-
for (int j = 0; j < vectorDim; j++) {
106-
result[i].values[j] = random.nextDouble();
58+
@Override
59+
protected RowTypeInfo getRowTypeInfo() {
60+
return new RowTypeInfo(
61+
new TypeInformation[] {TypeInformation.of(DenseVector[].class)},
62+
columnNames[0]);
10763
}
10864
}
109-
return result;
110-
}
111-
}
112-
113-
@Override
114-
public Map<Param<?>, Object> getParamMap() {
115-
return paramMap;
65+
};
11666
}
11767
}

flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DenseVectorGenerator.java

Lines changed: 30 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -18,86 +18,43 @@
1818

1919
package org.apache.flink.ml.benchmark.datagenerator.common;
2020

21-
import org.apache.flink.api.common.functions.RichMapFunction;
22-
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
23-
import org.apache.flink.api.java.tuple.Tuple2;
24-
import org.apache.flink.configuration.Configuration;
25-
import org.apache.flink.ml.benchmark.datagenerator.InputDataGenerator;
21+
import org.apache.flink.api.common.typeinfo.TypeInformation;
22+
import org.apache.flink.api.java.typeutils.RowTypeInfo;
2623
import org.apache.flink.ml.benchmark.datagenerator.param.HasVectorDim;
27-
import org.apache.flink.ml.common.datastream.TableUtils;
28-
import org.apache.flink.ml.linalg.DenseVector;
2924
import org.apache.flink.ml.linalg.Vectors;
30-
import org.apache.flink.ml.param.Param;
31-
import org.apache.flink.ml.util.ParamUtils;
32-
import org.apache.flink.streaming.api.datastream.DataStream;
33-
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
34-
import org.apache.flink.table.api.Table;
35-
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
36-
import org.apache.flink.util.NumberSequenceIterator;
25+
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
26+
import org.apache.flink.types.Row;
3727
import org.apache.flink.util.Preconditions;
3828

39-
import java.util.HashMap;
40-
import java.util.Map;
41-
import java.util.Random;
42-
4329
/** A DataGenerator which creates a table of DenseVector. */
44-
public class DenseVectorGenerator
45-
implements InputDataGenerator<DenseVectorGenerator>, HasVectorDim<DenseVectorGenerator> {
46-
private final Map<Param<?>, Object> paramMap = new HashMap<>();
47-
48-
public DenseVectorGenerator() {
49-
ParamUtils.initializeMapWithDefaultValues(paramMap, this);
50-
}
30+
public class DenseVectorGenerator extends InputTableGenerator<DenseVectorGenerator>
31+
implements HasVectorDim<DenseVectorGenerator> {
5132

5233
@Override
53-
public Table[] getData(StreamTableEnvironment tEnv) {
54-
StreamExecutionEnvironment env = TableUtils.getExecutionEnvironment(tEnv);
55-
56-
DataStream<DenseVector> dataStream =
57-
env.fromParallelCollection(
58-
new NumberSequenceIterator(1L, getNumValues()),
59-
BasicTypeInfo.LONG_TYPE_INFO)
60-
.map(new RandomDenseVectorGenerator(getSeed(), getVectorDim()));
61-
62-
Table dataTable = tEnv.fromDataStream(dataStream);
63-
if (getColNames() != null) {
64-
Preconditions.checkState(getColNames().length == 1);
65-
Preconditions.checkState(getColNames()[0].length == 1);
66-
dataTable = dataTable.as(getColNames()[0][0]);
67-
}
68-
69-
return new Table[] {dataTable};
70-
}
71-
72-
private static class RandomDenseVectorGenerator extends RichMapFunction<Long, DenseVector> {
73-
private final int vectorDim;
74-
private final long initSeed;
75-
private Random random;
76-
77-
private RandomDenseVectorGenerator(long initSeed, int vectorDim) {
78-
this.vectorDim = vectorDim;
79-
this.initSeed = initSeed;
80-
}
81-
82-
@Override
83-
public void open(Configuration parameters) throws Exception {
84-
super.open(parameters);
85-
int index = getRuntimeContext().getIndexOfThisSubtask();
86-
random = new Random(Tuple2.of(initSeed, index).hashCode());
87-
}
88-
89-
@Override
90-
public DenseVector map(Long value) {
91-
double[] values = new double[vectorDim];
92-
for (int i = 0; i < vectorDim; i++) {
93-
values[i] = random.nextDouble();
34+
public RowGenerator[] getRowGenerators() {
35+
String[][] columnNames = getColNames();
36+
Preconditions.checkState(columnNames.length == 1);
37+
Preconditions.checkState(columnNames[0].length == 1);
38+
int vectorDim = getVectorDim();
39+
40+
return new RowGenerator[] {
41+
new RowGenerator(getNumValues(), getSeed()) {
42+
43+
@Override
44+
protected Row nextRow() {
45+
double[] values = new double[vectorDim];
46+
for (int i = 0; i < values.length; i++) {
47+
values[i] = random.nextDouble();
48+
}
49+
return Row.of(Vectors.dense(values));
50+
}
51+
52+
@Override
53+
protected RowTypeInfo getRowTypeInfo() {
54+
return new RowTypeInfo(
55+
new TypeInformation[] {DenseVectorTypeInfo.INSTANCE}, columnNames[0]);
56+
}
9457
}
95-
return Vectors.dense(values);
96-
}
97-
}
98-
99-
@Override
100-
public Map<Param<?>, Object> getParamMap() {
101-
return paramMap;
58+
};
10259
}
10360
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.ml.benchmark.datagenerator.common;
20+
21+
import org.apache.flink.ml.benchmark.datagenerator.InputDataGenerator;
22+
import org.apache.flink.ml.common.datastream.TableUtils;
23+
import org.apache.flink.ml.param.Param;
24+
import org.apache.flink.ml.util.ParamUtils;
25+
import org.apache.flink.streaming.api.datastream.DataStream;
26+
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
27+
import org.apache.flink.table.api.Table;
28+
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
29+
import org.apache.flink.types.Row;
30+
31+
import java.util.HashMap;
32+
import java.util.Map;
33+
34+
/** Base class for generating data as input table arrays. */
35+
public abstract class InputTableGenerator<T extends InputTableGenerator<T>>
36+
implements InputDataGenerator<T> {
37+
protected final Map<Param<?>, Object> paramMap = new HashMap<>();
38+
39+
public InputTableGenerator() {
40+
ParamUtils.initializeMapWithDefaultValues(paramMap, this);
41+
}
42+
43+
@Override
44+
public final Table[] getData(StreamTableEnvironment tEnv) {
45+
StreamExecutionEnvironment env = TableUtils.getExecutionEnvironment(tEnv);
46+
47+
RowGenerator[] rowGenerators = getRowGenerators();
48+
Table[] dataTables = new Table[rowGenerators.length];
49+
for (int i = 0; i < rowGenerators.length; i++) {
50+
DataStream<Row> dataStream =
51+
env.addSource(rowGenerators[i], "sourceOp-" + i)
52+
.returns(rowGenerators[i].getRowTypeInfo());
53+
dataTables[i] = tEnv.fromDataStream(dataStream);
54+
}
55+
56+
return dataTables;
57+
}
58+
59+
/** Gets generators for all input tables. */
60+
protected abstract RowGenerator[] getRowGenerators();
61+
62+
@Override
63+
public final Map<Param<?>, Object> getParamMap() {
64+
return paramMap;
65+
}
66+
}

0 commit comments

Comments
 (0)