<dependency>
    <groupId>org.java-websocket</groupId>
    <artifactId>Java-WebSocket</artifactId>
    <version>1.5.3</version>
</dependency>
package com.test.demo.sse;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.scheduling.annotation.EnableAsync;
import org.springframework.scheduling.concurrent.CustomizableThreadFactory;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
/**
 * Thread-pool configuration for asynchronous work.
 *
 * <p>Enables {@code @Async} processing and exposes two pools sized from the {@code sync.*}
 * properties: a {@link ThreadPoolTaskExecutor} (bean {@code executor}) and a
 * {@link ScheduledExecutorService} (bean {@code scheduledExecutorService}) for periodic tasks.
 */
@EnableAsync
@Configuration
public class AsyncConfig {

    /**
     * Core (default) number of threads. Property {@code sync.corePoolSize}, default 50.
     * Also used as the size of the scheduled pool.
     */
    @Value("${sync.corePoolSize:50}")
    private int corePoolSize;

    /**
     * Maximum number of threads. Property {@code sync.maxPoolSize}, default 200.
     * NOTE(review): a ThreadPoolExecutor only grows beyond its core size once its queue is full,
     * so with the very large default queue capacity below this limit is practically never reached.
     */
    @Value("${sync.maxPoolSize:200}")
    private int maxPoolSize;

    /**
     * Capacity of the task queue. Property {@code sync.queueCapacity}, default 10,000,000.
     */
    @Value("${sync.queueCapacity:10000000}")
    private int queueCapacity;

    /**
     * General-purpose executor whose threads are named {@code async-executor-N}.
     * Keep-alive time (60 s) and the rejection policy (AbortPolicy) are left at their defaults.
     *
     * @return the initialised executor
     */
    @Bean
    public Executor executor() {
        ThreadPoolTaskExecutor pool = new ThreadPoolTaskExecutor();
        pool.setThreadNamePrefix("async-executor-");
        pool.setQueueCapacity(queueCapacity);
        pool.setMaxPoolSize(maxPoolSize);
        pool.setCorePoolSize(corePoolSize);
        // Build the underlying ThreadPoolExecutor from the settings above.
        pool.initialize();
        return pool;
    }

    /**
     * Scheduled pool of {@code corePoolSize} threads named {@code schedule-executor-N}.
     *
     * @return the scheduled executor
     */
    @Bean
    public ScheduledExecutorService scheduledExecutorService() {
        CustomizableThreadFactory threadFactory = new CustomizableThreadFactory("schedule-executor-");
        return Executors.newScheduledThreadPool(corePoolSize, threadFactory);
    }
}
package com.test.demo.sse;
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.ApplicationEvent;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import javax.servlet.http.HttpServletRequest;
/**
 * Static access to the Spring {@link ApplicationContext} and to the servlet request bound to the
 * current thread, for code that is not itself a Spring bean.
 */
@Component
public class SpringContextUtils implements ApplicationContextAware {

