|
18 | 18 |
|
19 | 19 | package org.apache.flink.ml.benchmark.datagenerator.common;
|
20 | 20 |
|
21 |
| -import org.apache.flink.api.common.functions.RichMapFunction; |
22 |
| -import org.apache.flink.api.common.typeinfo.BasicTypeInfo; |
23 |
| -import org.apache.flink.api.java.tuple.Tuple2; |
24 |
| -import org.apache.flink.configuration.Configuration; |
25 |
| -import org.apache.flink.ml.benchmark.datagenerator.InputDataGenerator; |
| 21 | +import org.apache.flink.api.common.typeinfo.TypeInformation; |
| 22 | +import org.apache.flink.api.java.typeutils.RowTypeInfo; |
26 | 23 | import org.apache.flink.ml.benchmark.datagenerator.param.HasArraySize;
|
27 | 24 | import org.apache.flink.ml.benchmark.datagenerator.param.HasVectorDim;
|
28 |
| -import org.apache.flink.ml.common.datastream.TableUtils; |
29 | 25 | import org.apache.flink.ml.linalg.DenseVector;
|
30 |
| -import org.apache.flink.ml.param.Param; |
31 |
| -import org.apache.flink.ml.util.ParamUtils; |
32 |
| -import org.apache.flink.streaming.api.datastream.DataStream; |
33 |
| -import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; |
34 |
| -import org.apache.flink.table.api.DataTypes; |
35 |
| -import org.apache.flink.table.api.Schema; |
36 |
| -import org.apache.flink.table.api.Table; |
37 |
| -import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; |
38 |
| -import org.apache.flink.util.NumberSequenceIterator; |
| 26 | +import org.apache.flink.types.Row; |
39 | 27 | import org.apache.flink.util.Preconditions;
|
40 | 28 |
|
41 |
| -import java.util.HashMap; |
42 |
| -import java.util.Map; |
43 |
| -import java.util.Random; |
44 |
| - |
45 | 29 | /** A DataGenerator which creates a table of DenseVector array. */
|
46 |
| -public class DenseVectorArrayGenerator |
47 |
| - implements InputDataGenerator<DenseVectorArrayGenerator>, |
48 |
| - HasArraySize<DenseVectorArrayGenerator>, |
| 30 | +public class DenseVectorArrayGenerator extends InputTableGenerator<DenseVectorArrayGenerator> |
| 31 | + implements HasArraySize<DenseVectorArrayGenerator>, |
49 | 32 | HasVectorDim<DenseVectorArrayGenerator> {
|
50 |
| - private final Map<Param<?>, Object> paramMap = new HashMap<>(); |
51 |
| - |
52 |
| - public DenseVectorArrayGenerator() { |
53 |
| - ParamUtils.initializeMapWithDefaultValues(paramMap, this); |
54 |
| - } |
55 | 33 |
|
56 | 34 | @Override
|
57 |
| - public Table[] getData(StreamTableEnvironment tEnv) { |
58 |
| - StreamExecutionEnvironment env = TableUtils.getExecutionEnvironment(tEnv); |
59 |
| - |
60 |
| - DataStream<DenseVector[]> dataStream = |
61 |
| - env.fromParallelCollection( |
62 |
| - new NumberSequenceIterator(1L, getNumValues()), |
63 |
| - BasicTypeInfo.LONG_TYPE_INFO) |
64 |
| - .map( |
65 |
| - new GenerateRandomContinuousVectorArrayFunction( |
66 |
| - getSeed(), getVectorDim(), getArraySize())); |
67 |
| - |
68 |
| - Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector[].class)).build(); |
69 |
| - Table dataTable = tEnv.fromDataStream(dataStream, schema); |
70 |
| - if (getColNames() != null) { |
71 |
| - Preconditions.checkState(getColNames().length == 1); |
72 |
| - Preconditions.checkState(getColNames()[0].length == 1); |
73 |
| - dataTable = dataTable.as(getColNames()[0][0]); |
74 |
| - } |
75 |
| - |
76 |
| - return new Table[] {dataTable}; |
77 |
| - } |
| 35 | + protected RowGenerator[] getRowGenerators() { |
| 36 | + String[][] columnNames = getColNames(); |
| 37 | + Preconditions.checkState(columnNames.length == 1); |
| 38 | + Preconditions.checkState(columnNames[0].length == 1); |
| 39 | + int arraySize = getArraySize(); |
| 40 | + int vectorDim = getVectorDim(); |
78 | 41 |
|
79 |
| - private static class GenerateRandomContinuousVectorArrayFunction |
80 |
| - extends RichMapFunction<Long, DenseVector[]> { |
81 |
| - private final int vectorDim; |
82 |
| - private final long initSeed; |
83 |
| - private final int arraySize; |
84 |
| - private Random random; |
85 |
| - |
86 |
| - private GenerateRandomContinuousVectorArrayFunction( |
87 |
| - long initSeed, int vectorDim, int arraySize) { |
88 |
| - this.vectorDim = vectorDim; |
89 |
| - this.initSeed = initSeed; |
90 |
| - this.arraySize = arraySize; |
91 |
| - } |
92 |
| - |
93 |
| - @Override |
94 |
| - public void open(Configuration parameters) throws Exception { |
95 |
| - super.open(parameters); |
96 |
| - int index = getRuntimeContext().getIndexOfThisSubtask(); |
97 |
| - random = new Random(Tuple2.of(initSeed, index).hashCode()); |
98 |
| - } |
| 42 | + return new RowGenerator[] { |
| 43 | + new RowGenerator(getNumValues(), getSeed()) { |
| 44 | + @Override |
| 45 | + protected Row nextRow() { |
| 46 | + DenseVector[] result = new DenseVector[arraySize]; |
| 47 | + for (int i = 0; i < arraySize; i++) { |
| 48 | + result[i] = new DenseVector(vectorDim); |
| 49 | + for (int j = 0; j < vectorDim; j++) { |
| 50 | + result[i].values[j] = random.nextDouble(); |
| 51 | + } |
| 52 | + } |
| 53 | + Row row = new Row(1); |
| 54 | + row.setField(0, result); |
| 55 | + return row; |
| 56 | + } |
99 | 57 |
|
100 |
| - @Override |
101 |
| - public DenseVector[] map(Long value) { |
102 |
| - DenseVector[] result = new DenseVector[arraySize]; |
103 |
| - for (int i = 0; i < arraySize; i++) { |
104 |
| - result[i] = new DenseVector(vectorDim); |
105 |
| - for (int j = 0; j < vectorDim; j++) { |
106 |
| - result[i].values[j] = random.nextDouble(); |
| 58 | + @Override |
| 59 | + protected RowTypeInfo getRowTypeInfo() { |
| 60 | + return new RowTypeInfo( |
| 61 | + new TypeInformation[] {TypeInformation.of(DenseVector[].class)}, |
| 62 | + columnNames[0]); |
107 | 63 | }
|
108 | 64 | }
|
109 |
| - return result; |
110 |
| - } |
111 |
| - } |
112 |
| - |
113 |
| - @Override |
114 |
| - public Map<Param<?>, Object> getParamMap() { |
115 |
| - return paramMap; |
| 65 | + }; |
116 | 66 | }
|
117 | 67 | }
|
0 commit comments