196 lines
4.8 KiB
Go
196 lines
4.8 KiB
Go
package handler
|
||
|
||
import (
|
||
"encoding/json"
|
||
"net/http"
|
||
"time"
|
||
|
||
"gongzheng_minimax/service"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
)
|
||
|
||
// LLMHandler handles HTTP requests for the LLM service
|
||
type LLMHandler struct {
|
||
llmService *service.LLMService
|
||
}
|
||
|
||
// NewLLMHandler creates a new instance of LLMHandler
|
||
func NewLLMHandler(llmService *service.LLMService) *LLMHandler {
|
||
return &LLMHandler{
|
||
llmService: llmService,
|
||
}
|
||
}
|
||
|
||
// Chat handles chat requests
|
||
func (h *LLMHandler) Chat(c *gin.Context) {
|
||
var requestData map[string]interface{}
|
||
if err := c.ShouldBindJSON(&requestData); err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request data"})
|
||
return
|
||
}
|
||
|
||
response, err := h.llmService.CallLLMAPI(requestData)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
// Check if the response is a channel (streaming response)
|
||
if messageChan, ok := response.(chan service.Message); ok {
|
||
// Set headers for SSE
|
||
c.Header("Content-Type", "text/event-stream")
|
||
c.Header("Cache-Control", "no-cache")
|
||
c.Header("Connection", "keep-alive")
|
||
c.Header("Transfer-Encoding", "chunked")
|
||
|
||
// Create a channel to handle client disconnection
|
||
clientGone := c.Writer.CloseNotify()
|
||
|
||
// Stream the messages
|
||
for {
|
||
select {
|
||
case <-clientGone:
|
||
return
|
||
case message, ok := <-messageChan:
|
||
if !ok {
|
||
return
|
||
}
|
||
|
||
// Convert message to JSON
|
||
jsonData, err := json.Marshal(message)
|
||
if err != nil {
|
||
continue
|
||
}
|
||
|
||
// Write the SSE message
|
||
c.SSEvent("message", string(jsonData))
|
||
c.Writer.Flush()
|
||
|
||
// If this is the end message, close the connection
|
||
if message.IsEnd {
|
||
return
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// Non-streaming response
|
||
c.JSON(http.StatusOK, response)
|
||
}
|
||
|
||
// StopConversation handles stopping a conversation
|
||
func (h *LLMHandler) StopConversation(c *gin.Context) {
|
||
taskID := c.Param("task_id")
|
||
if taskID == "" {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "Task ID is required"})
|
||
return
|
||
}
|
||
|
||
result, err := h.llmService.StopConversation(taskID)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, result)
|
||
}
|
||
|
||
// DeleteConversation handles deleting a conversation
|
||
func (h *LLMHandler) DeleteConversation(c *gin.Context) {
|
||
conversationID := c.Param("conversation_id")
|
||
if conversationID == "" {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "Conversation ID is required"})
|
||
return
|
||
}
|
||
|
||
user := c.DefaultQuery("user", "default_user")
|
||
result, err := h.llmService.DeleteConversation(conversationID, user)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, result)
|
||
}
|
||
|
||
// SynthesizeSpeech handles text-to-speech requests
|
||
func (h *LLMHandler) SynthesizeSpeech(c *gin.Context) {
|
||
var request struct {
|
||
Text string `json:"text" binding:"required"`
|
||
Audio string `json:"audio" binding:"required"`
|
||
}
|
||
|
||
if err := c.ShouldBindJSON(&request); err != nil {
|
||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request data"})
|
||
return
|
||
}
|
||
|
||
result, err := h.llmService.SynthesizeSpeech(request.Text, request.Audio)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, result)
|
||
}
|
||
|
||
// StreamText handles streaming text output
|
||
func (h *LLMHandler) StreamText(c *gin.Context) {
|
||
// Set headers for SSE
|
||
c.Header("Content-Type", "text/event-stream")
|
||
c.Header("Cache-Control", "no-cache")
|
||
c.Header("Connection", "keep-alive")
|
||
c.Header("Transfer-Encoding", "chunked")
|
||
|
||
segments := []string{
|
||
"好的,",
|
||
"我已经成功替换了文本内容。",
|
||
"新的文本是一段连续的描述,",
|
||
"没有换行,",
|
||
"总共65个字符,",
|
||
"符合100字以内的要求,",
|
||
"并且是一个连续的段落。",
|
||
"现在我需要完成任务。",
|
||
}
|
||
|
||
// Create a channel to handle client disconnection
|
||
clientGone := c.Writer.CloseNotify()
|
||
|
||
conversationID := "conv_" + time.Now().Format("20060102150405")
|
||
taskID := "task_" + time.Now().Format("20060102150405")
|
||
|
||
// Stream the segments
|
||
for _, segment := range segments {
|
||
select {
|
||
case <-clientGone:
|
||
return
|
||
default:
|
||
// Create message object
|
||
message := map[string]interface{}{
|
||
"event": "message",
|
||
"answer": segment,
|
||
"conversation_id": conversationID,
|
||
"task_id": taskID,
|
||
}
|
||
|
||
// Convert to JSON and send
|
||
jsonData, _ := json.Marshal(message)
|
||
c.Writer.Write([]byte("data: " + string(jsonData) + "\n\n"))
|
||
c.Writer.Flush()
|
||
}
|
||
}
|
||
|
||
// Send end message
|
||
endMessage := map[string]interface{}{
|
||
"event": "message_end",
|
||
"answer": "",
|
||
"conversation_id": conversationID,
|
||
"task_id": taskID,
|
||
}
|
||
|
||
jsonData, _ := json.Marshal(endMessage)
|
||
c.Writer.Write([]byte("data: " + string(jsonData) + "\n\n"))
|
||
c.Writer.Flush()
|
||
}
|