Skip to content

Commit 10c1ef4

Browse files
[FLINK-27096] Optimize KMeans performance
This closes apache#110.
1 parent 454f7d1 commit 10c1ef4

File tree

14 files changed

+460
-127
lines changed

14 files changed

+460
-127
lines changed

flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/DistanceMeasure.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
package org.apache.flink.ml.common.distance;
2020

21-
import org.apache.flink.ml.linalg.Vector;
21+
import org.apache.flink.ml.linalg.VectorWithNorm;
2222

2323
import java.io.Serializable;
2424

@@ -40,5 +40,8 @@ static DistanceMeasure getInstance(String distanceMeasure) {
4040
*
4141
* <p>Required: The two vectors should have the same dimension.
4242
*/
43-
double distance(Vector v1, Vector v2);
43+
double distance(VectorWithNorm v1, VectorWithNorm v2);
44+
45+
/** Finds the index of the centroid in {@code centroids} that is closest to the given point. */
46+
int findClosest(VectorWithNorm[] centroids, VectorWithNorm point);
4447
}

flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/EuclideanDistanceMeasure.java

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818

1919
package org.apache.flink.ml.common.distance;
2020

21-
import org.apache.flink.ml.linalg.Vector;
22-
import org.apache.flink.util.Preconditions;
21+
import org.apache.flink.ml.linalg.BLAS;
22+
import org.apache.flink.ml.linalg.VectorWithNorm;
2323

