Skip to content

Commit a929e9c

Browse files
committed
Implement allocation-friendly method to get user count in SimpUserRegistry
SPR-14930
1 parent da63898 commit a929e9c

File tree

6 files changed

+37
-10
lines changed

6 files changed

+37
-10
lines changed

spring-messaging/src/main/java/org/springframework/messaging/simp/user/MultiServerUserRegistry.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,16 @@ public Set<SimpUser> getUsers() {
135135
return result;
136136
}
137137

138+
@Override
139+
public int getUserCount() {
140+
int userCount = 0;
141+
for (UserRegistrySnapshot registry : this.remoteRegistries.values()) {
142+
userCount += registry.getUserMap().size();
143+
}
144+
userCount += this.localRegistry.getUserCount();
145+
return userCount;
146+
}
147+
138148
@Override
139149
public Set<SimpSubscription> findSubscriptions(SimpSubscriptionMatcher matcher) {
140150
Set<SimpSubscription> result = new HashSet<>();

spring-messaging/src/main/java/org/springframework/messaging/simp/user/SimpUserRegistry.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,13 @@ public interface SimpUserRegistry {
4040
*/
4141
Set<SimpUser> getUsers();
4242

43+
/**
44+
* Return the count of all connected users.
45+
* @return the connected user count.
46+
* @since 4.3.5
47+
*/
48+
int getUserCount();
49+
4350
/**
4451
* Find subscriptions with the given matcher.
4552
* @param matcher the matcher to use

spring-messaging/src/test/java/org/springframework/messaging/simp/user/MultiServerUserRegistryTests.java

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,10 @@ public void getUserFromLocalRegistry() throws Exception {
6161
SimpUser user = Mockito.mock(SimpUser.class);
6262
Set<SimpUser> users = Collections.singleton(user);
6363
when(this.localRegistry.getUsers()).thenReturn(users);
64+
when(this.localRegistry.getUserCount()).thenReturn(1);
6465
when(this.localRegistry.getUser("joe")).thenReturn(user);
6566

66-
assertEquals(1, this.registry.getUsers().size());
67+
assertEquals(1, this.registry.getUserCount());
6768
assertSame(user, this.registry.getUser("joe"));
6869
}
6970

@@ -84,7 +85,7 @@ public void getUserFromRemoteRegistry() throws Exception {
8485
this.registry.addRemoteRegistryDto(message, this.converter, 20000);
8586

8687

87-
assertEquals(1, this.registry.getUsers().size());
88+
assertEquals(1, this.registry.getUserCount());
8889
SimpUser user = this.registry.getUser("joe");
8990
assertNotNull(user);
9091
assertTrue(user.hasSessions());
@@ -125,7 +126,7 @@ public void findSubscriptionsFromRemoteRegistry() throws Exception {
125126
this.registry.addRemoteRegistryDto(message, this.converter, 20000);
126127

127128

128-
assertEquals(3, this.registry.getUsers().size());
129+
assertEquals(3, this.registry.getUserCount());
129130
Set<SimpSubscription> matches = this.registry.findSubscriptions(s -> s.getDestination().equals("/match"));
130131
assertEquals(2, matches.size());
131132
Iterator<SimpSubscription> iterator = matches.iterator();
@@ -157,7 +158,7 @@ public void getSessionsWhenUserIsConnectedToMultipleServers() throws Exception {
157158
this.registry.addRemoteRegistryDto(message, this.converter, 20000);
158159

159160

160-
assertEquals(1, this.registry.getUsers().size());
161+
assertEquals(1, this.registry.getUserCount());
161162
SimpUser user = this.registry.getUsers().iterator().next();
162163
assertTrue(user.hasSessions());
163164
assertEquals(2, user.getSessions().size());
@@ -187,9 +188,9 @@ public void purgeExpiredRegistries() throws Exception {
187188
this.registry.addRemoteRegistryDto(message, this.converter, -1);
188189

189190

190-
assertEquals(1, this.registry.getUsers().size());
191+
assertEquals(1, this.registry.getUserCount());
191192
this.registry.purgeExpiredRegistries();
192-
assertEquals(0, this.registry.getUsers().size());
193+
assertEquals(0, this.registry.getUserCount());
193194
}
194195

195196
}

spring-messaging/src/test/java/org/springframework/messaging/simp/user/UserRegistryMessageHandlerTests.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ public void broadcastRegistry() throws Exception {
126126

127127
MultiServerUserRegistry remoteRegistry = new MultiServerUserRegistry(mock(SimpUserRegistry.class));
128128
remoteRegistry.addRemoteRegistryDto(message, this.converter, 20000);
129-
assertEquals(2, remoteRegistry.getUsers().size());
129+
assertEquals(2, remoteRegistry.getUserCount());
130130
assertNotNull(remoteRegistry.getUser("joe"));
131131
assertNotNull(remoteRegistry.getUser("jane"));
132132
}
@@ -142,14 +142,15 @@ public void handleMessage() throws Exception {
142142

143143
HashSet<SimpUser> simpUsers = new HashSet<>(Arrays.asList(simpUser1, simpUser2));
144144
SimpUserRegistry remoteUserRegistry = mock(SimpUserRegistry.class);
145+
when(remoteUserRegistry.getUserCount()).thenReturn(2);
145146
when(remoteUserRegistry.getUsers()).thenReturn(simpUsers);
146147

147148
MultiServerUserRegistry remoteRegistry = new MultiServerUserRegistry(remoteUserRegistry);
148149
Message<?> message = this.converter.toMessage(remoteRegistry.getLocalRegistryDto(), null);
149150

150151
this.handler.handleMessage(message);
151152

152-
assertEquals(2, remoteRegistry.getUsers().size());
153+
assertEquals(2, remoteRegistry.getUserCount());
153154
assertNotNull(this.multiServerRegistry.getUser("joe"));
154155
assertNotNull(this.multiServerRegistry.getUser("jane"));
155156
}
@@ -159,13 +160,14 @@ public void handleMessageFromOwnBroadcast() throws Exception {
159160

160161
TestSimpUser simpUser = new TestSimpUser("joe");
161162
simpUser.addSessions(new TestSimpSession("123"));
163+
when(this.localRegistry.getUserCount()).thenReturn(1);
162164
when(this.localRegistry.getUsers()).thenReturn(Collections.singleton(simpUser));
163165

164-
assertEquals(1, this.multiServerRegistry.getUsers().size());
166+
assertEquals(1, this.multiServerRegistry.getUserCount());
165167

166168
Message<?> message = this.converter.toMessage(this.multiServerRegistry.getLocalRegistryDto(), null);
167169
this.multiServerRegistry.addRemoteRegistryDto(message, this.converter, 20000);
168-
assertEquals(1, this.multiServerRegistry.getUsers().size());
170+
assertEquals(1, this.multiServerRegistry.getUserCount());
169171
}
170172

171173

spring-websocket/src/main/java/org/springframework/web/socket/messaging/DefaultSimpUserRegistry.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,11 @@ public Set<SimpUser> getUsers() {
141141
return new HashSet<>(this.users.values());
142142
}
143143

144+
@Override
145+
public int getUserCount() {
146+
return this.users.size();
147+
}
148+
144149
public Set<SimpSubscription> findSubscriptions(SimpSubscriptionMatcher matcher) {
145150
Set<SimpSubscription> result = new HashSet<>();
146151
for (LocalSimpSession session : this.sessions.values()) {

spring-websocket/src/test/java/org/springframework/web/socket/messaging/DefaultSimpUserRegistryTests.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ public void addOneSessionId() {
5757
SimpUser simpUser = registry.getUser("joe");
5858
assertNotNull(simpUser);
5959

60+
assertEquals(1, registry.getUserCount());
6061
assertEquals(1, simpUser.getSessions().size());
6162
assertNotNull(simpUser.getSession("123"));
6263
}
@@ -82,6 +83,7 @@ public void addMultipleSessionIds() {
8283
SimpUser simpUser = registry.getUser("joe");
8384
assertNotNull(simpUser);
8485

86+
assertEquals(1, registry.getUserCount());
8587
assertEquals(3, simpUser.getSessions().size());
8688
assertNotNull(simpUser.getSession("123"));
8789
assertNotNull(simpUser.getSession("456"));

0 commit comments

Comments
 (0)