- import org.springframework.context.annotation.Bean;
- import org.springframework.context.annotation.Configuration;
- import org.springframework.web.socket.server.standard.ServerEndpointExporter;
-
- /**
- * websocket 配置
- *
- * @author ruoyi
- */
- @Configuration
- public class WebSocketConfig
- {
- @Bean
- public ServerEndpointExporter serverEndpointExporter()
- {
- return new ServerEndpointExporter();
- }
- }
工具类 SemaphoreUtils
- import java.util.concurrent.Semaphore;
- import org.slf4j.Logger;
- import org.slf4j.LoggerFactory;
-
- /**
- * 信号量相关处理
- *
- * @author ruoyi
- */
- public class SemaphoreUtils{
- /**
- * SemaphoreUtils 日志控制器
- */
- private static final Logger LOGGER = LoggerFactory.getLogger(SemaphoreUtils.class);
-
- /**
- * 获取信号量
- *
- * @param semaphore
- * @return
- */
- public static boolean tryAcquire(Semaphore semaphore)
- {
- boolean flag = false;
- try
- {
- flag = semaphore.tryAcquire();
- }
- catch (Exception e)
- {
- LOGGER.error("获取信号量异常", e);
- }
- return flag;
- }
-
- /**
- * 释放信号量
- *
- * @param semaphore
- */
- public static void release(Semaphore semaphore)
- {
- try
- {
- semaphore.release();
- }
- catch (Exception e)
- {
- LOGGER.error("释放信号量异常", e);
- }
- }
- }
服务端类WebSocketServer
- import java.util.concurrent.Semaphore;
- import javax.websocket.OnClose;
- import javax.websocket.OnError;
- import javax.websocket.OnMessage;
- import javax.websocket.OnOpen;
- import javax.websocket.Session;
- import javax.websocket.server.ServerEndpoint;
- import com.lxh.demo.util.SemaphoreUtils;
- import org.slf4j.Logger;
- import org.slf4j.LoggerFactory;
- import org.springframework.stereotype.Component;
-
- /**
- * websocket 消息处理
- *
- * @author ruoyi
- */
- @Component
- @ServerEndpoint("/websocket/message")
- public class WebSocketServer
- {
- /**
- * WebSocketServer 日志控制器
- */
- private static final Logger LOGGER = LoggerFactory.getLogger(WebSocketServer.class);
-
- /**
- * 默认最多允许同时在线人数100
- */
- public static int socketMaxOnlineCount = 100;
-
- private static Semaphore socketSemaphore = new Semaphore(socketMaxOnlineCount);
-
- /**
- * 连接建立成功调用的方法
- */
- @OnOpen
- public void onOpen(Session session) throws Exception{
- boolean semaphoreFlag = false;
- // 尝试获取信号量
- semaphoreFlag = SemaphoreUtils.tryAcquire(socketSemaphore);
- if (!semaphoreFlag)
- {
- // 未获取到信号量
- LOGGER.error("\n 当前在线人数超过限制数- {}", socketMaxOnlineCount);
- WebSocketUsers.sendMessageToUserByText(session, "当前在线人数超过限制数:" + socketMaxOnlineCount);
- session.close();
- }
- else
- {
- // 添加用户
- WebSocketUsers.put(session.getId(), session);
- LOGGER.info("\n 建立连接 - {}", session);
- LOGGER.info("\n 当前人数 - {}", WebSocketUsers.getUsers().size());
- WebSocketUsers.sendMessageToUserByText(session, "连接成功");
- }
- }
-
- /**
- * 连接关闭时处理
- */
- @OnClose
- public void onClose(Session session)
- {
- LOGGER.info("\n 关闭连接 - {}", session);
- // 移除用户
- WebSocketUsers.remove(session.getId());
- // 获取到信号量则需释放
- SemaphoreUtils.release(socketSemaphore);
- }
-
- /**
- * 抛出异常时处理
- */
- @OnError
- public void onError(Session session, Throwable exception) throws Exception
- {
- if (session.isOpen())
- {
- // 关闭连接
- session.close();
- }
- String sessionId = session.getId();
- LOGGER.info("\n 连接异常 - {}", sessionId);
- LOGGER.info("\n 异常信息 - {}", exception);
- // 移出用户
- WebSocketUsers.remove(sessionId);
- // 获取到信号量则需释放
- SemaphoreUtils.release(socketSemaphore);
- }
-
- /**
- * 服务器接收到客户端消息时调用的方法
- */
- @OnMessage
- public void onMessage(String message, Session session)
- {
- String msg = message.replace("你", "我").replace("吗", "");
- WebSocketUsers.sendMessageToUserByText(session, msg);
- }
- }
WebSocketUsers工具类
- import java.io.IOException;
- import java.util.Collection;
- import java.util.Map;
- import java.util.Set;
- import java.util.concurrent.ConcurrentHashMap;
- import javax.websocket.Session;
- import org.slf4j.Logger;
- import org.slf4j.LoggerFactory;
-
- /**
- * websocket 客户端用户集
- *
- * @author ruoyi
- */
- public class WebSocketUsers
- {
- /**
- * WebSocketUsers 日志控制器
- */
- private static final Logger LOGGER = LoggerFactory.getLogger(WebSocketUsers.class);
-
- /**
- * 用户集
- */
- private static Map
USERS = new ConcurrentHashMap(); -
- /**
- * 存储用户
- *
- * @param key 唯一键
- * @param session 用户信息
- */
- public static void put(String key, Session session)
- {
- USERS.put(key, session);
- }
-
- /**
- * 移除用户
- *
- * @param session 用户信息
- *
- * @return 移除结果
- */
- public static boolean remove(Session session)
- {
- String key = null;
- boolean flag = USERS.containsValue(session);
- if (flag)
- {
- Set
> entries = USERS.entrySet(); - for (Map.Entry
entry : entries) - {
- Session value = entry.getValue();
- if (value.equals(session))
- {
- key = entry.getKey();
- break;
- }
- }
- }
- else
- {
- return true;
- }
- return remove(key);
- }
-
- /**
- * 移出用户
- *
- * @param key 键
- */
- public static boolean remove(String key)
- {
- LOGGER.info("\n 正在移出用户 - {}", key);
- Session remove = USERS.remove(key);
- if (remove != null)
- {
- boolean containsValue = USERS.containsValue(remove);
- LOGGER.info("\n 移出结果 - {}", containsValue ? "失败" : "成功");
- return containsValue;
- }
- else
- {
- return true;
- }
- }
-
- /**
- * 获取在线用户列表
- *
- * @return 返回用户集合
- */
- public static Map
getUsers() - {
- return USERS;
- }
-
- /**
- * 群发消息文本消息
- *
- * @param message 消息内容
- */
- public static void sendMessageToUsersByText(String message)
- {
- Collection
values = USERS.values(); - for (Session value : values)
- {
- sendMessageToUserByText(value, message);
- }
- }
-
- /**
- * 发送文本消息
- *
- * @param session 缓存
- * @param message 消息内容
- */
- public static void sendMessageToUserByText(Session session, String message)
- {
- if (session != null)
- {
- try
- {
- session.getBasicRemote().sendText(message);
- }
- catch (IOException e)
- {
- LOGGER.error("\n[发送消息异常]", e);
- }
- }
- else
- {
- LOGGER.info("\n[你已离线]");
- }
- }
- }
Html 页面代码
- html>
- <html lang="zh" xmlns:th="http://www.thymeleaf.org">
- <head>
- <meta charset="utf-8">
- <meta http-equiv="X-UA-Compatible" content="IE=edge">
- <title>测试界面title>
- head>
-
- <body>
-
- <div>
- <input type="text" style="width: 20%" value="ws://127.0.0.1/websocket/message" id="url">
- <button id="btn_join">连接button>
- <button id="btn_exit">断开button>
- div>
- <br/>
- <textarea id="message" cols="100" rows="9">textarea> <button id="btn_send">发送消息button>
- <br/>
- <br/>
- <textarea id="text_content" readonly="readonly" cols="100" rows="9">textarea>返回内容
- <br/>
- <br/>
- <script th:src="@{/js/jquery.min.js}" >script>
- <script type="text/javascript">
- $(document).ready(function(){
- var ws = null;
- // 连接
- $('#btn_join').click(function() {
- var url = $("#url").val();
- ws = new WebSocket(url);
- ws.onopen = function(event) {
- $('#text_content').append('已经打开连接!' + '\n');
- }
- ws.onmessage = function(event) {
- $('#text_content').append(event.data + '\n');
- }
- ws.onclose = function(event) {
- $('#text_content').append('已经关闭连接!' + '\n');
- }
- });
- // 发送消息
- $('#btn_send').click(function() {
- var message = $('#message').val();
- if (ws) {
- ws.send(message);
- } else {
- alert("未连接到服务器");
- }
- });
- //断开
- $('#btn_exit').click(function() {
- if (ws) {
- ws.close();
- ws = null;
- }
- });
- })
- script>
- body>
- html>
成功运行后,页面如下

