Skip to content

Commit c28ce0e

Browse files
committed
Ensure WebSocketHandlerEndpoint can connect only once
WebSocketHandlerEndpoint and SockJsWebSocketHandler are stateful wrappers that are not intended to be used with one client connection.
1 parent db4de52 commit c28ce0e

File tree

3 files changed

+57
-41
lines changed

3 files changed

+57
-41
lines changed

spring-websocket/src/main/java/org/springframework/sockjs/AbstractSockJsSession.java

+8-5
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public abstract class AbstractSockJsSession implements WebSocketSession {
4343

4444
private final HandlerProvider<WebSocketHandler> handlerProvider;
4545

46-
private final TextMessageHandler handler;
46+
private TextMessageHandler handler;
4747

4848
private State state = State.NEW;
4949

@@ -61,10 +61,6 @@ public AbstractSockJsSession(String sessionId, HandlerProvider<WebSocketHandler>
6161
Assert.notNull(sessionId, "sessionId is required");
6262
Assert.notNull(handlerProvider, "handlerProvider is required");
6363
this.sessionId = sessionId;
64-
65-
WebSocketHandler webSocketHandler = handlerProvider.getHandler();
66-
Assert.isInstanceOf(TextMessageHandler.class, webSocketHandler, "Expected a TextMessageHandler");
67-
this.handler = (TextMessageHandler) webSocketHandler;
6864
this.handlerProvider = handlerProvider;
6965
}
7066

@@ -127,9 +123,16 @@ protected void updateLastActiveTime() {
127123

128124
public void delegateConnectionEstablished() throws Exception {
129125
this.state = State.OPEN;
126+
initHandler();
130127
this.handler.afterConnectionEstablished(this);
131128
}
132129

130+
private void initHandler() {
131+
WebSocketHandler webSocketHandler = handlerProvider.getHandler();
132+
Assert.isInstanceOf(TextMessageHandler.class, webSocketHandler, "Expected a TextMessageHandler");
133+
this.handler = (TextMessageHandler) webSocketHandler;
134+
}
135+
133136
public void delegateMessages(String[] messages) throws Exception {
134137
for (String message : messages) {
135138
this.handler.handleTextMessage(new TextMessage(message), this);

spring-websocket/src/main/java/org/springframework/sockjs/server/transport/SockJsWebSocketHandler.java

+8-2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.springframework.sockjs.server.transport;
1818

1919
import java.io.IOException;
20+
import java.util.concurrent.atomic.AtomicInteger;
2021

2122
import org.apache.commons.logging.Log;
2223
import org.apache.commons.logging.LogFactory;
@@ -53,6 +54,8 @@ public class SockJsWebSocketHandler implements TextMessageHandler {
5354

5455
private AbstractSockJsSession session;
5556

57+
private final AtomicInteger sessionCount = new AtomicInteger(0);
58+
5659
// TODO: JSON library used must be configurable
5760
private final ObjectMapper objectMapper = new ObjectMapper();
5861

@@ -70,6 +73,7 @@ protected SockJsConfiguration getSockJsConfig() {
7073

7174
@Override
7275
public void afterConnectionEstablished(WebSocketSession wsSession) throws Exception {
76+
Assert.isTrue(this.sessionCount.compareAndSet(0, 1), "Unexpected connection");
7377
this.session = new WebSocketServerSockJsSession(wsSession, getSockJsConfig());
7478
}
7579

@@ -80,14 +84,16 @@ public void handleTextMessage(TextMessage message, WebSocketSession wsSession) t
8084
logger.trace("Ignoring empty message");
8185
return;
8286
}
87+
String[] messages;
8388
try {
84-
String[] messages = this.objectMapper.readValue(payload, String[].class);
85-
this.session.delegateMessages(messages);
89+
messages = this.objectMapper.readValue(payload, String[].class);
8690
}
8791
catch (IOException e) {
8892
logger.error("Broken data received. Terminating WebSocket connection abruptly", e);
8993
wsSession.close();
94+
return;
9095
}
96+
this.session.delegateMessages(messages);
9197
}
9298

9399
@Override

spring-websocket/src/main/java/org/springframework/websocket/endpoint/WebSocketHandlerEndpoint.java

+41-34
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
package org.springframework.websocket.endpoint;
1818

19+
import java.util.concurrent.atomic.AtomicInteger;
20+
1921
import javax.websocket.CloseReason;
2022
import javax.websocket.Endpoint;
2123
import javax.websocket.EndpointConfig;
@@ -51,6 +53,8 @@ public class WebSocketHandlerEndpoint extends Endpoint {
5153

5254
private WebSocketSession webSocketSession;
5355

56+
private final AtomicInteger sessionCount = new AtomicInteger(0);
57+
5458

5559
public WebSocketHandlerEndpoint(HandlerProvider<WebSocketHandler> handlerProvider) {
5660
Assert.notNull(handlerProvider, "handlerProvider is required");
@@ -59,48 +63,54 @@ public WebSocketHandlerEndpoint(HandlerProvider<WebSocketHandler> handlerProvide
5963

6064
@Override
6165
public void onOpen(final javax.websocket.Session session, EndpointConfig config) {
66+
67+
Assert.isTrue(this.sessionCount.compareAndSet(0, 1), "Unexpected connection");
68+
6269
if (logger.isDebugEnabled()) {
63-
logger.debug("Client connected, WebSocket session id=" + session.getId() + ", uri=" + session.getRequestURI());
70+
logger.debug("Client connected, javax.websocket.Session id="
71+
+ session.getId() + ", uri=" + session.getRequestURI());
6472
}
65-
try {
66-
this.handler = handlerProvider.getHandler();
67-
this.webSocketSession = new StandardWebSocketSession(session);
6873

69-
if (this.handler instanceof TextMessageHandler) {
70-
session.addMessageHandler(new MessageHandler.Whole<String>() {
74+
this.webSocketSession = new StandardWebSocketSession(session);
75+
this.handler = handlerProvider.getHandler();
76+
77+
if (this.handler instanceof TextMessageHandler) {
78+
session.addMessageHandler(new MessageHandler.Whole<String>() {
79+
@Override
80+
public void onMessage(String message) {
81+
handleTextMessage(session, message);
82+
}
83+
});
84+
}
85+
else if (this.handler instanceof BinaryMessageHandler) {
86+
if (this.handler instanceof PartialMessageHandler) {
87+
session.addMessageHandler(new MessageHandler.Partial<byte[]>() {
7188
@Override
72-
public void onMessage(String message) {
73-
handleTextMessage(session, message);
89+
public void onMessage(byte[] messagePart, boolean isLast) {
90+
handleBinaryMessage(session, messagePart, isLast);
7491
}
7592
});
7693
}
77-
else if (this.handler instanceof BinaryMessageHandler) {
78-
if (this.handler instanceof PartialMessageHandler) {
79-
session.addMessageHandler(new MessageHandler.Partial<byte[]>() {
80-
@Override
81-
public void onMessage(byte[] messagePart, boolean isLast) {
82-
handleBinaryMessage(session, messagePart, isLast);
83-
}
84-
});
85-
}
86-
else {
87-
session.addMessageHandler(new MessageHandler.Whole<byte[]>() {
88-
@Override
89-
public void onMessage(byte[] message) {
90-
handleBinaryMessage(session, message, true);
91-
}
92-
});
93-
}
94-
}
9594
else {
95+
session.addMessageHandler(new MessageHandler.Whole<byte[]>() {
96+
@Override
97+
public void onMessage(byte[] message) {
98+
handleBinaryMessage(session, message, true);
99+
}
100+
});
101+
}
102+
}
103+
else {
104+
if (logger.isWarnEnabled()) {
96105
logger.warn("WebSocketHandler handles neither text nor binary messages: " + this.handler);
97106
}
107+
}
98108

109+
try {
99110
this.handler.afterConnectionEstablished(this.webSocketSession);
100111
}
101112
catch (Throwable ex) {
102-
// TODO
103-
logger.error("Error while processing new session", ex);
113+
this.handler.handleError(ex, this.webSocketSession);
104114
}
105115
}
106116

@@ -113,8 +123,7 @@ private void handleTextMessage(javax.websocket.Session session, String message)
113123
((TextMessageHandler) handler).handleTextMessage(textMessage, this.webSocketSession);
114124
}
115125
catch (Throwable ex) {
116-
// TODO
117-
logger.error("Error while processing message", ex);
126+
this.handler.handleError(ex, this.webSocketSession);
118127
}
119128
}
120129

@@ -127,8 +136,7 @@ private void handleBinaryMessage(javax.websocket.Session session, byte[] message
127136
((BinaryMessageHandler) handler).handleBinaryMessage(binaryMessage, this.webSocketSession);
128137
}
129138
catch (Throwable ex) {
130-
// TODO
131-
logger.error("Error while processing message", ex);
139+
this.handler.handleError(ex, this.webSocketSession);
132140
}
133141
}
134142

@@ -142,7 +150,6 @@ public void onClose(javax.websocket.Session session, CloseReason reason) {
142150
this.handler.afterConnectionClosed(closeStatus, this.webSocketSession);
143151
}
144152
catch (Throwable ex) {
145-
// TODO
146153
logger.error("Error while processing session closing", ex);
147154
}
148155
finally {
@@ -157,7 +164,7 @@ public void onError(javax.websocket.Session session, Throwable exception) {
157164
this.handler.handleError(exception, this.webSocketSession);
158165
}
159166
catch (Throwable ex) {
160-
// TODO
167+
// TODO: close the session?
161168
logger.error("Failed to handle error", ex);
162169
}
163170
}

0 commit comments

Comments
 (0)