    /**
     * Context injected by Spring through {@link #setApplicationContext}. Volatile because it is
     * written by the container and read later from arbitrary threads (e.g. async executor threads).
     */
    private static volatile ApplicationContext applicationContext;

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        SpringContextUtils.applicationContext = applicationContext;
    }

    /**
     * Returns the captured application context.
     *
     * @return the context, or {@code null} if Spring has not injected it yet
     */
    public static ApplicationContext getApplicationContext() {
        return applicationContext;
    }

    /**
     * Returns the servlet request bound to the current thread.
     *
     * @return the current request
     * @throws IllegalStateException if no request is bound to this thread (for example when called
     *                               from an async or scheduled thread)
     */
    public static HttpServletRequest getHttpServletRequest() {
        // currentRequestAttributes() fails with a descriptive IllegalStateException instead of the
        // NullPointerException getRequestAttributes() + cast would produce outside a request.
        return ((ServletRequestAttributes) RequestContextHolder.currentRequestAttributes()).getRequest();
    }

    /**
     * Returns the scheme, host and port of the current request (e.g. {@code https://host:8080}),
     * obtained by stripping the request URI from the end of the request URL.
     *
     * @return the origin part of the current request URL
     */
    public static String getDomain() {
        HttpServletRequest request = getHttpServletRequest();
        StringBuffer url = request.getRequestURL();
        return url.delete(url.length() - request.getRequestURI().length(), url.length()).toString();
    }

    /**
     * Returns the {@code Origin} header of the current request.
     *
     * @return the header value, or {@code null} if absent
     */
    public static String getOrigin() {
        HttpServletRequest request = getHttpServletRequest();
        return request.getHeader("Origin");
    }

    /**
     * Looks up a bean by name.
     *
     * @param name bean name
     * @return the bean instance
     * @throws IllegalStateException if the context has not been injected yet
     */
    public static Object getBean(String name) {
        return requireContext().getBean(name);
    }

    /**
     * Looks up the unique bean of the given type.
     *
     * @param clazz bean type
     * @param <T>   bean type
     * @return the bean instance
     * @throws IllegalStateException if the context has not been injected yet
     */
    public static <T> T getBean(Class<T> clazz) {
        return requireContext().getBean(clazz);
    }

    /**
     * Looks up a bean by name and type.
     *
     * @param name  bean name
     * @param clazz required bean type
     * @param <T>   bean type
     * @return the bean instance
     * @throws IllegalStateException if the context has not been injected yet
     */
    public static <T> T getBean(String name, Class<T> clazz) {
        return requireContext().getBean(name, clazz);
    }

    /**
     * Publishes an application event; silently does nothing before the context is available.
     *
     * @param event the event to publish
     */
    public static void publishEvent(ApplicationEvent event) {
        ApplicationContext context = applicationContext;
        if (context == null) {
            return;
        }
        context.publishEvent(event);
    }

    /** Returns the injected context or fails with a clear message instead of an NPE. */
    private static ApplicationContext requireContext() {
        ApplicationContext context = applicationContext;
        if (context == null) {
            throw new IllegalStateException("ApplicationContext has not been injected yet");
        }
        return context;
    }
}
package com.test.demo.sse;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.nio.charset.StandardCharsets;
import java.util.UUID;
import java.util.concurrent.*;
/**
*
* MySseEmitter
*
* Description: 解决SseEmitter浏览器下中文乱码问题
*/
@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
public class MySseEmitter extends SseEmitter {
/**
* websocket返回的所有信息,只用于将消息发送到前端
*/
private StringBuilder totalAnswer = new StringBuilder();
/**
* websocket返回的所有信息,用于最终存储的消息内容
*/
private StringBuilder totalAnswerStorage = new StringBuilder();
/**
* 链接是否已主动断开,,true:已主动断开,false:未断开
*/
private boolean disconnected = false;
/**
* 本条消息的唯一id
*/
private String messageUuid = UUID.randomUUID().toString();
/**
* 本次会话的唯一id
*/
private String conversationUuid = UUID.randomUUID().toString();
/**
* 是否匀速返回,true:需要匀速,false:不需要匀速
*/
private boolean speedControl;
/**
* 是否已经开始匀速返回信息,true:已经开始,false:还没有开始
*/
private boolean startSendMsgWithSpeedControl = false;
/**
* 已经发送的消息的长度
*/
private int sendLength = 0;
/**
* 所有消息是否已经全部匀速返回,true:已经全部返回,false:还没有全部返回
*/
private boolean endSendMsgWithSpeedControl = false;
/**
* 匀速吐字间隔时间,单位:毫秒
*/
private long sleepTime = 20L;
/**
* 匀速发送消息时每次返回多少个字符
*/
private int sendMsgSpeed = 1;
/**
* 超时时间,单位:毫秒
*/
private long timeout;
/**
* 当前登录人
*/
private String userUid;
/**
* 当前登录人的问题
*/
private String userQuestion;
/**
* 解决中文乱码
*
* @param outputMessage
*/
@Override
protected void extendResponse(ServerHttpResponse outputMessage) {
super.extendResponse(outputMessage);
HttpHeaders headers = outputMessage.getHeaders();
headers.setContentType(new MediaType(MediaType.TEXT_EVENT_STREAM, StandardCharsets.UTF_8));
}
/**
* 创建SSE对象
*
* @param speedControl 是否开启匀速,true:开启,false:关闭
* @param timeout 超时时间,单位:毫秒
* @param userUid 当前登录人
* @param userQuestion 当前登录人的问题
*/
public MySseEmitter(boolean speedControl, long timeout, String userUid, String userQuestion) {
// 设置超时时间,单位:毫秒
super(timeout);
this.speedControl = speedControl;
this.timeout = timeout;
this.userUid = userUid;
this.userQuestion = userQuestion;
}
/**
* 自定义发送消息方法
*
* @param message 具体消息
* @param msgStorage 本次发送的消息内容是否需要进行存储,true:需要,false:不需要
* @return 是否需要关闭链接,true:是,false:否
*/
public boolean mySend(String message, boolean msgStorage) {
try {
if (StringUtils.isNotEmpty(message)) {
// 处理换行,PC换行\r、\n、、\r\n都行,移动只能\r\n
message = message.replaceAll("\r", "\n")
.replaceAll("\n\n", "\n")
.replaceAll("\n\n", "\n")
.replaceAll("\n\n", "\n")
.replaceAll("\n\n", "\n")
.replaceAll("\n\n", "\n")
.replaceAll("\n\n", "\n")
.replaceAll("\n\n", "\n")
.replaceAll("\n\n", "\n")
.replaceAll("\n\n", "\n")
.replaceAll("\n\n", "\n")
.replaceAll("\n\n", "\n")
.replaceAll("\n\n", "\n")
.replaceAll("\n", "\r\n");
this.totalAnswer.append(message);
if (msgStorage) {
this.totalAnswerStorage.append(message);
}
if (!this.speedControl) {
super.send(message);
} else if (!this.startSendMsgWithSpeedControl && !this.disconnected) {
// 异步发送
SpringContextUtils.getBean(AsyncService.class).sendMsgWithSpeedControl(this);
}
}
} catch (Exception e) {
log.error("==>MySseEmitter send error,conversationUuid:{}", this.conversationUuid, e);
this.disconnected = true;
}
return this.disconnected;
}
/**
* 断开连接
*
* @param msgStorageType 消息处理类型,0:不存储,1:本地数据库存储
*/
public void myComplete(String msgStorageType) {
try {
if (!this.speedControl || this.disconnected || this.endSendMsgWithSpeedControl) {
super.complete();
} else {
Future> future = SpringContextUtils.getBean(ScheduledExecutorService.class).scheduleAtFixedRate(() -> {
if (this.endSendMsgWithSpeedControl) {
log.info("==>当前消息已全部返回完成,主动断开与端上链接,conversationUuid:{}", conversationUuid);
throw new RuntimeException("==>当前消息已全部返回完成,主动断开与端上链接,conversationUuid:" + conversationUuid);
}
}, this.sleepTime, this.sleepTime, TimeUnit.MILLISECONDS);
try {
// 超时时间,单位:毫秒
future.get(timeout, TimeUnit.MILLISECONDS);
} catch (TimeoutException | ExecutionException e) {
log.info("==>等待断开链接任务执行结束,conversationUuid:{}", conversationUuid);
// 取消任务
future.cancel(true);
} catch (Exception e) {
log.error("==>等待断开链接任务执行异常,conversationUuid:{}", conversationUuid, e);
// 取消任务
future.cancel(true);
}
super.complete();
}
} catch (Exception ignore) {
}
if ("1".equals(msgStorageType)) {
// 本地数据库存储消息,异步保存数据
SpringContextUtils.getBean(AsyncService.class).saveMsg(this.messageUuid, this.conversationUuid,
this.userUid, this.userQuestion, this.totalAnswerStorage.toString());
}
}
}
package com.test.demo.sse;
import lombok.extern.slf4j.Slf4j;
import org.java_websocket.client.WebSocketClient;
import org.java_websocket.drafts.Draft_6455;
import javax.net.ssl.*;
import java.net.Socket;
import java.net.URI;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
/**
 * WebSocketClient (RFC 6455 draft) that, by default, accepts any server certificate on
 * {@code wss} connections.
 *
 * <p><b>Security warning:</b> with certificate checking disabled the connection is exposed to
 * man-in-the-middle attacks. Because JSSE performs endpoint identification inside the trust
 * manager, the no-op trust manager used here also bypasses that check. Use
 * {@link #MyWebSocketClient(URI, int, boolean)} with {@code false} wherever the server presents a
 * valid certificate.
 */
