• 【java】实现sse调用websocket接口,忽略wss证书并控制sse吐字速度


    maven

            
            <dependency>
                <groupId>org.java-websocket</groupId>
                <artifactId>Java-WebSocket</artifactId>
                <version>1.5.3</version>
            </dependency>
            
    

    AsyncConfig

    package com.test.demo.sse;
    
    import org.springframework.beans.factory.annotation.Value;
    import org.springframework.context.annotation.Bean;
    import org.springframework.context.annotation.Configuration;
    import org.springframework.scheduling.annotation.EnableAsync;
    import org.springframework.scheduling.concurrent.CustomizableThreadFactory;
    import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
    
    import java.util.concurrent.Executor;
    import java.util.concurrent.Executors;
    import java.util.concurrent.ScheduledExecutorService;
    
    /**
     * 

    * AsyncConfig *

    * Description: 异步配置 */ @EnableAsync @Configuration public class AsyncConfig { /** * 核心线程数(默认线程数) */ @Value("${sync.corePoolSize:50}") private int corePoolSize; /** * 最大线程数 */ @Value("${sync.maxPoolSize:200}") private int maxPoolSize; /** * 缓冲队列数数量 */ @Value("${sync.queueCapacity:10000000}") private int queueCapacity; @Bean public Executor executor() { ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor(); // 核心线程数(默认线程数) taskExecutor.setCorePoolSize(corePoolSize); // 最大线程数 taskExecutor.setMaxPoolSize(maxPoolSize); // 缓冲队列数,默认Integer.MAX_VALUE. taskExecutor.setQueueCapacity(queueCapacity); // 线程池名前缀 taskExecutor.setThreadNamePrefix("async-executor-"); // 允许线程空闲时间(单位:秒),默认:60 // taskExecutor.setKeepAliveSeconds(60); // 线程池对拒绝任务的处理策略,默认值AbortPolicy // taskExecutor.setRejectedExecutionHandler(new ThreadPoolExecutor.AbortPolicy()); // 初始化 taskExecutor.initialize(); return taskExecutor; } @Bean public ScheduledExecutorService scheduledExecutorService() { return Executors.newScheduledThreadPool(corePoolSize, new CustomizableThreadFactory("schedule-executor-")); } }

    SpringContextUtils

    package com.test.demo.sse;
    
    import org.springframework.beans.BeansException;
    import org.springframework.context.ApplicationContext;
    import org.springframework.context.ApplicationContextAware;
    import org.springframework.context.ApplicationEvent;
    import org.springframework.stereotype.Component;
    import org.springframework.web.context.request.RequestContextHolder;
    import org.springframework.web.context.request.ServletRequestAttributes;
    
    import javax.servlet.http.HttpServletRequest;
    
    /**
     * Static access to the Spring {@link ApplicationContext} and to the
     * servlet request bound to the current thread.
     *
     * <p>The context is captured when Spring calls
     * {@link #setApplicationContext(ApplicationContext)} on this component; the
     * bean lookup helpers fail with a {@link NullPointerException} if they are
     * used before that happens.
     */
    @Component
    public class SpringContextUtils implements ApplicationContextAware {

        /**
         * The application context instance.
         */
        private static ApplicationContext applicationContext;

        @Override
        public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
            SpringContextUtils.applicationContext = applicationContext;
        }

        /**
         * Returns the captured application context.
         *
         * @return the context, or {@code null} if it has not been injected yet
         */
        public static ApplicationContext getApplicationContext() {
            return applicationContext;
        }

        /**
         * Returns the {@link HttpServletRequest} taken from
         * {@link RequestContextHolder#getRequestAttributes()}.
         *
         * @return the current request
         * @throws NullPointerException if no request attributes are available
         *                              for the calling thread
         */
        public static HttpServletRequest getHttpServletRequest() {
            return ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest();
        }

        /**
         * Returns the request URL with the request URI removed from its end,
         * i.e. the part of the URL that precedes the path.
         *
         * @return the leading part of the current request URL
         */
        public static String getDomain() {
            HttpServletRequest request = getHttpServletRequest();
            StringBuffer url = request.getRequestURL();
            return url.delete(url.length() - request.getRequestURI().length(), url.length()).toString();
        }

        /**
         * Returns the {@code Origin} header of the current request.
         *
         * @return the header value, or {@code null} if the header is absent
         */
        public static String getOrigin() {
            HttpServletRequest request = getHttpServletRequest();
            return request.getHeader("Origin");
        }

        /**
         * Looks up a bean by name.
         *
         * @param name bean name
         * @return the bean instance
         */
        public static Object getBean(String name) {
            return getApplicationContext().getBean(name);
        }

        /**
         * Looks up a bean by type.
         *
         * @param clazz required bean type
         * @param <T>   bean type
         * @return the bean instance
         */
        public static <T> T getBean(Class<T> clazz) {
            return getApplicationContext().getBean(clazz);
        }

        /**
         * Looks up a bean by name and type.
         *
         * @param name  bean name
         * @param clazz required bean type
         * @param <T>   bean type
         * @return the bean instance
         */
        public static <T> T getBean(String name, Class<T> clazz) {
            return getApplicationContext().getBean(name, clazz);
        }

        /**
         * Publishes an application event; does nothing if the context has not
         * been injected yet.
         *
         * @param event event to publish
         */
        public static void publishEvent(ApplicationEvent event) {
            if (applicationContext == null) {
                return;
            }
            applicationContext.publishEvent(event);
        }
    }
    
    

    MySseEmitter

    package com.test.demo.sse;
    
    import lombok.Data;
    import lombok.EqualsAndHashCode;
    import lombok.extern.slf4j.Slf4j;
    import org.apache.commons.lang3.StringUtils;
    import org.springframework.http.HttpHeaders;
    import org.springframework.http.MediaType;
    import org.springframework.http.server.ServerHttpResponse;
    import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
    
    import java.nio.charset.StandardCharsets;
    import java.util.UUID;
    import java.util.concurrent.*;
    
    /**
     * 

    * MySseEmitter *

    * Description: 解决SseEmitter浏览器下中文乱码问题 */ @EqualsAndHashCode(callSuper = true) @Data @Slf4j public class MySseEmitter extends SseEmitter { /** * websocket返回的所有信息,只用于将消息发送到前端 */ private StringBuilder totalAnswer = new StringBuilder(); /** * websocket返回的所有信息,用于最终存储的消息内容 */ private StringBuilder totalAnswerStorage = new StringBuilder(); /** * 链接是否已主动断开,,true:已主动断开,false:未断开 */ private boolean disconnected = false; /** * 本条消息的唯一id */ private String messageUuid = UUID.randomUUID().toString(); /** * 本次会话的唯一id */ private String conversationUuid = UUID.randomUUID().toString(); /** * 是否匀速返回,true:需要匀速,false:不需要匀速 */ private boolean speedControl; /** * 是否已经开始匀速返回信息,true:已经开始,false:还没有开始 */ private boolean startSendMsgWithSpeedControl = false; /** * 已经发送的消息的长度 */ private int sendLength = 0; /** * 所有消息是否已经全部匀速返回,true:已经全部返回,false:还没有全部返回 */ private boolean endSendMsgWithSpeedControl = false; /** * 匀速吐字间隔时间,单位:毫秒 */ private long sleepTime = 20L; /** * 匀速发送消息时每次返回多少个字符 */ private int sendMsgSpeed = 1; /** * 超时时间,单位:毫秒 */ private long timeout; /** * 当前登录人 */ private String userUid; /** * 当前登录人的问题 */ private String userQuestion; /** * 解决中文乱码 * * @param outputMessage */ @Override protected void extendResponse(ServerHttpResponse outputMessage) { super.extendResponse(outputMessage); HttpHeaders headers = outputMessage.getHeaders(); headers.setContentType(new MediaType(MediaType.TEXT_EVENT_STREAM, StandardCharsets.UTF_8)); } /** * 创建SSE对象 * * @param speedControl 是否开启匀速,true:开启,false:关闭 * @param timeout 超时时间,单位:毫秒 * @param userUid 当前登录人 * @param userQuestion 当前登录人的问题 */ public MySseEmitter(boolean speedControl, long timeout, String userUid, String userQuestion) { // 设置超时时间,单位:毫秒 super(timeout); this.speedControl = speedControl; this.timeout = timeout; this.userUid = userUid; this.userQuestion = userQuestion; } /** * 自定义发送消息方法 * * @param message 具体消息 * @param msgStorage 本次发送的消息内容是否需要进行存储,true:需要,false:不需要 * @return 是否需要关闭链接,true:是,false:否 */ public boolean mySend(String message, 
boolean msgStorage) { try { if (StringUtils.isNotEmpty(message)) { // 处理换行,PC换行\r、\n、、\r\n都行,移动只能\r\n message = message.replaceAll("\r", "\n") .replaceAll("\n\n", "\n") .replaceAll("\n\n", "\n") .replaceAll("\n\n", "\n") .replaceAll("\n\n", "\n") .replaceAll("\n\n", "\n") .replaceAll("\n\n", "\n") .replaceAll("\n\n", "\n") .replaceAll("\n\n", "\n") .replaceAll("\n\n", "\n") .replaceAll("\n\n", "\n") .replaceAll("\n\n", "\n") .replaceAll("\n\n", "\n") .replaceAll("\n", "\r\n"); this.totalAnswer.append(message); if (msgStorage) { this.totalAnswerStorage.append(message); } if (!this.speedControl) { super.send(message); } else if (!this.startSendMsgWithSpeedControl && !this.disconnected) { // 异步发送 SpringContextUtils.getBean(AsyncService.class).sendMsgWithSpeedControl(this); } } } catch (Exception e) { log.error("==>MySseEmitter send error,conversationUuid:{}", this.conversationUuid, e); this.disconnected = true; } return this.disconnected; } /** * 断开连接 * * @param msgStorageType 消息处理类型,0:不存储,1:本地数据库存储 */ public void myComplete(String msgStorageType) { try { if (!this.speedControl || this.disconnected || this.endSendMsgWithSpeedControl) { super.complete(); } else { Future future = SpringContextUtils.getBean(ScheduledExecutorService.class).scheduleAtFixedRate(() -> { if (this.endSendMsgWithSpeedControl) { log.info("==>当前消息已全部返回完成,主动断开与端上链接,conversationUuid:{}", conversationUuid); throw new RuntimeException("==>当前消息已全部返回完成,主动断开与端上链接,conversationUuid:" + conversationUuid); } }, this.sleepTime, this.sleepTime, TimeUnit.MILLISECONDS); try { // 超时时间,单位:毫秒 future.get(timeout, TimeUnit.MILLISECONDS); } catch (TimeoutException | ExecutionException e) { log.info("==>等待断开链接任务执行结束,conversationUuid:{}", conversationUuid); // 取消任务 future.cancel(true); } catch (Exception e) { log.error("==>等待断开链接任务执行异常,conversationUuid:{}", conversationUuid, e); // 取消任务 future.cancel(true); } super.complete(); } } catch (Exception ignore) { } if ("1".equals(msgStorageType)) { // 本地数据库存储消息,异步保存数据 
SpringContextUtils.getBean(AsyncService.class).saveMsg(this.messageUuid, this.conversationUuid, this.userUid, this.userQuestion, this.totalAnswerStorage.toString()); } } }

    MyWebSocketClient

    package com.test.demo.sse;
    
    import lombok.extern.slf4j.Slf4j;
    import org.java_websocket.client.WebSocketClient;
    import org.java_websocket.drafts.Draft_6455;
    
    import javax.net.ssl.*;
    import java.net.Socket;
    import java.net.URI;
    import java.security.cert.CertificateException;
    import java.security.cert.X509Certificate;
    
    /**
     * 

    * MyWebSocketClient *

    * Description: 自定义WebSocketClient,忽略wss证书 */ @Slf4j public abstract class MyWebSocketClient extends WebSocketClient { /** * 创建WebSocketClient * * @param serverUri websocket 地址 * @param connectTimeout 连接超时时间,单位:毫秒 */ public MyWebSocketClient(URI serverUri, int connectTimeout) { // 设置连接超时时间 super(serverUri, new Draft_6455(), null, connectTimeout); // 设置不验证SSL证书的SSLContext TrustManager[] trustAllCerts = new TrustManager[]{new X509ExtendedTrustManager() { @Override public void checkClientTrusted(X509Certificate[] x509Certificates, String s, Socket socket) throws CertificateException { } @Override public void checkServerTrusted(X509Certificate[] x509Certificates, String s, Socket socket) throws CertificateException { } @Override public void checkClientTrusted(X509Certificate[] x509Certificates, String s, SSLEngine sslEngine) throws CertificateException { } @Override public void checkServerTrusted(X509Certificate[] x509Certificates, String s, SSLEngine sslEngine) throws CertificateException { } @Override public X509Certificate[] getAcceptedIssuers() { return null; } @Override public void checkClientTrusted(X509Certificate[] arg0, String arg1) throws CertificateException { } @Override public void checkServerTrusted(X509Certificate[] arg0, String arg1) throws CertificateException { } }}; try { SSLContext ssl = SSLContext.getInstance("SSL"); ssl.init(null, trustAllCerts, new java.security.SecureRandom()); SSLSocketFactory socketFactory = ssl.getSocketFactory(); this.setSocketFactory(socketFactory); } catch (Exception e) { log.error("==>初始化SSLContext失败", e); } } }

    MyWebSocketClientHelper

    package com.test.demo.sse;
    
    import lombok.extern.slf4j.Slf4j;
    import org.apache.commons.lang3.StringUtils;
    import org.java_websocket.client.WebSocketClient;
    import org.java_websocket.handshake.ServerHandshake;
    
    import java.net.URI;
    
    /**
     * 

    * MyWebSocketClientHelper *

    * Description: */ @Slf4j public class MyWebSocketClientHelper { /** * WebSocketClient连接并发送消息 * * @param sseEmitter sse链接 * @param msgStorageType 消息处理类型,0:不存储,1:本地数据库存储 */ public static void connectAndSend(MySseEmitter sseEmitter, String msgStorageType) { String commonErrorMsg = "通用错误信息,报错啦"; String messageUuid = sseEmitter.getMessageUuid(); WebSocketClient client = null; try { client = new MyWebSocketClient(new URI("wss://xxxxx"), Integer.parseInt(Long.toString(sseEmitter.getTimeout()))) { @Override public void onOpen(ServerHandshake serverHandshake) { log.info("==>connect success,messageUuid:{}", messageUuid); try { String requestParam = "xxxxxxx"; this.send(requestParam); } catch (Exception e) { log.error("==>sendRequest error,messageUuid:{}", messageUuid, e); throw e; } } @Override public void onMessage(String result) { log.info("==>messageUuid:{},onMessage:{}", messageUuid, result); try { sseEmitter.mySend(result, true); } catch (Exception e) { log.error("==>onMessage error,messageUuid:{}", messageUuid, e); } } @Override public void onClose(int code, String reason, boolean remote) { // 1. code(int类型):表示关闭连接的原因,通常是一个整数。例如,如果连接正常关闭,code的值可能是1000(表示正常关闭);如果连接因为服务器主动关闭而关闭,code的值可能是1006(表示服务器端强制关闭)。 // 2. reason(String类型):表示关闭连接的原因,通常是一段文本描述。这个参数是可选的,如果没有提供原因,可以传递一个空字符串或者null。 // 3. 
remote(boolean类型):表示连接是否被清理。如果为true,表示连接正常关闭;如果为false,表示连接异常关闭。 log.info("==>onClose,messageUuid:{},code:{},reason:{},remote:{}", messageUuid, code, reason, remote); try { if (!sseEmitter.isDisconnected() && StringUtils.isBlank(sseEmitter.getTotalAnswer().toString())) { sseEmitter.mySend(commonErrorMsg, true); } sseEmitter.myComplete(msgStorageType); } catch (Exception ignored) { } } @Override public void onError(Exception e) { log.error("==>onError,messageUuid:{}", messageUuid, e); try { this.close(); } catch (Exception ignored) { } } }; client.connect(); } catch (Exception e) { log.error("==>WebSocketClientConnectAndSend error,messageUuid:{}", messageUuid, e); try { if (client != null) { client.close(); } } catch (Exception ignored) { } } } }

    AsyncService

    package com.test.demo.sse;
    
    /**
     * Operations that are meant to run asynchronously on behalf of an SSE
     * conversation.
     */
    public interface AsyncService {

        /**
         * Releases the emitter's accumulated text at a constant rate.
         *
         * @param sseEmitter emitter whose text is to be sent
         */
        void sendMsgWithSpeedControl(MySseEmitter sseEmitter);

        /**
         * Stores a finished message.
         *
         * @param messageUuid        message uid
         * @param conversationUuid   conversation uid
         * @param userUid            uid of the current user
         * @param userQuestion       question entered by the user
         * @param totalAnswerStorage message content returned by the websocket
         */
        void saveMsg(String messageUuid,
                     String conversationUuid,
                     String userUid,
                     String userQuestion,
                     String totalAnswerStorage);
    }

    AsyncServiceImpl

    package com.test.demo.sse;
    
    import lombok.extern.slf4j.Slf4j;
    import org.springframework.beans.factory.annotation.Autowired;
    import org.springframework.context.annotation.Lazy;
    import org.springframework.scheduling.annotation.Async;
    import org.springframework.stereotype.Service;
    import org.springframework.transaction.annotation.Transactional;
    
    import java.util.concurrent.*;
    
    /**
     * 

    * AsyncServiceImpl *

    * Description: */ @Slf4j @Async @Service public class AsyncServiceImpl implements AsyncService { @Lazy @Autowired ScheduledExecutorService scheduledExecutorService; @Override public void sendMsgWithSpeedControl(MySseEmitter sseEmitter) { sseEmitter.setStartSendMsgWithSpeedControl(true); sseEmitter.setEndSendMsgWithSpeedControl(false); final String messageUuid = sseEmitter.getMessageUuid(); // 使用scheduleAtFixedRate方法安排任务 Future future = scheduledExecutorService.scheduleAtFixedRate(() -> { int sendLength = sseEmitter.getSendLength(); int sendMsgSpeed = sseEmitter.getSendMsgSpeed(); // 当前时刻,所有的消息 String nowAllAnswer = sseEmitter.getTotalAnswer().toString(); int totalAnswerLength = nowAllAnswer.length(); if (sendLength >= totalAnswerLength) { log.info("==>当前时刻所有消息已全部发送完成,任务执行结束,messageUuid:{}", messageUuid); throw new RuntimeException("==>当前时刻所有消息已全部发送完成,任务执行结束,messageUuid:" + messageUuid); } String message; if ((sendLength + sendMsgSpeed) > totalAnswerLength) { message = nowAllAnswer.substring(sendLength); } else { message = nowAllAnswer.substring(sendLength, sendLength + sendMsgSpeed); } if (message.endsWith("\r") && (sendLength + sendMsgSpeed + 1) <= totalAnswerLength) { message = nowAllAnswer.substring(sendLength, sendLength + sendMsgSpeed + 1); } try { sseEmitter.send(message); } catch (Exception e) { sseEmitter.setDisconnected(true); log.info("==>发送消息失败,视为端上主动断开链接,任务执行结束,messageUuid:{}", messageUuid); throw new RuntimeException("==>发送消息失败,视为端上主动断开链接,任务执行结束,messageUuid:" + messageUuid); } sendLength += message.length(); sseEmitter.setSendLength(sendLength); }, sseEmitter.getSleepTime(), sseEmitter.getSleepTime(), TimeUnit.MILLISECONDS); // 尝试获取任务结果,如果超过超时时间则抛出TimeoutException异常 try { // 超时时间,单位:毫秒 future.get(sseEmitter.getTimeout(), TimeUnit.MILLISECONDS); } catch (TimeoutException | ExecutionException e) { log.info("==>任务执行结束,messageUuid:{}", messageUuid); // 取消任务 future.cancel(true); } catch (Exception e) { log.error("==>任务执行异常,messageUuid:{}", messageUuid, 
e); // 取消任务 future.cancel(true); } sseEmitter.setStartSendMsgWithSpeedControl(false); sseEmitter.setEndSendMsgWithSpeedControl(true); } @Override @Transactional(rollbackFor = Exception.class) public void saveMsg(String messageUuid, String conversationUuid, String userUid, String userQuestion, String totalAnswerStorage) { log.info("==>保存会话和信息,messageUuid:{},conversationUuid:{},userUid:{},userQuestion:{},totalAnswerStorage:{}", messageUuid, conversationUuid, userUid, userQuestion, totalAnswerStorage); // 会话不存在的,新建会话并保存 // 保存消息 } }

    使用方法

    package com.test.demo.sse;
    
    import lombok.extern.slf4j.Slf4j;
    import org.springframework.web.bind.annotation.PostMapping;
    import org.springframework.web.bind.annotation.RequestBody;
    import org.springframework.web.bind.annotation.RequestMapping;
    import org.springframework.web.bind.annotation.RestController;
    import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
    
    /**
     * 

    * SseTestController *

    * Description: */ @Slf4j @RestController @RequestMapping("/sse") public class SseTestController { /** * 用户提问提问 * * @param input 输入信息 * @return */ @PostMapping(value = "/ask") public SseEmitter ask(@RequestBody @Valid Input input) { MySseEmitter sseEmitter = new MySseEmitter(true, 60000L, input.getUserUid(), input.getUserQuestion); try { MyWebSocketClientHelper.connectAndSend(sseEmitter, "1"); } catch (Exception e) { log.error("ask error", e); sseEmitter.myComplete("1"); } return sseEmitter; } }
  • 相关阅读:
    wsgiref模块、web框架、django框架简介
    Git提交代码仓库的两种方式
    2023年系统规划与设计管理师-第一章信息的综合知识
    Qt 窗口的坐标体系
    思辨:移动开发的未来在哪?
    ToBeWritten之基于ATT&CK的模拟攻击:闭环的防御与安全运营
    基于机器视觉的移动消防机器人(二)--详细设计
    Mybatis源码分析-查询机制&工作原理
    Java当中的数组的定义与简单使用
    【Python基础】Python调试器pdb
  • 原文地址:https://blog.csdn.net/muguazhi/article/details/140345010