26
26
import org .apache .flink .api .common .state .ListState ;
27
27
import org .apache .flink .api .common .state .ListStateDescriptor ;
28
28
import org .apache .flink .api .common .typeinfo .TypeInformation ;
29
+ import org .apache .flink .api .common .typeutils .base .IntSerializer ;
30
+ import org .apache .flink .api .dag .Transformation ;
29
31
import org .apache .flink .api .java .tuple .Tuple2 ;
30
32
import org .apache .flink .api .java .typeutils .TypeExtractor ;
33
+ import org .apache .flink .core .memory .ManagedMemoryUseCase ;
34
+ import org .apache .flink .iteration .datacache .nonkeyed .ListStateWithCache ;
31
35
import org .apache .flink .iteration .operator .OperatorStateUtils ;
32
36
import org .apache .flink .runtime .state .StateInitializationContext ;
33
37
import org .apache .flink .runtime .state .StateSnapshotContext ;
34
38
import org .apache .flink .streaming .api .datastream .DataStream ;
35
39
import org .apache .flink .streaming .api .functions .windowing .AllWindowFunction ;
40
+ import org .apache .flink .streaming .api .operators .AbstractStreamOperator ;
36
41
import org .apache .flink .streaming .api .operators .AbstractUdfStreamOperator ;
37
42
import org .apache .flink .streaming .api .operators .BoundedOneInput ;
38
43
import org .apache .flink .streaming .api .operators .OneInputStreamOperator ;
39
44
import org .apache .flink .streaming .api .operators .TimestampedCollector ;
40
45
import org .apache .flink .streaming .api .windowing .windows .GlobalWindow ;
41
46
import org .apache .flink .streaming .runtime .streamrecord .StreamRecord ;
47
+ import org .apache .flink .table .api .TableException ;
42
48
import org .apache .flink .util .Collector ;
43
49
44
50
import org .apache .commons .collections .IteratorUtils ;
45
51
52
+ import java .util .ArrayList ;
46
53
import java .util .Arrays ;
54
+ import java .util .Collections ;
55
+ import java .util .Iterator ;
47
56
import java .util .List ;
57
+ import java .util .Optional ;
58
+ import java .util .Random ;
48
59
49
60
/** Provides utility functions for {@link DataStream}. */
50
61
@ Internal
@@ -105,6 +116,56 @@ public static <T> DataStream<T> reduce(DataStream<T> input, ReduceFunction<T> fu
105
116
}
106
117
}
107
118
119
+ /**
120
+ * Performs a uniform sampling over the elements in a bounded data stream.
121
+ *
122
+ * <p>This method takes samples without replacement. If the number of elements in the stream is
123
+ * smaller than expected number of samples, all elements will be included in the sample.
124
+ *
125
+ * @param input The input data stream.
126
+ * @param numSamples The number of elements to be sampled.
127
+ * @param randomSeed The seed to randomly pick elements as sample.
128
+ * @return A data stream containing a list of the sampled elements.
129
+ */
130
+ public static <T > DataStream <T > sample (DataStream <T > input , int numSamples , long randomSeed ) {
131
+ int inputParallelism = input .getParallelism ();
132
+
133
+ return input .transform (
134
+ "samplingOperator" ,
135
+ input .getType (),
136
+ new SamplingOperator <>(numSamples , randomSeed ))
137
+ .setParallelism (inputParallelism )
138
+ .transform (
139
+ "samplingOperator" ,
140
+ input .getType (),
141
+ new SamplingOperator <>(numSamples , randomSeed ))
142
+ .setParallelism (1 )
143
+ .map (x -> x , input .getType ())
144
+ .setParallelism (inputParallelism );
145
+ }
146
+
147
+ /**
148
+ * Sets {Transformation#declareManagedMemoryUseCaseAtOperatorScope(ManagedMemoryUseCase, int)}
149
+ * using the given bytes for {@link ManagedMemoryUseCase#OPERATOR}.
150
+ *
151
+ * <p>This method is in reference to Flink's ExecNodeUtil.setManagedMemoryWeight. The provided
152
+ * bytes should be in the same scale as existing usage in Flink, for example,
153
+ * StreamExecWindowAggregate.WINDOW_AGG_MEMORY_RATIO.
154
+ */
155
+ public static <T > void setManagedMemoryWeight (
156
+ Transformation <T > transformation , long memoryBytes ) {
157
+ if (memoryBytes > 0 ) {
158
+ final int weightInMebibyte = Math .max (1 , (int ) (memoryBytes >> 20 ));
159
+ final Optional <Integer > previousWeight =
160
+ transformation .declareManagedMemoryUseCaseAtOperatorScope (
161
+ ManagedMemoryUseCase .OPERATOR , weightInMebibyte );
162
+ if (previousWeight .isPresent ()) {
163
+ throw new TableException (
164
+ "Managed memory weight has been set, this should not happen." );
165
+ }
166
+ }
167
+ }
168
+
108
169
/**
109
170
* A stream operator to apply {@link MapPartitionFunction} on each partition of the input
110
171
* bounded data stream.
@@ -113,7 +174,7 @@ private static class MapPartitionOperator<IN, OUT>
113
174
extends AbstractUdfStreamOperator <OUT , MapPartitionFunction <IN , OUT >>
114
175
implements OneInputStreamOperator <IN , OUT >, BoundedOneInput {
115
176
116
- private ListState <IN > valuesState ;
177
+ private ListStateWithCache <IN > valuesState ;
117
178
118
179
public MapPartitionOperator (MapPartitionFunction <IN , OUT > mapPartitionFunc ) {
119
180
super (mapPartitionFunc );
@@ -122,24 +183,32 @@ public MapPartitionOperator(MapPartitionFunction<IN, OUT> mapPartitionFunc) {
122
183
@ Override
123
184
public void initializeState (StateInitializationContext context ) throws Exception {
124
185
super .initializeState (context );
125
- ListStateDescriptor <IN > descriptor =
126
- new ListStateDescriptor <>(
127
- "inputState" ,
128
- getOperatorConfig ()
129
- .getTypeSerializerIn (0 , getClass ().getClassLoader ()));
130
- valuesState = context .getOperatorStateStore ().getListState (descriptor );
186
+
187
+ valuesState =
188
+ new ListStateWithCache <>(
189
+ getOperatorConfig ().getTypeSerializerIn (0 , getClass ().getClassLoader ()),
190
+ getContainingTask (),
191
+ getRuntimeContext (),
192
+ context ,
193
+ config .getOperatorID ());
131
194
}
132
195
133
196
@ Override
134
- public void endInput ( ) throws Exception {
135
- userFunction . mapPartition ( valuesState . get (), new TimestampedCollector <>( output ) );
136
- valuesState .clear ( );
197
+ public void snapshotState ( StateSnapshotContext context ) throws Exception {
198
+ super . snapshotState ( context );
199
+ valuesState .snapshotState ( context );
137
200
}
138
201
139
202
@ Override
140
203
public void processElement (StreamRecord <IN > input ) throws Exception {
141
204
valuesState .add (input .getValue ());
142
205
}
206
+
207
+ @ Override
208
+ public void endInput () throws Exception {
209
+ userFunction .mapPartition (valuesState .get (), new TimestampedCollector <>(output ));
210
+ valuesState .clear ();
211
+ }
143
212
}
144
213
145
214
/** A stream operator to apply {@link ReduceFunction} on the input bounded data stream. */
@@ -176,7 +245,7 @@ public void initializeState(StateInitializationContext context) throws Exception
176
245
state =
177
246
context .getOperatorStateStore ()
178
247
.getListState (
179
- new ListStateDescriptor <T >(
248
+ new ListStateDescriptor <>(
180
249
"state" ,
181
250
getOperatorConfig ()
182
251
.getTypeSerializerIn (
@@ -256,4 +325,80 @@ public void flatMap(T[] values, Collector<Tuple2<Integer, T[]>> collector) {
256
325
}
257
326
}
258
327
}
328
+
329
+ /*
330
+ * A stream operator that takes a randomly sampled subset of elements in a bounded data stream.
331
+ */
332
+ private static class SamplingOperator <T > extends AbstractStreamOperator <T >
333
+ implements OneInputStreamOperator <T , T >, BoundedOneInput {
334
+ private final int numSamples ;
335
+
336
+ private final Random random ;
337
+
338
+ private ListState <T > samplesState ;
339
+
340
+ private List <T > samples ;
341
+
342
+ private ListState <Integer > countState ;
343
+
344
+ private int count ;
345
+
346
+ SamplingOperator (int numSamples , long randomSeed ) {
347
+ this .numSamples = numSamples ;
348
+ this .random = new Random (randomSeed );
349
+ }
350
+
351
+ @ Override
352
+ public void initializeState (StateInitializationContext context ) throws Exception {
353
+ super .initializeState (context );
354
+
355
+ ListStateDescriptor <T > samplesDescriptor =
356
+ new ListStateDescriptor <>(
357
+ "samplesState" ,
358
+ getOperatorConfig ()
359
+ .getTypeSerializerIn (0 , getClass ().getClassLoader ()));
360
+ samplesState = context .getOperatorStateStore ().getListState (samplesDescriptor );
361
+ samples = new ArrayList <>(numSamples );
362
+ samplesState .get ().forEach (samples ::add );
363
+
364
+ ListStateDescriptor <Integer > countDescriptor =
365
+ new ListStateDescriptor <>("countState" , IntSerializer .INSTANCE );
366
+ countState = context .getOperatorStateStore ().getListState (countDescriptor );
367
+ Iterator <Integer > countIterator = countState .get ().iterator ();
368
+ if (countIterator .hasNext ()) {
369
+ count = countIterator .next ();
370
+ } else {
371
+ count = 0 ;
372
+ }
373
+ }
374
+
375
+ @ Override
376
+ public void snapshotState (StateSnapshotContext context ) throws Exception {
377
+ super .snapshotState (context );
378
+ samplesState .update (samples );
379
+ countState .update (Collections .singletonList (count ));
380
+ }
381
+
382
+ @ Override
383
+ public void processElement (StreamRecord <T > streamRecord ) throws Exception {
384
+ T value = streamRecord .getValue ();
385
+ count ++;
386
+
387
+ if (samples .size () < numSamples ) {
388
+ samples .add (value );
389
+ } else {
390
+ int index = random .nextInt (count );
391
+ if (index < numSamples ) {
392
+ samples .set (index , value );
393
+ }
394
+ }
395
+ }
396
+
397
+ @ Override
398
+ public void endInput () throws Exception {
399
+ for (T sample : samples ) {
400
+ output .collect (new StreamRecord <>(sample ));
401
+ }
402
+ }
403
+ }
259
404
}
0 commit comments