@Slf4j
public abstract class MyWebSocketClient extends WebSocketClient {

    /**
     * Creates a client that trusts every server certificate (applied to {@code wss} URIs only).
     *
     * @param serverUri      WebSocket address
     * @param connectTimeout connect timeout, in milliseconds
     */
    public MyWebSocketClient(URI serverUri, int connectTimeout) {
        this(serverUri, connectTimeout, true);
    }

    /**
     * Creates a client, optionally disabling server-certificate validation.
     *
     * @param serverUri            WebSocket address
     * @param connectTimeout       connect timeout, in milliseconds
     * @param trustAllCertificates true to accept any certificate on {@code wss} URIs; false to leave
     *                             socket creation to the library
     */
    public MyWebSocketClient(URI serverUri, int connectTimeout, boolean trustAllCertificates) {
        super(serverUri, new Draft_6455(), null, connectTimeout);
        // A custom socket factory replaces the library's own socket creation, so installing an SSL
        // factory for a plain ws:// URI would make the client speak TLS to a non-TLS endpoint.
        if (trustAllCertificates && "wss".equalsIgnoreCase(serverUri.getScheme())) {
            SSLSocketFactory socketFactory = createTrustAllSocketFactory();
            if (socketFactory != null) {
                this.setSocketFactory(socketFactory);
            }
        }
    }

