Skip to content

Commit 3d0fdef

Browse files
author
volth
committed
Support java.net.InetAddress (encoding and decoding) and user-defined types (encoding only)
1 parent 7dc83b9 commit 3d0fdef

File tree

5 files changed

+136
-58
lines changed

5 files changed

+136
-58
lines changed
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/*
2+
* Copyright 2013 Maurício Linhares
3+
*
4+
* Maurício Linhares licenses this file to you under the Apache License,
5+
* version 2.0 (the "License"); you may not use this file except in compliance
6+
* with the License. You may obtain a copy of the License at:
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12+
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13+
* License for the specific language governing permissions and limitations
14+
* under the License.
15+
*/
16+
17+
package com.github.mauricio.async.db.column
18+
19+
import java.net.InetAddress
20+
import sun.net.util.IPAddressUtil.{textToNumericFormatV4,textToNumericFormatV6}
21+
22+
object InetAddressEncoderDecoder extends ColumnEncoderDecoder {
23+
24+
override def decode(value: String): Any = {
25+
if (value contains ':') {
26+
InetAddress.getByAddress(textToNumericFormatV6(value))
27+
} else {
28+
InetAddress.getByAddress(textToNumericFormatV4(value))
29+
}
30+
}
31+
32+
override def encode(value: Any): String = {
33+
value.asInstanceOf[InetAddress].getHostAddress
34+
}
35+
36+
}

postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/column/ColumnTypes.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ object ColumnTypes {
6767
final val UUIDArray = 2951
6868
final val XMLArray = 143
6969

70+
final val Inet = 869
71+
final val InetArray = 1041
7072
}
7173

