All checks were successful
Gitea Actions Demo / Explore-Gitea-Actions (push) Successful in 42s
290 lines
7.4 KiB
Go
290 lines
7.4 KiB
Go
package handler
|
||
|
||
import (
|
||
"encoding/json"
|
||
"fmt"
|
||
"hash/fnv"
|
||
"net/http"
|
||
"strconv"
|
||
"time"
|
||
|
||
"gongzheng_minimax/service"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
)
|
||
|
||
// ChatExtRequest is the JSON body accepted by the external QA chat endpoint
// (see LLMHandler.ChatExt). Field names on the wire use the "dh-" prefix.
type ChatExtRequest struct {
	Sid    string `json:"sid"`
	DhCode string `json:"dh-code"`
	// DhQuestion is forwarded to the service layer as "content".
	DhQuestion string `json:"dh-question"`
	// DhConversationID must be converted before it is used as the service
	// layer's "conversation_id".
	DhConversationID string `json:"dh-conversation-id"`
	// DhContext carries the conversation context.
	DhContext []interface{} `json:"dh-context"`
}
|
||
|
||
// LLMHandler handles HTTP requests for the LLM service
|
||
type LLMHandler struct {
|
||
llmService *service.LLMService
|
||
}
|
||
|
||
// NewLLMHandler creates a new instance of LLMHandler
|
||
func NewLLMHandler(llmService *service.LLMService) *LLMHandler {
|
||
return &LLMHandler{
|
||
llmService: llmService,
|
||
}
|
||
}
|
||
|
||
// Chat handles chat requests
|
||
func (h *LLMHandler) Chat(c *gin.Context) {
|
||
var requestData map[string]interface{}
|
||
if err := c.ShouldBindJSON(&requestData); err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request data"})
|
||
return
|
||
}
|
||
|
||
response, err := h.llmService.CallLLMAPI(requestData)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
// Check if the response is a channel (streaming response)
|
||
if messageChan, ok := response.(chan service.Message); ok {
|
||
// Set headers for SSE
|
||
c.Header("Content-Type", "text/event-stream")
|
||
c.Header("Cache-Control", "no-cache")
|
||
c.Header("Connection", "keep-alive")
|
||
c.Header("Transfer-Encoding", "chunked")
|
||
|
||
// Create a channel to handle client disconnection
|
||
clientGone := c.Writer.CloseNotify()
|
||
|
||
// Stream the messages
|
||
for {
|
||
select {
|
||
case <-clientGone:
|
||
return
|
||
case message, ok := <-messageChan:
|
||
if !ok {
|
||
return
|
||
}
|
||
|
||
// Convert message to JSON
|
||
jsonData, err := json.Marshal(message)
|
||
if err != nil {
|
||
continue
|
||
}
|
||
|
||
// Write the SSE message
|
||
c.SSEvent("message", string(jsonData))
|
||
c.Writer.Flush()
|
||
|
||
// If this is the end message, close the connection
|
||
if message.IsEnd {
|
||
return
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// Non-streaming response
|
||
c.JSON(http.StatusOK, response)
|
||
}
|
||
|
||
// ChatExt handles external QA chat requests
|
||
func (h *LLMHandler) ChatExt(c *gin.Context) {
|
||
var requestData ChatExtRequest
|
||
if err := c.ShouldBindJSON(&requestData); err != nil {
|
||
fmt.Printf("Error binding JSON: %v\n", err)
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request data"})
|
||
return
|
||
}
|
||
|
||
fmt.Printf("Received ChatExt request: %+v\n", requestData)
|
||
|
||
// 将UUID转换为19位数字字符串
|
||
h64 := fnv.New64a()
|
||
h64.Write([]byte(requestData.DhConversationID))
|
||
hashValue := h64.Sum64()
|
||
conversationID := strconv.FormatUint(hashValue, 10)
|
||
// 确保长度为19位,不足补0,超过截取(FNV64通常是19-20位数字)
|
||
if len(conversationID) < 19 {
|
||
conversationID = fmt.Sprintf("%019s", conversationID)
|
||
} else if len(conversationID) > 19 {
|
||
conversationID = conversationID[:19]
|
||
}
|
||
|
||
fmt.Printf("Converted ConversationID: %s -> %s\n", requestData.DhConversationID, conversationID)
|
||
|
||
// 构造 Service 层需要的参数 map
|
||
serviceData := map[string]interface{}{
|
||
"tag_ids": []string{"1", "2"},
|
||
"conversation_id": conversationID,
|
||
"content": requestData.DhQuestion,
|
||
}
|
||
|
||
fmt.Printf("Calling Service with data: %+v\n", serviceData)
|
||
|
||
response, err := h.llmService.CallExtQAAPI(serviceData)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
// Check if the response is a channel (streaming response)
|
||
if messageChan, ok := response.(chan service.Message); ok {
|
||
// Set headers for SSE
|
||
c.Header("Content-Type", "text/event-stream")
|
||
c.Header("Cache-Control", "no-cache")
|
||
c.Header("Connection", "keep-alive")
|
||
c.Header("Transfer-Encoding", "chunked")
|
||
|
||
// Create a channel to handle client disconnection
|
||
clientGone := c.Writer.CloseNotify()
|
||
|
||
// Stream the messages
|
||
for {
|
||
select {
|
||
case <-clientGone:
|
||
return
|
||
case message, ok := <-messageChan:
|
||
if !ok {
|
||
return
|
||
}
|
||
|
||
// Convert message to JSON
|
||
jsonData, err := json.Marshal(message)
|
||
if err != nil {
|
||
continue
|
||
}
|
||
|
||
// Write the SSE message
|
||
c.SSEvent("message", string(jsonData))
|
||
c.Writer.Flush()
|
||
|
||
// If this is the end message, close the connection
|
||
if message.IsEnd {
|
||
return
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// Non-streaming response
|
||
c.JSON(http.StatusOK, response)
|
||
}
|
||
|
||
// StopConversation handles stopping a conversation
|
||
func (h *LLMHandler) StopConversation(c *gin.Context) {
|
||
taskID := c.Param("task_id")
|
||
if taskID == "" {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "Task ID is required"})
|
||
return
|
||
}
|
||
|
||
result, err := h.llmService.StopConversation(taskID)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, result)
|
||
}
|
||
|
||
// DeleteConversation handles deleting a conversation
|
||
func (h *LLMHandler) DeleteConversation(c *gin.Context) {
|
||
conversationID := c.Param("conversation_id")
|
||
if conversationID == "" {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "Conversation ID is required"})
|
||
return
|
||
}
|
||
|
||
user := c.DefaultQuery("user", "default_user")
|
||
result, err := h.llmService.DeleteConversation(conversationID, user)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, result)
|
||
}
|
||
|
||
// SynthesizeSpeech handles text-to-speech requests
|
||
func (h *LLMHandler) SynthesizeSpeech(c *gin.Context) {
|
||
var request struct {
|
||
Text string `json:"text" binding:"required"`
|
||
Audio string `json:"audio" binding:"required"`
|
||
}
|
||
|
||
if err := c.ShouldBindJSON(&request); err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request data"})
|
||
return
|
||
}
|
||
|
||
result, err := h.llmService.SynthesizeSpeech(request.Text, request.Audio)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, result)
|
||
}
|
||
|
||
// StreamText handles streaming text output
|
||
func (h *LLMHandler) StreamText(c *gin.Context) {
|
||
// Set headers for SSE
|
||
c.Header("Content-Type", "text/event-stream")
|
||
c.Header("Cache-Control", "no-cache")
|
||
c.Header("Connection", "keep-alive")
|
||
c.Header("Transfer-Encoding", "chunked")
|
||
|
||
segments := []string{
|
||
"好的,",
|
||
"我已经成功替换了文本内容。",
|
||
"新的文本是一段连续的描述,",
|
||
"没有换行,",
|
||
"总共65个字符,",
|
||
"符合100字以内的要求,",
|
||
"并且是一个连续的段落。",
|
||
"现在我需要完成任务。",
|
||
}
|
||
|
||
// Create a channel to handle client disconnection
|
||
clientGone := c.Writer.CloseNotify()
|
||
|
||
conversationID := "conv_" + time.Now().Format("20060102150405")
|
||
taskID := "task_" + time.Now().Format("20060102150405")
|
||
|
||
// Stream the segments
|
||
for _, segment := range segments {
|
||
select {
|
||
case <-clientGone:
|
||
return
|
||
default:
|
||
// Create message object
|
||
message := map[string]interface{}{
|
||
"event": "message",
|
||
"answer": segment,
|
||
"conversation_id": conversationID,
|
||
"task_id": taskID,
|
||
}
|
||
|
||
// Convert to JSON and send
|
||
jsonData, _ := json.Marshal(message)
|
||
c.Writer.Write([]byte("data: " + string(jsonData) + "\n\n"))
|
||
c.Writer.Flush()
|
||
}
|
||
}
|
||
|
||
// Send end message
|
||
endMessage := map[string]interface{}{
|
||
"event": "message_end",
|
||
"answer": "",
|
||
"conversation_id": conversationID,
|
||
"task_id": taskID,
|
||
}
|
||
|
||
jsonData, _ := json.Marshal(endMessage)
|
||
c.Writer.Write([]byte("data: " + string(jsonData) + "\n\n"))
|
||
c.Writer.Flush()
|
||
}
|