Skip to content

Commit 77ae525

Browse files
committed
Merge pull request #10 from alaisi/pipelining
WIP: Add pipelining to executed statements
2 parents 618ee11 + 60bec87 commit 77ae525

File tree

5 files changed

+209
-17
lines changed

5 files changed

+209
-17
lines changed

src/main/java/com/github/pgasync/ConnectionPoolBuilder.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,11 @@ public ConnectionPoolBuilder ssl(boolean ssl) {
8383
return this;
8484
}
8585

86+
public ConnectionPoolBuilder pipeline(boolean pipeline) {
87+
properties.usePipelining = pipeline;
88+
return this;
89+
}
90+
8691
public ConnectionPoolBuilder validationQuery(String validationQuery) {
8792
properties.validationQuery = validationQuery;
8893
return this;
@@ -102,6 +107,7 @@ public static class PoolProperties {
102107
DataConverter dataConverter = null;
103108
List<Converter<?>> converters = new ArrayList<>();
104109
boolean useSsl;
110+
boolean usePipelining;
105111
String validationQuery = "SELECT 1";
106112

107113
public String getHostname() {
@@ -125,6 +131,9 @@ public int getPoolSize() {
125131
public boolean getUseSsl() {
126132
return useSsl;
127133
}
134+
public boolean getUsePipelining() {
135+
return usePipelining;
136+
}
128137
public DataConverter getDataConverter() {
129138
return dataConverter != null ? dataConverter : new DataConverter(converters);
130139
}

src/main/java/com/github/pgasync/impl/PgConnectionPool.java

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ static class QueuedCallback {
6060
final ConnectionValidator validator;
6161

6262
final int poolSize;
63+
protected final boolean pipeline;
64+
6365
int currentSize;
6466
volatile boolean closed;
6567

@@ -71,6 +73,7 @@ public PgConnectionPool(PoolProperties properties) {
7173
this.poolSize = properties.getPoolSize();
7274
this.dataConverter = properties.getDataConverter();
7375
this.validator = properties.getValidator();
76+
this.pipeline = properties.getUsePipelining();
7477
}
7578

7679
@Override
@@ -82,15 +85,18 @@ public void query(String sql, Consumer<ResultSet> onResult, Consumer<Throwable>
8285
@SuppressWarnings("rawtypes")
8386
public void query(String sql, List params, Consumer<ResultSet> onResult, Consumer<Throwable> onError) {
8487
getConnection(connection ->
88+
{
8589
connection.query(sql, params,
8690
result -> {
87-
release(connection);
91+
releaseIfNotPipelining(connection);
8892
onResult.accept(result);
8993
},
9094
exception -> {
91-
release(connection);
95+
releaseIfNotPipelining(connection);
9296
onError.accept(exception);
93-
}),
97+
});
98+
releaseIfPipelining(connection);
99+
},
94100
onError);
95101
}
96102

@@ -175,6 +181,18 @@ void getConnection(final Consumer<Connection> onConnection, final Consumer<Throw
175181

176182
}
177183

184+
private void releaseIfPipelining(Connection connection) {
185+
if (pipeline) {
186+
release(connection);
187+
}
188+
}
189+
190+
private void releaseIfNotPipelining(Connection connection) {
191+
if (!pipeline) {
192+
release(connection);
193+
}
194+
}
195+
178196
@Override
179197
public void release(Connection connection) {
180198
if(closed) {

src/main/java/com/github/pgasync/impl/netty/NettyPgConnectionPool.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ public NettyPgConnectionPool(PoolProperties properties) {
4040

4141
@Override
4242
protected PgProtocolStream openStream(InetSocketAddress address) {
43-
return new NettyPgProtocolStream(group, address, useSsl);
43+
return new NettyPgProtocolStream(group, address, useSsl, pipeline);
4444
}
4545

4646
@Override

src/main/java/com/github/pgasync/impl/netty/NettyPgProtocolStream.java

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@
3434
import java.util.UUID;
3535
import java.util.concurrent.ConcurrentHashMap;
3636
import java.util.concurrent.ConcurrentMap;
37+
import java.util.Queue;
38+
import java.util.concurrent.ArrayBlockingQueue;
39+
import java.util.concurrent.LinkedBlockingDeque;
3740
import java.util.function.Consumer;
3841

3942
import static java.util.Collections.singletonList;
@@ -51,12 +54,13 @@ public class NettyPgProtocolStream implements PgProtocolStream {
5154
final ConcurrentMap<String,Map<String,Consumer<String>>> listeners = new ConcurrentHashMap<>();
5255

5356
ChannelHandlerContext ctx;
54-
volatile Consumer<Message> onReceive;
57+
final Queue<Consumer<Message>> onReceivers;
5558

56-
public NettyPgProtocolStream(EventLoopGroup group, SocketAddress address, boolean useSsl) {
59+
public NettyPgProtocolStream(EventLoopGroup group, SocketAddress address, boolean useSsl, boolean pipeline) {
5760
this.group = group;
5861
this.address = address;
5962
this.useSsl = useSsl; // TODO: refactor into SSLConfig with trust parameters
63+
this.onReceivers = pipeline ? new LinkedBlockingDeque<>() : new ArrayBlockingQueue<>(1);
6064
}
6165

6266
@Override
@@ -78,16 +82,22 @@ public void send(Message message, Consumer<List<Message>> replyTo) {
7882
if(!isConnected()) {
7983
throw new IllegalStateException("Channel is closed");
8084
}
81-
onReceive = newReplyHandler(replyTo);
85+
addNewReplyHandler(replyTo);
8286
ctx.writeAndFlush(message);
8387
}
8488

89+
private void addNewReplyHandler(Consumer<List<Message>> replyTo) {
90+
if (!onReceivers.offer(newReplyHandler(replyTo))) {
91+
replyTo.accept(singletonList(new ChannelError("Pipelining not enabled")));
92+
}
93+
}
94+
8595
@Override
8696
public void send(List<Message> messages, Consumer<List<Message>> replyTo) {
8797
if(!isConnected()) {
8898
throw new IllegalStateException("Channel is closed");
8999
}
90-
onReceive = newReplyHandler(replyTo);
100+
addNewReplyHandler(replyTo);
91101
messages.forEach(ctx::write);
92102
ctx.flush();
93103
}
@@ -128,7 +138,7 @@ Consumer<Message> newReplyHandler(Consumer<List<Message>> consumer) {
128138
if(msg instanceof ReadyForQuery
129139
|| msg instanceof ChannelError
130140
|| (msg instanceof Authentication && !((Authentication) msg).isAuthenticationOk())) {
131-
onReceive = null;
141+
onReceivers.remove();
132142
consumer.accept(messages);
133143
}
134144
};
@@ -159,7 +169,7 @@ public void channelActive(ChannelHandlerContext context) {
159169
}
160170
void startup(ChannelHandlerContext context) {
161171
ctx = context;
162-
onReceive = newReplyHandler(replyTo);
172+
addNewReplyHandler(replyTo);
163173
context.writeAndFlush(startup);
164174
context.pipeline().remove(this);
165175
}
@@ -210,19 +220,15 @@ public void channelRead(ChannelHandlerContext context, Object msg) throws Except
210220
publishNotification((NotificationResponse) msg);
211221
return;
212222
}
213-
onReceive.accept((Message) msg);
223+
onReceivers.peek().accept((Message) msg);
214224
}
215225
@Override
216226
public void channelInactive(ChannelHandlerContext context) throws Exception {
217-
if(onReceive != null) {
218-
onReceive.accept(new ChannelError("Channel state changed to inactive"));
219-
}
227+
onReceivers.forEach(r -> r.accept(new ChannelError("Channel state changed to inactive")));
220228
}
221229
@Override
222230
public void exceptionCaught(ChannelHandlerContext context, Throwable cause) throws Exception {
223-
if(onReceive != null) {
224-
onReceive.accept(new ChannelError(cause));
225-
}
231+
onReceivers.forEach(r -> r.accept(new ChannelError(cause)));
226232
}
227233
};
228234
}
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
15+
package com.github.pgasync.impl;
16+
17+
import static com.github.pgasync.impl.DatabaseRule.createPoolBuilder;
18+
import static java.lang.System.currentTimeMillis;
19+
import static java.util.concurrent.TimeUnit.MILLISECONDS;
20+
import static java.util.concurrent.TimeUnit.SECONDS;
21+
import static org.hamcrest.CoreMatchers.containsString;
22+
import static org.hamcrest.CoreMatchers.is;
23+
import static org.hamcrest.CoreMatchers.isA;
24+
import static org.junit.Assert.assertThat;
25+
26+
import java.util.Deque;
27+
import java.util.concurrent.BlockingQueue;
28+
import java.util.concurrent.CountDownLatch;
29+
import java.util.concurrent.LinkedBlockingDeque;
30+
import java.util.concurrent.SynchronousQueue;
31+
import java.util.concurrent.atomic.AtomicLong;
32+
import java.util.function.Consumer;
33+
34+
import org.junit.After;
35+
import org.junit.Test;
36+
37+
import com.github.pgasync.Connection;
38+
import com.github.pgasync.ConnectionPool;
39+
import com.github.pgasync.ResultSet;
40+
41+
/**
42+
* Tests for statement pipelining.
43+
*
44+
* @author Mikko Tiihonen
45+
*/
46+
public class PipelineTest {
47+
final Consumer<Throwable> err = t -> { throw new AssertionError("failed", t); };
48+
49+
Connection c;
50+
ConnectionPool pool;
51+
52+
@After
53+
public void closeConnection() {
54+
if (c != null) {
55+
pool.release(c);
56+
}
57+
if (pool != null) {
58+
pool.close();
59+
}
60+
}
61+
62+
@Test
63+
public void connectionPipelinesQueries() throws InterruptedException {
64+
pool = createPoolBuilder(1).pipeline(true).build();
65+
66+
int count = 5;
67+
double sleep = 0.5;
68+
Deque<Long> results = new LinkedBlockingDeque<>();
69+
long startWrite = currentTimeMillis();
70+
for (int i = 0; i < count; ++i) {
71+
pool.query("select " + i + ", pg_sleep(" + sleep + ")", r -> results.add(currentTimeMillis()),
72+
err);
73+
}
74+
long writeTime = currentTimeMillis() - startWrite;
75+
76+
long remoteWaitTimeSeconds = (long) (sleep * count);
77+
SECONDS.sleep(1 + remoteWaitTimeSeconds);
78+
long readTime = results.getLast() - results.getFirst();
79+
80+
assertThat(results.size(), is(count));
81+
assertThat(MILLISECONDS.toSeconds(writeTime), is(0L));
82+
assertThat(MILLISECONDS.toSeconds(readTime + 999) >= remoteWaitTimeSeconds, is(true));
83+
}
84+
85+
private Connection getConnection(boolean pipeline) throws InterruptedException {
86+
pool = createPoolBuilder(1).pipeline(pipeline).build();
87+
SynchronousQueue<Connection> connQueue = new SynchronousQueue<>();
88+
pool.getConnection(c -> connQueue.add(c), err);
89+
return c = connQueue.take();
90+
}
91+
92+
@Test
93+
public void disabledConnectionPipeliningThrowsErrorWhenPipeliningIsAttempted() throws Exception {
94+
Connection c = getConnection(false);
95+
96+
BlockingQueue<ResultSet> rs = new LinkedBlockingDeque<>();
97+
BlockingQueue<Throwable> err = new LinkedBlockingDeque<>();
98+
for (int i = 0; i < 2; ++i) {
99+
c.query("select " + i + ", pg_sleep(0.5)", r -> rs.add(r), e -> err.add(e));
100+
}
101+
assertThat(err.take().getMessage(), containsString("Pipelining not enabled"));
102+
assertThat(rs.take(), isA(ResultSet.class));
103+
}
104+
105+
@Test
106+
public void connectionPoolPipelinesQueries() throws InterruptedException {
107+
Connection c = getConnection(true);
108+
109+
int count = 5;
110+
double sleep = 0.5;
111+
Deque<Long> results = new LinkedBlockingDeque<>();
112+
long startWrite = currentTimeMillis();
113+
for (int i = 0; i < count; ++i) {
114+
c.query("select " + i + ", pg_sleep(" + sleep + ")", r -> results.add(currentTimeMillis()),
115+
err);
116+
}
117+
long writeTime = currentTimeMillis() - startWrite;
118+
119+
long remoteWaitTimeSeconds = (long) (sleep * count);
120+
SECONDS.sleep(1 + remoteWaitTimeSeconds);
121+
long readTime = results.getLast() - results.getFirst();
122+
123+
assertThat(results.size(), is(count));
124+
assertThat(MILLISECONDS.toSeconds(writeTime), is(0L));
125+
assertThat(MILLISECONDS.toSeconds(readTime + 999) >= remoteWaitTimeSeconds, is(true));
126+
}
127+
128+
@Test
129+
public void connectionPoolPipelinesQueriesWithinTransaction() throws InterruptedException {
130+
pool = createPoolBuilder(1).pipeline(true).build();
131+
132+
int count = 5;
133+
double sleep = 0.5;
134+
Deque<Long> results = new LinkedBlockingDeque<>();
135+
AtomicLong writeTime = new AtomicLong();
136+
137+
CountDownLatch sync = new CountDownLatch(1);
138+
long startWrite = currentTimeMillis();
139+
pool.begin(t -> {
140+
for (int i = 0; i < count; ++i) {
141+
t.query("select " + i + ", pg_sleep(" + sleep + ")", r -> results.add(currentTimeMillis()),
142+
err);
143+
}
144+
t.commit(() -> {
145+
sync.countDown();
146+
} , err);
147+
writeTime.set(currentTimeMillis() - startWrite);
148+
} , err);
149+
sync.await(3, SECONDS);
150+
151+
long remoteWaitTimeSeconds = (long) (sleep * count);
152+
SECONDS.sleep(1 + remoteWaitTimeSeconds);
153+
long readTime = results.getLast() - results.getFirst();
154+
155+
assertThat(results.size(), is(count));
156+
assertThat(MILLISECONDS.toSeconds(writeTime.get()), is(0L));
157+
assertThat(MILLISECONDS.toSeconds(readTime + 999) >= remoteWaitTimeSeconds, is(true));
158+
}
159+
}

0 commit comments

Comments
 (0)