
Commit 573ac55

xguo27 authored and marmbrus committed
[SPARK-12512][SQL] support column name with dot in withColumn()
Author: Xiu Guo <xguo27@gmail.com>

Closes apache#10500 from xguo27/SPARK-12512.
1 parent 43706bf commit 573ac55
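
Not part of the commit: a minimal sketch of the behavior the fix enables, assuming a Spark 1.6-style SQLContext named sqlContext with its implicits imported (it mirrors the new DataFrameSuite test added below).

    import org.apache.spark.sql.functions.lit
    import sqlContext.implicits._           // assumes an in-scope SQLContext

    // Two columns whose names contain a literal dot.
    val df = Seq("a" -> "b").toDF("col.a", "col.b")

    // Replacing an existing dotted column: the name is matched against the
    // resolved output attributes instead of being re-parsed as a nested field.
    df.withColumn("col.a", lit("c"))        // rows: ("c", "b")

    // Adding a new column whose name merely contains a dot appends it as usual.
    df.withColumn("col.c", lit("c"))        // rows: ("a", "b", "c")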

File tree

2 files changed: 27 additions and 12 deletions


sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 20 additions & 12 deletions
@@ -1171,13 +1171,17 @@ class DataFrame private[sql](
    */
   def withColumn(colName: String, col: Column): DataFrame = {
     val resolver = sqlContext.analyzer.resolver
-    val replaced = schema.exists(f => resolver(f.name, colName))
-    if (replaced) {
-      val colNames = schema.map { field =>
-        val name = field.name
-        if (resolver(name, colName)) col.as(colName) else Column(name)
+    val output = queryExecution.analyzed.output
+    val shouldReplace = output.exists(f => resolver(f.name, colName))
+    if (shouldReplace) {
+      val columns = output.map { field =>
+        if (resolver(field.name, colName)) {
+          col.as(colName)
+        } else {
+          Column(field)
+        }
       }
-      select(colNames : _*)
+      select(columns : _*)
     } else {
       select(Column("*"), col.as(colName))
     }
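
Why this helps: the old code rebuilt every untouched column from its string name, and an unquoted name is re-parsed by Catalyst, so a dot in it is read as a step into a nested field rather than as part of the column name. The new code maps over the analyzed plan's already-resolved output attributes and wraps each one directly, so the name is never re-parsed. The contrast below is lifted from the hunk above and annotated; the Column factory methods used here are internal to the sql package, so read it as explanation rather than a standalone snippet. The next hunk applies the same change to the metadata-preserving overload.

    // Old: rebuild from the string name; for a field named "col.a" this is
    // re-parsed and, roughly, resolved as field "a" inside a column "col".
    Column(field.name)

    // New: wrap the resolved attribute itself; the dot is just a character
    // in the attribute's name and is never re-interpreted.
    Column(field)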
@@ -1188,13 +1192,17 @@ class DataFrame private[sql](
    */
   private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = {
     val resolver = sqlContext.analyzer.resolver
-    val replaced = schema.exists(f => resolver(f.name, colName))
-    if (replaced) {
-      val colNames = schema.map { field =>
-        val name = field.name
-        if (resolver(name, colName)) col.as(colName, metadata) else Column(name)
+    val output = queryExecution.analyzed.output
+    val shouldReplace = output.exists(f => resolver(f.name, colName))
+    if (shouldReplace) {
+      val columns = output.map { field =>
+        if (resolver(field.name, colName)) {
+          col.as(colName, metadata)
+        } else {
+          Column(field)
+        }
       }
-      select(colNames : _*)
+      select(columns : _*)
     } else {
       select(Column("*"), col.as(colName, metadata))
     }
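
This second overload differs from the first only in threading column Metadata through the replacement via col.as(colName, metadata). A hedged illustration (not from the commit) of the public aliasing API that carries the metadata:

    import org.apache.spark.sql.functions.lit
    import org.apache.spark.sql.types.MetadataBuilder

    // Build some column metadata and attach it while aliasing to a dotted name.
    val meta = new MetadataBuilder().putString("comment", "replacement value").build()
    val replacement = lit("c").as("col.a", meta)   // alias that keeps the supplied metadata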

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 7 additions & 0 deletions
@@ -1221,4 +1221,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       " _2: bigint ... 2 more fields> ... 2 more fields> ... 2 more fields]")
 
   }
+
+  test("SPARK-12512: support `.` in column name for withColumn()") {
+    val df = Seq("a" -> "b").toDF("col.a", "col.b")
+    checkAnswer(df.select(df("*")), Row("a", "b"))
+    checkAnswer(df.withColumn("col.a", lit("c")), Row("c", "b"))
+    checkAnswer(df.withColumn("col.c", lit("c")), Row("a", "b", "c"))
+  }
 }