7274
/*

postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/column/PostgreSQLColumnDecoderRegistry.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class PostgreSQLColumnDecoderRegistry( charset : Charset = CharsetUtil.UTF_8 ) e
4646
private final val timeWithTimestampArrayDecoder = new ArrayDecoder(TimeWithTimezoneEncoderDecoder)
4747
private final val intervalArrayDecoder = new ArrayDecoder(PostgreSQLIntervalEncoderDecoder)
4848
private final val uuidArrayDecoder = new ArrayDecoder(UUIDEncoderDecoder)
49+
private final val inetAddressArrayDecoder = new ArrayDecoder(InetAddressEncoderDecoder)
4950

5051
override def decode(kind: ColumnData, value: ByteBuf, charset: Charset): Any = {
5152
decoderFor(kind.dataType).decode(kind, value, charset)
@@ -114,6 +115,9 @@ class PostgreSQLColumnDecoderRegistry( charset : Charset = CharsetUtil.UTF_8 ) e
114115
case XMLArray => this.stringArrayDecoder
115116
case ByteA => ByteArrayEncoderDecoder
116117

118+
case Inet => InetAddressEncoderDecoder
119+
case InetArray => this.inetAddressArrayDecoder
120+
117121
case _ => StringEncoderDecoder
118122
}
119123
}

postgresql-async/src/main/scala/com/github/mauricio/async/db/postgresql/column/PostgreSQLColumnEncoderRegistry.scala

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ class PostgreSQLColumnEncoderRegistry extends ColumnEncoderRegistry {
5252
classOf[BigDecimal] -> (BigDecimalEncoderDecoder -> ColumnTypes.Numeric),
5353
classOf[java.math.BigDecimal] -> (BigDecimalEncoderDecoder -> ColumnTypes.Numeric),
5454

55+
classOf[java.net.InetAddress] -> (InetAddressEncoderDecoder -> ColumnTypes.Inet),
56+
5557
classOf[java.util.UUID] -> (UUIDEncoderDecoder -> ColumnTypes.UUID),
5658

5759
classOf[LocalDate] -> ( DateEncoderDecoder -> ColumnTypes.Date ),
@@ -104,17 +106,12 @@ class PostgreSQLColumnEncoderRegistry extends ColumnEncoderRegistry {
104106
if (encoder.isDefined) {
105107
encoder.get._1.encode(value)
106108
} else {
107-
108-
val view: Option[Traversable[Any]] = value match {
109-
case i: java.lang.Iterable[_] => Some(i.toIterable)
110-
case i: Traversable[_] => Some(i)
111-
case i: Array[_] => Some(i.toIterable)
112-
case _ => None
113-
}
114-
115-
view match {
116-
case Some(collection) => encodeArray(collection)
117-
case None => {
109+
value match {
110+
case i: java.lang.Iterable[_] => encodeArray(i.toIterable)
111+
case i: Traversable[_] => encodeArray(i)
112+
case i: Array[_] => encodeArray(i.toIterable)
113+
case p: Product => encodeComposite(p)
114+
case _ => {
118115
this.classesSequence.find(entry => entry._1.isAssignableFrom(value.getClass)) match {
119116
case Some(parent) => parent._2._1.encode(value)
120117
case None => value.toString
@@ -126,14 +123,9 @@ class PostgreSQLColumnEncoderRegistry extends ColumnEncoderRegistry {
126123

127124
}
128125

129-
private def encodeArray(collection: Traversable[_]): String = {
130-
val builder = new StringBuilder()
131-
132-
builder.append('{')
133-
134-
val result = collection.map {
126+
private def encodeComposite(p: Product): String = {
127+
p.productIterator.map {
135128
item =>
136-
137129
if (item == null || item == None) {
138130
"NULL"
139131
} else {
@@ -143,13 +135,22 @@ class PostgreSQLColumnEncoderRegistry extends ColumnEncoderRegistry {
143135
this.encode(item)
144136
}
145137
}
138+
}.mkString("(", ",", ")")
139+
}
146140

147-
}.mkString(",")
148-
149-
builder.append(result)
150-
builder.append('}')
151-
152-
builder.toString()
141+
private def encodeArray(collection: Traversable[_]): String = {
142+
collection.map {
143+
item =>
144+
if (item == null || item == None) {
145+
"NULL"
146+
} else {
147+
if (this.shouldQuote(item)) {
148+
"\"" + this.encode(item).replaceAllLiterally("\\", """\\""").replaceAllLiterally("\"", """\"""") + "\""
149+
} else {
150+
this.encode(item)
151+
}
152+
}
153+
}.mkString("{", ",", "}")
153154
}
154155

155156
private def shouldQuote(value: Any): Boolean = {

postgresql-async/src/test/scala/com/github/mauricio/async/db/postgresql/ArrayTypesSpec.scala

Lines changed: 69 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -16,73 +16,108 @@
1616

1717
package com.github.mauricio.async.db.postgresql
1818

19-
import com.github.mauricio.async.db.column.TimestampWithTimezoneEncoderDecoder
19+
import com.github.mauricio.async.db.column.{TimestampWithTimezoneEncoderDecoder, InetAddressEncoderDecoder}
2020
import org.specs2.mutable.Specification
21+
import java.net.InetAddress
2122

2223
class ArrayTypesSpec extends Specification with DatabaseTestHelper {
23-
24-
val simpleCreate = """create temp table type_test_table (
25-
bigserial_column bigserial not null,
26-
smallint_column integer[] not null,
27-
text_column text[] not null,
28-
timestamp_column timestamp with time zone[] not null,
29-
constraint bigserial_column_pkey primary key (bigserial_column)
30-
)"""
24+
// `uniq` allows sbt to run the tests concurrently as there is no CREATE TEMP TYPE
25+
def simpleCreate(uniq: String) = s"""DROP TYPE IF EXISTS dir_$uniq;
26+
CREATE TYPE direction_$uniq AS ENUM ('in','out');
27+
DROP TYPE IF EXISTS endpoint_$uniq;
28+
CREATE TYPE endpoint_$uniq AS (ip inet, port integer);
29+
create temp table type_test_table_$uniq (
30+
bigserial_column bigserial not null,
31+
smallint_column integer[] not null,
32+
text_column text[] not null,
33+
inet_column inet[] not null,
34+
direction_column direction_$uniq[] not null,
35+
endpoint_column endpoint_$uniq[] not null,
36+
timestamp_column timestamp with time zone[] not null,
37+
constraint bigserial_column_pkey primary key (bigserial_column)
38+
)"""
39+
def simpleDrop(uniq: String) = s"""drop table if exists type_test_table_$uniq;
40+
drop type if exists endpoint_$uniq;
41+
drop type if exists direction_$uniq"""
3142

3243
val insert =
33-
"""insert into type_test_table
34-
(smallint_column, text_column, timestamp_column)
44+
"""insert into type_test_table_cptat
45+
(smallint_column, text_column, inet_column, direction_column, endpoint_column, timestamp_column)
3546
values (
3647
'{1,2,3,4}',
3748
'{"some,\"comma,separated,text","another line of text","fake\,backslash","real\\,backslash\\",NULL}',
49+
'{"127.0.0.1","2002:15::1"}',
50+
'{"in","out"}',
51+
'{"(\"127.0.0.1\",80)","(\"2002:15::1\",443)"}',
3852
'{"2013-04-06 01:15:10.528-03","2013-04-06 01:15:08.528-03"}'
3953
)"""
4054

41-
val insertPreparedStatement = """insert into type_test_table
42-
(smallint_column, text_column, timestamp_column)
43-
values (?,?,?)"""
55+
val insertPreparedStatement = """insert into type_test_table_csaups
56+
(smallint_column, text_column, inet_column, direction_column, endpoint_column, timestamp_column)
57+
values (?,?,?,?,?,?)"""
4458

4559
"connection" should {
4660

4761
"correctly parse the array type" in {
4862

4963
withHandler {
5064
handler =>
51-
executeDdl(handler, simpleCreate)
52-
executeDdl(handler, insert, 1)
53-
val result = executeQuery(handler, "select * from type_test_table").rows.get
54-
result(0)("smallint_column") === List(1,2,3,4)
55-
result(0)("text_column") === List("some,\"comma,separated,text", "another line of text", "fake,backslash", "real\\,backslash\\", null )
56-
result(0)("timestamp_column") === List(
57-
TimestampWithTimezoneEncoderDecoder.decode("2013-04-06 01:15:10.528-03"),
58-
TimestampWithTimezoneEncoderDecoder.decode("2013-04-06 01:15:08.528-03")
59-
)
65+
try {
66+
executeDdl(handler, simpleCreate("cptat"))
67+
executeDdl(handler, insert, 1)
68+
val result = executeQuery(handler, "select * from type_test_table_cptat").rows.get
69+
result(0)("smallint_column") === List(1,2,3,4)
70+
result(0)("text_column") === List("some,\"comma,separated,text", "another line of text", "fake,backslash", "real\\,backslash\\", null )
71+
result(0)("timestamp_column") === List(
72+
TimestampWithTimezoneEncoderDecoder.decode("2013-04-06 01:15:10.528-03"),
73+
TimestampWithTimezoneEncoderDecoder.decode("2013-04-06 01:15:08.528-03")
74+
)
75+
} finally {
76+
executeDdl(handler, simpleDrop("cptat"))
77+
}
6078
}
6179

6280
}
6381

6482
"correctly send arrays using prepared statements" in {
83+
case class Endpoint(ip: InetAddress, port: Int)
6584

6685
val timestamps = List(
6786
TimestampWithTimezoneEncoderDecoder.decode("2013-04-06 01:15:10.528-03"),
6887
TimestampWithTimezoneEncoderDecoder.decode("2013-04-06 01:15:08.528-03")
6988
)
89+
val inets = List(
90+
InetAddressEncoderDecoder.decode("127.0.0.1"),
91+
InetAddressEncoderDecoder.decode("2002:15::1")
92+
)
93+
val directions = List("in", "out")
94+
val endpoints = List(
95+
Endpoint(InetAddress.getByName("127.0.0.1"), 80), // case class
96+
(InetAddress.getByName("2002:15::1"), 443) // tuple
97+
)
7098
val numbers = List(1,2,3,4)
7199
val texts = List("some,\"comma,separated,text", "another line of text", "fake,backslash", "real\\,backslash\\", null )
72100

73101
withHandler {
74102
handler =>
75-
executeDdl(handler, simpleCreate)
76-
executePreparedStatement(
77-
handler,
78-
this.insertPreparedStatement,
79-
Array( numbers, texts, timestamps ) )
80-
81-
val result = executeQuery(handler, "select * from type_test_table").rows.get
82-
83-
result(0)("smallint_column") === numbers
84-
result(0)("text_column") === texts
85-
result(0)("timestamp_column") === timestamps
103+
try {
104+
executeDdl(handler, simpleCreate("csaups"))
105+
executePreparedStatement(
106+
handler,
107+
this.insertPreparedStatement,
108+
Array( numbers, texts, inets, directions, endpoints, timestamps ) )
109+
110+
val result = executeQuery(handler, "select * from type_test_table_csaups").rows.get
111+
112+
result(0)("smallint_column") === numbers
113+
result(0)("text_column") === texts
114+
result(0)("inet_column") === inets
115+
result(0)("direction_column") === "{in,out}" // user type decoding not supported
116+
result(0)("endpoint_column") === """{"(127.0.0.1,80)","(2002:15::1,443)"}""" // user type decoding not supported
117+
result(0)("timestamp_column") === timestamps
118+
} finally {
119+
executeDdl(handler, simpleDrop("csaups"))
120+
}
86121
}
87122

88123
}

0 commit comments

Comments
 (0)