Skip to content

Commit 58fce03

Browse files
zhipeng93 authored and lindong28 committed
[FLINK-27877] Add benchmark configuration for StringIndexer, StandardScaler and Bucketizer
1 parent 341df45 commit 58fce03

File tree

6 files changed

+326
-0
lines changed

6 files changed

+326
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.ml.benchmark.datagenerator.common;
20+
21+
import org.apache.flink.api.common.typeinfo.TypeInformation;
22+
import org.apache.flink.api.common.typeinfo.Types;
23+
import org.apache.flink.api.java.typeutils.RowTypeInfo;
24+
import org.apache.flink.types.Row;
25+
import org.apache.flink.util.Preconditions;
26+
27+
import java.util.Arrays;
28+
29+
/** A DataGenerator which creates a table of doubles. */
30+
public class DoubleGenerator extends InputTableGenerator<DoubleGenerator> {
31+
32+
@Override
33+
protected RowGenerator[] getRowGenerators() {
34+
String[][] colNames = getColNames();
35+
Preconditions.checkState(colNames.length == 1);
36+
int numOutputCols = colNames[0].length;
37+
38+
return new RowGenerator[] {
39+
new RowGenerator(getNumValues(), getSeed()) {
40+
@Override
41+
public Row nextRow() {
42+
Row r = new Row(numOutputCols);
43+
for (int i = 0; i < numOutputCols; i++) {
44+
r.setField(i, random.nextDouble());
45+
}
46+
return r;
47+
}
48+
49+
@Override
50+
protected RowTypeInfo getRowTypeInfo() {
51+
TypeInformation[] outputTypes = new TypeInformation[colNames[0].length];
52+
Arrays.fill(outputTypes, Types.DOUBLE);
53+
return new RowTypeInfo(outputTypes, colNames[0]);
54+
}
55+
}
56+
};
57+
}
58+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.ml.benchmark.datagenerator.common;
20+
21+
import org.apache.flink.api.common.typeinfo.TypeInformation;
22+
import org.apache.flink.api.common.typeinfo.Types;
23+
import org.apache.flink.api.java.typeutils.RowTypeInfo;
24+
import org.apache.flink.ml.param.IntParam;
25+
import org.apache.flink.ml.param.Param;
26+
import org.apache.flink.ml.param.ParamValidators;
27+
import org.apache.flink.types.Row;
28+
import org.apache.flink.util.Preconditions;
29+
30+
import java.util.Arrays;
31+
32+
/** A DataGenerator which creates a table of random strings. */
33+
public class RandomStringGenerator extends InputTableGenerator<RandomStringGenerator> {
34+
public static final Param<Integer> NUM_DISTINCT_VALUE =
35+
new IntParam(
36+
"numDistinctValue",
37+
"Number of distinct values of the data to be generated.",
38+
10,
39+
ParamValidators.gt(0));
40+
41+
public int getNumDistinctValue() {
42+
return get(NUM_DISTINCT_VALUE);
43+
}
44+
45+
public RandomStringGenerator setNumDistinctValue(int value) {
46+
return set(NUM_DISTINCT_VALUE, value);
47+
}
48+
49+
@Override
50+
protected RowGenerator[] getRowGenerators() {
51+
String[][] colNames = getColNames();
52+
Preconditions.checkState(colNames.length == 1);
53+
int numOutputCols = colNames[0].length;
54+
int numDistinctValues = getNumDistinctValue();
55+
56+
return new RowGenerator[] {
57+
new RowGenerator(getNumValues(), getSeed()) {
58+
@Override
59+
public Row nextRow() {
60+
Row r = new Row(numOutputCols);
61+
for (int i = 0; i < numOutputCols; i++) {
62+
r.setField(i, Integer.toString(random.nextInt(numDistinctValues)));
63+
}
64+
return r;
65+
}
66+
67+
@Override
68+
protected RowTypeInfo getRowTypeInfo() {
69+
TypeInformation[] outputTypes = new TypeInformation[colNames[0].length];
70+
Arrays.fill(outputTypes, Types.STRING);
71+
return new RowTypeInfo(outputTypes, colNames[0]);
72+
}
73+
}
74+
};
75+
}
76+
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one or more
2+
// contributor license agreements. See the NOTICE file distributed with
3+
// this work for additional information regarding copyright ownership.
4+
// The ASF licenses this file to You under the Apache License, Version 2.0
5+
// (the "License"); you may not use this file except in compliance with
6+
// the License. You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
{
17+
"version": 1,
18+
"bucketizer100000000": {
19+
"inputData": {
20+
"className": "org.apache.flink.ml.benchmark.datagenerator.common.DoubleGenerator",
21+
"paramMap": {
22+
"colNames": [
23+
[
24+
"col0"
25+
]
26+
],
27+
"seed": 2,
28+
"numValues": 100000000
29+
}
30+
},
31+
"stage": {
32+
"className": "org.apache.flink.ml.feature.bucketizer.Bucketizer",
33+
"paramMap": {
34+
"outputCols": [
35+
"outputCol0"
36+
],
37+
"handleInvalid": "skip",
38+
"inputCols": [
39+
"col0"
40+
],
41+
"splitsArray": [
42+
[
43+
-1.0,
44+
0.0,
45+
0.5,
46+
1.0,
47+
2.0
48+
]
49+
]
50+
}
51+
}
52+
}
53+
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one or more
2+
// contributor license agreements. See the NOTICE file distributed with
3+
// this work for additional information regarding copyright ownership.
4+
// The ASF licenses this file to You under the Apache License, Version 2.0
5+
// (the "License"); you may not use this file except in compliance with
6+
// the License. You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
{
17+
"version": 1,
18+
"standardscaler10000000": {
19+
"inputData": {
20+
"className": "org.apache.flink.ml.benchmark.datagenerator.common.DenseVectorGenerator",
21+
"paramMap": {
22+
"vectorDim": 100,
23+
"colNames": [
24+
[
25+
"features"
26+
]
27+
],
28+
"seed": 2,
29+
"numValues": 10000000
30+
}
31+
},
32+
"stage": {
33+
"className": "org.apache.flink.ml.feature.standardscaler.StandardScaler",
34+
"paramMap": {
35+
"inputCol": "features",
36+
"withMean": true,
37+
"withStd": true,
38+
"outputCol": "outputCol"
39+
}
40+
}
41+
}
42+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one or more
2+
// contributor license agreements. See the NOTICE file distributed with
3+
// this work for additional information regarding copyright ownership.
4+
// The ASF licenses this file to You under the Apache License, Version 2.0
5+
// (the "License"); you may not use this file except in compliance with
6+
// the License. You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
{
17+
"version": 1,
18+
"stringindexer100000000": {
19+
"inputData": {
20+
"className": "org.apache.flink.ml.benchmark.datagenerator.common.RandomStringGenerator",
21+
"paramMap": {
22+
"colNames": [
23+
[
24+
"col0"
25+
]
26+
],
27+
"seed": 2,
28+
"numValues": 100000000,
29+
"numDistinctValue": 100
30+
}
31+
},
32+
"stage": {
33+
"className": "org.apache.flink.ml.feature.stringindexer.StringIndexer",
34+
"paramMap": {
35+
"outputCols": [
36+
"outputCol0"
37+
],
38+
"handleInvalid": "skip",
39+
"inputCols": [
40+
"col0"
41+
],
42+
"stringOrderType": "arbitrary"
43+
}
44+
}
45+
}
46+
}

flink-ml-benchmark/src/test/java/org/apache/flink/ml/benchmark/DataGeneratorTest.java

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222
import org.apache.flink.configuration.Configuration;
2323
import org.apache.flink.ml.benchmark.datagenerator.common.DenseVectorArrayGenerator;
2424
import org.apache.flink.ml.benchmark.datagenerator.common.DenseVectorGenerator;
25+
import org.apache.flink.ml.benchmark.datagenerator.common.DoubleGenerator;
2526
import org.apache.flink.ml.benchmark.datagenerator.common.LabeledPointWithWeightGenerator;
27+
import org.apache.flink.ml.benchmark.datagenerator.common.RandomStringGenerator;
2628
import org.apache.flink.ml.linalg.DenseVector;
2729
import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
2830
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -136,4 +138,53 @@ public void testLabeledPointWithWeightGenerator() {
136138
}
137139
assertEquals(generator.getNumValues(), count);
138140
}
141+
142+
@Test
143+
public void testRandomStringGenerator() {
144+
String col1 = "col1";
145+
String col2 = "col2";
146+
147+
RandomStringGenerator generator =
148+
new RandomStringGenerator()
149+
.setColNames(new String[] {col1, col2})
150+
.setSeed(2L)
151+
.setNumValues(5)
152+
.setNumDistinctValue(2);
153+
154+
int count = 0;
155+
for (CloseableIterator<Row> it = generator.getData(tEnv)[0].execute().collect();
156+
it.hasNext(); ) {
157+
Row row = it.next();
158+
count++;
159+
String value1 = (String) row.getField(col1);
160+
String value2 = (String) row.getField(col2);
161+
assertTrue(Integer.parseInt(value1) < generator.getNumDistinctValue());
162+
assertTrue(Integer.parseInt(value2) < generator.getNumDistinctValue());
163+
}
164+
assertEquals(generator.getNumValues(), count);
165+
}
166+
167+
@Test
168+
public void testDoubleGenerator() {
169+
String col1 = "col1";
170+
String col2 = "col2";
171+
172+
DoubleGenerator generator =
173+
new DoubleGenerator()
174+
.setColNames(new String[] {"col1", "col2"})
175+
.setSeed(2L)
176+
.setNumValues(5);
177+
178+
int count = 0;
179+
for (CloseableIterator<Row> it = generator.getData(tEnv)[0].execute().collect();
180+
it.hasNext(); ) {
181+
Row row = it.next();
182+
count++;
183+
double value1 = (Double) row.getField(col1);
184+
double value2 = (Double) row.getField(col2);
185+
assertTrue(value1 <= 1 && value1 >= 0);
186+
assertTrue(value2 <= 1 && value2 >= 0);
187+
}
188+
assertEquals(generator.getNumValues(), count);
189+
}
139190
}

0 commit comments

Comments
 (0)