    /**
     * Builds an SSL socket factory backed by {@link TrustAllTrustManager}.
     *
     * @return the factory, or {@code null} if the SSL context could not be initialised (the error is
     *         logged and the client keeps the library's default behaviour)
     */
    private static SSLSocketFactory createTrustAllSocketFactory() {
        try {
            SSLContext ssl = SSLContext.getInstance("TLS");
            ssl.init(null, new TrustManager[]{new TrustAllTrustManager()}, new java.security.SecureRandom());
            return ssl.getSocketFactory();
        } catch (java.security.GeneralSecurityException e) {
            log.error("==>初始化SSLContext失败", e);
            return null;
        }
    }

    /** Trust manager that performs no checks: every certificate chain and peer is accepted. */
    private static final class TrustAllTrustManager extends X509ExtendedTrustManager {

        @Override
        public void checkClientTrusted(X509Certificate[] chain, String authType, Socket socket) throws CertificateException {
        }

        @Override
        public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket) throws CertificateException {
        }

        @Override
        public void checkClientTrusted(X509Certificate[] chain, String authType, SSLEngine engine) throws CertificateException {
        }

        @Override
        public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEngine engine) throws CertificateException {
        }

        @Override
        public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException {
        }

        @Override
        public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException {
        }

        @Override
        public X509Certificate[] getAcceptedIssuers() {
            // The X509TrustManager contract requires a non-null (possibly empty) array.
            return new X509Certificate[0];
        }
    }
}
package com.test.demo.sse;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.java_websocket.client.WebSocketClient;
import org.java_websocket.handshake.ServerHandshake;
import java.net.URI;
/**
 * Connects to the upstream WebSocket service and relays its answer into a {@link MySseEmitter}.
 */
@Slf4j
public class MyWebSocketClientHelper {

    /**
     * Opens a WebSocket connection and sends the request once it is open. Every received message is
     * forwarded to {@code sseEmitter}; when the connection closes, the emitter is completed (sending a
     * generic error text first if nothing was received). Presumably {@code connect()} returns before
     * the handshake finishes (as opposed to {@code connectBlocking()}); the emitter is finished from
     * {@code onClose()}.
     *
     * @param sseEmitter     the SSE connection to the browser
     * @param msgStorageType how to handle the message; "0" = do not store, "1" = store in the local database
     */
    public static void connectAndSend(MySseEmitter sseEmitter, String msgStorageType) {
        String commonErrorMsg = "通用错误信息,报错啦";
        String messageUuid = sseEmitter.getMessageUuid();
        WebSocketClient client = null;
        try {
            // The connect timeout reuses the SSE timeout. The client takes an int, so clamp instead of
            // failing (the old Integer.parseInt(Long.toString(..)) threw above Integer.MAX_VALUE).
            int connectTimeout = (int) Math.min(sseEmitter.getTimeout(), (long) Integer.MAX_VALUE);
            client = new MyWebSocketClient(new URI("wss://xxxxx"), connectTimeout) {
                @Override
                public void onOpen(ServerHandshake serverHandshake) {
                    log.info("==>connect success,messageUuid:{}", messageUuid);
                    try {
                        String requestParam = "xxxxxxx";
                        this.send(requestParam);
                    } catch (Exception e) {
                        log.error("==>sendRequest error,messageUuid:{}", messageUuid, e);
                        throw e;
                    }
                }

                @Override
                public void onMessage(String result) {
                    log.info("==>messageUuid:{},onMessage:{}", messageUuid, result);
                    try {
                        sseEmitter.mySend(result, true);
                    } catch (Exception e) {
                        log.error("==>onMessage error,messageUuid:{}", messageUuid, e);
                    }
                }

                /**
                 * @param code   close status code, e.g. 1000 = normal closure, 1006 = abnormal closure
                 *               (connection dropped without a close frame)
                 * @param reason textual close reason; may be empty
                 * @param remote whether the close was initiated by the remote peer
                 */
                @Override
                public void onClose(int code, String reason, boolean remote) {
                    log.info("==>onClose,messageUuid:{},code:{},reason:{},remote:{}", messageUuid, code, reason, remote);
                    try {
                        // Nothing was received: tell the browser something went wrong before closing.
                        if (!sseEmitter.isDisconnected() && StringUtils.isBlank(sseEmitter.getTotalAnswer().toString())) {
                            sseEmitter.mySend(commonErrorMsg, true);
                        }
                        sseEmitter.myComplete(msgStorageType);
                    } catch (Exception ignored) {
                        // best effort: the emitter handles and logs its own failures
                    }
                }

                @Override
                public void onError(Exception e) {
                    log.error("==>onError,messageUuid:{}", messageUuid, e);
                    try {
                        this.close();
                    } catch (Exception ignored) {
                        // best effort close
                    }
                }
            };
            client.connect();
        } catch (Exception e) {
            log.error("==>WebSocketClientConnectAndSend error,messageUuid:{}", messageUuid, e);
            if (client == null) {
                // The client was never created, so no onClose() callback will complete the emitter:
                // report the failure and finish it here instead of leaving it open until its timeout.
                sseEmitter.mySend(commonErrorMsg, true);
                sseEmitter.myComplete(msgStorageType);
            } else {
                try {
                    client.close();
                } catch (Exception ignored) {
                    // best effort close
                }
            }
        }
    }
}
package com.test.demo.sse;
/**
 * Operations that the SSE relay runs asynchronously.
 */
