go_digital_backend/handler/llm_handler.go
songjvcheng 24f72168dd
All checks were successful
Gitea Actions Demo / Explore-Gitea-Actions (push) Successful in 38s
优化
2025-12-30 18:12:43 +08:00

290 lines
7.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package handler
import (
"encoding/json"
"fmt"
"hash/fnv"
"net/http"
"strconv"
"time"
"gongzheng_minimax/service"
"github.com/gin-gonic/gin"
)
// ChatExtRequest is the JSON body accepted by ChatExt. Its keys follow the
// external caller's hyphenated "dh-" naming scheme.
type ChatExtRequest struct {
	// Sid is decoded from "sid"; ChatExt does not currently read it.
	Sid string `json:"sid"`
	// DhCode is decoded from "dh-code"; ChatExt does not currently read it.
	DhCode string `json:"dh-code"`
	// DhQuestion is the user's question; ChatExt forwards it as "content".
	DhQuestion string `json:"dh-question"`
	// DhConversationID is the caller's conversation identifier (a UUID
	// according to ChatExt's comments); ChatExt hashes it into the numeric
	// "conversation_id" passed to the service.
	DhConversationID string `json:"dh-conversation-id"`
	// DhContext carries prior conversation context. Its element shape is not
	// defined here, and ChatExt currently ignores it.
	DhContext []interface{} `json:"dh-context"`
}
// LLMHandler exposes the LLM service through gin handler methods. Its only
// state is the service those methods call into.
type LLMHandler struct {
	// llmService is the backend the handler methods delegate to.
	llmService *service.LLMService
}
// NewLLMHandler returns an LLMHandler bound to llmService. The pointer is
// stored as given; it is neither copied nor checked for nil.
func NewLLMHandler(llmService *service.LLMService) *LLMHandler {
	handler := new(LLMHandler)
	handler.llmService = llmService
	return handler
}
// Chat proxies a free-form JSON request body to the LLM service.
//
// The body is bound into a generic map and passed unchanged to
// llmService.CallLLMAPI. If the service returns a chan service.Message, each
// message is relayed as a Server-Sent Event named "message" whose data is the
// JSON-encoded message. The stream stops when:
//   - the channel is closed,
//   - a message with IsEnd set has been sent, or
//   - the client disconnects.
//
// Any other result is written as a 200 JSON response.
//
// Errors: 400 when the body cannot be bound as JSON; 500 with the service's
// error text when CallLLMAPI fails.
func (h *LLMHandler) Chat(c *gin.Context) {
	var requestData map[string]interface{}
	if err := c.ShouldBindJSON(&requestData); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request data"})
		return
	}

	response, err := h.llmService.CallLLMAPI(requestData)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	messageChan, isStream := response.(chan service.Message)
	if !isStream {
		// Non-streaming response.
		c.JSON(http.StatusOK, response)
		return
	}

	// Set headers for SSE.
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("Transfer-Encoding", "chunked")

	// The request context is cancelled when the client's connection closes.
	// It replaces the deprecated CloseNotify. Gin's CloseNotify type-asserts
	// the underlying writer to http.CloseNotifier and panics when the writer
	// does not implement it.
	clientGone := c.Request.Context().Done()

	for {
		select {
		case <-clientGone:
			// NOTE(review): once we stop receiving, the producer behind
			// messageChan may block on its next send unless it also watches
			// for cancellation — confirm in service.CallLLMAPI.
			return
		case message, open := <-messageChan:
			if !open {
				return
			}
			jsonData, err := json.Marshal(message)
			if err != nil {
				// Skip the message as before, but leave a trace instead of
				// dropping it silently.
				fmt.Printf("Chat: skipping message that failed to marshal: %v\n", err)
				continue
			}
			c.SSEvent("message", string(jsonData))
			c.Writer.Flush()
			// The end-of-stream marker has been delivered; finish the response.
			if message.IsEnd {
				return
			}
		}
	}
}
// extConversationID maps the caller's conversation identifier to the
// 19-character decimal string that ChatExt sends as "conversation_id".
// NOTE(review): the external QA API presumably requires a 19-digit numeric
// ID — confirm.
//
// The identifier is hashed with 64-bit FNV-1a and written in decimal:
//   - A 20-digit result keeps only its first 19 digits.
//   - A shorter result is left-padded with '0'.
//
// This reproduces the previous inline `fmt.Sprintf("%019s", ...)` /
// slice logic exactly, so existing conversations keep their IDs. It no longer
// depends on fmt zero-padding a %s operand. Distinct inputs can collide; for
// example, a 20-digit hash and the 19-digit hash equal to its first 19 digits
// map to the same ID.
func extConversationID(id string) string {
	const idLen = 19

	hasher := fnv.New64a()
	hasher.Write([]byte(id)) // hash.Hash.Write never returns an error.

	digits := strconv.AppendUint(nil, hasher.Sum64(), 10)
	if len(digits) > idLen {
		// A uint64 has at most 20 decimal digits; keep the leading 19.
		digits = digits[:idLen]
	}

	padded := make([]byte, idLen-len(digits), idLen)
	for i := range padded {
		padded[i] = '0'
	}
	return string(append(padded, digits...))
}

