Skip to content

Commit 8dc6549

Browse files
holdenkdavies
authored and committed
[SPARK-12300] [SQL] [PYSPARK] fix schema inference on local collections
Current schema inference for local Python collections halts as soon as there are no NullTypes. This is different than when we specify a sampling ratio of 1.0 on a distributed collection. This could result in incomplete schema information. Author: Holden Karau <holden@us.ibm.com> Closes apache#10275 from holdenk/SPARK-12300-fix-schmea-inferance-on-local-collections. (cherry picked from commit d1ca634) Signed-off-by: Davies Liu <davies.liu@gmail.com>
1 parent c069ffc commit 8dc6549

File tree

2 files changed

+14
-7
lines changed

2 files changed

+14
-7
lines changed

python/pyspark/sql/context.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import sys
1919
import warnings
2020
import json
21+
from functools import reduce
2122

2223
if sys.version >= '3':
2324
basestring = unicode = str
@@ -236,14 +237,9 @@ def _inferSchemaFromList(self, data):
236237
if type(first) is dict:
237238
warnings.warn("inferring schema from dict is deprecated,"
238239
"please use pyspark.sql.Row instead")
239-
schema = _infer_schema(first)
240+
schema = reduce(_merge_type, map(_infer_schema, data))
240241
if _has_nulltype(schema):
241-
for r in data:
242-
schema = _merge_type(schema, _infer_schema(r))
243-
if not _has_nulltype(schema):
244-
break
245-
else:
246-
raise ValueError("Some of types cannot be determined after inferring")
242+
raise ValueError("Some of types cannot be determined after inferring")
247243
return schema
248244

249245
def _inferSchema(self, rdd, samplingRatio=None):

python/pyspark/sql/tests.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,17 @@ def test_apply_schema_to_row(self):
353353
df3 = self.sqlCtx.createDataFrame(rdd, df.schema)
354354
self.assertEqual(10, df3.count())
355355

356+
def test_infer_schema_to_local(self):
357+
input = [{"a": 1}, {"b": "coffee"}]
358+
rdd = self.sc.parallelize(input)
359+
df = self.sqlCtx.createDataFrame(input)
360+
df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0)
361+
self.assertEqual(df.schema, df2.schema)
362+
363+
rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
364+
df3 = self.sqlCtx.createDataFrame(rdd, df.schema)
365+
self.assertEqual(10, df3.count())
366+
356367
def test_serialize_nested_array_and_map(self):
357368
d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
358369
rdd = self.sc.parallelize(d)

0 commit comments

Comments
 (0)