2424
/** A {@link DistanceMeasure} that measures the Euclidean distance between two vectors. */
2525
public class EuclideanDistanceMeasure implements DistanceMeasure {
@@ -33,16 +33,35 @@ public static EuclideanDistanceMeasure getInstance() {
3333
return instance;
3434
}
3535

36-
// TODO: Improve distance calculation with BLAS.
3736
@Override
38-
public double distance(Vector v1, Vector v2) {
39-
Preconditions.checkArgument(v1.size() == v2.size());
40-
double squaredDistance = 0.0;
37+
public double distance(VectorWithNorm v1, VectorWithNorm v2) {
38+
return Math.sqrt(distanceSquare(v1, v2));
39+
}
40+
41+
private double distanceSquare(VectorWithNorm v1, VectorWithNorm v2) {
42+
return v1.l2Norm * v1.l2Norm + v2.l2Norm * v2.l2Norm - 2.0 * BLAS.dot(v1.vector, v2.vector);
43+
}
4144

42-
for (int i = 0; i < v1.size(); i++) {
43-
double diff = v1.get(i) - v2.get(i);
44-
squaredDistance += diff * diff;
45+
@Override
46+
public int findClosest(VectorWithNorm[] centroids, VectorWithNorm point) {
47+
double bestL2DistanceSquare = Double.POSITIVE_INFINITY;
48+
int bestIndex = 0;
49+
for (int i = 0; i < centroids.length; i++) {
50+
VectorWithNorm centroid = centroids[i];
51+
52+
double lowerBoundSqrt = point.l2Norm - centroid.l2Norm;
53+
double lowerBound = lowerBoundSqrt * lowerBoundSqrt;
54+
if (lowerBound >= bestL2DistanceSquare) {
55+
continue;
56+
}
57+
58+
double l2DistanceSquare = distanceSquare(point, centroid);
59+
if (l2DistanceSquare < bestL2DistanceSquare) {
60+
bestL2DistanceSquare = l2DistanceSquare;
61+
bestIndex = i;
62+
}
4563
}
46-
return Math.sqrt(squaredDistance);
64+
65+
return bestIndex;
4766
}
4867
}

flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,21 @@ private static double dot(SparseVector x, SparseVector y) {
114114
}
115115

116116
/** \sqrt(\sum_i x_i * x_i) . */
117-
public static double norm2(DenseVector x) {
117+
public static double norm2(Vector x) {
118+
if (x instanceof DenseVector) {
119+
return norm2((DenseVector) x);
120+
}
121+
return norm2((SparseVector) x);
122+
}
123+
124+
private static double norm2(DenseVector x) {
118125
return JAVA_BLAS.dnrm2(x.size(), x.values, 1);
119126
}
120127

128+
private static double norm2(SparseVector x) {
129+
return JAVA_BLAS.dnrm2(x.values.length, x.values, 1);
130+
}
131+
121132
/** x = x * a . */
122133
public static void scal(double a, DenseVector x) {
123134
JAVA_BLAS.dscal(x.size(), a, x.values, 1);
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.flink.ml.linalg;
21+
22+
import org.apache.flink.api.common.typeinfo.TypeInfo;
23+
import org.apache.flink.ml.linalg.typeinfo.VectorWithNormTypeInfoFactory;
24+
25+
/** A vector with its norm. */
26+
@TypeInfo(VectorWithNormTypeInfoFactory.class)
27+
public class VectorWithNorm {
28+
public final Vector vector;
29+
30+
public final double l2Norm;
31+
32+
public VectorWithNorm(Vector vector) {
33+
this(vector, BLAS.norm2(vector));
34+
}
35+
36+
public VectorWithNorm(Vector vector, double l2Norm) {
37+
this.vector = vector;
38+
this.l2Norm = l2Norm;
39+
}
40+
}

flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseVectorSerializer.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ public void serialize(DenseVector vector, DataOutputView target) throws IOExcept
8484
target.writeInt(len);
8585

8686
for (int i = 0; i < len; i++) {
87-
Bits.putDouble(buf, i << 3, vector.values[i]);
87+
Bits.putDouble(buf, (i & 127) << 3, vector.values[i]);
8888
if ((i & 127) == 127) {
8989
target.write(buf);
9090
}
@@ -104,12 +104,12 @@ public DenseVector deserialize(DataInputView source) throws IOException {
104104
private void readDoubleArray(double[] dst, DataInputView source, int len) throws IOException {
105105
int index = 0;
106106
for (int i = 0; i < (len >> 7); i++) {
107-
source.read(buf, 0, 1024);
107+
source.readFully(buf, 0, 1024);
108108
for (int j = 0; j < 128; j++) {
109109
dst[index++] = Bits.getDouble(buf, j << 3);
110110
}
111111
}
112-
source.read(buf, 0, (len << 3) & 1023);
112+
source.readFully(buf, 0, (len << 3) & 1023);
113113
for (int j = 0; j < (len & 127); j++) {
114114
dst[index++] = Bits.getDouble(buf, j << 3);
115115
}
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.flink.ml.linalg.typeinfo;
21+
22+
import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot;
23+
import org.apache.flink.api.common.typeutils.TypeSerializer;
24+
import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
25+
import org.apache.flink.core.memory.DataInputView;
26+
import org.apache.flink.core.memory.DataOutputView;
27+
import org.apache.flink.ml.linalg.DenseVector;
28+
import org.apache.flink.ml.linalg.Vector;
29+
import org.apache.flink.ml.linalg.VectorWithNorm;
30+
31+
import java.io.IOException;
32+
33+
/** Specialized serializer for {@link VectorWithNorm}. */
34+
public class VectorWithNormSerializer extends TypeSerializer<VectorWithNorm> {
35+
private final VectorSerializer vectorSerializer = new VectorSerializer();
36+
37+
private static final long serialVersionUID = 1L;
38+
39+
private static final double[] EMPTY = new double[0];
40+
41+
@Override
42+
public boolean isImmutableType() {
43+
return false;
44+
}
45+
46+
@Override
47+
public TypeSerializer<VectorWithNorm> duplicate() {
48+
return new VectorWithNormSerializer();
49+
}
50+
51+
@Override
52+
public VectorWithNorm createInstance() {
53+
return new VectorWithNorm(new DenseVector(EMPTY));
54+
}
55+
56+
@Override
57+
public VectorWithNorm copy(VectorWithNorm from) {
58+
Vector vector = vectorSerializer.copy(from.vector);
59+
return new VectorWithNorm(vector, from.l2Norm);
60+
}
61+
62+
@Override
63+
public VectorWithNorm copy(VectorWithNorm from, VectorWithNorm reuse) {
64+
Vector vector = vectorSerializer.copy(from.vector, reuse.vector);
65+
return new VectorWithNorm(vector, from.l2Norm);
66+
}
67+
68+
@Override
69+
public int getLength() {
70+
return -1;
71+
}
72+
73+
@Override
74+
public void serialize(VectorWithNorm from, DataOutputView dataOutputView) throws IOException {
75+
vectorSerializer.serialize(from.vector, dataOutputView);
76+
dataOutputView.writeDouble(from.l2Norm);
77+
}
78+
79+
@Override
80+
public VectorWithNorm deserialize(DataInputView dataInputView) throws IOException {
81+
Vector vector = vectorSerializer.deserialize(dataInputView);
82+
double l2NormSquare = dataInputView.readDouble();
83+
return new VectorWithNorm(vector, l2NormSquare);
84+
}
85+
86+
@Override
87+
public VectorWithNorm deserialize(VectorWithNorm reuse, DataInputView dataInputView)
88+
throws IOException {
89+
Vector vector = vectorSerializer.deserialize(reuse.vector, dataInputView);
90+
double l2NormSquare = dataInputView.readDouble();
91+
return new VectorWithNorm(vector, l2NormSquare);
92+
}
93+
94+
@Override
95+
public void copy(DataInputView dataInputView, DataOutputView dataOutputView)
96+
throws IOException {
97+
vectorSerializer.copy(dataInputView, dataOutputView);
98+
dataOutputView.write(dataInputView, 8);
99+
}
100+
101+
@Override
102+
public boolean equals(Object o) {
103+
return o instanceof VectorWithNormSerializer;
104+
}
105+
106+
@Override
107+
public int hashCode() {
108+
return VectorWithNormSerializer.class.hashCode();
109+
}
110+
111+
@Override
112+
public TypeSerializerSnapshot<VectorWithNorm> snapshotConfiguration() {
113+
return new VectorWithNormSerializerSnapshot();
114+
}
115+
116+
private static class VectorWithNormSerializerSnapshot
117+
extends SimpleTypeSerializerSnapshot<VectorWithNorm> {
118+
public VectorWithNormSerializerSnapshot() {
119+
super(VectorWithNormSerializer::new);
120+
}
121+
}
122+
}
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.flink.ml.linalg.typeinfo;
21+
22+
import org.apache.flink.api.common.ExecutionConfig;
23+
import org.apache.flink.api.common.typeinfo.TypeInformation;
24+
import org.apache.flink.api.common.typeutils.TypeSerializer;
25+
import org.apache.flink.ml.linalg.VectorWithNorm;
26+
27+
/** A {@link TypeInformation} for the {@link VectorWithNorm} type. */
28+
public class VectorWithNormTypeInfo extends TypeInformation<VectorWithNorm> {
29+
@Override
30+
public boolean isBasicType() {
31+
return false;
32+
}
33+
34+
@Override
35+
public boolean isTupleType() {
36+
return false;
37+
}
38+
39+
@Override
40+
public int getArity() {
41+
return 2;
42+
}
43+
44+
@Override
45+
public int getTotalFields() {
46+
return 2;
47+
}
48+
49+
@Override
50+
public Class<VectorWithNorm> getTypeClass() {
51+
return VectorWithNorm.class;
52+
}
53+
54+
@Override
55+
public boolean isKeyType() {
56+
return false;
57+
}
58+
59+
@Override
60+
public TypeSerializer<VectorWithNorm> createSerializer(ExecutionConfig executionConfig) {
61+
return new VectorWithNormSerializer();
62+
}
63+
64+
@Override
65+
public String toString() {
66+
return "VectorWithNormType";
67+
}
68+
69+
@Override
70+
public boolean equals(Object o) {
71+
return o instanceof VectorWithNormTypeInfo;
72+
}
73+
74+
@Override
75+
public int hashCode() {
76+
return getClass().hashCode();
77+
}
78+
79+
@Override
80+
public boolean canEqual(Object o) {
81+
return o instanceof VectorWithNormTypeInfo;
82+
}
83+
}

0 commit comments

Comments
 (0)