26
26
import org .apache .flink .ml .linalg .DenseVector ;
27
27
import org .apache .flink .ml .linalg .SparseVector ;
28
28
import org .apache .flink .ml .linalg .Vector ;
29
+ import org .apache .flink .ml .linalg .Vectors ;
29
30
import org .apache .flink .ml .linalg .typeinfo .VectorTypeInfo ;
30
31
import org .apache .flink .ml .param .Param ;
31
32
import org .apache .flink .ml .util .ParamUtils ;
42
43
43
44
import java .io .IOException ;
44
45
import java .util .HashMap ;
45
- import java .util .LinkedHashMap ;
46
46
import java .util .Map ;
47
47
48
48
/**
@@ -90,14 +90,27 @@ public AssemblerFunc(String[] inputCols, String handleInvalid) {
90
90
}
91
91
92
92
@ Override
93
- public void flatMap (Row value , Collector <Row > out ) throws Exception {
93
+ public void flatMap (Row value , Collector <Row > out ) {
94
+ int nnz = 0 ;
95
+ int vectorSize = 0 ;
94
96
try {
95
- Object [] objects = new Object [inputCols .length ];
96
- for (int i = 0 ; i < objects .length ; ++i ) {
97
- objects [i ] = value .getField (inputCols [i ]);
97
+ for (String inputCol : inputCols ) {
98
+ Object object = value .getField (inputCol );
99
+ Preconditions .checkNotNull (object , "Input column value should not be null." );
100
+ if (object instanceof Number ) {
101
+ nnz += 1 ;
102
+ vectorSize += 1 ;
103
+ } else if (object instanceof SparseVector ) {
104
+ nnz += ((SparseVector ) object ).indices .length ;
105
+ vectorSize += ((SparseVector ) object ).size ();
106
+ } else if (object instanceof DenseVector ) {
107
+ nnz += ((DenseVector ) object ).size ();
108
+ vectorSize += ((DenseVector ) object ).size ();
109
+ } else {
110
+ throw new IllegalArgumentException (
111
+ "Input type has not been supported yet." );
112
+ }
98
113
}
99
- Vector assembledVector = assemble (objects );
100
- out .collect (Row .join (value , Row .of (assembledVector )));
101
114
} catch (Exception e ) {
102
115
switch (handleInvalid ) {
103
116
case ERROR_INVALID :
@@ -112,6 +125,13 @@ public void flatMap(Row value, Collector<Row> out) throws Exception {
112
125
"Unsupported " + HANDLE_INVALID + " type: " + handleInvalid );
113
126
}
114
127
}
128
+
129
+ boolean toDense = nnz * RATIO > vectorSize ;
130
+ Vector assembledVec =
131
+ toDense
132
+ ? assembleDense (inputCols , value , vectorSize )
133
+ : assembleSparse (inputCols , value , vectorSize , nnz );
134
+ out .collect (Row .join (value , Row .of (assembledVec )));
115
135
}
116
136
}
117
137
@@ -129,57 +149,69 @@ public Map<Param<?>, Object> getParamMap() {
129
149
return paramMap ;
130
150
}
131
151
132
- private static Vector assemble (Object [] objects ) {
133
- int offset = 0 ;
134
- Map <Integer , Double > map = new LinkedHashMap <>(objects .length );
135
- for (Object object : objects ) {
136
- Preconditions .checkNotNull (object , "Input column value should not be null." );
152
+ /** Assembles the input columns into a dense vector. */
153
+ private static Vector assembleDense (String [] inputCols , Row inputRow , int vectorSize ) {
154
+ double [] values = new double [vectorSize ];
155
+ int currentOffset = 0 ;
156
+
157
+ for (String inputCol : inputCols ) {
158
+ Object object = inputRow .getField (inputCol );
137
159
if (object instanceof Number ) {
138
- map .put (offset ++, ((Number ) object ).doubleValue ());
139
- } else if (object instanceof Vector ) {
140
- offset = appendVector ((Vector ) object , map , offset );
160
+ values [currentOffset ++] = ((Number ) object ).doubleValue ();
161
+ } else if (object instanceof SparseVector ) {
162
+ SparseVector sparseVector = (SparseVector ) object ;
163
+ for (int i = 0 ; i < sparseVector .indices .length ; i ++) {
164
+ values [currentOffset + sparseVector .indices [i ]] = sparseVector .values [i ];
165
+ }
166
+ currentOffset += sparseVector .size ();
167
+
141
168
} else {
142
- throw new IllegalArgumentException ( "Input type has not been supported yet." ) ;
143
- }
144
- }
169
+ DenseVector denseVector = ( DenseVector ) object ;
170
+ System . arraycopy (
171
+ denseVector . values , 0 , values , currentOffset , denseVector . values . length );
145
172
146
- if (map .size () * RATIO > offset ) {
147
- DenseVector assembledVector = new DenseVector (offset );
148
- for (int key : map .keySet ()) {
149
- assembledVector .values [key ] = map .get (key );
173
+ currentOffset += denseVector .size ();
150
174
}
151
- return assembledVector ;
152
- } else {
153
- return convertMapToSparseVector (offset , map );
154
175
}
176
+ return Vectors .dense (values );
155
177
}
156
178
157
- private static int appendVector (Vector vec , Map <Integer , Double > map , int offset ) {
158
- if (vec instanceof SparseVector ) {
159
- SparseVector sparseVector = (SparseVector ) vec ;
160
- int [] indices = sparseVector .indices ;
161
- double [] values = sparseVector .values ;
162
- for (int i = 0 ; i < indices .length ; ++i ) {
163
- map .put (offset + indices [i ], values [i ]);
164
- }
165
- offset += sparseVector .size ();
166
- } else {
167
- DenseVector denseVector = (DenseVector ) vec ;
168
- for (int i = 0 ; i < denseVector .size (); ++i ) {
169
- map .put (offset ++, denseVector .values [i ]);
170
- }
171
- }
172
- return offset ;
173
- }
179
+ /** Assembles the input columns into a sparse vector. */
180
+ private static Vector assembleSparse (
181
+ String [] inputCols , Row inputRow , int vectorSize , int nnz ) {
182
+ int [] indices = new int [nnz ];
183
+ double [] values = new double [nnz ];
174
184
175
- private static SparseVector convertMapToSparseVector (int size , Map <Integer , Double > map ) {
176
- int [] indices = new int [map .size ()];
177
- double [] values = new double [map .size ()];
178
- int offset = 0 ;
179
- for (Map .Entry <Integer , Double > entry : map .entrySet ()) {
180
- indices [offset ] = entry .getKey ();
181
- values [offset ++] = entry .getValue ();
185
+ int currentIndex = 0 ;
186
+ int currentOffset = 0 ;
187
+
188
+ for (String inputCol : inputCols ) {
189
+ Object object = inputRow .getField (inputCol );
190
+ if (object instanceof Number ) {
191
+ indices [currentOffset ] = currentIndex ;
192
+ values [currentOffset ] = ((Number ) object ).doubleValue ();
193
+ currentOffset ++;
194
+ currentIndex ++;
195
+ } else if (object instanceof SparseVector ) {
196
+ SparseVector sparseVector = (SparseVector ) object ;
197
+ for (int i = 0 ; i < sparseVector .indices .length ; i ++) {
198
+ indices [currentOffset + i ] = sparseVector .indices [i ] + currentIndex ;
199
+ }
200
+ System .arraycopy (
201
+ sparseVector .values , 0 , values , currentOffset , sparseVector .values .length );
202
+ currentIndex += sparseVector .size ();
203
+ currentOffset += sparseVector .indices .length ;
204
+ } else {
205
+ DenseVector denseVector = (DenseVector ) object ;
206
+ for (int i = 0 ; i < denseVector .size (); ++i ) {
207
+ indices [currentOffset + i ] = i + currentIndex ;
208
+ }
209
+ System .arraycopy (
210
+ denseVector .values , 0 , values , currentOffset , denseVector .values .length );
211
+ currentIndex += denseVector .size ();
212
+ currentOffset += denseVector .size ();
213
+ }
182
214
}
183
- return new SparseVector (size , indices , values );
215
+ return new SparseVector (vectorSize , indices , values );
184
216
}
185
217
}
0 commit comments