2025-06-17 20:40:48 +08:00

196 lines
4.8 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package handler
import (
"encoding/json"
"net/http"
"time"
"gongzheng_minimax/service"
"github.com/gin-gonic/gin"
)
// LLMHandler adapts the LLM service to gin HTTP handlers. Each exported
// method (Chat, StopConversation, DeleteConversation, SynthesizeSpeech,
// StreamText) is intended to be registered as a gin route handler.
type LLMHandler struct {
	// llmService performs the chat, conversation-management and
	// speech-synthesis calls the handlers delegate to.
	llmService *service.LLMService
}
// NewLLMHandler builds an LLMHandler that delegates all work to llmService.
// The pointer is stored as given; no nil check is performed here.
func NewLLMHandler(llmService *service.LLMService) *LLMHandler {
	h := new(LLMHandler)
	h.llmService = llmService
	return h
}
// Chat forwards the request's JSON object body to the LLM service and relays
// the result to the client.
//
// If the service returns a message channel (either chan service.Message or
// <-chan service.Message), the messages are streamed as Server-Sent Events
// until the channel is closed, a message with IsEnd set has been handled, or
// the client disconnects. Any other result is written as JSON with status 200.
//
// Error responses:
//   - 400 {"error": "Invalid request data"} when the body cannot be bound to a
//     JSON object.
//   - 500 {"error": <service error text>} when the service call fails.
func (h *LLMHandler) Chat(c *gin.Context) {
	var requestData map[string]interface{}
	if err := c.ShouldBindJSON(&requestData); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request data"})
		return
	}

	response, err := h.llmService.CallLLMAPI(requestData)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	// A channel result means the service is streaming. Accept both the
	// bidirectional and the receive-only channel type, so narrowing the
	// service's return type does not silently fall through to c.JSON (which
	// cannot encode a channel).
	switch messages := response.(type) {
	case chan service.Message:
		streamMessages(c, messages)
	case <-chan service.Message:
		streamMessages(c, messages)
	default:
		c.JSON(http.StatusOK, response)
	}
}