// ChatExt handles external QA chat requests.
//
// The body is decoded as a ChatExtRequest. Its conversation ID is converted
// with extConversationID, and the service receives:
//   - the question as "content",
//   - the converted ID as "conversation_id",
//   - a fixed "tag_ids" list.
//
// If llmService.CallExtQAAPIStreamDirect returns a chan service.Message, the
// messages are relayed as Server-Sent Events named "message". Each event's
// data is the JSON-encoded message. The stream stops on channel close, after
// a message with IsEnd set, or when the client disconnects. Any other result
// is written as a 200 JSON response.
//
// Errors: 400 when the body cannot be bound; 500 with the service's error
// text when the service call fails.
func (h *LLMHandler) ChatExt(c *gin.Context) {
	var requestData ChatExtRequest
	if err := c.ShouldBindJSON(&requestData); err != nil {
		fmt.Printf("Error binding JSON: %v\n", err)
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request data"})
		return
	}
	// NOTE(review): this logs the user's full question and context to stdout.
	fmt.Printf("Received ChatExt request: %+v\n", requestData)

	// NOTE(review): an empty dh-conversation-id always hashes to the same value
	// ("1469598103934665603"). Every request without one therefore shares a
	// single upstream conversation. Consider rejecting such requests or
	// generating a fresh ID.
	conversationID := extConversationID(requestData.DhConversationID)
	fmt.Printf("Converted ConversationID: %s -> %s\n", requestData.DhConversationID, conversationID)

	// Build the parameter map the service layer expects.
	serviceData := map[string]interface{}{
		// NOTE(review): hard-coded tag IDs; their meaning is defined by the
		// external QA API, not here.
		"tag_ids":         []int{1, 11, 29},
		"conversation_id": conversationID,
		"content":         requestData.DhQuestion,
	}
	fmt.Printf("Calling Service with data: %+v\n", serviceData)

	response, err := h.llmService.CallExtQAAPIStreamDirect(serviceData)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	messageChan, isStream := response.(chan service.Message)
	if !isStream {
		// Non-streaming response.
		c.JSON(http.StatusOK, response)
		return
	}

	// Set headers for SSE.
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("Transfer-Encoding", "chunked")

	// The request context is cancelled when the client's connection closes.
	// It replaces the deprecated CloseNotify, which panics when the underlying
	// writer does not implement http.CloseNotifier.
	clientGone := c.Request.Context().Done()

	for {
		select {
		case <-clientGone:
			// NOTE(review): the producer behind messageChan may block on its
			// next send once we stop receiving — confirm in the service.
			return
		case message, open := <-messageChan:
			if !open {
				return
			}
			jsonData, err := json.Marshal(message)
			if err != nil {
				// Skip the message as before, but do not drop it silently.
				fmt.Printf("ChatExt: skipping message that failed to marshal: %v\n", err)
				continue
			}
			c.SSEvent("message", string(jsonData))
			c.Writer.Flush()
			// The end-of-stream marker has been delivered; finish the response.
			if message.IsEnd {
				return
			}
		}
	}
}
// StopConversation asks the LLM service to stop the task named by the
// "task_id" path parameter and writes the service's result as JSON.
//
// Responses: 400 when the path parameter is empty; 500 with the service's
// error text on failure; 200 with the result otherwise.
func (h *LLMHandler) StopConversation(c *gin.Context) {
	id := c.Param("task_id")
	if len(id) == 0 {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Task ID is required"})
		return
	}

	switch outcome, stopErr := h.llmService.StopConversation(id); {
	case stopErr != nil:
		c.JSON(http.StatusInternalServerError, gin.H{"error": stopErr.Error()})
	default:
		c.JSON(http.StatusOK, outcome)
	}
}
// DeleteConversation deletes the conversation named by the "conversation_id"
// path parameter and writes the service's result as JSON. The owning user
// comes from the "user" query parameter; "default_user" is used when that
// parameter is absent.
//
// Responses: 400 when the path parameter is empty; 500 with the service's
// error text on failure; 200 with the result otherwise.
func (h *LLMHandler) DeleteConversation(c *gin.Context) {
	id := c.Param("conversation_id")
	if len(id) == 0 {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Conversation ID is required"})
		return
	}

	owner := c.DefaultQuery("user", "default_user")
	if outcome, delErr := h.llmService.DeleteConversation(id, owner); delErr != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": delErr.Error()})
	} else {
		c.JSON(http.StatusOK, outcome)
	}
}
// SynthesizeSpeech runs text-to-speech through the LLM service.
//
// The JSON body must carry non-empty "text" and "audio" strings; gin's
// `binding:"required"` validation rejects missing or empty values.
// NOTE(review): the meaning of "audio" (voice name, reference clip, ...) is
// defined by llmService.SynthesizeSpeech, not here.
//
// Responses: 400 on a bad body; 500 with the service's error text on failure;
// 200 with the service's result otherwise.
func (h *LLMHandler) SynthesizeSpeech(c *gin.Context) {
	type speechRequest struct {
		Text  string `json:"text" binding:"required"`
		Audio string `json:"audio" binding:"required"`
	}

	var payload speechRequest
	if bindErr := c.ShouldBindJSON(&payload); bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request data"})
		return
	}

	switch outcome, synthErr := h.llmService.SynthesizeSpeech(payload.Text, payload.Audio); {
	case synthErr != nil:
		c.JSON(http.StatusInternalServerError, gin.H{"error": synthErr.Error()})
	default:
		c.JSON(http.StatusOK, outcome)
	}
}
// StreamText streams a fixed, hard-coded reply as Server-Sent Events without
// calling the LLM service. NOTE(review): this looks like a demo/test endpoint
// — confirm.
//
// Each segment is written as a raw `data: {...}` frame whose JSON has
// "event": "message". A final frame with "event": "message_end" and an empty
// "answer" ends the stream. Both IDs are built from the current time at
// one-second resolution, so requests within the same second share them.
// Streaming stops early if the client disconnects or a write fails.
func (h *LLMHandler) StreamText(c *gin.Context) {
	// Set headers for SSE.
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("Transfer-Encoding", "chunked")

	segments := []string{
		"好的,",
		"我已经成功替换了文本内容。",
		"新的文本是一段连续的描述,",
		"没有换行,",
		"总共65个字符",
		"符合100字以内的要求",
		"并且是一个连续的段落。",
		"现在我需要完成任务。",
	}

	// Take a single timestamp for both IDs. Calling time.Now() twice could
	// give the two IDs different values if a second boundary fell between
	// the calls.
	stamp := time.Now().Format("20060102150405")
	conversationID := "conv_" + stamp
	taskID := "task_" + stamp

	// writeFrame writes one SSE data frame and flushes it. It returns false
	// when the frame could not be written, so the caller stops streaming.
	writeFrame := func(event, answer string) bool {
		jsonData, err := json.Marshal(map[string]interface{}{
			"event":           event,
			"answer":          answer,
			"conversation_id": conversationID,
			"task_id":         taskID,
		})
		if err != nil {
			// Unreachable in practice: every value is a string.
			return false
		}
		if _, err := c.Writer.Write([]byte("data: " + string(jsonData) + "\n\n")); err != nil {
			return false
		}
		c.Writer.Flush()
		return true
	}

	// The request context is cancelled when the client's connection closes.
	// It replaces the deprecated CloseNotify, which panics when the underlying
	// writer does not implement http.CloseNotifier.
	clientGone := c.Request.Context().Done()

	for _, segment := range segments {
		select {
		case <-clientGone:
			return
		default:
		}
		if !writeFrame("message", segment) {
			return
		}
	}

	// Send the end message.
	writeFrame("message_end", "")
}