public interface AsyncService {

    /**
     * Saves the conversation and the answer asynchronously.
     *
     * @param messageUuid        uid of the message
     * @param conversationUuid   uid of the conversation
     * @param userUid            uid of the current user
     * @param userQuestion       the question the user asked
     * @param totalAnswerStorage the answer text received over the WebSocket, to be stored
     */
    void saveMsg(String messageUuid, String conversationUuid, String userUid, String userQuestion, String totalAnswerStorage);

    /**
     * Sends the text buffered in the emitter to the client asynchronously, at a uniform pace.
     *
     * @param sseEmitter the emitter whose buffered answer is to be sent
     */
    void sendMsgWithSpeedControl(MySseEmitter sseEmitter);
}
package com.test.demo.sse;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Lazy;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.util.concurrent.*;
/**
 * Default {@link AsyncService}. Because of the class-level {@code @Async}, its public methods run on
 * the async executor when invoked through the Spring proxy.
 */
@Slf4j
@Async
@Service
public class AsyncServiceImpl implements AsyncService {

    /** Scheduler that drives the paced output. */
    @Lazy
    @Autowired
    ScheduledExecutorService scheduledExecutorService;

    /**
     * Drains {@code sseEmitter.getTotalAnswer()} to the client at a fixed pace: every
     * {@code sleepTime} ms the next {@code sendMsgSpeed} unsent characters are sent.
     *
     * <p>Blocks the calling thread until everything buffered so far has been sent, a send fails (the
     * emitter is then marked disconnected), or the emitter's timeout elapses. Afterwards the start
     * flag is cleared and the end flag set.
     *
     * <p>NOTE(review): text appended by {@code mySend} after the last tick found nothing left to send,
     * but before the start flag is cleared here, does not trigger a new task and may never be sent.
     *
     * @param sseEmitter the emitter whose buffered answer is sent
     */
    @Override
    public void sendMsgWithSpeedControl(MySseEmitter sseEmitter) {
        sseEmitter.setStartSendMsgWithSpeedControl(true);
        sseEmitter.setEndSendMsgWithSpeedControl(false);
        final String messageUuid = sseEmitter.getMessageUuid();
        long period = sseEmitter.getSleepTime();
        // A periodic task stops running once it throws; sendNextChunk throws to end the loop.
        Future<?> future = scheduledExecutorService.scheduleAtFixedRate(
                () -> sendNextChunk(sseEmitter, messageUuid), period, period, TimeUnit.MILLISECONDS);
        try {
            future.get(sseEmitter.getTimeout(), TimeUnit.MILLISECONDS);
        } catch (TimeoutException | ExecutionException e) {
            // ExecutionException is the normal end: the tick threw because it finished or failed.
            log.info("==>任务执行结束,messageUuid:{}", messageUuid);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.error("==>任务执行异常,messageUuid:{}", messageUuid, e);
        } catch (Exception e) {
            log.error("==>任务执行异常,messageUuid:{}", messageUuid, e);
        } finally {
            future.cancel(true);
        }
        sseEmitter.setStartSendMsgWithSpeedControl(false);
        sseEmitter.setEndSendMsgWithSpeedControl(true);
    }

