Skip to content

Commit ecd0a9f

Browse files
authored
Merge pull request apache#748 from tuohai666/dev
for apache#675 COM_STMT_EXECUTE procedure
2 parents 40c961f + a9d611d commit ecd0a9f

File tree

13 files changed

+247
-57
lines changed

13 files changed

+247
-57
lines changed

sharding-proxy/src/main/java/io/shardingjdbc/proxy/backend/common/StatementExecuteBackendHandler.java

Lines changed: 102 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,27 @@ public final class StatementExecuteBackendHandler implements BackendHandler {
6565

6666
private final PreparedStatementRoutingEngine routingEngine;
6767

68+
private List<Connection> connections;
69+
70+
private List<ResultSet> resultSets;
71+
72+
private MergedResult mergedResult;
73+
74+
private int currentSequenceId;
75+
76+
private int columnCount;
77+
78+
private final List<ColumnType> columnTypes;
79+
80+
private boolean hasMoreResultValueFlag;
81+
6882
public StatementExecuteBackendHandler(final List<PreparedStatementParameter> preparedStatementParameters, final int statementId, final DatabaseType databaseType, final boolean showSQL) {
6983
this.preparedStatementParameters = preparedStatementParameters;
7084
routingEngine = new PreparedStatementRoutingEngine(PreparedStatementRegistry.getInstance().getSQL(statementId), ShardingRuleRegistry.getInstance().getShardingRule(), databaseType, showSQL);
85+
connections = new ArrayList<>(1024);
86+
resultSets = new ArrayList<>(1024);
87+
columnTypes = new ArrayList<>(32);
88+
hasMoreResultValueFlag = true;
7189
}
7290

7391
@Override
@@ -77,24 +95,23 @@ public CommandResponsePackets execute() {
7795
if (routeResult.getExecutionUnits().isEmpty()) {
7896
return new CommandResponsePackets(new OKPacket(1, 0, 0, StatusFlag.SERVER_STATUS_AUTOCOMMIT.getValue(), 0, ""));
7997
}
80-
List<ColumnType> columnTypes = new ArrayList<>(32);
8198
List<CommandResponsePackets> result = new LinkedList<>();
8299
for (SQLExecutionUnit each : routeResult.getExecutionUnits()) {
83100
// TODO multiple threads
84-
result.add(execute(routeResult.getSqlStatement(), each, columnTypes));
101+
result.add(execute(routeResult.getSqlStatement(), each));
85102
}
86-
return merge(routeResult.getSqlStatement(), result, columnTypes);
103+
return merge(routeResult.getSqlStatement(), result);
87104
}
88105

89-
private CommandResponsePackets execute(final SQLStatement sqlStatement, final SQLExecutionUnit sqlExecutionUnit, final List<ColumnType> columnTypes) {
106+
private CommandResponsePackets execute(final SQLStatement sqlStatement, final SQLExecutionUnit sqlExecutionUnit) {
90107
switch (sqlStatement.getType()) {
91108
case DQL:
92-
return executeQuery(ShardingRuleRegistry.getInstance().getDataSourceMap().get(sqlExecutionUnit.getDataSource()), sqlExecutionUnit.getSql(), columnTypes);
109+
return executeQuery(ShardingRuleRegistry.getInstance().getDataSourceMap().get(sqlExecutionUnit.getDataSource()), sqlExecutionUnit.getSql());
93110
case DML:
94111
case DDL:
95112
return executeUpdate(ShardingRuleRegistry.getInstance().getDataSourceMap().get(sqlExecutionUnit.getDataSource()), sqlExecutionUnit.getSql(), sqlStatement);
96113
default:
97-
return executeCommon(ShardingRuleRegistry.getInstance().getDataSourceMap().get(sqlExecutionUnit.getDataSource()), sqlExecutionUnit.getSql(), columnTypes);
114+
return executeCommon(ShardingRuleRegistry.getInstance().getDataSourceMap().get(sqlExecutionUnit.getDataSource()), sqlExecutionUnit.getSql());
98115
}
99116
}
100117

@@ -112,13 +129,16 @@ private void setJDBCPreparedStatementParameters(final PreparedStatement prepared
112129
}
113130
}
114131

115-
private CommandResponsePackets executeQuery(final DataSource dataSource, final String sql, final List<ColumnType> columnTypes) {
116-
try (
117-
Connection connection = dataSource.getConnection();
118-
PreparedStatement preparedStatement = connection.prepareStatement(sql)) {
132+
private CommandResponsePackets executeQuery(final DataSource dataSource, final String sql) {
133+
PreparedStatement preparedStatement;
134+
try {
135+
Connection connection = dataSource.getConnection();
136+
connections.add(connection);
137+
preparedStatement = connection.prepareStatement(sql);
138+
preparedStatement.setFetchSize(Integer.MIN_VALUE);
119139
setJDBCPreparedStatementParameters(preparedStatement);
120-
ResultSet resultSet = preparedStatement.executeQuery();
121-
return getDatabaseProtocolPackets(resultSet, columnTypes);
140+
resultSets.add(preparedStatement.executeQuery());
141+
return getDatabaseProtocolPackets();
122142
} catch (final SQLException ex) {
123143
return new CommandResponsePackets(new ErrPacket(1, ex.getErrorCode(), "", ex.getSQLState(), ex.getMessage()));
124144
}
@@ -144,24 +164,24 @@ private CommandResponsePackets executeUpdate(final DataSource dataSource, final
144164
} catch (final SQLException ex) {
145165
return new CommandResponsePackets(new ErrPacket(1, ex.getErrorCode(), "", ex.getSQLState(), ex.getMessage()));
146166
} finally {
147-
if (preparedStatement != null) {
167+
if (null != preparedStatement) {
148168
try {
149169
preparedStatement.close();
150170
} catch (final SQLException ignore) {
151171
}
152172
}
153173
}
154-
155174
}
156175

157-
private CommandResponsePackets executeCommon(final DataSource dataSource, final String sql, final List<ColumnType> columnTypes) {
176+
private CommandResponsePackets executeCommon(final DataSource dataSource, final String sql) {
158177
try (
159178
Connection connection = dataSource.getConnection();
160179
PreparedStatement preparedStatement = connection.prepareStatement(sql)) {
161180
setJDBCPreparedStatementParameters(preparedStatement);
162181
boolean hasResultSet = preparedStatement.execute();
163182
if (hasResultSet) {
164-
return getDatabaseProtocolPackets(preparedStatement.getResultSet(), columnTypes);
183+
resultSets.add(preparedStatement.getResultSet());
184+
return getDatabaseProtocolPackets();
165185
} else {
166186
return new CommandResponsePackets(new OKPacket(1, preparedStatement.getUpdateCount(), 0, StatusFlag.SERVER_STATUS_AUTOCOMMIT.getValue(), 0, ""));
167187
}
@@ -170,11 +190,11 @@ private CommandResponsePackets executeCommon(final DataSource dataSource, final
170190
}
171191
}
172192

173-
private CommandResponsePackets getDatabaseProtocolPackets(final ResultSet resultSet, final List<ColumnType> columnTypes) throws SQLException {
193+
private CommandResponsePackets getDatabaseProtocolPackets() throws SQLException {
174194
CommandResponsePackets result = new CommandResponsePackets();
175195
int currentSequenceId = 0;
176-
ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
177-
int columnCount = resultSetMetaData.getColumnCount();
196+
ResultSetMetaData resultSetMetaData = resultSets.get(resultSets.size() - 1).getMetaData();
197+
columnCount = resultSetMetaData.getColumnCount();
178198
if (0 == columnCount) {
179199
result.addPacket(new OKPacket(++currentSequenceId, 0, 0, StatusFlag.SERVER_STATUS_AUTOCOMMIT.getValue(), 0, ""));
180200
return result;
@@ -188,14 +208,6 @@ private CommandResponsePackets getDatabaseProtocolPackets(final ResultSet result
188208
columnTypes.add(columnType);
189209
}
190210
result.addPacket(new EofPacket(++currentSequenceId, 0, StatusFlag.SERVER_STATUS_AUTOCOMMIT.getValue()));
191-
while (resultSet.next()) {
192-
List<Object> data = new ArrayList<>(columnCount);
193-
for (int i = 1; i <= columnCount; i++) {
194-
data.add(resultSet.getObject(i));
195-
}
196-
result.addPacket(new BinaryResultSetRowPacket(++currentSequenceId, columnCount, data, columnTypes));
197-
}
198-
result.addPacket(new EofPacket(++currentSequenceId, 0, StatusFlag.SERVER_STATUS_AUTOCOMMIT.getValue()));
199211
return result;
200212
}
201213

@@ -208,7 +220,7 @@ private long getGeneratedKey(final PreparedStatement preparedStatement) throws S
208220
return result;
209221
}
210222

211-
private CommandResponsePackets merge(final SQLStatement sqlStatement, final List<CommandResponsePackets> packets, final List<ColumnType> columnTypes) {
223+
private CommandResponsePackets merge(final SQLStatement sqlStatement, final List<CommandResponsePackets> packets) {
212224
if (1 == packets.size()) {
213225
return packets.iterator().next();
214226
}
@@ -225,7 +237,7 @@ private CommandResponsePackets merge(final SQLStatement sqlStatement, final List
225237
return mergeDML(headPackets);
226238
}
227239
if (SQLType.DQL == sqlStatement.getType() || SQLType.DAL == sqlStatement.getType()) {
228-
return mergeDQLorDAL(sqlStatement, packets, columnTypes);
240+
return mergeDQLorDAL(sqlStatement, packets);
229241
}
230242
return packets.get(0);
231243
}
@@ -241,44 +253,88 @@ private CommandResponsePackets mergeDML(final CommandResponsePackets firstPacket
241253
return new CommandResponsePackets(new OKPacket(1, affectedRows, 0, StatusFlag.SERVER_STATUS_AUTOCOMMIT.getValue(), 0, ""));
242254
}
243255

244-
private CommandResponsePackets mergeDQLorDAL(final SQLStatement sqlStatement, final List<CommandResponsePackets> packets, final List<ColumnType> columnTypes) {
256+
private CommandResponsePackets mergeDQLorDAL(final SQLStatement sqlStatement, final List<CommandResponsePackets> packets) {
245257
List<QueryResult> queryResults = new ArrayList<>(packets.size());
246-
for (CommandResponsePackets each : packets) {
258+
for (int i = 0; i < packets.size(); i++) {
247259
// TODO replace to a common PacketQueryResult
248-
queryResults.add(new MySQLPacketStatementExecuteQueryResult(each));
260+
queryResults.add(new MySQLPacketStatementExecuteQueryResult(packets.get(i), resultSets.get(i), columnTypes));
249261
}
250-
MergedResult mergedResult;
251262
try {
252263
mergedResult = MergeEngineFactory.newInstance(ShardingRuleRegistry.getInstance().getShardingRule(), queryResults, sqlStatement).merge();
253264
} catch (final SQLException ex) {
254265
return new CommandResponsePackets(new ErrPacket(1, ex.getErrorCode(), "", ex.getSQLState(), ex.getMessage()));
255266
}
256-
return buildPackets(packets, mergedResult, columnTypes);
267+
return buildPackets(packets);
257268
}
258269

259-
private CommandResponsePackets buildPackets(final List<CommandResponsePackets> packets, final MergedResult mergedResult, final List<ColumnType> columnTypes) {
270+
private CommandResponsePackets buildPackets(final List<CommandResponsePackets> packets) {
260271
CommandResponsePackets result = new CommandResponsePackets();
261272
Iterator<DatabaseProtocolPacket> databaseProtocolPacketsSampling = packets.iterator().next().getDatabaseProtocolPackets().iterator();
262273
FieldCountPacket fieldCountPacketSampling = (FieldCountPacket) databaseProtocolPacketsSampling.next();
263274
result.addPacket(fieldCountPacketSampling);
264-
int columnCount = fieldCountPacketSampling.getColumnCount();
275+
++currentSequenceId;
265276
for (int i = 0; i < columnCount; i++) {
266277
result.addPacket(databaseProtocolPacketsSampling.next());
278+
++currentSequenceId;
267279
}
268280
result.addPacket(databaseProtocolPacketsSampling.next());
269-
int currentSequenceId = result.size();
281+
++currentSequenceId;
282+
return result;
283+
}
284+
285+
/**
286+
* Has more Result value.
287+
*
288+
* @return has more result value
289+
* @throws SQLException sql exception
290+
*/
291+
public boolean hasMoreResultValue() throws SQLException {
292+
if (!hasMoreResultValueFlag) {
293+
return false;
294+
}
295+
if (!mergedResult.next()) {
296+
hasMoreResultValueFlag = false;
297+
cleanJDBCResources();
298+
}
299+
return true;
300+
}
301+
302+
/**
303+
* Get result value.
304+
*
305+
* @return database protocol packet
306+
*/
307+
public DatabaseProtocolPacket getResultValue() {
308+
if (!hasMoreResultValueFlag) {
309+
return new EofPacket(++currentSequenceId, 0, StatusFlag.SERVER_STATUS_AUTOCOMMIT.getValue());
310+
}
270311
try {
271-
while (mergedResult.next()) {
272-
List<Object> data = new ArrayList<>(columnCount);
273-
for (int i = 1; i <= columnCount; i++) {
274-
data.add(mergedResult.getValue(i, Object.class));
275-
}
276-
result.addPacket(new BinaryResultSetRowPacket(++currentSequenceId, columnCount, data, columnTypes));
312+
List<Object> data = new ArrayList<>(columnCount);
313+
for (int i = 1; i <= columnCount; i++) {
314+
data.add(mergedResult.getValue(i, Object.class));
277315
}
316+
return new BinaryResultSetRowPacket(++currentSequenceId, columnCount, data, columnTypes);
278317
} catch (final SQLException ex) {
279-
return new CommandResponsePackets(new ErrPacket(1, ex.getErrorCode(), "", ex.getSQLState(), ex.getMessage()));
318+
return new ErrPacket(1, ex.getErrorCode(), "", ex.getSQLState(), ex.getMessage());
319+
}
320+
}
321+
322+
private void cleanJDBCResources() {
323+
for (ResultSet each : resultSets) {
324+
if (null != each) {
325+
try {
326+
each.close();
327+
} catch (final SQLException ignore) {
328+
}
329+
}
330+
}
331+
for (Connection each : connections) {
332+
if (null != each) {
333+
try {
334+
each.close();
335+
} catch (final SQLException ignore) {
336+
}
337+
}
280338
}
281-
result.addPacket(new EofPacket(++currentSequenceId, 0, StatusFlag.SERVER_STATUS_AUTOCOMMIT.getValue()));
282-
return result;
283339
}
284340
}

sharding-proxy/src/main/java/io/shardingjdbc/proxy/backend/mysql/MySQLPacketStatementExecuteQueryResult.java

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,21 @@
1919

2020
import io.shardingjdbc.core.merger.QueryResult;
2121
import io.shardingjdbc.proxy.transport.common.packet.DatabaseProtocolPacket;
22+
import io.shardingjdbc.proxy.transport.mysql.constant.ColumnType;
2223
import io.shardingjdbc.proxy.transport.mysql.packet.command.CommandResponsePackets;
2324
import io.shardingjdbc.proxy.transport.mysql.packet.command.statement.execute.BinaryResultSetRowPacket;
2425
import io.shardingjdbc.proxy.transport.mysql.packet.command.text.query.ColumnDefinition41Packet;
2526
import io.shardingjdbc.proxy.transport.mysql.packet.command.text.query.FieldCountPacket;
2627
import lombok.RequiredArgsConstructor;
2728

2829
import java.io.InputStream;
30+
import java.sql.ResultSet;
31+
import java.sql.SQLException;
32+
import java.util.ArrayList;
2933
import java.util.Calendar;
3034
import java.util.HashMap;
3135
import java.util.Iterator;
36+
import java.util.List;
3237
import java.util.Map;
3338

3439
/**
@@ -41,15 +46,19 @@ public final class MySQLPacketStatementExecuteQueryResult implements QueryResult
4146

4247
private final int columnCount;
4348

49+
private final List<ColumnType> columnTypes;
50+
4451
private final Map<Integer, String> columnIndexAndLabelMap;
4552

4653
private final Map<String, Integer> columnLabelAndIndexMap;
4754

48-
private final Iterator<DatabaseProtocolPacket> data;
55+
private final ResultSet resultSet;
56+
57+
private int currentSequenceId;
4958

5059
private BinaryResultSetRowPacket currentRow;
5160

52-
public MySQLPacketStatementExecuteQueryResult(final CommandResponsePackets packets) {
61+
public MySQLPacketStatementExecuteQueryResult(final CommandResponsePackets packets, final ResultSet resultSet, final List<ColumnType> columnTypes) {
5362
Iterator<DatabaseProtocolPacket> packetIterator = packets.getDatabaseProtocolPackets().iterator();
5463
columnCount = ((FieldCountPacket) packetIterator.next()).getColumnCount();
5564
columnIndexAndLabelMap = new HashMap<>(columnCount, 1);
@@ -59,15 +68,18 @@ public MySQLPacketStatementExecuteQueryResult(final CommandResponsePackets packe
5968
columnIndexAndLabelMap.put(i, columnDefinition41Packet.getName());
6069
columnLabelAndIndexMap.put(columnDefinition41Packet.getName(), i);
6170
}
62-
packetIterator.next();
63-
data = packetIterator;
71+
this.resultSet = resultSet;
72+
this.columnTypes = columnTypes;
6473
}
6574

6675
@Override
67-
public boolean next() {
68-
DatabaseProtocolPacket databaseProtocolPacket = data.next();
69-
if (databaseProtocolPacket instanceof BinaryResultSetRowPacket) {
70-
currentRow = (BinaryResultSetRowPacket) databaseProtocolPacket;
76+
public boolean next() throws SQLException {
77+
if (resultSet.next()) {
78+
List<Object> data = new ArrayList<>(columnCount);
79+
for (int i = 1; i <= columnCount; i++) {
80+
data.add(resultSet.getObject(i));
81+
}
82+
currentRow = new BinaryResultSetRowPacket(++currentSequenceId, columnCount, data, columnTypes);
7183
return true;
7284
}
7385
return false;

sharding-proxy/src/main/java/io/shardingjdbc/proxy/frontend/mysql/MySQLFrontendHandler.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,11 @@ public void run() {
6969
int sequenceId = mysqlPacketPayload.readInt1();
7070
CommandPacket commandPacket = CommandPacketFactory.getCommandPacket(sequenceId, mysqlPacketPayload);
7171
for (DatabaseProtocolPacket each : commandPacket.execute().getDatabaseProtocolPackets()) {
72-
context.write(each);
72+
context.writeAndFlush(each);
73+
}
74+
while (commandPacket.hasMoreResultValue()) {
75+
context.writeAndFlush(commandPacket.getResultValue());
7376
}
74-
context.flush();
7577
}
7678
});
7779
}

sharding-proxy/src/main/java/io/shardingjdbc/proxy/transport/mysql/packet/MySQLPacketPayload.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
* @see <a href="https://dev.mysql.com/doc/internals/en/binary-protocol-value.html">binary protocol value</a>
3434
*
3535
* @author zhangliang
36+
* @author zhangyonglun
3637
*/
3738
@RequiredArgsConstructor
3839
@Getter

sharding-proxy/src/main/java/io/shardingjdbc/proxy/transport/mysql/packet/command/CommandPacket.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package io.shardingjdbc.proxy.transport.mysql.packet.command;
1919

20+
import io.shardingjdbc.proxy.transport.common.packet.DatabaseProtocolPacket;
2021
import io.shardingjdbc.proxy.transport.mysql.packet.MySQLPacket;
2122

2223
/**
@@ -36,4 +37,18 @@ public CommandPacket(final int sequenceId) {
3637
* @return result packets to be sent
3738
*/
3839
public abstract CommandResponsePackets execute();
40+
41+
/**
42+
* Has more result value.
43+
*
44+
* @return has more result value
45+
*/
46+
public abstract boolean hasMoreResultValue();
47+
48+
/**
49+
* Get result value.
50+
*
51+
* @return result to be sent
52+
*/
53+
public abstract DatabaseProtocolPacket getResultValue();
3954
}

0 commit comments

Comments
 (0)