Skip to content

Commit 91e7464

Browse files
[FLINK-27715] Add pyflink examples
This closes apache#121.
1 parent 68aab8a commit 91e7464

34 files changed

+1685
-32
lines changed

.github/workflows/python-checks.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ jobs:
4141
run: python -m mypy --config=setup.cfg
4242
- name: Test the source code
4343
working-directory: flink-ml-python
44-
run: pytest
44+
run: |
45+
pytest pyflink/ml
46+
pytest pyflink/examples
4547
4648

flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/IndexToStringModel.java

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -147,15 +147,8 @@ public void flatMap(Row input, Collector<Row> out) {
147147
}
148148

149149
Row outputStrings = new Row(inputCols.length);
150-
int stringId;
151150
for (int i = 0; i < inputCols.length; i++) {
152-
try {
153-
stringId = (Integer) input.getField(inputCols[i]);
154-
} catch (Exception e) {
155-
throw new RuntimeException(
156-
"The input contains non-integer value: "
157-
+ input.getField(inputCols[i] + "."));
158-
}
151+
int stringId = (Integer) input.getField(inputCols[i]);
159152
if (stringId < stringArrays[i].length && stringId >= 0) {
160153
outputStrings.setField(i, stringArrays[i][stringId]);
161154
} else {
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
################################################################################
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
################################################################################
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
################################################################################
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
################################################################################
18+
19+
# Simple program that trains a KNN model and uses it for classification.
20+
#
21+
# Before executing this program, please make sure you have followed Flink ML's
22+
# quick start guideline to set up Flink ML and Flink environment. The guideline
23+
# can be found at
24+
#
25+
# https://nightlies.apache.org/flink/flink-ml-docs-master/docs/try-flink-ml/quick-start/
26+
27+
from pyflink.common import Types
28+
from pyflink.datastream import StreamExecutionEnvironment
29+
from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo
30+
from pyflink.ml.lib.classification.knn import KNN
31+
from pyflink.table import StreamTableEnvironment
32+
33+
# create a new StreamExecutionEnvironment
34+
env = StreamExecutionEnvironment.get_execution_environment()
35+
36+
# create a StreamTableEnvironment
37+
t_env = StreamTableEnvironment.create(env)
38+
39+
# generate input training and prediction data
40+
train_data = t_env.from_data_stream(
41+
env.from_collection([
42+
(Vectors.dense([2.0, 3.0]), 1.0),
43+
(Vectors.dense([2.1, 3.1]), 1.0),
44+
(Vectors.dense([200.1, 300.1]), 2.0),
45+
(Vectors.dense([200.2, 300.2]), 2.0),
46+
(Vectors.dense([200.3, 300.3]), 2.0),
47+
(Vectors.dense([200.4, 300.4]), 2.0),
48+
(Vectors.dense([200.4, 300.4]), 2.0),
49+
(Vectors.dense([200.6, 300.6]), 2.0),
50+
(Vectors.dense([2.1, 3.1]), 1.0),
51+
(Vectors.dense([2.1, 3.1]), 1.0),
52+
(Vectors.dense([2.1, 3.1]), 1.0),
53+
(Vectors.dense([2.1, 3.1]), 1.0),
54+
(Vectors.dense([2.3, 3.2]), 1.0),
55+
(Vectors.dense([2.3, 3.2]), 1.0),
56+
(Vectors.dense([2.8, 3.2]), 3.0),
57+
(Vectors.dense([300., 3.2]), 4.0),
58+
(Vectors.dense([2.2, 3.2]), 1.0),
59+
(Vectors.dense([2.4, 3.2]), 5.0),
60+
(Vectors.dense([2.5, 3.2]), 5.0),
61+
(Vectors.dense([2.5, 3.2]), 5.0),
62+
(Vectors.dense([2.1, 3.1]), 1.0)
63+
],
64+
type_info=Types.ROW_NAMED(
65+
['features', 'label'],
66+
[DenseVectorTypeInfo(), Types.DOUBLE()])))
67+
68+
predict_data = t_env.from_data_stream(
69+
env.from_collection([
70+
(Vectors.dense([4.0, 4.1]), 5.0),
71+
(Vectors.dense([300, 42]), 2.0),
72+
],
73+
type_info=Types.ROW_NAMED(
74+
['features', 'label'],
75+
[DenseVectorTypeInfo(), Types.DOUBLE()])))
76+
77+
# create a knn object and initialize its parameters
78+
knn = KNN().set_k(4)
79+
80+
# train the knn model
81+
model = knn.fit(train_data)
82+
83+
# use the knn model for predictions
84+
output = model.transform(predict_data)[0]
85+
86+
# extract and display the results
87+
field_names = output.get_schema().get_field_names()
88+
for result in t_env.to_data_stream(output).execute_and_collect():
89+
features = result[field_names.index(knn.get_features_col())]
90+
expected_result = result[field_names.index(knn.get_label_col())]
91+
actual_result = result[field_names.index(knn.get_prediction_col())]
92+
print('Features: ' + str(features) + ' \tExpected Result: ' + str(expected_result)
93+
+ ' \tActual Result: ' + str(actual_result))
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
################################################################################
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
################################################################################
18+
19+
# Simple program that trains a LinearSVC model and uses it for classification.
20+
#
21+
# Before executing this program, please make sure you have followed Flink ML's
22+
# quick start guideline to set up Flink ML and Flink environment. The guideline
23+
# can be found at
24+
#
25+
# https://nightlies.apache.org/flink/flink-ml-docs-master/docs/try-flink-ml/quick-start/
26+
27+
from pyflink.common import Types
28+
from pyflink.datastream import StreamExecutionEnvironment
29+
from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo
30+
from pyflink.ml.lib.classification.linearsvc import LinearSVC
31+
from pyflink.table import StreamTableEnvironment
32+
33+
# create a new StreamExecutionEnvironment
34+
env = StreamExecutionEnvironment.get_execution_environment()
35+
36+
# create a StreamTableEnvironment
37+
t_env = StreamTableEnvironment.create(env)
38+
39+
# generate input data
40+
input_table = t_env.from_data_stream(
41+
env.from_collection([
42+
(Vectors.dense([1, 2, 3, 4]), 0., 1.),
43+
(Vectors.dense([2, 2, 3, 4]), 0., 2.),
44+
(Vectors.dense([3, 2, 3, 4]), 0., 3.),
45+
(Vectors.dense([4, 2, 3, 4]), 0., 4.),
46+
(Vectors.dense([5, 2, 3, 4]), 0., 5.),
47+
(Vectors.dense([11, 2, 3, 4]), 1., 1.),
48+
(Vectors.dense([12, 2, 3, 4]), 1., 2.),
49+
(Vectors.dense([13, 2, 3, 4]), 1., 3.),
50+
(Vectors.dense([14, 2, 3, 4]), 1., 4.),
51+
(Vectors.dense([15, 2, 3, 4]), 1., 5.),
52+
],
53+
type_info=Types.ROW_NAMED(
54+
['features', 'label', 'weight'],
55+
[DenseVectorTypeInfo(), Types.DOUBLE(), Types.DOUBLE()])
56+
))
57+
58+
# create a linear svc object and initialize its parameters
59+
linear_svc = LinearSVC().set_weight_col('weight')
60+
61+
# train the linear svc model
62+
model = linear_svc.fit(input_table)
63+
64+
# use the linear svc model for predictions
65+
output = model.transform(input_table)[0]
66+
67+
# extract and display the results
68+
field_names = output.get_schema().get_field_names()
69+
for result in t_env.to_data_stream(output).execute_and_collect():
70+
features = result[field_names.index(linear_svc.get_features_col())]
71+
expected_result = result[field_names.index(linear_svc.get_label_col())]
72+
prediction_result = result[field_names.index(linear_svc.get_prediction_col())]
73+
raw_prediction_result = result[field_names.index(linear_svc.get_raw_prediction_col())]
74+
print('Features: ' + str(features) + ' \tExpected Result: ' + str(expected_result)
75+
+ ' \tPrediction Result: ' + str(prediction_result)
76+
+ ' \tRaw Prediction Result: ' + str(raw_prediction_result))
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
################################################################################
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
################################################################################
18+
19+
# Simple program that trains a LogisticRegression model and uses it for
20+
# classification.
21+
#
22+
# Before executing this program, please make sure you have followed Flink ML's
23+
# quick start guideline to set up Flink ML and Flink environment. The guideline
24+
# can be found at
25+
#
26+
# https://nightlies.apache.org/flink/flink-ml-docs-master/docs/try-flink-ml/quick-start/
27+
28+
from pyflink.common import Types
29+
from pyflink.datastream import StreamExecutionEnvironment
30+
from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo
31+
from pyflink.ml.lib.classification.logisticregression import LogisticRegression
32+
from pyflink.table import StreamTableEnvironment
33+
34+
# create a new StreamExecutionEnvironment
35+
env = StreamExecutionEnvironment.get_execution_environment()
36+
37+
# create a StreamTableEnvironment
38+
t_env = StreamTableEnvironment.create(env)
39+
40+
# generate input data
41+
input_data = t_env.from_data_stream(
42+
env.from_collection([
43+
(Vectors.dense([1, 2, 3, 4]), 0., 1.),
44+
(Vectors.dense([2, 2, 3, 4]), 0., 2.),
45+
(Vectors.dense([3, 2, 3, 4]), 0., 3.),
46+
(Vectors.dense([4, 2, 3, 4]), 0., 4.),
47+
(Vectors.dense([5, 2, 3, 4]), 0., 5.),
48+
(Vectors.dense([11, 2, 3, 4]), 1., 1.),
49+
(Vectors.dense([12, 2, 3, 4]), 1., 2.),
50+
(Vectors.dense([13, 2, 3, 4]), 1., 3.),
51+
(Vectors.dense([14, 2, 3, 4]), 1., 4.),
52+
(Vectors.dense([15, 2, 3, 4]), 1., 5.),
53+
],
54+
type_info=Types.ROW_NAMED(
55+
['features', 'label', 'weight'],
56+
[DenseVectorTypeInfo(), Types.DOUBLE(), Types.DOUBLE()])
57+
))
58+
59+
# create a logistic regression object and initialize its parameters
60+
logistic_regression = LogisticRegression().set_weight_col('weight')
61+
62+
# train the logistic regression model
63+
model = logistic_regression.fit(input_data)
64+
65+
# use the logistic regression model for predictions
66+
output = model.transform(input_data)[0]
67+
68+
# extract and display the results
69+
field_names = output.get_schema().get_field_names()
70+
for result in t_env.to_data_stream(output).execute_and_collect():
71+
features = result[field_names.index(logistic_regression.get_features_col())]
72+
expected_result = result[field_names.index(logistic_regression.get_label_col())]
73+
prediction_result = result[field_names.index(logistic_regression.get_prediction_col())]
74+
raw_prediction_result = result[field_names.index(logistic_regression.get_raw_prediction_col())]
75+
print('Features: ' + str(features) + ' \tExpected Result: ' + str(expected_result)
76+
+ ' \tPrediction Result: ' + str(prediction_result)
77+
+ ' \tRaw Prediction Result: ' + str(raw_prediction_result))
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
################################################################################
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
################################################################################
18+
19+
# Simple program that trains a NaiveBayes model and uses it for classification.
20+
#
21+
# Before executing this program, please make sure you have followed Flink ML's
22+
# quick start guideline to set up Flink ML and Flink environment. The guideline
23+
# can be found at
24+
#
25+
# https://nightlies.apache.org/flink/flink-ml-docs-master/docs/try-flink-ml/quick-start/
26+
27+
from pyflink.common import Types
28+
from pyflink.datastream import StreamExecutionEnvironment
29+
from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo
30+
from pyflink.ml.lib.classification.naivebayes import NaiveBayes
31+
from pyflink.table import StreamTableEnvironment
32+
33+
# create a new StreamExecutionEnvironment
34+
env = StreamExecutionEnvironment.get_execution_environment()
35+
36+
# create a StreamTableEnvironment
37+
t_env = StreamTableEnvironment.create(env)
38+
39+
# generate input training and prediction data
40+
train_table = t_env.from_data_stream(
41+
env.from_collection([
42+
(Vectors.dense([0, 0.]), 11.),
43+
(Vectors.dense([1, 0]), 10.),
44+
(Vectors.dense([1, 1.]), 10.),
45+
],
46+
type_info=Types.ROW_NAMED(
47+
['features', 'label'],
48+
[DenseVectorTypeInfo(), Types.DOUBLE()])))
49+
50+
predict_table = t_env.from_data_stream(
51+
env.from_collection([
52+
(Vectors.dense([0, 1.]),),
53+
(Vectors.dense([0, 0.]),),
54+
(Vectors.dense([1, 0]),),
55+
(Vectors.dense([1, 1.]),),
56+
],
57+
type_info=Types.ROW_NAMED(
58+
['features'],
59+
[DenseVectorTypeInfo()])))
60+
61+
# create a naive bayes object and initialize its parameters
62+
naive_bayes = NaiveBayes() \
63+
.set_smoothing(1.0) \
64+
.set_features_col('features') \
65+
.set_label_col('label') \
66+
.set_prediction_col('prediction') \
67+
.set_model_type('multinomial')
68+
69+
# train the naive bayes model
70+
model = naive_bayes.fit(train_table)
71+
72+
# use the naive bayes model for predictions
73+
output = model.transform(predict_table)[0]
74+
75+
# extract and display the results
76+
field_names = output.get_schema().get_field_names()
77+
for result in t_env.to_data_stream(output).execute_and_collect():
78+
features = result[field_names.index(naive_bayes.get_features_col())]
79+
prediction_result = result[field_names.index(naive_bayes.get_prediction_col())]
80+
print('Features: ' + str(features) + ' \tPrediction Result: ' + str(prediction_result))
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
################################################################################
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
################################################################################

0 commit comments

Comments
 (0)