Skip to content

Commit 41cccd1

Browse files
authored
[FLINK-28601] Add Transformer for FeatureHasher
This closes apache#133.
1 parent e5da0da commit 41cccd1

File tree

11 files changed

+802
-1
lines changed

11 files changed

+802
-1
lines changed
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.ml.examples.feature;
20+
21+
import org.apache.flink.ml.feature.featurehasher.FeatureHasher;
22+
import org.apache.flink.ml.linalg.Vector;
23+
import org.apache.flink.streaming.api.datastream.DataStream;
24+
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
25+
import org.apache.flink.table.api.Table;
26+
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
27+
import org.apache.flink.types.Row;
28+
import org.apache.flink.util.CloseableIterator;
29+
30+
import java.util.Arrays;
31+
32+
/** Simple program that creates a FeatureHasher instance and uses it for feature engineering. */
33+
public class FeatureHasherExample {
34+
public static void main(String[] args) {
35+
36+
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
37+
StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
38+
39+
// Generates input data.
40+
DataStream<Row> dataStream =
41+
env.fromCollection(
42+
Arrays.asList(Row.of(0, "a", 1.0, true), Row.of(1, "c", 1.0, false)));
43+
Table inputDataTable = tEnv.fromDataStream(dataStream).as("id", "f0", "f1", "f2");
44+
45+
// Creates a FeatureHasher object and initializes its parameters.
46+
FeatureHasher featureHash =
47+
new FeatureHasher()
48+
.setInputCols("f0", "f1", "f2")
49+
.setCategoricalCols("f0", "f2")
50+
.setOutputCol("vec")
51+
.setNumFeatures(1000);
52+
53+
// Uses the FeatureHasher object for feature transformations.
54+
Table outputTable = featureHash.transform(inputDataTable)[0];
55+
56+
// Extracts and displays the results.
57+
for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
58+
Row row = it.next();
59+
60+
Object[] inputValues = new Object[featureHash.getInputCols().length];
61+
for (int i = 0; i < inputValues.length; i++) {
62+
inputValues[i] = row.getField(featureHash.getInputCols()[i]);
63+
}
64+
Vector outputValue = (Vector) row.getField(featureHash.getOutputCol());
65+
66+
System.out.printf(
67+
"Input Values: %s \tOutput Value: %s\n",
68+
Arrays.toString(inputValues), outputValue);
69+
}
70+
}
71+
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.ml.common.param;
20+
21+
import org.apache.flink.ml.param.Param;
22+
import org.apache.flink.ml.param.ParamValidators;
23+
import org.apache.flink.ml.param.StringArrayParam;
24+
import org.apache.flink.ml.param.WithParams;
25+
26+
/** Interface for the shared categoricalCols param. */
27+
public interface HasCategoricalCols<T> extends WithParams<T> {
28+
Param<String[]> CATEGORICAL_COLS =
29+
new StringArrayParam(
30+
"categoricalCols",
31+
"Categorical column names.",
32+
new String[] {},
33+
ParamValidators.notNull());
34+
35+
default String[] getCategoricalCols() {
36+
return get(CATEGORICAL_COLS);
37+
}
38+
39+
default T setCategoricalCols(String... value) {
40+
return set(CATEGORICAL_COLS, value);
41+
}
42+
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.ml.common.param;
20+
21+
import org.apache.flink.ml.param.IntParam;
22+
import org.apache.flink.ml.param.Param;
23+
import org.apache.flink.ml.param.ParamValidators;
24+
import org.apache.flink.ml.param.WithParams;
25+
26+
/** Interface for the shared num features param. */
27+
public interface HasNumFeatures<T> extends WithParams<T> {
28+
Param<Integer> NUM_FEATURES =
29+
new IntParam(
30+
"numFeatures",
31+
"The number of features. It will be the length of the output vector.",
32+
262144,
33+
ParamValidators.gt(0));
34+
35+
default int getNumFeatures() {
36+
return get(NUM_FEATURES);
37+
}
38+
39+
default T setNumFeatures(int value) {
40+
set(NUM_FEATURES, value);
41+
return (T) this;
42+
}
43+
}
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.ml.feature.featurehasher;
20+
21+
import org.apache.flink.api.common.functions.MapFunction;
22+
import org.apache.flink.api.java.typeutils.RowTypeInfo;
23+
import org.apache.flink.ml.api.Transformer;
24+
import org.apache.flink.ml.common.datastream.TableUtils;
25+
import org.apache.flink.ml.linalg.SparseVector;
26+
import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
27+
import org.apache.flink.ml.param.Param;
28+
import org.apache.flink.ml.util.ParamUtils;
29+
import org.apache.flink.ml.util.ReadWriteUtils;
30+
import org.apache.flink.streaming.api.datastream.DataStream;
31+
import org.apache.flink.table.api.DataTypes;
32+
import org.apache.flink.table.api.Table;
33+
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
34+
import org.apache.flink.table.api.internal.TableImpl;
35+
import org.apache.flink.table.catalog.ResolvedSchema;
36+
import org.apache.flink.table.types.DataType;
37+
import org.apache.flink.types.Row;
38+
import org.apache.flink.util.Preconditions;
39+
40+
import org.apache.commons.lang3.ArrayUtils;
41+
42+
import java.io.IOException;
43+
import java.util.ArrayList;
44+
import java.util.Arrays;
45+
import java.util.HashMap;
46+
import java.util.List;
47+
import java.util.Map;
48+
import java.util.TreeMap;
49+
50+
import static org.apache.flink.shaded.guava30.com.google.common.hash.Hashing.murmur3_32;
51+
52+
/**
53+
* A Transformer that transforms a set of categorical or numerical features into a sparse vector of
54+
* a specified dimension. The rules of hashing categorical columns and numerical columns are as
55+
* follows:
56+
*
57+
* <ul>
58+
* <li>For numerical columns, the index of this feature in the output vector is the hash value of
59+
* the column name and its correponding value is the same as the input.
60+
* <li>For categorical columns, the index of this feature in the output vector is the hash value
61+
* of the string "column_name=value" and the corresponding value is 1.0.
62+
* </ul>
63+
*
64+
* <p>If multiple features are projected into the same column, the output values are accumulated.
65+
* For the hashing trick, see https://en.wikipedia.org/wiki/Feature_hashing for details.
66+
*/
67+
public class FeatureHasher
68+
implements Transformer<FeatureHasher>, FeatureHasherParams<FeatureHasher> {
69+
private final Map<Param<?>, Object> paramMap = new HashMap<>();
70+
private static final org.apache.flink.shaded.guava30.com.google.common.hash.HashFunction HASH =
71+
murmur3_32(0);
72+
73+
public FeatureHasher() {
74+
ParamUtils.initializeMapWithDefaultValues(paramMap, this);
75+
}
76+
77+
@Override
78+
public Table[] transform(Table... inputs) {
79+
Preconditions.checkArgument(inputs.length == 1);
80+
StreamTableEnvironment tEnv =
81+
(StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
82+
ResolvedSchema tableSchema = inputs[0].getResolvedSchema();
83+
RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(tableSchema);
84+
RowTypeInfo outputTypeInfo =
85+
new RowTypeInfo(
86+
ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), VectorTypeInfo.INSTANCE),
87+
ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol()));
88+
DataStream<Row> output =
89+
tEnv.toDataStream(inputs[0])
90+
.map(
91+
new HashFunction(
92+
getInputCols(),
93+
generateCategoricalCols(
94+
tableSchema, getInputCols(), getCategoricalCols()),
95+
getNumFeatures()),
96+
outputTypeInfo);
97+
Table outputTable = tEnv.fromDataStream(output);
98+
return new Table[] {outputTable};
99+
}
100+
101+
/**
102+
* The main logic for transforming the categorical and numerical features into a sparse vector.
103+
* It uses MurMurHash3 to compute the transformed index in the output vector. If multiple
104+
* features are projected to the same column, their values are accumulated.
105+
*/
106+
private static class HashFunction implements MapFunction<Row, Row> {
107+
private final String[] categoricalCols;
108+
private final int numFeatures;
109+
private final String[] numericCols;
110+
111+
public HashFunction(String[] inputCols, String[] categoricalCols, int numFeatures) {
112+
this.categoricalCols = categoricalCols;
113+
this.numFeatures = numFeatures;
114+
this.numericCols = ArrayUtils.removeElements(inputCols, this.categoricalCols);
115+
}
116+
117+
@Override
118+
public Row map(Row row) {
119+
TreeMap<Integer, Double> feature = new TreeMap<>();
120+
for (String col : numericCols) {
121+
if (null != row.getField(col)) {
122+
double value = ((Number) row.getFieldAs(col)).doubleValue();
123+
updateMap(col, value, feature, numFeatures);
124+
}
125+
}
126+
for (String col : categoricalCols) {
127+
if (null != row.getField(col)) {
128+
updateMap(col + "=" + row.getField(col), 1.0, feature, numFeatures);
129+
}
130+
}
131+
int nnz = feature.size();
132+
int[] indices = new int[nnz];
133+
double[] values = new double[nnz];
134+
int pos = 0;
135+
for (Map.Entry<Integer, Double> entry : feature.entrySet()) {
136+
indices[pos] = entry.getKey();
137+
values[pos] = entry.getValue();
138+
pos++;
139+
}
140+
return Row.join(row, Row.of(new SparseVector(numFeatures, indices, values)));
141+
}
142+
}
143+
144+
private String[] generateCategoricalCols(
145+
ResolvedSchema tableSchema, String[] inputCols, String[] categoricalCols) {
146+
if (null == inputCols) {
147+
return categoricalCols;
148+
}
149+
List<String> categoricalList = Arrays.asList(categoricalCols);
150+
List<String> inputList = Arrays.asList(inputCols);
151+
if (categoricalCols.length > 0 && !inputList.containsAll(categoricalList)) {
152+
throw new IllegalArgumentException("CategoricalCols must be included in inputCols!");
153+
}
154+
List<DataType> dataColTypes = tableSchema.getColumnDataTypes();
155+
List<String> dataColNames = tableSchema.getColumnNames();
156+
List<DataType> inputColTypes = new ArrayList<>();
157+
for (String col : inputCols) {
158+
for (int i = 0; i < dataColNames.size(); ++i) {
159+
if (col.equals(dataColNames.get(i))) {
160+
inputColTypes.add(dataColTypes.get(i));
161+
break;
162+
}
163+
}
164+
}
165+
List<String> resultColList = new ArrayList<>();
166+
for (int i = 0; i < inputCols.length; i++) {
167+
boolean included = categoricalList.contains(inputCols[i]);
168+
if (included
169+
|| DataTypes.BOOLEAN().equals(inputColTypes.get(i))
170+
|| DataTypes.STRING().equals(inputColTypes.get(i))) {
171+
resultColList.add(inputCols[i]);
172+
}
173+
}
174+
return resultColList.toArray(new String[0]);
175+
}
176+
177+
/**
178+
* Updates the treeMap which saves the key-value pair of the final vector, use the hash value of
179+
* the string as key and the accumulate the corresponding value.
180+
*
181+
* @param s the string to hash
182+
* @param value the accumulated value
183+
*/
184+
private static void updateMap(
185+
String s, double value, TreeMap<Integer, Double> feature, int numFeature) {
186+
int hashValue = Math.abs(HASH.hashUnencodedChars(s).asInt());
187+
188+
int index = Math.floorMod(hashValue, numFeature);
189+
if (feature.containsKey(index)) {
190+
feature.put(index, feature.get(index) + value);
191+
} else {
192+
feature.put(index, value);
193+
}
194+
}
195+
196+
@Override
197+
public void save(String path) throws IOException {
198+
ReadWriteUtils.saveMetadata(this, path);
199+
}
200+
201+
public static FeatureHasher load(StreamTableEnvironment env, String path) throws IOException {
202+
return ReadWriteUtils.loadStageParam(path);
203+
}
204+
205+
@Override
206+
public Map<Param<?>, Object> getParamMap() {
207+
return paramMap;
208+
}
209+
}

0 commit comments

Comments
 (0)