Skip to content

Commit ce3e9cb

Browse files
authored
[FLINK-27096] Optimize VectorAssembler performance
This closes apache#114.
1 parent de08eb8 commit ce3e9cb

File tree

1 file changed: 82 additions (+82) and 50 deletions (-50).

flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java

Lines changed: 82 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@
2626
import org.apache.flink.ml.linalg.DenseVector;
2727
import org.apache.flink.ml.linalg.SparseVector;
2828
import org.apache.flink.ml.linalg.Vector;
29+
import org.apache.flink.ml.linalg.Vectors;
2930
import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
3031
import org.apache.flink.ml.param.Param;
3132
import org.apache.flink.ml.util.ParamUtils;
@@ -42,7 +43,6 @@
4243

4344
import java.io.IOException;
4445
import java.util.HashMap;
45-
import java.util.LinkedHashMap;
4646
import java.util.Map;
4747

4848
/**
@@ -90,14 +90,27 @@ public AssemblerFunc(String[] inputCols, String handleInvalid) {
9090
}
9191

9292
@Override
93-
public void flatMap(Row value, Collector<Row> out) throws Exception {
93+
public void flatMap(Row value, Collector<Row> out) {
94+
int nnz = 0;
95+
int vectorSize = 0;
9496
try {
95-
Object[] objects = new Object[inputCols.length];
96-
for (int i = 0; i < objects.length; ++i) {
97-
objects[i] = value.getField(inputCols[i]);
97+
for (String inputCol : inputCols) {
98+
Object object = value.getField(inputCol);
99+
Preconditions.checkNotNull(object, "Input column value should not be null.");
100+
if (object instanceof Number) {
101+
nnz += 1;
102+
vectorSize += 1;
103+
} else if (object instanceof SparseVector) {
104+
nnz += ((SparseVector) object).indices.length;
105+
vectorSize += ((SparseVector) object).size();
106+
} else if (object instanceof DenseVector) {
107+
nnz += ((DenseVector) object).size();
108+
vectorSize += ((DenseVector) object).size();
109+
} else {
110+
throw new IllegalArgumentException(
111+
"Input type has not been supported yet.");
112+
}
98113
}
99-
Vector assembledVector = assemble(objects);
100-
out.collect(Row.join(value, Row.of(assembledVector)));
101114
} catch (Exception e) {
102115
switch (handleInvalid) {
103116
case ERROR_INVALID:
@@ -112,6 +125,13 @@ public void flatMap(Row value, Collector<Row> out) throws Exception {
112125
"Unsupported " + HANDLE_INVALID + " type: " + handleInvalid);
113126
}
114127
}
128+
129+
boolean toDense = nnz * RATIO > vectorSize;
130+
Vector assembledVec =
131+
toDense
132+
? assembleDense(inputCols, value, vectorSize)
133+
: assembleSparse(inputCols, value, vectorSize, nnz);
134+
out.collect(Row.join(value, Row.of(assembledVec)));
115135
}
116136
}
117137

@@ -129,57 +149,69 @@ public Map<Param<?>, Object> getParamMap() {
129149
return paramMap;
130150
}
131151

132-
private static Vector assemble(Object[] objects) {
133-
int offset = 0;
134-
Map<Integer, Double> map = new LinkedHashMap<>(objects.length);
135-
for (Object object : objects) {
136-
Preconditions.checkNotNull(object, "Input column value should not be null.");
152+
/** Assembles the input columns into a dense vector. */
153+
private static Vector assembleDense(String[] inputCols, Row inputRow, int vectorSize) {
154+
double[] values = new double[vectorSize];
155+
int currentOffset = 0;
156+
157+
for (String inputCol : inputCols) {
158+
Object object = inputRow.getField(inputCol);
137159
if (object instanceof Number) {
138-
map.put(offset++, ((Number) object).doubleValue());
139-
} else if (object instanceof Vector) {
140-
offset = appendVector((Vector) object, map, offset);
160+
values[currentOffset++] = ((Number) object).doubleValue();
161+
} else if (object instanceof SparseVector) {
162+
SparseVector sparseVector = (SparseVector) object;
163+
for (int i = 0; i < sparseVector.indices.length; i++) {
164+
values[currentOffset + sparseVector.indices[i]] = sparseVector.values[i];
165+
}
166+
currentOffset += sparseVector.size();
167+
141168
} else {
142-
throw new IllegalArgumentException("Input type has not been supported yet.");
143-
}
144-
}
169+
DenseVector denseVector = (DenseVector) object;
170+
System.arraycopy(
171+
denseVector.values, 0, values, currentOffset, denseVector.values.length);
145172

146-
if (map.size() * RATIO > offset) {
147-
DenseVector assembledVector = new DenseVector(offset);
148-
for (int key : map.keySet()) {
149-
assembledVector.values[key] = map.get(key);
173+
currentOffset += denseVector.size();
150174
}
151-
return assembledVector;
152-
} else {
153-
return convertMapToSparseVector(offset, map);
154175
}
176+
return Vectors.dense(values);
155177
}
156178

