/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.ml.feature.featurehasher;

import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.api.Transformer;
import org.apache.flink.ml.common.datastream.TableUtils;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.table.catalog.ResolvedSchema;
import org.apache.flink.table.types.DataType;
import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;

import org.apache.commons.lang3.ArrayUtils;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

import static org.apache.flink.shaded.guava30.com.google.common.hash.Hashing.murmur3_32;
/**
 * A Transformer that transforms a set of categorical or numerical features into a sparse vector of
 * a specified dimension. The rules of hashing categorical columns and numerical columns are as
 * follows:
 *
 * <ul>
 *   <li>For numerical columns, the index of this feature in the output vector is the hash value of
 *       the column name and its corresponding value is the same as the input.
 *   <li>For categorical columns, the index of this feature in the output vector is the hash value
 *       of the string "column_name=value" and the corresponding value is 1.0.
 * </ul>
 *
 * <p>If multiple features are projected into the same column, the output values are accumulated.
 * See https://en.wikipedia.org/wiki/Feature_hashing for details on the hashing trick.
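 *
 * <p>A minimal usage sketch (hypothetical column names; assumes an existing {@code Table
 * inputTable} and the param setters inherited from {@link FeatureHasherParams}):
 *
 * <pre>{@code
 * FeatureHasher hasher =
 *         new FeatureHasher()
 *                 .setInputCols("f0", "f1")
 *                 .setCategoricalCols("f1")
 *                 .setOutputCol("vec")
 *                 .setNumFeatures(1000);
 * Table output = hasher.transform(inputTable)[0];
 * }</pre>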
 */
public class FeatureHasher
        implements Transformer<FeatureHasher>, FeatureHasherParams<FeatureHasher> {
    private final Map<Param<?>, Object> paramMap = new HashMap<>();
    private static final org.apache.flink.shaded.guava30.com.google.common.hash.HashFunction HASH =
            murmur3_32(0);

    public FeatureHasher() {
        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
    }

    @Override
    public Table[] transform(Table... inputs) {
        Preconditions.checkArgument(inputs.length == 1);
        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
        ResolvedSchema tableSchema = inputs[0].getResolvedSchema();
        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(tableSchema);
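        // The output schema is the input schema plus one vector column holding the hashed result.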
        RowTypeInfo outputTypeInfo =
                new RowTypeInfo(
                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), VectorTypeInfo.INSTANCE),
                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol()));
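        // Hash every input row; string/boolean columns are treated as categorical automatically.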
        DataStream<Row> output =
                tEnv.toDataStream(inputs[0])
                        .map(
                                new HashFunction(
                                        getInputCols(),
                                        generateCategoricalCols(
                                                tableSchema, getInputCols(), getCategoricalCols()),
                                        getNumFeatures()),
                                outputTypeInfo);
        Table outputTable = tEnv.fromDataStream(output);
        return new Table[] {outputTable};
    }

    /**
     * The main logic for transforming the categorical and numerical features into a sparse vector.
     * It uses MurmurHash3 to compute the transformed index in the output vector. If multiple
     * features are projected to the same column, their values are accumulated.
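     *
     * <p>For example, a numerical column {@code f0} with value 2.5 contributes 2.5 at the index
     * derived from hashing the string {@code "f0"}, while a categorical column {@code f1} with
     * value "a" contributes 1.0 at the index derived from hashing the string {@code "f1=a"}.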
     */
    private static class HashFunction implements MapFunction<Row, Row> {
        private final String[] categoricalCols;
        private final int numFeatures;
        private final String[] numericCols;

        public HashFunction(String[] inputCols, String[] categoricalCols, int numFeatures) {
            this.categoricalCols = categoricalCols;
            this.numFeatures = numFeatures;
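            // The numerical columns are the input columns not marked as categorical.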
            this.numericCols = ArrayUtils.removeElements(inputCols, this.categoricalCols);
        }

        @Override
        public Row map(Row row) {
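            // A TreeMap accumulates one value per index and yields the indices in ascending order.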
            TreeMap<Integer, Double> feature = new TreeMap<>();
            for (String col : numericCols) {
                if (null != row.getField(col)) {
                    double value = ((Number) row.getFieldAs(col)).doubleValue();
                    updateMap(col, value, feature, numFeatures);
                }
            }
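            // Categorical features hash the string "column=value" and contribute 1.0 each.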
            for (String col : categoricalCols) {
                if (null != row.getField(col)) {
                    updateMap(col + "=" + row.getField(col), 1.0, feature, numFeatures);
                }
            }
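            // Flatten the sorted map into the parallel index/value arrays of a SparseVector.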
            int nnz = feature.size();
            int[] indices = new int[nnz];
            double[] values = new double[nnz];
            int pos = 0;
            for (Map.Entry<Integer, Double> entry : feature.entrySet()) {
                indices[pos] = entry.getKey();
                values[pos] = entry.getValue();
                pos++;
            }
            return Row.join(row, Row.of(new SparseVector(numFeatures, indices, values)));
        }
    }

    private String[] generateCategoricalCols(
            ResolvedSchema tableSchema, String[] inputCols, String[] categoricalCols) {
        if (null == inputCols) {
            return categoricalCols;
        }
        List<String> categoricalList = Arrays.asList(categoricalCols);
        List<String> inputList = Arrays.asList(inputCols);
        if (categoricalCols.length > 0 && !inputList.containsAll(categoricalList)) {
            throw new IllegalArgumentException("CategoricalCols must be included in inputCols!");
        }
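        // Resolve the data type of each input column from the table schema, preserving order.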
        List<DataType> dataColTypes = tableSchema.getColumnDataTypes();
        List<String> dataColNames = tableSchema.getColumnNames();
        List<DataType> inputColTypes = new ArrayList<>();
        for (String col : inputCols) {
            for (int i = 0; i < dataColNames.size(); ++i) {
                if (col.equals(dataColNames.get(i))) {
                    inputColTypes.add(dataColTypes.get(i));
                    break;
                }
            }
        }
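        // String and boolean columns are treated as categorical even if not explicitly listed.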
        List<String> resultColList = new ArrayList<>();
        for (int i = 0; i < inputCols.length; i++) {
            boolean included = categoricalList.contains(inputCols[i]);
            if (included
                    || DataTypes.BOOLEAN().equals(inputColTypes.get(i))
                    || DataTypes.STRING().equals(inputColTypes.get(i))) {
                resultColList.add(inputCols[i]);
            }
        }
        return resultColList.toArray(new String[0]);
    }

    /**
     * Updates the TreeMap that stores the key-value pairs of the final vector, using the hash
     * value of the string as the key and accumulating the corresponding value.
     *
     * @param s the string to hash
     * @param value the value to accumulate
     * @param feature the map from vector index to accumulated value
     * @param numFeatures the dimension of the output vector
     */
    private static void updateMap(
            String s, double value, TreeMap<Integer, Double> feature, int numFeatures) {
        int hashValue = Math.abs(HASH.hashUnencodedChars(s).asInt());

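        // floorMod keeps the index in [0, numFeatures), covering the Integer.MIN_VALUE edge case
        // of Math.abs.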
        int index = Math.floorMod(hashValue, numFeatures);
        feature.merge(index, value, Double::sum);
    }

    @Override
    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata(this, path);
    }

    public static FeatureHasher load(StreamTableEnvironment env, String path) throws IOException {
        return ReadWriteUtils.loadStageParam(path);
    }

    @Override
    public Map<Param<?>, Object> getParamMap() {
        return paramMap;
    }
}