Skip to content

Commit 9ca4672

Browse files
committed
Fix handshake handling issue
1 parent 46bcffc commit 9ca4672

File tree

4 files changed

+31
-4
lines changed

4 files changed

+31
-4
lines changed

spring-websocket/src/main/java/org/springframework/websocket/CloseStatus.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,10 @@ public boolean equals(Object other) {
198198
return (this.code == otherStatus.code && ObjectUtils.nullSafeEquals(this.reason, otherStatus.reason));
199199
}
200200

201+
public boolean equalsCode(CloseStatus other) {
202+
return this.code == other.code;
203+
}
204+
201205
@Override
202206
public String toString() {
203207
return "CloseStatus [code=" + this.code + ", reason=" + this.reason + "]";

spring-websocket/src/main/java/org/springframework/websocket/client/endpoint/StandardWebSocketClient.java

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,12 @@
2727
import javax.websocket.ClientEndpointConfig.Configurator;
2828
import javax.websocket.ContainerProvider;
2929
import javax.websocket.Endpoint;
30+
import javax.websocket.HandshakeResponse;
3031
import javax.websocket.Session;
3132
import javax.websocket.WebSocketContainer;
3233

34+
import org.apache.commons.logging.Log;
35+
import org.apache.commons.logging.LogFactory;
3336
import org.springframework.http.HttpHeaders;
3437
import org.springframework.web.util.UriComponentsBuilder;
3538
import org.springframework.websocket.WebSocketHandler;
@@ -47,6 +50,8 @@
4750
*/
4851
public class StandardWebSocketClient implements WebSocketClient {
4952

53+
private static final Log logger = LogFactory.getLog(StandardWebSocketClient.class);
54+
5055
private static final Set<String> EXCLUDED_HEADERS = new HashSet<String>(
5156
Arrays.asList("Sec-WebSocket-Accept", "Sec-WebSocket-Extensions", "Sec-WebSocket-Key",
5257
"Sec-WebSocket-Protocol", "Sec-WebSocket-Version"));
@@ -83,9 +88,22 @@ public WebSocketSession doHandshake(WebSocketHandler webSocketHandler,
8388
public void beforeRequest(Map<String, List<String>> headers) {
8489
for (String headerName : httpHeaders.keySet()) {
8590
if (!EXCLUDED_HEADERS.contains(headerName)) {
86-
headers.put(headerName, httpHeaders.get(headerName));
91+
List<String> value = httpHeaders.get(headerName);
92+
if (logger.isTraceEnabled()) {
93+
logger.trace("Adding header [" + headerName + "=" + value + "]");
94+
}
95+
headers.put(headerName, value);
8796
}
8897
}
98+
if (logger.isTraceEnabled()) {
99+
logger.trace("Handshake request headers: " + headers);
100+
}
101+
}
102+
@Override
103+
public void afterResponse(HandshakeResponse handshakeResponse) {
104+
if (logger.isTraceEnabled()) {
105+
logger.trace("Handshake response headers: " + handshakeResponse.getHeaders());
106+
}
89107
}
90108
});
91109
}

spring-websocket/src/main/java/org/springframework/websocket/client/jetty/JettyWebSocketClient.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ public WebSocketSession doHandshake(WebSocketHandler webSocketHandler, String ur
134134
public WebSocketSession doHandshake(WebSocketHandler webSocketHandler, HttpHeaders headers, URI uri)
135135
throws WebSocketConnectFailureException {
136136

137+
// TODO: populate headers
138+
137139
JettyWebSocketListenerAdapter listener = new JettyWebSocketListenerAdapter(webSocketHandler);
138140

139141
try {

spring-websocket/src/main/java/org/springframework/websocket/server/DefaultHandshakeHandler.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.nio.charset.Charset;
2121
import java.security.MessageDigest;
2222
import java.security.NoSuchAlgorithmException;
23+
import java.util.ArrayList;
2324
import java.util.Arrays;
2425
import java.util.Collections;
2526
import java.util.List;
@@ -34,6 +35,7 @@
3435
import org.springframework.http.server.ServerHttpRequest;
3536
import org.springframework.http.server.ServerHttpResponse;
3637
import org.springframework.util.ClassUtils;
38+
import org.springframework.util.CollectionUtils;
3739
import org.springframework.util.StringUtils;
3840
import org.springframework.websocket.WebSocketHandler;
3941

@@ -53,7 +55,7 @@ public class DefaultHandshakeHandler implements HandshakeHandler {
5355

5456
protected Log logger = LogFactory.getLog(getClass());
5557

56-
private List<String> supportedProtocols;
58+
private List<String> supportedProtocols = new ArrayList<String>();
5759

5860
private RequestUpgradeStrategy requestUpgradeStrategy;
5961

@@ -101,7 +103,8 @@ public final boolean doHandshake(ServerHttpRequest request, ServerHttpResponse r
101103
handleInvalidUpgradeHeader(request, response);
102104
return false;
103105
}
104-
if (!request.getHeaders().getConnection().contains("Upgrade")) {
106+
if (!request.getHeaders().getConnection().contains("Upgrade") &&
107+
!request.getHeaders().getConnection().contains("upgrade")) {
105108
handleInvalidConnectHeader(request, response);
106109
return false;
107110
}
@@ -188,7 +191,7 @@ protected boolean isValidOrigin(ServerHttpRequest request) {
188191
}
189192

190193
protected String selectProtocol(List<String> requestedProtocols) {
191-
if (requestedProtocols != null) {
194+
if (CollectionUtils.isEmpty(requestedProtocols)) {
192195
for (String protocol : requestedProtocols) {
193196
if (this.supportedProtocols.contains(protocol)) {
194197
return protocol;

0 commit comments

Comments
 (0)