    /**
     * One tick of the paced sender: sends the next unsent slice of the buffered answer.
     * A slice ending in {@code \r} is extended by one character when possible, so a CRLF pair is
     * never split across two events. A non-positive {@code sendMsgSpeed} is treated as 1; otherwise
     * the task would send empty strings until the timeout.
     *
     * @throws RuntimeException to stop the periodic task once everything buffered so far has been
     *                          sent, or when sending fails (the emitter is then marked disconnected)
     */
    private void sendNextChunk(MySseEmitter sseEmitter, String messageUuid) {
        int sendLength = sseEmitter.getSendLength();
        int sendMsgSpeed = Math.max(1, sseEmitter.getSendMsgSpeed());
        // Snapshot of everything received up to now.
        String nowAllAnswer = sseEmitter.getTotalAnswer().toString();
        int totalAnswerLength = nowAllAnswer.length();
        if (sendLength >= totalAnswerLength) {
            log.info("==>当前时刻所有消息已全部发送完成,任务执行结束,messageUuid:{}", messageUuid);
            throw new RuntimeException("==>当前时刻所有消息已全部发送完成,任务执行结束,messageUuid:" + messageUuid);
        }
        int end = Math.min(sendLength + sendMsgSpeed, totalAnswerLength);
        if (nowAllAnswer.charAt(end - 1) == '\r' && end < totalAnswerLength) {
            end++;
        }
        String message = nowAllAnswer.substring(sendLength, end);
        try {
            sseEmitter.send(message);
        } catch (Exception e) {
            sseEmitter.setDisconnected(true);
            log.info("==>发送消息失败,视为端上主动断开链接,任务执行结束,messageUuid:{}", messageUuid);
            throw new RuntimeException("==>发送消息失败,视为端上主动断开链接,任务执行结束,messageUuid:" + messageUuid, e);
        }
        sseEmitter.setSendLength(sendLength + message.length());
    }

    /**
     * Saves the conversation (creating it when it does not exist yet) and the message, in one transaction.
     *
     * @param messageUuid        uid of the message
     * @param conversationUuid   uid of the conversation
     * @param userUid            uid of the current user
     * @param userQuestion       the user's question
     * @param totalAnswerStorage the answer text to store
     */
    @Override
    @Transactional(rollbackFor = Exception.class)
    public void saveMsg(String messageUuid, String conversationUuid, String userUid, String userQuestion, String totalAnswerStorage) {
        log.info("==>保存会话和信息,messageUuid:{},conversationUuid:{},userUid:{},userQuestion:{},totalAnswerStorage:{}",
                messageUuid, conversationUuid, userUid, userQuestion, totalAnswerStorage);
        // TODO: create the conversation if it does not exist, then save the message
    }
}
package com.test.demo.sse;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import javax.validation.Valid;
/**
 * Demo endpoints that stream answers to the browser via Server-Sent Events.
 */
@Slf4j
@RestController
@RequestMapping("/sse")
public class SseTestController {

    /** SSE timeout for one answer, in milliseconds. */
    private static final long SSE_TIMEOUT_MILLIS = 60000L;

    /** Message storage type passed to the emitter: "1" = store in the local database. */
    private static final String MSG_STORAGE_LOCAL_DB = "1";

    /**
     * Handles a user question: connects to the upstream WebSocket, streams its answer back as SSE
     * with paced output, and stores the answer once finished.
     *
     * @param input validated request body carrying the user's uid and question
     * @return the emitter through which the answer is streamed
     */
    @PostMapping(value = "/ask")
    public SseEmitter ask(@RequestBody @Valid Input input) {
        MySseEmitter sseEmitter = new MySseEmitter(true, SSE_TIMEOUT_MILLIS, input.getUserUid(), input.getUserQuestion());
        try {
            MyWebSocketClientHelper.connectAndSend(sseEmitter, MSG_STORAGE_LOCAL_DB);
        } catch (Exception e) {
            log.error("ask error", e);
            sseEmitter.myComplete(MSG_STORAGE_LOCAL_DB);
        }
        return sseEmitter;
    }
}