注意此时没有走用户认证,那么就要对路径放行,因为若依框架用的是SpringSecurity,所以找到文件SecurityConfig.java ,进行路径放行

虽然按着上述步骤我们完成了浏览器(客户端)和Java(服务端)的WebSocket通信,但是我们不能限定哪些用户可以连接我们的服务端获取数据,服务端也不知道应该具体给哪些用户发送消息,在我们框架之前交互我们是通过浏览器传递toke 值来实现用户身份确认的,那么我们的WebSocket可不可以也这样呢?
很不幸的是 ws连接是无法像http一样完全自主定义请求头的,给token认证带来了不便,我们大致可以通过以下集中方式完成用户认证
1、将 token 明文携带在 url 中,例如ws://localhost:8080/weggo/websocket/message?Authorization=Bearer+token
2、通过websocket下的子协议来实现,Stomp这个协议来实现,前端采用SocketJs框架来实现对应定制请求头。实现携带authorization=Bearer +token 的需求,这样就可以正常建立连接
3、利用子协议数组,将 token 携带在 protocols 里,var ws = new WebSocket(url, ["token"]);
这样后端在 onOpen 事件中,就可以从 server 中读取 Sec-WebSocket-Protocol 属性来进行 token 的获取,具体可以参考WebScoket构造函数官方文档
- var aWebSocket = new WebSocket(url [, protocols]);
- url
- 要连接的URL;这应该是WebSocket服务器将响应的URL。
- protocols 可选
- 一个协议字符串或者一个包含协议字符串的数组。这些字符串用于指定子协议,这样单个服务器可以实现多个WebSocket子协议
- (例如,您可能希望一台服务器能够根据指定的协议(protocol)处理不同类型的交互)。如果不指定协议字符串,则假定为空字符串。
protocols对应的就是发起ws连接时, 携带在请求头中的Sec-WebSocket-Protocol属性, 服务端可以获取到此属性的值用于通信逻辑(即通信子协议,当然用来进行token认证也是完全没问题的),前端人员在请求头上携带sec-websocket-protocol=Bearer +token后台在请求到达oauth2之前进行拦截,然后将在请求头上添加Authorization=Bearer +token(key首字母大写),然后在响应头(respone)上添加sec-websocket-protocol=Bearer +token(不添加会报错)
方法3部分代码示例
- //前端
- var aWebSocket = new WebSocket(url ['用户token']);
-
- //后端
- @Override
- public void afterConnectionEstablished(WebSocketSession session) throws Exception {
- //这里就是我们所提交的token
- String submitedToken=session.getHandshakeHeaders().get("sec-websocket-protocol").get(0);
-
- //根据token取得登录用户信息(业务逻辑根据你自己的来处理)
- }
另外,如果需要在第一次握手前的时候就取得token,只需要在header里面取得就可以啦
- @Override
- public boolean beforeHandshake(ServerHttpRequest serverHttpRequest, ServerHttpResponse serverHttpResponse, WebSocketHandler webSocketHandler, Map
map) throws Exception { - System.out.println("准备握手");
- String submitedToken = serverHttpRequest.getHeaders().get("sec-websocket-protocol")
- return true;
- }
因为我的项目是APP 移动端与服务端进行交互,所以后来选择了最简单实现的方案一
首先要解决的就是在拦截器获取url 的token 信息,原框架只从head里面获取,所以需要稍加改动
找到TokenService.java文件里的getToken方法,改成如下,这样就可以获取url 中的token 了又不影响原来的Http 请求
- private String getToken(HttpServletRequest request)
- {
- String token = Optional.ofNullable(request.getHeader(header)).orElse(request.getParameter(header));
- if (StringUtils.isNotEmpty(token) && token.startsWith(Constants.TOKEN_PREFIX))
- {
- token = token.replace(Constants.TOKEN_PREFIX, "");
- }
- return token;
- }
接下来就是需要对我们的WebSocket类进行改造了,为了方便阅读,去除了WebSocketUsers类,添加了类变量webSocketSet来存储客户端对象
- import com.alibaba.fastjson2.JSON;
- import com.tongchuang.common.utils.SecurityUtils;
- import com.tongchuang.web.mqtt.domain.DeviceInfo;
- import io.netty.util.HashedWheelTimer;
- import io.netty.util.Timeout;
- import org.slf4j.Logger;
- import org.slf4j.LoggerFactory;
- import org.springframework.security.core.Authentication;
- import org.springframework.stereotype.Component;
-
- import javax.websocket.*;
- import javax.websocket.server.PathParam;
- import javax.websocket.server.ServerEndpoint;
- import java.io.IOException;
- import java.util.HashMap;
- import java.util.Map;
- import java.util.concurrent.CopyOnWriteArraySet;
- import java.util.concurrent.Semaphore;
- import java.util.concurrent.TimeUnit;
- import java.util.concurrent.atomic.AtomicInteger;
- import java.util.function.Function;
-
- /**
- * websocket 消息处理
- *
- * @author stronger
- */
- @Component
- @ServerEndpoint("/websocket/message")
- public class WebSocketServer {
- /*========================声明类变量,意在所有实例共享=================================================*/
- /**
- * WebSocketServer 日志控制器
- */
- private static final Logger LOGGER = LoggerFactory.getLogger(WebSocketServer.class);
-
- /**
- * 默认最多允许同时在线人数100
- */
- public static int socketMaxOnlineCount = 100;
-
- private static Semaphore socketSemaphore = new Semaphore(socketMaxOnlineCount);
-
- HashedWheelTimer timer = new HashedWheelTimer(1, TimeUnit.SECONDS, 8);
- /**
- * concurrent包的线程安全Set,用来存放每个客户端对应的MyWebSocket对象。
- */
- private static final CopyOnWriteArraySet
webSocketSet = new CopyOnWriteArraySet<>(); - /**
- * 连接数
- */
- private static final AtomicInteger count = new AtomicInteger();
-
- /*========================声明实例变量,意在每个实例独享=======================================================*/
- /**
- * 与某个客户端的连接会话,需要通过它来给客户端发送数据
- */
- private Session session;
- /**
- * 用户id
- */
- private String sid = "";
-
- /**
- * 连接建立成功调用的方法
- */
- @OnOpen
- public void onOpen(Session session) throws Exception {
- // 尝试获取信号量
- boolean semaphoreFlag = SemaphoreUtils.tryAcquire(socketSemaphore);
- if (!semaphoreFlag) {
- // 未获取到信号量
- LOGGER.error("\n 当前在线人数超过限制数- {}", socketMaxOnlineCount);
- // 给当前Session 登录用户发送消息
- sendMessageToUserByText(session, "当前在线人数超过限制数:" + socketMaxOnlineCount);
- session.close();
- } else {
- // 返回此会话的经过身份验证的用户,如果此会话没有经过身份验证的用户,则返回null
- Authentication authentication = (Authentication) session.getUserPrincipal();
- SecurityUtils.setAuthentication(authentication);
- String username = SecurityUtils.getUsername();
- this.session = session;
- //如果存在就先删除一个,防止重复推送消息
- for (WebSocketServer webSocket : webSocketSet) {
- if (webSocket.sid.equals(username)) {
- webSocketSet.remove(webSocket);
- count.getAndDecrement();
- }
- }
- count.getAndIncrement();
- webSocketSet.add(this);
- this.sid = username;
- LOGGER.info("\n 当前人数 - {}", count);
- sendMessageToUserByText(session, "连接成功");
- }
- }
-
- /**
- * 连接关闭时处理
- */
- @OnClose
- public void onClose(Session session) {
- LOGGER.info("\n 关闭连接 - {}", session);
- // 移除用户
- webSocketSet.remove(session);
- // 获取到信号量则需释放
- SemaphoreUtils.release(socketSemaphore);
- }
-
- /**
- * 抛出异常时处理
- */
- @OnError
- public void onError(Session session, Throwable exception) throws Exception {
- if (session.isOpen()) {
- // 关闭连接
- session.close();
- }
- String sessionId = session.getId();
- LOGGER.info("\n 连接异常 - {}", sessionId);
- LOGGER.info("\n 异常信息 - {}", exception);
- // 移出用户
- webSocketSet.remove(session);
- // 获取到信号量则需释放
- SemaphoreUtils.release(socketSemaphore);
- }
-
- /**
- * 服务器接收到客户端消息时调用的方法
- */
- @OnMessage
- public void onMessage(String message, Session session) {
- Authentication authentication = (Authentication) session.getUserPrincipal();
- LOGGER.info("收到来自" + sid + "的信息:" + message);
- // 实时更新
- this.refresh(sid, authentication);
- sendMessageToUserByText(session, "我收到了你的新消息哦");
- }
-
- /**
- * 刷新定时任务,发送信息
- */
- private void refresh(String userId, Authentication authentication) {
- this.start(5000L, task -> {
- // 判断用户是否在线,不在线则不用处理,因为在内部无法关闭该定时任务,所以通过返回值在外部进行判断。
- if (WebSocketServer.isConn(userId)) {
- // 因为这里是长链接,不会和普通网页一样,每次发送http 请求可以走拦截器【doFilterInternal】续约,所以需要手动续约
- SecurityUtils.setAuthentication(authentication);
- // 从数据库或者缓存中获取信息,构建自定义的Bean
- DeviceInfo deviceInfo = DeviceInfo.builder().Macaddress("de5a735951ee").Imei("351517175516665")
- .Battery("99").Charge("0").Latitude("116.402649").Latitude("39.914859").Altitude("80")
- .Method(SecurityUtils.getUsername()).build();
- // TODO判断数据是否有更新
- // 发送最新数据给前端
- WebSocketServer.sendInfo("JSON", deviceInfo, userId);
- // 设置返回值,判断是否需要继续执行
- return true;
- }
- return false;
- });
- }
-
- private void start(long delay, Function
function) { - timer.newTimeout(t -> {
- // 获取返回值,判断是否执行
- Boolean result = function.apply(t);
- if (result) {
- timer.newTimeout(t.task(), delay, TimeUnit.MILLISECONDS);
- }
- }, delay, TimeUnit.MILLISECONDS);
- }
-
- /**
- * 判断是否有链接
- *
- * @return
- */
- public static boolean isConn(String sid) {
- for (WebSocketServer item : webSocketSet) {
- if (item.sid.equals(sid)) {
- return true;
- }
- }
- return false;
- }
-
- /**
- * 群发自定义消息
- * 或者指定用户发送消息
- */
- public static void sendInfo(String type, Object data, @PathParam("sid") String sid) {
- // 遍历WebSocketServer对象集合,如果符合条件就推送
- for (WebSocketServer item : webSocketSet) {
- try {
- //这里可以设定只推送给这个sid的,为null则全部推送
- if (sid == null) {
- item.sendMessage(type, data);
- } else if (item.sid.equals(sid)) {
- item.sendMessage(type, data);
- }
- } catch (IOException ignored) {
- }
- }
- }
-
- /**
- * 实现服务器主动推送
- */
- private void sendMessage(String type, Object data) throws IOException {
- Map
result = new HashMap<>(); - result.put("type", type);
- result.put("data", data);
- this.session.getAsyncRemote().sendText(JSON.toJSONString(result));
- }
-
- /**
- * 实现服务器主动推送-根据session
- */
- public static void sendMessageToUserByText(Session session, String message) {
- if (session != null) {
- try {
- session.getBasicRemote().sendText(message);
- } catch (IOException e) {
- LOGGER.error("\n[发送消息异常]", e);
- }
- } else {
- LOGGER.info("\n[你已离线]");
- }
- }
- }

- public class SecurityUtils
- {
-
- public static void setAuthentication(Authentication authentication) {
- SecurityContextHolder.getContext().setAuthentication(authentication);
- }
- }