• 一个java项目中,如何使用sse协议,构造一个chatgpt的流式对话接口


    前言
    如何注册chatGPT,怎么和它交互,本文就不讲了;因为网上教程一大堆,而且你要使用的话,通常会再包一个算法服务,用来做一些数据训练和过滤处理之类的,业务服务基本不会直接与原生chatGPT交互。
    而下面阐述的,就是业务服务与算法服务的交互。

    业务需求-需要实现什么样的功能

    想要一个类似与AI问答助手的机器人,可以实现根据某些场景对话提问的功能

    1. 可以直接提问,类似直接使用chatGPT,只不过这个提问的过程会做一些业务通用处理,比如问答数据的归纳反馈、敏感词过滤等等。
    2. 也可以给它喂一篇论文,喂一批近期的资讯,或者是一本小说之类的,根据指定的上下文去进行问答(这种场景需要先投递数据建立相关索引)
    3. ai的回答要求和chatGPT一样保持流式返回(也就是一个字一个字,一边生成一边返回,而不是等整个回答生成完之后一股脑返回)

    剖析

    重点是流式,这里我们预设算法侧已经有了一个流式返回的接口,整体的交互如下图所示
    在这里插入图片描述
    下面分别介绍几个关键节点的数据交互设计,仅供参考

    q1

    简述:页面发送问答数据给业务服务端

    {
      "chatId": 233,
      "question": "这篇论文有几个论点?"
    }
    
    • 1
    • 2
    • 3
    • 4
    • 这里的chatId可以理解为一个对话框id,业务服务端可以根据这个来进行问答归类、批量删除收藏、问答上下文查询等操作。
    • question就是问题的内容

    这里需要注意就是,交互数据格式尽可能简单、易拓展,有些产品的页面交互设计的非常复杂,什么历史问答、角色信息之类的,套了一层又一层,其实很多都没必要的,这样前端组装起来也麻烦,也不利于数据的管理与后期功能的拓展。

    q2

    简述:就是业务服务根据前端传来的问题和所属的对话框,把相应的上下文查询出来(甚至可以前端维护一个是否发送上下文的开关,更动态一点),包装成算法服务所需要的问题数据,发给它。

    {
              "chatId":  233,
              "userName":  "张三",
              "messageKey":  "0a795f6a-a029-435f-8d67-6f6f8e078cfe",
              "message":  "这篇论文有几个论点?",
              "chatHistory":  [
              	{
                                  "messageKey":  "0a795f6a-a029-435f-8d67-6f6f8e07dasd",
                                  "question":  "这篇论文的作者是谁",
                                  "answer":  "这篇论文的作者是李四博士。"
                        }
              ],
              "callbackUrl":  "http://xxx/chat/question/callback"
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 上述的messageKey就是一个消息的key,用以常规的接口调试
    • chatHistory就是历史问答记录,即上下文,众所周知chatGPT带不带上下文,回答的结果可能截然不同
    • callbackUrl是业务定义的一个回调接口,用来回调一些算法侧异步生产的信息,比如原文的定位信息、根据当前问题生成的推荐问答等,这些和流式的回答是不会一起返回的,所以额外提供一个接口来接收。

    q3和a3

    这两步不详述(主要我也不是研究算法模型的哈哈,不是很清楚细节)
    我们只需要定义好a2返回的结果即可

    a2

    简述:主要是算法侧返回给业务服务的同步的流式回答,同时还可能有异步的额外信息的回调(q2的callbackUrl来接收)。所以a2的返回结果分为两个response
    response1:同步的流式回答,一般在2-7s内返回第一个字符

    data:data:data:... // 省略一些输出
    data:data:
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6

    流式问答的规范可以参考:流式接口协议规范

    response2:异步的拓展信息(可有可无)

    {
        "messageKey":"0a795f6a-a029-435f-8d67-6f6f8e078cfe", //必传 回调的消息key,每次问答唯一
        "expand":{
            "recommendedQuestions":[ // 推荐问题
                "这篇论文的主要论点是什么?"
            ],
            "originalIndex":[{
                "sourceId":3432,
                "text":"首先第一个论点是......",
                "textIndex":90
            }]
        }
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13

    a1

    简述:a1要返回的格式很好理解,就是把a2中的两个response组合在一起,需要注意的有几点

    1. a2的response2不一定有,需要设置超时策略,且需要在流式回答最后输出
    2. a2中的response1是流式,response2不是;但输出到a1的时候,需要保证都在流中
    3. 最好需要约定一些event来作为标识符
    event:messageKey // 消息key事件
    
    data:0a795f6a-a029-435f-8d67-6f6f8e078cfe
    
    event:answer // 流式回答开始事件
    
    data:"在论文"
    
    data:"中"
    
    data:","
    
    data:"我"
    
    data:"们"
    
    data:"一"
    
    data:"共"
    ...
    data:"几"
    
    data:"个"
    
    data:"论"
    
    data:"点"
    
    event:endTime
    
    data:2024-02-27 17:05:24
    
    event:expand // 拓展信息开始事件,此处等待15s超时
    
    data:{"recommendedQuestions":["这篇论文的主要论点是什么"],"originalIndex":{"sourceId":32133,"text":"首先第一个论点是......","textIndex":90}}
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35

    代码

    代码省略了一些无关紧要的业务特有的部分,只保留通用的部分
    工具类:SSEUtils,用来操作SSE客户端

    import lombok.extern.slf4j.Slf4j;
    
    import java.io.IOException;
    import java.util.Map;
    import java.util.concurrent.ConcurrentHashMap;
    
    /**
     * description
     *
     * @author luhui
     * @date 2024/1/25
     */
    @Slf4j
    public class SSEUtils {
        /**
         * timeout 30min
         */
        private static final Long DEFAULT_TIME_OUT = 30 * 60 * 1000L;
        /**
         * 订阅表
         */
        private static final Map<String, EvaEmitter> EMITTER_MAP = new ConcurrentHashMap<>();
    
        public static final String MSG_DATA_PREFIX = "data:";
        public static final String MSG_EVENT_PREFIX = "event:";
    
        /**
         * description: 创建流
         *
         * @param messageKey 本次问答的消息key
         * @return org.springframework.web.servlet.mvc.method.annotation.SseEmitter
         * @author luhui
         * @date 2024/2/23 17:09
         */
        public static EvaEmitter getEmitter(String messageKey) {
            if (null == messageKey || "".equals(messageKey)) {
                return null;
            }
    
            EvaEmitter emitter = EMITTER_MAP.get(messageKey);
            if (null == emitter) {
                emitter = new EvaEmitter(DEFAULT_TIME_OUT);
                EMITTER_MAP.put(messageKey, emitter);
            }
    
            return emitter;
        }
    
        /**
         * description: 发消息
         *
         * @param messageKey 本次问答的消息key
         * @param msg        消息
         * @author luhui
         * @date 2024/2/23 17:09
         */
        public static void pushMsg(String messageKey, String msg) throws IOException {
            EvaEmitter emitter = EMITTER_MAP.get(messageKey);
            if (null != emitter) {
                emitter.send(EvaEmitter.event().data(msg));
            }
        }
    
        public static void pushEvent(String messageKey, String eventDesc) throws IOException{
            EvaEmitter emitter = EMITTER_MAP.get(messageKey);
            if (null != emitter) {
                emitter.send(EvaEmitter.event().name(eventDesc));
            }
        }
    
        /**
         * description: 关闭流
         *
         * @param messageKey 本次问答的消息key
         * @author luhui
         * @date 2024/2/23 17:08
         */
        public static void closeEmitter(String messageKey) {
            EvaEmitter emitter = EMITTER_MAP.get(messageKey);
            if (null != emitter) {
                try {
                    emitter.complete();
                    EMITTER_MAP.remove(messageKey);
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        }
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89

    工具类:SSEClient ,用来获取SSE流

    import com.alibaba.fastjson.JSONObject;
    import lombok.extern.slf4j.Slf4j;
    
    import java.io.*;
    import java.net.HttpURLConnection;
    import java.net.URL;
    
    /**
     * description
     *
     * @author luhui
     * @date 2024/1/25
     */
    @Slf4j
    public class SSEClient {
        // timeout
        public static Integer DEFAULT_TIME_OUT = 60 * 1000;
    
        /**
         * 获取SSE输入流
         */
        public static InputStream getSseInputStream(String urlPath, JSONObject param, int timeoutMill) {
            HttpURLConnection urlConnection = null;
            try {
                urlConnection = getHttpURLConnection(urlPath, timeoutMill);
                putData(urlConnection, param);
                InputStream inputStream = urlConnection.getInputStream();
                return new BufferedInputStream(inputStream);
            } catch (IOException e) {
                e.printStackTrace();
            }
            return null;
        }
    
        /**
         * 读流数据
         */
        public static void readStream(InputStream is, MsgHandler msgHandler) throws IOException {
            BufferedReader reader = new BufferedReader(new InputStreamReader(is));
            try {
                String line = "";
                while ((line = reader.readLine()) != null) {
                    if ("".equals(line)) {
                        continue;
                    }
                    msgHandler.handleMsg(line);
                }
            } catch (Exception e) {
                e.printStackTrace();
                // 目前这里抛出的显式异常来自与用户手动关闭的连接,此时服务端与算法端的连接也捕获并关闭,无需存储
            } finally {
                // 服务器端主动关闭时,客户端手动关闭
                reader.close();
                is.close();
            }
        }
    
        private static HttpURLConnection getHttpURLConnection(String urlPath, int timeoutMill) throws IOException {
            URL url = new URL(urlPath);
            HttpURLConnection urlConnection = (HttpURLConnection) url.openConnection();
            urlConnection.setDoOutput(true);
            urlConnection.setDoInput(true);
            urlConnection.setUseCaches(false);
            urlConnection.setRequestMethod("POST");
            urlConnection.setRequestProperty("Connection", "Keep-Alive");
            urlConnection.setRequestProperty("Charset", "UTF-8");
            urlConnection.setRequestProperty("Content-Type", "application/json;charset=UTF-8");
            urlConnection.setRequestProperty("accept", "text/event-stream");
            // 读过期时间
            urlConnection.setReadTimeout(timeoutMill);
            return urlConnection;
        }
    
        public static void putData(HttpURLConnection connection, JSONObject jsonStr) throws IOException {
            byte[] writebytes = jsonStr.toJSONString().getBytes();
            connection.setRequestProperty("Content-Length", String.valueOf(writebytes.length));
            DataOutputStream wr = new DataOutputStream(connection.getOutputStream());
            wr.write(jsonStr.toJSONString().getBytes());
            wr.flush();
            wr.close();
        }
    
        /**
         * 消息处理接口
         */
        public interface MsgHandler {
            void handleMsg(String line) throws IOException;
        }
    }
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    • 66
    • 67
    • 68
    • 69
    • 70
    • 71
    • 72
    • 73
    • 74
    • 75
    • 76
    • 77
    • 78
    • 79
    • 80
    • 81
    • 82
    • 83
    • 84
    • 85
    • 86
    • 87
    • 88
    • 89
    • 90

    工具类:EvaEmitter,用来封装一些流信息

    import cn.hutool.core.date.DateTime;
    import com.alibaba.fastjson.JSONObject;
    import io.swagger.annotations.ApiModelProperty;
    import lombok.Data;
    import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
    
    /**
     * description EvaEmitter
     *
     * @author luhui
     * @date 2024/02/22
     */
    @Data
    public class EvaEmitter extends SseEmitter {
        public EvaEmitter(Long timeout) {
            super(timeout);
        }
    
        @ApiModelProperty("版本id")
        private Long versionId;
        @ApiModelProperty("用户问题")
        private String question;
        @ApiModelProperty("唯一消息key")
        private String messageKey;
        @ApiModelProperty("当前用户")
        private Long currentUid;
        @ApiModelProperty("当前用户名")
        private String currentUserName;
        @ApiModelProperty("项目id")
        private Long projectId;
        @ApiModelProperty("ai回答")
        private String aiAnswer;
        @ApiModelProperty("拓展信息")
        private JSONObject expand;
        @ApiModelProperty("错误信息")
        private JSONObject error;
        @ApiModelProperty("提问开始时间")
        private DateTime startTime;
    
        public JSONObject getHistory() {
            JSONObject history = new JSONObject();
            history.put("question", question);
            history.put("answer", aiAnswer);
            history.put("expand", expand);
            history.put("error", error);
            return history;
        }
    }
    
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49

    具体的chat交互方法

    	String messageKey = UUID.randomUUID().toString();
    	EvaEmitter emitter = SSEUtils.getEmitter(messageKey);
    	emitter.setProjectId(111);
    	// 初始化相关字段
    	
    	sseService.chatTransfer(messageKey);
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
        @Async
        @Override
        public void chatTransfer(String messageKey) {
            EvaEmitter emitter = SSEUtils.getEmitter(messageKey);
    
            // 正式参数
            JSONObject params = new JSONObject(true);
            params.put("versionId", emitter.getVersionId().toString());
            params.put("userName", emitter.getCurrentUserName());
            params.put("messageKey", emitter.getMessageKey());
            params.put("message", emitter.getQuestion());
            params.put("chatHistory", chatHistoryService.getChatHistory(emitter));
            params.put("callbackUrl", gateway + "/xxxchat/question/callback");
    
            InputStream inputStream = SSEClient.getSseInputStream(aiChatUrl, params, SSEClient.DEFAULT_TIME_OUT);
            try {
                StringBuilder answer = new StringBuilder();
                SSEUtils.pushEvent(messageKey, "messageKey");
                SSEUtils.pushMsg(messageKey, messageKey);
                SSEUtils.pushEvent(messageKey, "answer");
                AtomicReference<Boolean> sdkError = new AtomicReference<>(false);
                SSEClient.readStream(inputStream, line -> {
                    log.info("messageKey:{}, chatTransfer:{}", emitter.getMessageKey(), line);
                    String message = "";
                    if (sdkError.get()) {
                        String errorStr = line.split(SSEUtils.MSG_DATA_PREFIX)[1].trim();
                        if (StringUtils.isNotBlank(errorStr)) {
                            // 做一些错误处理
                            message = "算法未知错误,请稍后再试";
                            emitter.setError(message);
                        }
                    } else if (line.contains(SSEUtils.MSG_DATA_PREFIX)) {
                        message = line.split(SSEUtils.MSG_DATA_PREFIX)[1].trim();
                    } else if (line.contains(SSEUtils.MSG_EVENT_PREFIX)) {
                        sdkError.set(true);
                    } else {
                        message = "";
                    }
                    if (StringUtils.isNotBlank(message)) {
                        answer.append(message.replaceAll("\"", ""));
                        SSEUtils.pushMsg(messageKey, message);
                    }
                });
                emitter.setAiAnswer(answer.toString());
                // 保存当前问答消息,自行实现
                ChatHistoryEntity message = chatHistoryService.saveHistory(messageKey);
                SSEUtils.pushEvent(messageKey, "endTime");
                SSEUtils.pushMsg(messageKey, DateUtil.formatDateTime(message.getEndTime()));
                SSEUtils.pushEvent(messageKey, "expand");
                chatHistoryService.pushExpand(messageKey);
            } catch (IllegalStateException | IOException e) {
                log.error("pushMsg error, web端流已被关闭");
            } catch (Exception e) {
                e.printStackTrace();
            } finally {
                // 消息发送完或者出现异常的话,存储当前的消息,然后关闭流
                try {
                    chatHistoryService.saveHistory(messageKey);
                } catch (Exception e) {
                    e.printStackTrace();
                } finally {
                    SSEUtils.closeEmitter(messageKey);
                }
            }
        }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    • 33
    • 34
    • 35
    • 36
    • 37
    • 38
    • 39
    • 40
    • 41
    • 42
    • 43
    • 44
    • 45
    • 46
    • 47
    • 48
    • 49
    • 50
    • 51
    • 52
    • 53
    • 54
    • 55
    • 56
    • 57
    • 58
    • 59
    • 60
    • 61
    • 62
    • 63
    • 64
    • 65
    	@Override
        @Retryable(value = Exception.class, maxAttempts = 6, backoff = @Backoff(delay = 500, multiplier = 2))
        public void pushExpand(String messageKey) throws IOException {
           // 如果异步的拓展信息,即a2中的response2回调成功的话,会存储到这里
            Object expandObj = redisService.hGet(RedisConstants.CHAT_AI_RECOMMENDED_QUESTIONS, messageKey);
            if (expandObj == null) {
                log.error("未获取到相关拓展信息, 稍后重试");
                throw new RuntimeException("未获取到相关拓展信息");
            } else {
                JSONObject expand = JSONObject.parseObject(expandObj.toString());
                EvaEmitter emitter = SSEUtils.getEmitter(messageKey);
                emitter.setExpand(expand);
                SSEUtils.pushMsg(messageKey, expand.toJSONString());
                log.info("messageKey:{}, chatTransfer:{}", emitter.getMessageKey(), expand);
            }
        }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
  • 相关阅读:
    计算机组成原理(谭志虎主编 )
    Linux文件压缩和解压命令【gzip、gunzip、zip、unzip、tar】【详细总结】
    1.ts介绍
    HTTP响应详解, HTTP请求构造及HTTPS详解
    文件包含学习笔记总结
    数学建模学习(89):交叉熵优化算法(CEM)对多元函数寻优
    记一次 .NET 某外贸ERP 内存暴涨分析
    Fedora CoreOS 安装部署详解
    如何创建并运行java线程呢?
    (一)DepthAI-python相关接口:OAK Device
  • 原文地址:https://blog.csdn.net/qq_31363843/article/details/138076028