完整版
parent
7b1492be0a
commit
01fbfdd5b6
|
@ -1,5 +1,6 @@
|
|||
package com.sikadi.user.controller;
|
||||
|
||||
import com.bwie.common.result.Result;
|
||||
import com.sikadi.user.listener.XfXhStreamClient;
|
||||
import com.sikadi.user.mapper.XfXhConfig;
|
||||
import com.sikadi.user.service.GTPService;
|
||||
|
@ -27,8 +28,9 @@ public class GTPController {
|
|||
* @return 星火大模型的回答
|
||||
*/
|
||||
@GetMapping("/sendQuestion")
|
||||
public String sendQuestion(@RequestParam("question") String question) {
|
||||
return gtpService.sendQuestion(question);
|
||||
public Result sendQuestion(@RequestParam("question") String question) {
|
||||
String s = gtpService.sendQuestion(question);
|
||||
return Result.success(s);
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,115 @@
|
|||
package com.sikadi.user.dto.response;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import com.sikadi.user.dto.MsgDTO;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 返回参数
|
||||
* 对应生成的 JSON 结构参考 resources/demo-json/response.json
|
||||
*
|
||||
* @author 狐狸半面添
|
||||
* @create 2023-09-15 0:42
|
||||
*/
|
||||
@NoArgsConstructor
|
||||
@Data
|
||||
public class ResponseDTO {
|
||||
|
||||
@JsonProperty("header")
|
||||
private HeaderDTO header;
|
||||
@JsonProperty("payload")
|
||||
private PayloadDTO payload;
|
||||
|
||||
@NoArgsConstructor
|
||||
@Data
|
||||
public static class HeaderDTO {
|
||||
/**
|
||||
* 错误码,0表示正常,非0表示出错
|
||||
*/
|
||||
@JsonProperty("code")
|
||||
private Integer code;
|
||||
/**
|
||||
* 会话是否成功的描述信息
|
||||
*/
|
||||
@JsonProperty("message")
|
||||
private String message;
|
||||
/**
|
||||
* 会话的唯一id,用于讯飞技术人员查询服务端会话日志使用,出现调用错误时建议留存该字段
|
||||
*/
|
||||
@JsonProperty("sid")
|
||||
private String sid;
|
||||
/**
|
||||
* 会话状态,取值为[0,1,2];0代表首次结果;1代表中间结果;2代表最后一个结果
|
||||
*/
|
||||
@JsonProperty("status")
|
||||
private Integer status;
|
||||
}
|
||||
|
||||
@NoArgsConstructor
|
||||
@Data
|
||||
public static class PayloadDTO {
|
||||
@JsonProperty("choices")
|
||||
private ChoicesDTO choices;
|
||||
/**
|
||||
* 在最后一次结果返回
|
||||
*/
|
||||
@JsonProperty("usage")
|
||||
private UsageDTO usage;
|
||||
|
||||
@NoArgsConstructor
|
||||
@Data
|
||||
public static class ChoicesDTO {
|
||||
/**
|
||||
* 文本响应状态,取值为[0,1,2]; 0代表首个文本结果;1代表中间文本结果;2代表最后一个文本结果
|
||||
*/
|
||||
@JsonProperty("status")
|
||||
private Integer status;
|
||||
/**
|
||||
* 返回的数据序号,取值为[0,9999999]
|
||||
*/
|
||||
@JsonProperty("seq")
|
||||
private Integer seq;
|
||||
/**
|
||||
* 响应文本
|
||||
*/
|
||||
@JsonProperty("text")
|
||||
private List<MsgDTO> text;
|
||||
|
||||
}
|
||||
|
||||
@NoArgsConstructor
|
||||
@Data
|
||||
public static class UsageDTO {
|
||||
@JsonProperty("text")
|
||||
private TextDTO text;
|
||||
|
||||
@NoArgsConstructor
|
||||
@Data
|
||||
public static class TextDTO {
|
||||
/**
|
||||
* 保留字段,可忽略
|
||||
*/
|
||||
@JsonProperty("question_tokens")
|
||||
private Integer questionTokens;
|
||||
/**
|
||||
* 包含历史问题的总tokens大小
|
||||
*/
|
||||
@JsonProperty("prompt_tokens")
|
||||
private Integer promptTokens;
|
||||
/**
|
||||
* 回答的tokens大小
|
||||
*/
|
||||
@JsonProperty("completion_tokens")
|
||||
private Integer completionTokens;
|
||||
/**
|
||||
* prompt_tokens和completion_tokens的和,也是本次交互计费的tokens大小
|
||||
*/
|
||||
@JsonProperty("total_tokens")
|
||||
private Integer totalTokens;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,149 @@
|
|||
package com.sikadi.user.listener;
|
||||
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.sikadi.user.dto.MsgDTO;
|
||||
import com.sikadi.user.dto.request.RequestDTO;
|
||||
import com.sikadi.user.mapper.XfXhConfig;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.*;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import javax.annotation.Resource;
|
||||
import javax.crypto.Mac;
|
||||
import javax.crypto.spec.SecretKeySpec;
|
||||
import java.net.URL;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.text.SimpleDateFormat;
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* @author 狐狸半面添
|
||||
* @create 2023-09-15 1:10
|
||||
*/
|
||||
@Component
|
||||
@Slf4j
|
||||
public class XfXhStreamClient {
|
||||
@Resource
|
||||
private XfXhConfig xfXhConfig;
|
||||
|
||||
@Value("${xfxh.QPS}")
|
||||
private int connectionTokenCount;
|
||||
|
||||
/**
|
||||
* 获取令牌
|
||||
*/
|
||||
public static int GET_TOKEN_STATUS = 0;
|
||||
/**
|
||||
* 归还令牌
|
||||
*/
|
||||
public static int BACK_TOKEN_STATUS = 1;
|
||||
|
||||
/**
|
||||
* 操作令牌
|
||||
*
|
||||
* @param status 0-获取令牌 1-归还令牌
|
||||
* @return 是否操作成功
|
||||
*/
|
||||
public synchronized boolean operateToken(int status) {
|
||||
if (status == GET_TOKEN_STATUS) {
|
||||
// 获取令牌
|
||||
if (connectionTokenCount != 0) {
|
||||
// 说明还有令牌,将令牌数减一
|
||||
connectionTokenCount -= 1;
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
// 放回令牌
|
||||
connectionTokenCount += 1;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 发送消息
|
||||
*
|
||||
* @param uid 每个用户的id,用于区分不同用户
|
||||
* @param msgList 发送给大模型的消息,可以包含上下文内容
|
||||
* @return 获取websocket连接,以便于我们在获取完整大模型回复后手动关闭连接
|
||||
*/
|
||||
public WebSocket sendMsg(String uid, List<MsgDTO> msgList, WebSocketListener listener) {
|
||||
// 获取鉴权url
|
||||
String authUrl = this.getAuthUrl();
|
||||
// 鉴权方法生成失败,直接返回 null
|
||||
if (authUrl == null) {
|
||||
return null;
|
||||
}
|
||||
OkHttpClient okHttpClient = new OkHttpClient.Builder().build();
|
||||
// 将 https/http 连接替换为 ws/wss 连接
|
||||
String url = authUrl.replace("http://", "ws://").replace("https://", "wss://");
|
||||
Request request = new Request.Builder().url(url).build();
|
||||
// 建立 wss 连接
|
||||
WebSocket webSocket = okHttpClient.newWebSocket(request, listener);
|
||||
// 组装请求参数
|
||||
RequestDTO requestDTO = getRequestParam(uid, msgList);
|
||||
// 发送请求
|
||||
webSocket.send(JSONObject.toJSONString(requestDTO));
|
||||
return webSocket;
|
||||
}
|
||||
|
||||
/**
|
||||
* 生成鉴权方法,具体实现不用关心,这是讯飞官方定义的鉴权方式
|
||||
*
|
||||
* @return 鉴权访问大模型的路径
|
||||
*/
|
||||
public String getAuthUrl() {
|
||||
try {
|
||||
URL url = new URL(xfXhConfig.getHostUrl());
|
||||
// 时间
|
||||
SimpleDateFormat format = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss z", Locale.US);
|
||||
format.setTimeZone(TimeZone.getTimeZone("GMT"));
|
||||
String date = format.format(new Date());
|
||||
// 拼接
|
||||
String preStr = "host: " + url.getHost() + "\n" +
|
||||
"date: " + date + "\n" +
|
||||
"GET " + url.getPath() + " HTTP/1.1";
|
||||
// SHA256加密
|
||||
Mac mac = Mac.getInstance("hmacsha256");
|
||||
SecretKeySpec spec = new SecretKeySpec(xfXhConfig.getApiSecret().getBytes(StandardCharsets.UTF_8), "hmacsha256");
|
||||
mac.init(spec);
|
||||
|
||||
byte[] hexDigits = mac.doFinal(preStr.getBytes(StandardCharsets.UTF_8));
|
||||
// Base64加密
|
||||
String sha = Base64.getEncoder().encodeToString(hexDigits);
|
||||
// 拼接
|
||||
String authorizationOrigin = String.format("api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", xfXhConfig.getApiKey(), "hmac-sha256", "host date request-line", sha);
|
||||
// 拼接地址
|
||||
HttpUrl httpUrl = Objects.requireNonNull(HttpUrl.parse("https://" + url.getHost() + url.getPath())).newBuilder().
|
||||
addQueryParameter("authorization", Base64.getEncoder().encodeToString(authorizationOrigin.getBytes(StandardCharsets.UTF_8))).
|
||||
addQueryParameter("date", date).
|
||||
addQueryParameter("host", url.getHost()).
|
||||
build();
|
||||
|
||||
return httpUrl.toString();
|
||||
} catch (Exception e) {
|
||||
log.error("鉴权方法中发生错误:" + e.getMessage());
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取请求参数
|
||||
*
|
||||
* @param uid 每个用户的id,用于区分不同用户
|
||||
* @param msgList 发送给大模型的消息,可以包含上下文内容
|
||||
* @return 请求DTO,该 DTO 转 json 字符串后生成的格式参考 resources/demo-json/request.json
|
||||
*/
|
||||
public RequestDTO getRequestParam(String uid, List<MsgDTO> msgList) {
|
||||
RequestDTO requestDTO = new RequestDTO();
|
||||
requestDTO.setHeader(new RequestDTO.HeaderDTO(xfXhConfig.getAppId(), uid));
|
||||
requestDTO.setParameter(new RequestDTO.ParameterDTO(new RequestDTO.ParameterDTO.ChatDTO(xfXhConfig.getDomain(), xfXhConfig.getTemperature(), xfXhConfig.getMaxTokens())));
|
||||
requestDTO.setPayload(new RequestDTO.PayloadDTO(new RequestDTO.PayloadDTO.MessageDTO(msgList)));
|
||||
return requestDTO;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,77 @@
|
|||
package com.sikadi.user.listener;
|
||||
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.alibaba.nacos.shaded.org.checkerframework.checker.nullness.qual.Nullable;
|
||||
import com.sikadi.user.dto.MsgDTO;
|
||||
import com.sikadi.user.dto.response.ResponseDTO;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import okhttp3.Response;
|
||||
import okhttp3.WebSocket;
|
||||
import okhttp3.WebSocketListener;
|
||||
|
||||
import javax.validation.constraints.NotNull;
|
||||
|
||||
/**
|
||||
* @author 狐狸半面添
|
||||
* @create 2023-09-15 1:11
|
||||
*/
|
||||
/**
|
||||
* @author 狐狸半面添
|
||||
* @create 2023-09-15 1:11
|
||||
*/
|
||||
@Slf4j
|
||||
public class XfXhWebSocketListener extends WebSocketListener {
|
||||
private StringBuilder answer = new StringBuilder();
|
||||
|
||||
private boolean wsCloseFlag = false;
|
||||
|
||||
public StringBuilder getAnswer() {
|
||||
return answer;
|
||||
}
|
||||
|
||||
public boolean isWsCloseFlag() {
|
||||
return wsCloseFlag;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onOpen(@NotNull WebSocket webSocket, @NotNull Response response) {
|
||||
super.onOpen(webSocket, response);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onMessage(@NotNull WebSocket webSocket, @NotNull String text) {
|
||||
super.onMessage(webSocket, text);
|
||||
// 将大模型回复的 JSON 文本转为 ResponseDTO 对象
|
||||
ResponseDTO responseData = JSONObject.parseObject(text, ResponseDTO.class);
|
||||
// 如果响应数据中的 header 的 code 值不为 0,则表示响应错误
|
||||
if (responseData.getHeader().getCode() != 0) {
|
||||
// 日志记录
|
||||
log.error("发生错误,错误码为:" + responseData.getHeader().getCode() + "; " + "信息:" + responseData.getHeader().getMessage());
|
||||
// 设置回答
|
||||
this.answer = new StringBuilder("大模型响应错误,请稍后再试");
|
||||
// 关闭连接标识
|
||||
wsCloseFlag = true;
|
||||
return;
|
||||
}
|
||||
// 将回答进行拼接
|
||||
for (MsgDTO msgDTO : responseData.getPayload().getChoices().getText()) {
|
||||
this.answer.append(msgDTO.getContent());
|
||||
}
|
||||
// 对最后一个文本结果进行处理
|
||||
if (2 == responseData.getHeader().getStatus()) {
|
||||
wsCloseFlag = true;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFailure(@NotNull WebSocket webSocket, @NotNull Throwable t, @Nullable Response response) {
|
||||
super.onFailure(webSocket, t, response);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onClosed(@NotNull WebSocket webSocket, int code, @NotNull String reason) {
|
||||
super.onClosed(webSocket, code, reason);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue