Skip to content

Commit 30d2f78

Browse files
committed
Modify return type of subProtocolWebSocketHandler bean
The @bean method now returns WebSocketHandler allowing it to be decorated via WebSocketHandlerDecorator.
1 parent 542b5b2 commit 30d2f78

File tree

5 files changed

+50
-22
lines changed

5 files changed

+50
-22
lines changed

spring-messaging/src/main/java/org/springframework/messaging/simp/config/AbstractStompEndpointRegistration.java

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,13 @@
2222
import org.springframework.scheduling.TaskScheduler;
2323
import org.springframework.util.Assert;
2424
import org.springframework.util.ObjectUtils;
25+
import org.springframework.web.socket.WebSocketHandler;
2526
import org.springframework.web.socket.server.DefaultHandshakeHandler;
2627
import org.springframework.web.socket.server.HandshakeHandler;
2728
import org.springframework.web.socket.server.config.SockJsServiceRegistration;
2829
import org.springframework.web.socket.sockjs.SockJsService;
2930
import org.springframework.web.socket.sockjs.transport.handler.WebSocketTransportHandler;
31+
import org.springframework.web.socket.support.WebSocketHandlerDecorator;
3032

3133

3234
/**
@@ -39,7 +41,7 @@ public abstract class AbstractStompEndpointRegistration<M> implements StompEndpo
3941

4042
private final String[] paths;
4143

42-
private final SubProtocolWebSocketHandler wsHandler;
44+
private final WebSocketHandler wsHandler;
4345

4446
private HandshakeHandler handshakeHandler;
4547

@@ -48,7 +50,7 @@ public abstract class AbstractStompEndpointRegistration<M> implements StompEndpo
4850
private final TaskScheduler sockJsTaskScheduler;
4951

5052

51-
public AbstractStompEndpointRegistration(String[] paths, SubProtocolWebSocketHandler webSocketHandler,
53+
public AbstractStompEndpointRegistration(String[] paths, WebSocketHandler webSocketHandler,
5254
TaskScheduler sockJsTaskScheduler) {
5355

5456
Assert.notEmpty(paths, "No paths specified");
@@ -115,20 +117,27 @@ private HandshakeHandler getOrCreateHandshakeHandler() {
115117
if (handler instanceof DefaultHandshakeHandler) {
116118
DefaultHandshakeHandler defaultHandshakeHandler = (DefaultHandshakeHandler) handler;
117119
if (ObjectUtils.isEmpty(defaultHandshakeHandler.getSupportedProtocols())) {
118-
Set<String> protocols = this.wsHandler.getSupportedProtocols();
120+
Set<String> protocols = findSubProtocolWebSocketHandler(this.wsHandler).getSupportedProtocols();
119121
defaultHandshakeHandler.setSupportedProtocols(protocols.toArray(new String[protocols.size()]));
120122
}
121123
}
122124

123125
return handler;
124126
}
125127

128+
private static SubProtocolWebSocketHandler findSubProtocolWebSocketHandler(WebSocketHandler webSocketHandler) {
129+
WebSocketHandler actual = (webSocketHandler instanceof WebSocketHandlerDecorator) ?
130+
((WebSocketHandlerDecorator) webSocketHandler).getLastHandler() : webSocketHandler;
131+
Assert.isInstanceOf(SubProtocolWebSocketHandler.class, actual,
132+
"No SubProtocolWebSocketHandler found: " + webSocketHandler);
133+
return (SubProtocolWebSocketHandler) actual;
134+
}
135+
126136
protected abstract void addSockJsServiceMapping(M mappings, SockJsService sockJsService,
127-
SubProtocolWebSocketHandler wsHandler, String pathPattern);
137+
WebSocketHandler wsHandler, String pathPattern);
128138

129139
protected abstract void addWebSocketHandlerMapping(M mappings,
130-
SubProtocolWebSocketHandler wsHandler, HandshakeHandler handshakeHandler, String path);
131-
140+
WebSocketHandler wsHandler, HandshakeHandler handshakeHandler, String path);
132141

133142

134143
private class StompSockJsServiceRegistration extends SockJsServiceRegistration {

spring-messaging/src/main/java/org/springframework/messaging/simp/config/ServletStompEndpointRegistration.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@
1616

1717
package org.springframework.messaging.simp.config;
1818

19-
import org.springframework.messaging.handler.websocket.SubProtocolWebSocketHandler;
2019
import org.springframework.scheduling.TaskScheduler;
2120
import org.springframework.util.LinkedMultiValueMap;
2221
import org.springframework.util.MultiValueMap;
2322
import org.springframework.web.HttpRequestHandler;
23+
import org.springframework.web.socket.WebSocketHandler;
2424
import org.springframework.web.socket.server.HandshakeHandler;
2525
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
2626
import org.springframework.web.socket.sockjs.SockJsHttpRequestHandler;
@@ -38,8 +38,8 @@ public class ServletStompEndpointRegistration
3838
extends AbstractStompEndpointRegistration<MultiValueMap<HttpRequestHandler, String>> {
3939

4040

41-
public ServletStompEndpointRegistration(String[] paths, SubProtocolWebSocketHandler wsHandler,
42-
TaskScheduler sockJsTaskScheduler) {
41+
public ServletStompEndpointRegistration(String[] paths,
42+
WebSocketHandler wsHandler, TaskScheduler sockJsTaskScheduler) {
4343

4444
super(paths, wsHandler, sockJsTaskScheduler);
4545
}
@@ -51,15 +51,15 @@ protected MultiValueMap<HttpRequestHandler, String> createMappings() {
5151

5252
@Override
5353
protected void addSockJsServiceMapping(MultiValueMap<HttpRequestHandler, String> mappings,
54-
SockJsService sockJsService, SubProtocolWebSocketHandler wsHandler, String pathPattern) {
54+
SockJsService sockJsService, WebSocketHandler wsHandler, String pathPattern) {
5555

5656
SockJsHttpRequestHandler httpHandler = new SockJsHttpRequestHandler(sockJsService, wsHandler);
5757
mappings.add(httpHandler, pathPattern);
5858
}
5959

6060
@Override
6161
protected void addWebSocketHandlerMapping(MultiValueMap<HttpRequestHandler, String> mappings,
62-
SubProtocolWebSocketHandler wsHandler, HandshakeHandler handshakeHandler, String path) {
62+
WebSocketHandler wsHandler, HandshakeHandler handshakeHandler, String path) {
6363

6464
WebSocketHttpRequestHandler handler = new WebSocketHttpRequestHandler(wsHandler, handshakeHandler);
6565
mappings.add(handler, path);

spring-messaging/src/main/java/org/springframework/messaging/simp/config/ServletStompEndpointRegistry.java

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
import org.springframework.web.HttpRequestHandler;
3131
import org.springframework.web.servlet.handler.AbstractHandlerMapping;
3232
import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping;
33+
import org.springframework.web.socket.WebSocketHandler;
34+
import org.springframework.web.socket.support.WebSocketHandlerDecorator;
3335

3436

3537
/**
@@ -40,7 +42,9 @@
4042
*/
4143
public class ServletStompEndpointRegistry implements StompEndpointRegistry {
4244

43-
private final SubProtocolWebSocketHandler wsHandler;
45+
private final WebSocketHandler webSocketHandler;
46+
47+
private final SubProtocolWebSocketHandler subProtocolWebSocketHandler;
4448

4549
private final StompProtocolHandler stompHandler;
4650

@@ -49,23 +53,36 @@ public class ServletStompEndpointRegistry implements StompEndpointRegistry {
4953
private final TaskScheduler sockJsScheduler;
5054

5155

52-
public ServletStompEndpointRegistry(SubProtocolWebSocketHandler webSocketHandler,
56+
public ServletStompEndpointRegistry(WebSocketHandler webSocketHandler,
5357
MutableUserQueueSuffixResolver userQueueSuffixResolver, TaskScheduler defaultSockJsTaskScheduler) {
5458

5559
Assert.notNull(webSocketHandler);
5660
Assert.notNull(userQueueSuffixResolver);
5761

58-
this.wsHandler = webSocketHandler;
62+
this.webSocketHandler = webSocketHandler;
63+
this.subProtocolWebSocketHandler = findSubProtocolWebSocketHandler(webSocketHandler);
5964
this.stompHandler = new StompProtocolHandler();
6065
this.stompHandler.setUserQueueSuffixResolver(userQueueSuffixResolver);
6166
this.sockJsScheduler = defaultSockJsTaskScheduler;
6267
}
6368

69+
private static SubProtocolWebSocketHandler findSubProtocolWebSocketHandler(WebSocketHandler webSocketHandler) {
70+
71+
WebSocketHandler actual = (webSocketHandler instanceof WebSocketHandlerDecorator) ?
72+
((WebSocketHandlerDecorator) webSocketHandler).getLastHandler() : webSocketHandler;
73+
74+
Assert.isInstanceOf(SubProtocolWebSocketHandler.class, actual,
75+
"No SubProtocolWebSocketHandler found: " + webSocketHandler);
76+
77+
return (SubProtocolWebSocketHandler) actual;
78+
}
79+
6480

6581
@Override
6682
public StompEndpointRegistration addEndpoint(String... paths) {
67-
this.wsHandler.addProtocolHandler(this.stompHandler);
68-
ServletStompEndpointRegistration r = new ServletStompEndpointRegistration(paths, this.wsHandler, this.sockJsScheduler);
83+
this.subProtocolWebSocketHandler.addProtocolHandler(this.stompHandler);
84+
ServletStompEndpointRegistration r = new ServletStompEndpointRegistration(
85+
paths, this.webSocketHandler, this.sockJsScheduler);
6986
this.registrations.add(r);
7087
return r;
7188
}

spring-messaging/src/main/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationSupport.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
3535
import org.springframework.web.servlet.HandlerMapping;
3636
import org.springframework.web.servlet.handler.AbstractHandlerMapping;
37+
import org.springframework.web.socket.WebSocketHandler;
3738
import org.springframework.web.socket.server.config.SockJsServiceRegistration;
3839

3940

@@ -65,7 +66,7 @@ public HandlerMapping brokerWebSocketHandlerMapping() {
6566
}
6667

6768
@Bean
68-
public SubProtocolWebSocketHandler subProtocolWebSocketHandler() {
69+
public WebSocketHandler subProtocolWebSocketHandler() {
6970
SubProtocolWebSocketHandler wsHandler = new SubProtocolWebSocketHandler(webSocketRequestChannel());
7071
webSocketResponseChannel().subscribe(wsHandler);
7172
return wsHandler;

spring-messaging/src/test/java/org/springframework/messaging/simp/config/AbstractStompEndpointRegistrationTests.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.springframework.messaging.handler.websocket.SubProtocolWebSocketHandler;
2626
import org.springframework.messaging.support.channel.ExecutorSubscribableChannel;
2727
import org.springframework.scheduling.TaskScheduler;
28+
import org.springframework.web.socket.WebSocketHandler;
2829
import org.springframework.web.socket.server.DefaultHandshakeHandler;
2930
import org.springframework.web.socket.server.HandshakeHandler;
3031
import org.springframework.web.socket.sockjs.SockJsService;
@@ -122,13 +123,13 @@ protected List<Mapping> createMappings() {
122123

123124
@Override
124125
protected void addSockJsServiceMapping(List<Mapping> mappings, SockJsService sockJsService,
125-
SubProtocolWebSocketHandler wsHandler, String pathPattern) {
126+
WebSocketHandler wsHandler, String pathPattern) {
126127

127128
mappings.add(new Mapping(wsHandler, pathPattern, sockJsService));
128129
}
129130

130131
@Override
131-
protected void addWebSocketHandlerMapping(List<Mapping> mappings, SubProtocolWebSocketHandler wsHandler,
132+
protected void addWebSocketHandlerMapping(List<Mapping> mappings, WebSocketHandler wsHandler,
132133
HandshakeHandler handshakeHandler, String path) {
133134

134135
mappings.add(new Mapping(wsHandler, path, handshakeHandler));
@@ -137,22 +138,22 @@ protected void addWebSocketHandlerMapping(List<Mapping> mappings, SubProtocolWeb
137138

138139
private static class Mapping {
139140

140-
private final SubProtocolWebSocketHandler webSocketHandler;
141+
private final WebSocketHandler webSocketHandler;
141142

142143
private final String path;
143144

144145
private final HandshakeHandler handshakeHandler;
145146

146147
private final DefaultSockJsService sockJsService;
147148

148-
public Mapping(SubProtocolWebSocketHandler handler, String path, SockJsService sockJsService) {
149+
public Mapping(WebSocketHandler handler, String path, SockJsService sockJsService) {
149150
this.webSocketHandler = handler;
150151
this.path = path;
151152
this.handshakeHandler = null;
152153
this.sockJsService = (DefaultSockJsService) sockJsService;
153154
}
154155

155-
public Mapping(SubProtocolWebSocketHandler h, String path, HandshakeHandler hh) {
156+
public Mapping(WebSocketHandler h, String path, HandshakeHandler hh) {
156157
this.webSocketHandler = h;
157158
this.path = path;
158159
this.handshakeHandler = hh;

0 commit comments

Comments
 (0)