157-
private static int appendVector(Vector vec, Map<Integer, Double> map, int offset) {
158-
if (vec instanceof SparseVector) {
159-
SparseVector sparseVector = (SparseVector) vec;
160-
int[] indices = sparseVector.indices;
161-
double[] values = sparseVector.values;
162-
for (int i = 0; i < indices.length; ++i) {
163-
map.put(offset + indices[i], values[i]);
164-
}
165-
offset += sparseVector.size();
166-
} else {
167-
DenseVector denseVector = (DenseVector) vec;
168-
for (int i = 0; i < denseVector.size(); ++i) {
169-
map.put(offset++, denseVector.values[i]);
170-
}
171-
}
172-
return offset;
173-
}
179+
/** Assembles the input columns into a sparse vector. */
180+
private static Vector assembleSparse(
181+
String[] inputCols, Row inputRow, int vectorSize, int nnz) {
182+
int[] indices = new int[nnz];
183+
double[] values = new double[nnz];
174184

175-
private static SparseVector convertMapToSparseVector(int size, Map<Integer, Double> map) {
176-
int[] indices = new int[map.size()];
177-
double[] values = new double[map.size()];
178-
int offset = 0;
179-
for (Map.Entry<Integer, Double> entry : map.entrySet()) {
180-
indices[offset] = entry.getKey();
181-
values[offset++] = entry.getValue();
185+
int currentIndex = 0;
186+
int currentOffset = 0;
187+
188+
for (String inputCol : inputCols) {
189+
Object object = inputRow.getField(inputCol);
190+
if (object instanceof Number) {
191+
indices[currentOffset] = currentIndex;
192+
values[currentOffset] = ((Number) object).doubleValue();
193+
currentOffset++;
194+
currentIndex++;
195+
} else if (object instanceof SparseVector) {
196+
SparseVector sparseVector = (SparseVector) object;
197+
for (int i = 0; i < sparseVector.indices.length; i++) {
198+
indices[currentOffset + i] = sparseVector.indices[i] + currentIndex;
199+
}
200+
System.arraycopy(
201+
sparseVector.values, 0, values, currentOffset, sparseVector.values.length);
202+
currentIndex += sparseVector.size();
203+
currentOffset += sparseVector.indices.length;
204+
} else {
205+
DenseVector denseVector = (DenseVector) object;
206+
for (int i = 0; i < denseVector.size(); ++i) {
207+
indices[currentOffset + i] = i + currentIndex;
208+
}
209+
System.arraycopy(
210+
denseVector.values, 0, values, currentOffset, denseVector.values.length);
211+
currentIndex += denseVector.size();
212+
currentOffset += denseVector.size();
213+
}
182214
}
183-
return new SparseVector(size, indices, values);
215+
return new SparseVector(vectorSize, indices, values);
184216
}
185217
}

0 commit comments

Comments
 (0)