Skip to content

Commit c8fef60

Browse files
yunfengzhou-hublindong28
authored andcommitted
[FLINK-27798] Fix python test for Flink 1.15 migration.
This closes apache#104.
1 parent f932783 commit c8fef60

File tree

2 files changed

+29
-16
lines changed

2 files changed

+29
-16
lines changed

flink-ml-python/pyflink/ml/core/tests/test_pipeline.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
import os
1919
from typing import Dict, Any, List
2020

21-
import pandas as pd
22-
from pandas._testing import assert_frame_equal
2321
from pyflink.table import Table, StreamTableEnvironment
2422

2523
from pyflink.ml.core.api import Model
@@ -39,17 +37,19 @@ def test_pipeline_model(self):
3937
model = PipelineModel([model_a, model_b, model_c])
4038
output_table = model.transform(input_table)[0]
4139

42-
assert_frame_equal(output_table.to_pandas(),
43-
pd.DataFrame([[31], [32], [33]], columns=['a']))
40+
predicted_results = [result[0] for result in
41+
self.t_env.to_data_stream(output_table).execute_and_collect()]
42+
self.assertEqual(predicted_results, [31, 32, 33])
4443

4544
# Saves and loads the PipelineModel.
4645
path = os.path.join(self.temp_dir, "test_pipeline_model")
4746
model.save(path)
4847
loaded_model = PipelineModel.load(self.t_env, path)
4948

5049
output_table2 = loaded_model.transform(input_table)[0]
51-
assert_frame_equal(output_table2.to_pandas(),
52-
pd.DataFrame([[31], [32], [33]], columns=['a']))
50+
predicted_results = [result[0] for result in
51+
self.t_env.to_data_stream(output_table2).execute_and_collect()]
52+
self.assertEqual(predicted_results, [31, 32, 33])
5353

5454
def test_pipeline(self):
5555
input_table = self.t_env.from_elements([(1,), (2,), (3,)], ['a'])
@@ -60,8 +60,9 @@ def test_pipeline(self):
6060
model = estimator.fit(input_table)
6161
output_table = model.transform(input_table)[0]
6262

63-
assert_frame_equal(output_table.to_pandas(),
64-
pd.DataFrame([[21], [22], [23]], columns=['a']))
63+
predicted_results = [result[0] for result in
64+
self.t_env.to_data_stream(output_table).execute_and_collect()]
65+
self.assertEqual(predicted_results, [21, 22, 23])
6566

6667
# Saves and loads the PipelineModel.
6768
path = os.path.join(self.temp_dir, "test_pipeline")
@@ -70,8 +71,10 @@ def test_pipeline(self):
7071

7172
model = loaded_estimator.fit(input_table)
7273
output_table = model.transform(input_table)[0]
73-
assert_frame_equal(output_table.to_pandas(),
74-
pd.DataFrame([[21], [22], [23]], columns=['a']))
74+
75+
predicted_results = [result[0] for result in
76+
self.t_env.to_data_stream(output_table).execute_and_collect()]
77+
self.assertEqual(predicted_results, [21, 22, 23])
7578

7679

7780
class Add10Model(Model):

flink-ml-python/pyflink/ml/lib/feature/tests/test_vectorassembler.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# See the License for the specific language governing permissions and
1616
# limitations under the License.
1717
################################################################################
18+
import os
1819

1920
from pyflink.common import Types
2021

@@ -26,6 +27,7 @@
2627
class VectorAssemblerTest(PyFlinkMLTestCase):
2728
def setUp(self):
2829
super(VectorAssemblerTest, self).setUp()
30+
# TODO: Add test for handling invalid values after FLINK-27797 is resolved.
2931
self.input_data_table = self.t_env.from_data_stream(
3032
self.env.from_collection([
3133
(0,
@@ -37,7 +39,6 @@ def setUp(self):
3739
1.0,
3840
Vectors.sparse(5, [1, 2, 3, 4],
3941
[1.0, 2.0, 3.0, 4.0])),
40-
(2, None, None, None),
4142
],
4243
type_info=Types.ROW_NAMED(
4344
['id', 'vec', 'num', 'sparse_vec'],
@@ -60,13 +61,22 @@ def test_param(self):
6061
self.assertEqual('skip', vector_assembler.handle_invalid)
6162
self.assertEqual('assembled_vec', vector_assembler.output_col)
6263

63-
def test_keep_invalid(self):
64+
def test_save_load_transform(self):
6465
vector_assembler = VectorAssembler() \
6566
.set_input_cols('vec', 'num', 'sparse_vec') \
6667
.set_output_col('assembled_vec') \
6768
.set_handle_invalid('keep')
6869

69-
output = vector_assembler.transform(self.input_data_table)[0]
70-
self.assertEqual(
71-
['id', 'vec', 'num', 'sparse_vec', 'assembled_vec'],
72-
output.get_schema().get_field_names())
70+
path = os.path.join(self.temp_dir, 'test_save_load_transform_vector_assembler')
71+
vector_assembler.save(path)
72+
vector_assembler = VectorAssembler.load(self.t_env, path)
73+
74+
output_table = vector_assembler.transform(self.input_data_table)[0]
75+
actual_outputs = [(result[0], result[4]) for result in
76+
self.t_env.to_data_stream(output_table).execute_and_collect()]
77+
78+
for actual_output in actual_outputs:
79+
if actual_output[0] == 0:
80+
self.assertEqual(self.expected_output_data_1, actual_output[1])
81+
else:
82+
self.assertEqual(self.expected_output_data_2, actual_output[1])

0 commit comments

Comments
 (0)