dh chat api

This commit is contained in:
Song367 2025-12-27 16:30:27 +08:00
parent 6511d48d7b
commit babbf126aa
4 changed files with 355 additions and 40 deletions

View File

@ -30,12 +30,12 @@ jobs:
uses: https://gitea.yantootech.com/neil/build-push-action@v6 uses: https://gitea.yantootech.com/neil/build-push-action@v6
with: with:
push: true push: true
tags: 14.103.114.237:30005/gongzheng-backend:${{ gitea.run_id }} tags: 14.103.114.237:30005/dh-backend:${{ gitea.run_id }}
- name: Install - name: Install
run: | run: |
helm upgrade --install gongzheng-backend ./.gitea/charts \ helm upgrade --install dh-backend ./.gitea/charts \
--namespace gongzhengb \ --namespace dh \
--create-namespace \ --create-namespace \
--set image.repository=14.103.114.237:30005/gongzheng-backend \ --set image.repository=14.103.114.237:30005/dh-backend \
--set image.tag=${{ gitea.run_id }} --set image.tag=${{ gitea.run_id }}
- run: echo "🍏 This job's status is ${{ job.status }}." - run: echo "🍏 This job's status is ${{ job.status }}."

View File

@ -79,6 +79,63 @@ func (h *LLMHandler) Chat(c *gin.Context) {
c.JSON(http.StatusOK, response) c.JSON(http.StatusOK, response)
} }
// ChatExt handles external QA chat requests.
//
// The JSON request body is decoded into a generic map and passed to
// LLMService.CallExtQAAPI. When the service returns a chan service.Message the
// messages are relayed to the client as server-sent events named "message",
// one JSON-encoded Message per event, until the channel is closed, a message
// with IsEnd set has been written, or the client disconnects. Any other
// return value is written as a plain JSON response.
//
// Responses:
//   - 400 if the body is not valid JSON.
//   - 500 if the service call fails.
//   - 200 otherwise (SSE stream or JSON document).
func (h *LLMHandler) ChatExt(c *gin.Context) {
	var requestData map[string]interface{}
	if err := c.ShouldBindJSON(&requestData); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request data"})
		return
	}

	response, err := h.llmService.CallExtQAAPI(requestData)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	// Anything that is not a message channel is a non-streaming response.
	messageChan, streaming := response.(chan service.Message)
	if !streaming {
		c.JSON(http.StatusOK, response)
		return
	}

	// Set headers for SSE.
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("Transfer-Encoding", "chunked")

	// The request context is cancelled when the client connection goes away;
	// it replaces the deprecated http.CloseNotifier mechanism.
	clientGone := c.Request.Context().Done()

	for {
		select {
		case <-clientGone:
			// Keep receiving in the background so the producer is never left
			// blocked on a full channel; the loop ends when the producer
			// closes the channel.
			go func() {
				for range messageChan {
				}
			}()
			return

		case message, open := <-messageChan:
			if !open {
				return
			}

			jsonData, err := json.Marshal(message)
			if err != nil {
				// Best effort: skip a message that cannot be encoded and keep
				// the stream alive for the following ones.
				continue
			}

			c.SSEvent("message", string(jsonData))
			c.Writer.Flush()

			// The end marker is the last message of a conversation turn.
			if message.IsEnd {
				return
			}
		}
	}
}
// StopConversation handles stopping a conversation // StopConversation handles stopping a conversation
func (h *LLMHandler) StopConversation(c *gin.Context) { func (h *LLMHandler) StopConversation(c *gin.Context) {
taskID := c.Param("task_id") taskID := c.Param("task_id")

View File

@ -69,6 +69,7 @@ func main() {
// Define routes // Define routes
router.POST("/chat", llmHandler.Chat) router.POST("/chat", llmHandler.Chat)
router.POST("/chat-ext", llmHandler.ChatExt)
router.POST("/chat-messages/:task_id/stop", llmHandler.StopConversation) router.POST("/chat-messages/:task_id/stop", llmHandler.StopConversation)
router.DELETE("/conversations/:conversation_id", llmHandler.DeleteConversation) router.DELETE("/conversations/:conversation_id", llmHandler.DeleteConversation)
router.POST("/speech/synthesize", llmHandler.SynthesizeSpeech) router.POST("/speech/synthesize", llmHandler.SynthesizeSpeech)

View File

@ -127,6 +127,24 @@ type LLMOurRequestPayload struct {
Messages []LLMOurMessage `json:"messages"` Messages []LLMOurMessage `json:"messages"`
} }
// ExtQARequestPayload is the JSON body posted to the external QA API.
type ExtQARequestPayload struct {
	// TagIDs is serialized as "tag_ids".
	TagIDs []int `json:"tag_ids"`
	// ConversationID is serialized as "conversation_id".
	ConversationID string `json:"conversation_id"`
	// Content is serialized as "content".
	Content string `json:"content"`
}
// ExtQAResponse is one decoded JSON object received from the external QA API.
type ExtQAResponse struct {
	// RequestID is read from "request_id".
	RequestID string `json:"request_id"`
	// Output is read from the nested "output" object.
	Output struct {
		// Text is read from "text".
		Text string `json:"text"`
		// FinishReason is read from "finish_reason".
		FinishReason string `json:"finish_reason"`
		// SessionID is read from "session_id".
		SessionID string `json:"session_id"`
		// DocReferences is read from "doc_references"; its elements are left
		// undecoded as generic values.
		DocReferences []interface{} `json:"doc_references"`
	} `json:"output"`
}
// NewLLMService creates a new instance of LLMService // NewLLMService creates a new instance of LLMService
func NewLLMService(config Config) *LLMService { func NewLLMService(config Config) *LLMService {
return &LLMService{ return &LLMService{
@ -981,12 +999,10 @@ func (s *LLMService) TrimAudioSilence(audioData string) (string, error) {
if err := binary.Read(buf, binary.LittleEndian, &dataChunk); err != nil { if err := binary.Read(buf, binary.LittleEndian, &dataChunk); err != nil {
return "", fmt.Errorf("error reading chunk header: %v", err) return "", fmt.Errorf("error reading chunk header: %v", err)
} }
if string(dataChunk.Subchunk2ID[:]) == "data" { if string(dataChunk.Subchunk2ID[:]) == "data" {
break break
} }
// Skip this chunk if it's not "data"
// Skip this chunk
if _, err := buf.Seek(int64(dataChunk.Subchunk2Size), io.SeekCurrent); err != nil { if _, err := buf.Seek(int64(dataChunk.Subchunk2Size), io.SeekCurrent); err != nil {
return "", fmt.Errorf("error skipping chunk: %v", err) return "", fmt.Errorf("error skipping chunk: %v", err)
} }
@ -998,38 +1014,38 @@ func (s *LLMService) TrimAudioSilence(audioData string) (string, error) {
return "", fmt.Errorf("error reading audio data: %v", err) return "", fmt.Errorf("error reading audio data: %v", err)
} }
// Calculate samples per channel // Calculate samples
samplesPerChannel := len(audioBytes) / int(fmtChunk.BlockAlign) bytesPerSample := int(fmtChunk.BitsPerSample / 8)
channels := int(fmtChunk.NumChannels) if bytesPerSample == 0 {
bytesPerSample := int(fmtChunk.BitsPerSample) / 8 bytesPerSample = 2 // Default to 16-bit if unknown
}
numSamples := len(audioBytes) / bytesPerSample
// Find last non-silent sample
// Threshold: approx 1% of max amplitude for 16-bit audio (32768 * 0.01 ~= 327)
threshold := 327.0
// Find the last non-silent sample
lastNonSilent := 0 lastNonSilent := 0
silenceThreshold := 0.01 // Adjust this threshold as needed
for i := 0; i < samplesPerChannel; i++ { for i := 0; i < numSamples; i++ {
isSilent := true // Get sample value
for ch := 0; ch < channels; ch++ {
offset := i*int(fmtChunk.BlockAlign) + ch*bytesPerSample
if offset+bytesPerSample > len(audioBytes) {
continue
}
// Convert bytes to sample value
var sample int16 var sample int16
if err := binary.Read(bytes.NewReader(audioBytes[offset:offset+bytesPerSample]), binary.LittleEndian, &sample); err != nil { offset := i * bytesPerSample
continue
}
// Normalize sample to [-1, 1] range if offset+bytesPerSample > len(audioBytes) {
normalizedSample := float64(sample) / 32768.0
if math.Abs(normalizedSample) > silenceThreshold {
isSilent = false
break break
} }
if bytesPerSample == 2 {
sample = int16(binary.LittleEndian.Uint16(audioBytes[offset : offset+2]))
} else if bytesPerSample == 1 {
// 8-bit audio is usually unsigned 0-255, center at 128
sample = int16(audioBytes[offset]) - 128
sample *= 256 // Scale to 16-bit range roughly
} }
if !isSilent { if math.Abs(float64(sample)) > threshold {
lastNonSilent = i lastNonSilent = i
} }
} }
@ -1037,12 +1053,12 @@ func (s *LLMService) TrimAudioSilence(audioData string) (string, error) {
// Add a small buffer (e.g., 0.1 seconds) after the last non-silent sample // Add a small buffer (e.g., 0.1 seconds) after the last non-silent sample
bufferSamples := int(float64(fmtChunk.SampleRate) * 0.1) bufferSamples := int(float64(fmtChunk.SampleRate) * 0.1)
lastSample := lastNonSilent + bufferSamples lastSample := lastNonSilent + bufferSamples
if lastSample > samplesPerChannel { if lastSample > numSamples {
lastSample = samplesPerChannel lastSample = numSamples
} }
// Calculate new data size // Calculate new data size
newDataSize := lastSample * int(fmtChunk.BlockAlign) newDataSize := lastSample * bytesPerSample
trimmedAudio := audioBytes[:newDataSize] trimmedAudio := audioBytes[:newDataSize]
// Create new buffer for the trimmed audio // Create new buffer for the trimmed audio
@ -1070,19 +1086,260 @@ func (s *LLMService) TrimAudioSilence(audioData string) (string, error) {
return "", fmt.Errorf("error writing trimmed audio data: %v", err) return "", fmt.Errorf("error writing trimmed audio data: %v", err)
} }
// Encode back to base64
return base64.StdEncoding.EncodeToString(newBuf.Bytes()), nil return base64.StdEncoding.EncodeToString(newBuf.Bytes()), nil
} }
// CallExtQAAPI forwards a chat request to the external QA API and returns its
// streamed reply.
//
// It reads three keys from data:
//   - "tag_ids": a JSON array; numeric elements are converted to int and any
//     other element is ignored.
//   - "conversation_id" and "content": read through getString.
//
// On success the returned value is the chan Message produced by
// handleStreamingResponseForExt. On failure the returned value is an untyped
// nil, so callers may safely compare it with nil, and the error wraps the
// underlying cause.
//
// NOTE(review): when "tag_ids" is absent or contains no numbers the payload
// carries "tag_ids": null rather than [] — confirm the upstream API accepts
// that.
func (s *LLMService) CallExtQAAPI(data map[string]interface{}) (interface{}, error) {
	var tagIDs []int
	if rawIDs, ok := data["tag_ids"].([]interface{}); ok {
		for _, raw := range rawIDs {
			// encoding/json decodes every JSON number in a generic map as float64.
			if id, isNumber := raw.(float64); isNumber {
				tagIDs = append(tagIDs, int(id))
			}
		}
	}

	payload := ExtQARequestPayload{
		TagIDs:         tagIDs,
		ConversationID: getString(data, "conversation_id"),
		Content:        getString(data, "content"),
	}

	jsonData, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("error marshaling payload: %w", err)
	}

	url := "http://47.100.108.206:30028/api/qa/v1/chat/completionForExt"
	req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
	if err != nil {
		return nil, fmt.Errorf("error creating request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")

	messageChan, err := s.handleStreamingResponseForExt(req, data)
	if err != nil {
		// Return a literal nil: converting the nil channel to interface{}
		// would yield a non-nil interface value.
		return nil, err
	}
	return messageChan, nil
}
// handleStreamingResponseForExt sends req to the external QA API and adapts the
// streamed reply into Message values delivered on the returned channel.
//
// Every non-empty response line, with an optional "data:" prefix removed, is
// decoded as an ExtQAResponse; lines that are not valid JSON are skipped.
// Answer text is accumulated and forwarded in punctuation-bounded chunks:
//   - when the buffer splits into more than one segment and the completed
//     segments together exceed 10 runes, they are sent and the trailing
//     segment stays buffered;
//   - when the buffer contains punctuation but does not split into more than
//     one segment, it is sent whole once it exceeds 15 runes.
//
// When finish_reason is "stop" the remaining buffer is flushed, a final
// Message with IsEnd set is sent and the stream is finished. If the stream
// ends without a "stop", any buffered text is still flushed, but no IsEnd
// message is sent. The channel is closed when the goroutine exits.
//
// The data argument is not used. An error is returned when the request cannot
// be performed or the status code is not 200 OK.
//
// NOTE(review): sends block once the channel buffer (100) is full, so the
// caller has to keep receiving until the channel is closed.
func (s *LLMService) handleStreamingResponseForExt(req *http.Request, data map[string]interface{}) (chan Message, error) {
	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("error making request: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		resp.Body.Close()
		return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
	}

	// Sentence punctuation used as chunk boundaries. Non-ASCII entries are
	// written as escapes so they survive lossy re-encoding of the source file.
	punctuations := map[rune]bool{
		',': true, '\uFF0C': true, // comma, full-width comma
		'.': true, '\u3002': true, // full stop, ideographic full stop
		'!': true, '\uFF01': true, // exclamation mark, full-width
		'?': true, '\uFF1F': true, // question mark, full-width
		';': true, '\uFF1B': true, // semicolon, full-width
		'\uFF1A': true, // full-width colon
		'\u3001': true, // ideographic enumeration comma
	}

	// trimPunctuation removes a single trailing punctuation rune, if any.
	trimPunctuation := func(text string) string {
		last, size := utf8.DecodeLastRuneInString(text)
		if size > 0 && punctuations[last] {
			return text[:len(text)-size]
		}
		return text
	}

	// containsPunctuation reports whether text holds any punctuation rune.
	containsPunctuation := func(text string) bool {
		return strings.IndexFunc(text, func(r rune) bool { return punctuations[r] }) >= 0
	}

	// splitByPunctuation cuts text after each punctuation rune. Punctuation
	// that has no preceding text in its segment is dropped; a trailing
	// unpunctuated remainder becomes the last segment.
	splitByPunctuation := func(text string) []string {
		var segments []string
		var current strings.Builder
		for _, r := range text {
			if !punctuations[r] {
				current.WriteRune(r)
				continue
			}
			if current.Len() > 0 {
				current.WriteRune(r)
				segments = append(segments, current.String())
				current.Reset()
			}
		}
		if current.Len() > 0 {
			segments = append(segments, current.String())
		}
		return segments
	}

	messageChan := make(chan Message, 100) // buffered so the reader can run ahead of the consumer

	go func() {
		defer resp.Body.Close()
		defer close(messageChan)

		var (
			pending        string // text received but not yet forwarded
			allText        string // everything received so far, for the debug log only
			conversationID string // session id of the most recent response
			taskID         string // request id of the most recent response
		)

		// emit forwards one text chunk, without audio, using the current ids.
		emit := func(text string) {
			messageChan <- Message{
				Answer:         text,
				IsEnd:          false,
				ConversationID: conversationID,
				TaskID:         taskID,
				AudioData:      "",
			}
		}

		// handleLine processes one raw line and reports whether the upstream
		// signalled the end of the answer.
		handleLine := func(line string) bool {
			line = strings.TrimSpace(line)
			line = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
			if line == "" {
				return false
			}

			var response ExtQAResponse
			if err := json.Unmarshal([]byte(line), &response); err != nil {
				// Not a JSON payload (for example another SSE field): skip it.
				return false
			}

			answer := response.Output.Text
			conversationID = response.Output.SessionID
			taskID = response.RequestID
			finished := response.Output.FinishReason == "stop"

			if answer != "" || !finished {
				fmt.Println("源文本:", answer)

				pending += answer
				allText += answer

				newMessage := ""
				if containsPunctuation(pending) {
					segments := splitByPunctuation(pending)
					if len(segments) > 1 {
						// Every segment except the last ends in punctuation, so
						// the only criterion left is the length of the
						// completed part.
						completed := strings.Join(segments[:len(segments)-1], "")
						if utf8.RuneCountInString(completed) > 10 {
							pending = segments[len(segments)-1]
							newMessage = completed
						}
					} else if utf8.RuneCountInString(pending) > 15 {
						newMessage = pending
						pending = ""
					}
				}

				if newMessage != "" {
					chunk := strings.TrimSpace(newMessage)
					fmt.Println("new_message", trimPunctuation(chunk))
					fmt.Println("所有消息:", allText)
					emit(chunk)
				}
			}

			if !finished {
				return false
			}

			// Flush whatever is still buffered before signalling the end.
			if pending != "" {
				chunk := strings.TrimSpace(pending)
				fmt.Println("最后一段文本:", chunk)
				emit(chunk)
				pending = ""
			}
			messageChan <- Message{
				Answer:         "",
				IsEnd:          true,
				ConversationID: conversationID,
				TaskID:         taskID,
			}
			return true
		}

		reader := bufio.NewReader(resp.Body)
		for {
			// ReadString may return data together with an error (typically a
			// last line without a trailing newline), so handle the line first.
			line, readErr := reader.ReadString('\n')
			if handleLine(line) {
				return
			}
			if readErr != nil {
				if readErr != io.EOF {
					fmt.Printf("Error reading line: %v\n", readErr)
				}
				// Stop on any read error: retrying a failed reader would spin.
				break
			}
		}

		// The stream ended without a "stop" marker: do not lose buffered text.
		if rest := strings.TrimSpace(pending); rest != "" {
			fmt.Println("最后一段文本:", rest)
			emit(rest)
		}
	}()

	return messageChan, nil
}
// SaveBase64AsWAV saves base64 encoded audio data as a WAV file // SaveBase64AsWAV saves base64 encoded audio data as a WAV file
func (s *LLMService) SaveBase64AsWAV(base64Data string, outputPath string) error { func (s *LLMService) SaveBase64AsWAV(base64Data string, outputPath string) error {
// Decode base64 data // Decode base64 audio data
audioData, err := base64.StdEncoding.DecodeString(base64Data) audioData, err := base64.StdEncoding.DecodeString(base64Data)
if err != nil { if err != nil {
return fmt.Errorf("error decoding base64 data: %v", err) return fmt.Errorf("error decoding base64 audio: %v", err)
} }
// Validate WAV header // Valid WAV header check
if len(audioData) < 44 { // WAV header is 44 bytes if len(audioData) < 44 { // WAV header is 44 bytes
return fmt.Errorf("invalid WAV data: too short") return fmt.Errorf("invalid WAV data: too short")
} }