
Commit 18ec1b3

[FLINK-27096] Improve DataCache and KMeans performance

This closes apache#97.

1 parent: 239788f

File tree: 35 files changed, +1513 −337 lines

.gitignore

Lines changed: 1 addition & 0 deletions

```diff
@@ -17,5 +17,6 @@ target
 *.ipr
 *.iws
 *.pyc
+.vscode
 flink-ml-python/dist/
 flink-ml-python/apache_flink_ml.egg-info/
```

flink-ml-core/src/main/java/org/apache/flink/ml/common/broadcast/operator/AbstractBroadcastWrapperOperator.java

Lines changed: 3 additions & 7 deletions

```diff
@@ -389,14 +389,12 @@ private void processPendingElementsAndWatermarks(
             ThrowingConsumer<StreamRecord, Exception> elementConsumer,
             ThrowingConsumer<Watermark, Exception> watermarkConsumer)
             throws Exception {
-        dataCacheWriters[inputIndex].finishCurrentSegment();
-        List<Segment> pendingSegments = dataCacheWriters[inputIndex].getFinishSegments();
+        List<Segment> pendingSegments = dataCacheWriters[inputIndex].getSegments();
         if (pendingSegments.size() != 0) {
             DataCacheReader dataCacheReader =
                     new DataCacheReader<>(
                             new CacheElementTypeInfo<>(inTypes[inputIndex])
                                     .createSerializer(containingTask.getExecutionConfig()),
-                            basePath.getFileSystem(),
                             pendingSegments);
             while (dataCacheReader.hasNext()) {
                 CacheElement cacheElement = (CacheElement) dataCacheReader.next();
@@ -565,12 +563,10 @@ public void snapshotState(StateSnapshotContext stateSnapshotContext) throws Exce
             dos.writeInt(numInputs);
         }
         for (int i = 0; i < numInputs; i++) {
-            dataCacheWriters[i].finishCurrentSegment();
+            dataCacheWriters[i].writeSegmentsToFiles();
             DataCacheSnapshot dataCacheSnapshot =
                     new DataCacheSnapshot(
-                            basePath.getFileSystem(),
-                            null,
-                            dataCacheWriters[i].getFinishSegments());
+                            basePath.getFileSystem(), null, dataCacheWriters[i].getSegments());
             dataCacheSnapshot.writeTo(checkpointOutputStream);
         }
     }
```
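These hunks track an API simplification in the iteration DataCache: callers no longer finish segments explicitly or pass a FileSystem to DataCacheReader, and snapshotting flushes pending data via writeSegmentsToFiles(). A minimal sketch of the resulting write-snapshot-replay flow, using only the calls visible above (the `writer`, `serializer`, and `out` names are assumptions, not from the commit):

```java
// Hedged sketch of the simplified DataCache flow; not code from the commit.
// `writer`, `serializer`, and `out` are assumed to be in scope.
writer.writeSegmentsToFiles(); // flush in-memory segments to files before snapshotting
DataCacheSnapshot snapshot =
        new DataCacheSnapshot(basePath.getFileSystem(), null, writer.getSegments());
snapshot.writeTo(out);         // persist the segment metadata into the checkpoint stream

// Replaying the cache no longer requires a FileSystem argument:
DataCacheReader<CacheElement> reader = new DataCacheReader<>(serializer, writer.getSegments());
while (reader.hasNext()) {
    CacheElement element = reader.next();
    // ... handle the cached element
}
```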

flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java

Lines changed: 156 additions & 11 deletions
```diff
@@ -26,25 +26,36 @@
 import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.dag.Transformation;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.typeutils.TypeExtractor;
+import org.apache.flink.core.memory.ManagedMemoryUseCase;
+import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache;
 import org.apache.flink.iteration.operator.OperatorStateUtils;
 import org.apache.flink.runtime.state.StateInitializationContext;
 import org.apache.flink.runtime.state.StateSnapshotContext;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
 import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
 import org.apache.flink.streaming.api.operators.BoundedOneInput;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.operators.TimestampedCollector;
 import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.TableException;
 import org.apache.flink.util.Collector;
 
 import org.apache.commons.collections.IteratorUtils;
 
+import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
+import java.util.Iterator;
 import java.util.List;
+import java.util.Optional;
+import java.util.Random;
 
 /** Provides utility functions for {@link DataStream}. */
 @Internal
```
```diff
@@ -105,6 +116,56 @@ public static <T> DataStream<T> reduce(DataStream<T> input, ReduceFunction<T> fu
         }
     }
 
+    /**
+     * Performs a uniform sampling over the elements in a bounded data stream.
+     *
+     * <p>This method takes samples without replacement. If the number of elements in the stream is
+     * smaller than expected number of samples, all elements will be included in the sample.
+     *
+     * @param input The input data stream.
+     * @param numSamples The number of elements to be sampled.
+     * @param randomSeed The seed to randomly pick elements as sample.
+     * @return A data stream containing a list of the sampled elements.
+     */
+    public static <T> DataStream<T> sample(DataStream<T> input, int numSamples, long randomSeed) {
+        int inputParallelism = input.getParallelism();
+
+        return input.transform(
+                        "samplingOperator",
+                        input.getType(),
+                        new SamplingOperator<>(numSamples, randomSeed))
+                .setParallelism(inputParallelism)
+                .transform(
+                        "samplingOperator",
+                        input.getType(),
+                        new SamplingOperator<>(numSamples, randomSeed))
+                .setParallelism(1)
+                .map(x -> x, input.getType())
+                .setParallelism(inputParallelism);
+    }
+
+    /**
+     * Sets {Transformation#declareManagedMemoryUseCaseAtOperatorScope(ManagedMemoryUseCase, int)}
+     * using the given bytes for {@link ManagedMemoryUseCase#OPERATOR}.
+     *
+     * <p>This method is in reference to Flink's ExecNodeUtil.setManagedMemoryWeight. The provided
+     * bytes should be in the same scale as existing usage in Flink, for example,
+     * StreamExecWindowAggregate.WINDOW_AGG_MEMORY_RATIO.
+     */
+    public static <T> void setManagedMemoryWeight(
+            Transformation<T> transformation, long memoryBytes) {
+        if (memoryBytes > 0) {
+            final int weightInMebibyte = Math.max(1, (int) (memoryBytes >> 20));
+            final Optional<Integer> previousWeight =
+                    transformation.declareManagedMemoryUseCaseAtOperatorScope(
+                            ManagedMemoryUseCase.OPERATOR, weightInMebibyte);
+            if (previousWeight.isPresent()) {
+                throw new TableException(
+                        "Managed memory weight has been set, this should not happen.");
+            }
+        }
+    }
+
     /**
      * A stream operator to apply {@link MapPartitionFunction} on each partition of the input
      * bounded data stream.
```
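The new sample() pipeline chains a per-subtask sampler, a parallelism-1 sampler that merges the per-subtask reservoirs, and an identity map that restores the input parallelism. A hedged usage sketch for the two new utilities (the StreamExecutionEnvironment `env`, element type, sample size, seed, and memory figure are all illustrative, not from the commit):

```java
// Draw 100 elements uniformly at random from a bounded stream, with a fixed
// seed so repeated runs pick the same sample.
DataStream<Long> source = env.fromSequence(0L, 9_999L);
DataStream<Long> sampled = DataStreamUtils.sample(source, 100, 2022L);

// Declare roughly 64 MiB of managed memory for a transformation; the weight
// is registered in mebibytes under ManagedMemoryUseCase.OPERATOR.
DataStreamUtils.setManagedMemoryWeight(sampled.getTransformation(), 64L << 20);
```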
```diff
@@ -113,7 +174,7 @@ private static class MapPartitionOperator<IN, OUT>
             extends AbstractUdfStreamOperator<OUT, MapPartitionFunction<IN, OUT>>
             implements OneInputStreamOperator<IN, OUT>, BoundedOneInput {
 
-        private ListState<IN> valuesState;
+        private ListStateWithCache<IN> valuesState;
 
         public MapPartitionOperator(MapPartitionFunction<IN, OUT> mapPartitionFunc) {
             super(mapPartitionFunc);
```
```diff
@@ -122,24 +183,32 @@ public MapPartitionOperator(MapPartitionFunction<IN, OUT> mapPartitionFunc) {
         @Override
         public void initializeState(StateInitializationContext context) throws Exception {
             super.initializeState(context);
-            ListStateDescriptor<IN> descriptor =
-                    new ListStateDescriptor<>(
-                            "inputState",
-                            getOperatorConfig()
-                                    .getTypeSerializerIn(0, getClass().getClassLoader()));
-            valuesState = context.getOperatorStateStore().getListState(descriptor);
+
+            valuesState =
+                    new ListStateWithCache<>(
+                            getOperatorConfig().getTypeSerializerIn(0, getClass().getClassLoader()),
+                            getContainingTask(),
+                            getRuntimeContext(),
+                            context,
+                            config.getOperatorID());
         }
 
         @Override
-        public void endInput() throws Exception {
-            userFunction.mapPartition(valuesState.get(), new TimestampedCollector<>(output));
-            valuesState.clear();
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            valuesState.snapshotState(context);
         }
 
         @Override
         public void processElement(StreamRecord<IN> input) throws Exception {
             valuesState.add(input.getValue());
         }
+
+        @Override
+        public void endInput() throws Exception {
+            userFunction.mapPartition(valuesState.get(), new TimestampedCollector<>(output));
+            valuesState.clear();
+        }
     }
 
     /** A stream operator to apply {@link ReduceFunction} on the input bounded data stream. */
```
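MapPartitionOperator previously buffered its whole partition in heap-backed operator ListState; it now buffers through ListStateWithCache from the iteration DataCache package, which the commit title suggests can spill to disk. A lifecycle sketch inside a hypothetical operator, using only the constructor and methods that appear in the hunk above (`serializer`, `record`, and `snapshotContext` are assumed names):

```java
// Hedged sketch inside an AbstractStreamOperator subclass; not the commit's
// code. Construction mirrors MapPartitionOperator.initializeState above.
ListStateWithCache<IN> values =
        new ListStateWithCache<>(
                serializer,              // TypeSerializer<IN> of the buffered records
                getContainingTask(),     // the owning StreamTask
                getRuntimeContext(),     // runtime context, presumably for managed memory
                context,                 // StateInitializationContext, for restore
                config.getOperatorID()); // namespaces the cached data

values.add(record);                      // buffer a record (presumably spilling when memory fills)
Iterable<IN> all = values.get();         // replay everything buffered so far

// Unlike plain ListState, the snapshot must be forwarded explicitly so the
// cached segments end up in the checkpoint.
values.snapshotState(snapshotContext);
values.clear();
```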
```diff
@@ -176,7 +245,7 @@ public void initializeState(StateInitializationContext context) throws Exception
             state =
                     context.getOperatorStateStore()
                             .getListState(
-                                    new ListStateDescriptor<T>(
+                                    new ListStateDescriptor<>(
                                             "state",
                                             getOperatorConfig()
                                                     .getTypeSerializerIn(
```
```diff
@@ -256,4 +325,80 @@ public void flatMap(T[] values, Collector<Tuple2<Integer, T[]>> collector) {
             }
         }
     }
+
+    /*
+     * A stream operator that takes a randomly sampled subset of elements in a bounded data stream.
+     */
+    private static class SamplingOperator<T> extends AbstractStreamOperator<T>
+            implements OneInputStreamOperator<T, T>, BoundedOneInput {
+        private final int numSamples;
+
+        private final Random random;
+
+        private ListState<T> samplesState;
+
+        private List<T> samples;
+
+        private ListState<Integer> countState;
+
+        private int count;
+
+        SamplingOperator(int numSamples, long randomSeed) {
+            this.numSamples = numSamples;
+            this.random = new Random(randomSeed);
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            ListStateDescriptor<T> samplesDescriptor =
+                    new ListStateDescriptor<>(
+                            "samplesState",
+                            getOperatorConfig()
+                                    .getTypeSerializerIn(0, getClass().getClassLoader()));
+            samplesState = context.getOperatorStateStore().getListState(samplesDescriptor);
+            samples = new ArrayList<>(numSamples);
+            samplesState.get().forEach(samples::add);
+
+            ListStateDescriptor<Integer> countDescriptor =
+                    new ListStateDescriptor<>("countState", IntSerializer.INSTANCE);
+            countState = context.getOperatorStateStore().getListState(countDescriptor);
+            Iterator<Integer> countIterator = countState.get().iterator();
+            if (countIterator.hasNext()) {
+                count = countIterator.next();
+            } else {
+                count = 0;
+            }
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            samplesState.update(samples);
+            countState.update(Collections.singletonList(count));
+        }
+
+        @Override
+        public void processElement(StreamRecord<T> streamRecord) throws Exception {
+            T value = streamRecord.getValue();
+            count++;
+
+            if (samples.size() < numSamples) {
+                samples.add(value);
+            } else {
+                int index = random.nextInt(count);
+                if (index < numSamples) {
+                    samples.set(index, value);
+                }
+            }
+        }
+
+        @Override
+        public void endInput() throws Exception {
+            for (T sample : samples) {
+                output.collect(new StreamRecord<>(sample));
+            }
+        }
+    }
 }
```
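SamplingOperator is a classic reservoir sampler (Vitter's Algorithm R): the first numSamples elements fill the reservoir, and each later element replaces a uniformly chosen slot with probability numSamples/count, so after n elements every element has been retained with probability min(1, numSamples/n). The same update rule, restated as a small self-contained class (the class and method names here are ours; the logic is the operator's processElement):

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// Standalone restatement of SamplingOperator's reservoir update (Algorithm R).
final class Reservoir<T> {
    private final List<T> samples;
    private final int numSamples;
    private final Random random;
    private int count; // elements seen so far

    Reservoir(int numSamples, long seed) {
        this.samples = new ArrayList<>(numSamples);
        this.numSamples = numSamples;
        this.random = new Random(seed);
    }

    void offer(T value) {
        count++;
        if (samples.size() < numSamples) {
            samples.add(value); // fill phase: keep everything
        } else {
            int index = random.nextInt(count); // uniform in [0, count)
            if (index < numSamples) {          // hit with probability numSamples/count
                samples.set(index, value);
            }
        }
    }

    List<T> samples() {
        return samples;
    }
}
```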

flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java

Lines changed: 38 additions & 9 deletions

```diff
@@ -20,29 +20,36 @@
 package org.apache.flink.ml.linalg.typeinfo;
 
 import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
-import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton;
 import org.apache.flink.core.memory.DataInputView;
 import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.util.Bits;
 
 import java.io.IOException;
 import java.util.Arrays;
+import java.util.Objects;
 
 /** Specialized serializer for {@link DenseVector}. */
-public final class DenseVectorSerializer extends TypeSerializerSingleton<DenseVector> {
+public final class DenseVectorSerializer extends TypeSerializer<DenseVector> {
 
     private static final long serialVersionUID = 1L;
 
     private static final double[] EMPTY = new double[0];
 
-    public static final DenseVectorSerializer INSTANCE = new DenseVectorSerializer();
+    private final byte[] buf = new byte[1024];
 
     @Override
     public boolean isImmutableType() {
         return false;
     }
 
+    @Override
+    public TypeSerializer<DenseVector> duplicate() {
+        return new DenseVectorSerializer();
+    }
+
     @Override
     public DenseVector createInstance() {
         return new DenseVector(EMPTY);
@@ -75,9 +82,14 @@ public void serialize(DenseVector vector, DataOutputView target) throws IOExcept
 
         final int len = vector.values.length;
         target.writeInt(len);
+
         for (int i = 0; i < len; i++) {
-            target.writeDouble(vector.get(i));
+            Bits.putDouble(buf, (i & 127) << 3, vector.values[i]);
+            if ((i & 127) == 127) {
+                target.write(buf);
+            }
         }
+        target.write(buf, 0, (len & 127) << 3);
     }
 
     @Override
@@ -89,10 +101,17 @@ public DenseVector deserialize(DataInputView source) throws IOException {
         }
 
     // Reads `len` double values from `source` into `dst`.
-    private static void readDoubleArray(double[] dst, DataInputView source, int len)
-            throws IOException {
-        for (int i = 0; i < len; i++) {
-            dst[i] = source.readDouble();
+    private void readDoubleArray(double[] dst, DataInputView source, int len) throws IOException {
+        int index = 0;
+        for (int i = 0; i < (len >> 7); i++) {
+            source.read(buf, 0, 1024);
+            for (int j = 0; j < 128; j++) {
+                dst[index++] = Bits.getDouble(buf, j << 3);
+            }
+        }
+        source.read(buf, 0, (len << 3) & 1023);
+        for (int j = 0; j < (len & 127); j++) {
+            dst[index++] = Bits.getDouble(buf, j << 3);
         }
     }
 
@@ -116,6 +135,16 @@ public void copy(DataInputView source, DataOutputView target) throws IOException
         target.write(source, len * 8);
     }
 
+    @Override
+    public boolean equals(Object o) {
+        return o instanceof DenseVectorSerializer;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hashCode(DenseVectorSerializer.class);
+    }
+
     // ------------------------------------------------------------------------
 
     @Override
@@ -129,7 +158,7 @@ public static final class DenseVectorSerializerSnapshot
             extends SimpleTypeSerializerSnapshot<DenseVector> {
 
     public DenseVectorSerializerSnapshot() {
-        super(() -> INSTANCE);
+        super(DenseVectorSerializer::new);
     }
 }
```
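The rewrite drops the stateless INSTANCE singleton because each serializer now carries a mutable 1024-byte scratch buffer, which is why duplicate() and the snapshot hand out fresh instances (Flink duplicates serializers per thread) while equals()/hashCode() treat all instances as interchangeable. The hot loop batches 128 doubles per write() call instead of one writeDouble() per element. The bit arithmetic behind the chunking, spelled out with illustrative values:

```java
// How the chunking arithmetic in serialize() and readDoubleArray() works:
// buf holds 1024 bytes = 128 doubles of 8 bytes each.
int i = 200;   // an example element index
int len = 300; // an example vector length

assert ((i & 127) << 3) == (i % 128) * 8;     // byte offset of element i inside buf
assert (len >> 7) == len / 128;               // number of full 1024-byte chunks
assert ((len & 127) << 3) == (len % 128) * 8; // bytes in the final partial chunk
// serialize() flushes buf after every 128th element ((i & 127) == 127) and
// writes the remaining (len & 127) doubles, i.e. (len & 127) << 3 bytes, at the end.
```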

flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorTypeInfo.java

Lines changed: 1 addition & 2 deletions

```diff
@@ -62,9 +62,8 @@ public boolean isKeyType() {
     }
 
     @Override
-    @SuppressWarnings("unchecked")
     public TypeSerializer<DenseVector> createSerializer(ExecutionConfig executionConfig) {
-        return DenseVectorSerializer.INSTANCE;
+        return new DenseVectorSerializer();
     }
 
     // --------------------------------------------------------------------------------------------
```

flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorSerializer.java

Lines changed: 1 addition & 0 deletions

```diff
@@ -82,6 +82,7 @@ public void serialize(SparseVector vector, DataOutputView target) throws IOExcep
         target.writeInt(vector.n);
         final int len = vector.values.length;
         target.writeInt(len);
+        // TODO: optimize the serialization/deserialization process of SparseVectorSerializer.
        for (int i = 0; i < len; i++) {
            target.writeInt(vector.indices[i]);
            target.writeDouble(vector.values[i]);
```
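The TODO leaves SparseVectorSerializer on one writeInt()/writeDouble() pair per entry. One plausible shape for that future optimization, mirroring the DenseVectorSerializer change in this same commit (this is our sketch, not code from the repository, and Bits.putInt is assumed to exist alongside Bits.putDouble in org.apache.flink.ml.util.Bits):

```java
// Hypothetical buffered write of (index, value) pairs, 12 bytes per entry,
// modeled on the DenseVectorSerializer change above; `buf` is a 1024-byte
// instance field as in that class.
private static final int CHUNK_BYTES = 85 * 12; // 85 entries * 12 bytes = 1020 <= 1024

private void writeEntries(SparseVector vector, DataOutputView target) throws IOException {
    int offset = 0;
    for (int i = 0; i < vector.values.length; i++) {
        Bits.putInt(buf, offset, vector.indices[i]);        // assumed helper
        Bits.putDouble(buf, offset + 4, vector.values[i]);
        offset += 12;
        if (offset == CHUNK_BYTES) { // flush a full chunk
            target.write(buf, 0, offset);
            offset = 0;
        }
    }
    target.write(buf, 0, offset);    // flush the partial final chunk
}
```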