// streamMessages relays messages to the client as SSE events named "message",
// each carrying the JSON encoding of one service.Message. It returns when the
// channel is closed, after a message with IsEnd set, or when the client
// disconnects.
//
// NOTE(review): if the client disconnects early, nothing drains messages any
// more; confirm the producer in the service cannot block forever on a send.
func streamMessages(c *gin.Context, messages <-chan service.Message) {
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("Transfer-Encoding", "chunked")

	// net/http cancels the request context when the client's connection
	// closes. Unlike the deprecated CloseNotify (an unchecked
	// http.CloseNotifier assertion in gin), this works with any ResponseWriter.
	clientGone := c.Request.Context().Done()

	for {
		select {
		case <-clientGone:
			return
		case message, ok := <-messages:
			if !ok {
				return
			}
			// A message that cannot be encoded is skipped, but its IsEnd flag
			// is still honoured so the stream terminates when the service says so.
			if jsonData, err := json.Marshal(message); err == nil {
				c.SSEvent("message", string(jsonData))
				c.Writer.Flush()
			}
			if message.IsEnd {
				return
			}
		}
	}
}
// StopConversation asks the LLM service to stop the task named by the
// "task_id" path parameter and writes the service's result as JSON.
//
// Responses:
//   - 400 {"error": "Task ID is required"} when the path parameter is empty.
//   - 500 {"error": <service error text>} when the service call fails.
//   - 200 with the service result otherwise.
func (h *LLMHandler) StopConversation(c *gin.Context) {
	id := c.Param("task_id")
	if len(id) == 0 {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Task ID is required"})
		return
	}

	out, stopErr := h.llmService.StopConversation(id)
	if stopErr == nil {
		c.JSON(http.StatusOK, out)
		return
	}
	c.JSON(http.StatusInternalServerError, gin.H{"error": stopErr.Error()})
}
// DeleteConversation deletes the conversation named by the "conversation_id"
// path parameter on behalf of the user given in the "user" query parameter
// (falling back to "default_user" when absent). The service result is
// written as JSON.
//
// Responses:
//   - 400 {"error": "Conversation ID is required"} when the path parameter is
//     empty.
//   - 500 {"error": <service error text>} when the service call fails.
//   - 200 with the service result otherwise.
func (h *LLMHandler) DeleteConversation(c *gin.Context) {
	convID := c.Param("conversation_id")
	if len(convID) == 0 {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Conversation ID is required"})
		return
	}

	owner := c.DefaultQuery("user", "default_user")
	out, delErr := h.llmService.DeleteConversation(convID, owner)
	if delErr == nil {
		c.JSON(http.StatusOK, out)
		return
	}
	c.JSON(http.StatusInternalServerError, gin.H{"error": delErr.Error()})
}
// SynthesizeSpeech converts text to speech through the LLM service.
//
// The JSON body must contain non-empty "text" and "audio" string fields (both
// enforced by gin's `binding:"required"` validation). Responses:
//   - 400 {"error": "Invalid request data"} when binding or validation fails.
//   - 500 {"error": <service error text>} when the service call fails.
//   - 200 with the service result otherwise.
func (h *LLMHandler) SynthesizeSpeech(c *gin.Context) {
	type speechRequest struct {
		Text  string `json:"text" binding:"required"`
		Audio string `json:"audio" binding:"required"`
	}

	var req speechRequest
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request data"})
		return
	}

	out, synthErr := h.llmService.SynthesizeSpeech(req.Text, req.Audio)
	if synthErr == nil {
		c.JSON(http.StatusOK, out)
		return
	}
	c.JSON(http.StatusInternalServerError, gin.H{"error": synthErr.Error()})
}
// StreamText streams a fixed, canned reply to the client as Server-Sent Events.
//
// Each segment is sent as one "data:" frame holding a JSON object with the
// keys "event" ("message"), "answer" (the segment), "conversation_id" and
// "task_id". After the last segment, a final frame with event "message_end"
// and an empty answer is sent. Both IDs share the same request timestamp,
// e.g. "conv_20250617204048" / "task_20250617204048". Streaming stops early if
// the client disconnects or a frame cannot be written.
func (h *LLMHandler) StreamText(c *gin.Context) {
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("Transfer-Encoding", "chunked")

	// Canned demo reply (Chinese). Roughly: "OK, I have successfully replaced
	// the text content. The new text is one continuous description, without
	// line breaks, 65 characters in total, meeting the under-100-character
	// requirement, and is a single continuous paragraph. Now I need to finish
	// the task."
	segments := []string{
		"好的,",
		"我已经成功替换了文本内容。",
		"新的文本是一段连续的描述,",
		"没有换行,",
		"总共65个字符",
		"符合100字以内的要求",
		"并且是一个连续的段落。",
		"现在我需要完成任务。",
	}

	// Read the clock once: two separate time.Now() calls can straddle a
	// second boundary and give the two IDs different timestamps.
	stamp := time.Now().Format("20060102150405")
	conversationID := "conv_" + stamp
	taskID := "task_" + stamp

	// net/http cancels the request context when the client's connection
	// closes. Unlike the deprecated CloseNotify (an unchecked
	// http.CloseNotifier assertion in gin), this works with any ResponseWriter.
	clientGone := c.Request.Context().Done()

	for _, segment := range segments {
		select {
		case <-clientGone:
			return
		default:
		}
		frame := map[string]interface{}{
			"event":           "message",
			"answer":          segment,
			"conversation_id": conversationID,
			"task_id":         taskID,
		}
		if err := writeSSEData(c, frame); err != nil {
			// A failed encode or write (e.g. a dropped connection) ends the stream.
			return
		}
	}

	endFrame := map[string]interface{}{
		"event":           "message_end",
		"answer":          "",
		"conversation_id": conversationID,
		"task_id":         taskID,
	}
	// Error deliberately ignored: this is the last frame, so there is nothing
	// left to abort.
	_ = writeSSEData(c, endFrame)
}

// writeSSEData JSON-encodes payload, writes it to the response as a single
// SSE frame ("data: <json>\n\n") and flushes it immediately. It returns the
// first encoding or write error; nothing is flushed when an error occurs.
func writeSSEData(c *gin.Context, payload interface{}) error {
	jsonData, err := json.Marshal(payload)
	if err != nil {
		return err
	}
	if _, err := c.Writer.Write([]byte("data: " + string(jsonData) + "\n\n")); err != nil {
		return err
	}
	c.Writer.Flush()
	return nil
}