diff --git a/service/llm_service.go b/service/llm_service.go index 8ed3cc6..b95a712 100644 --- a/service/llm_service.go +++ b/service/llm_service.go @@ -460,16 +460,19 @@ func (s *LLMService) handleStreamingResponse(req *http.Request, data map[string] new_message = strings.Join(segments[:len(segments)-1], "") // initialSessage = "" } else { - continue + if containsPunctuation(format_message) && utf8.RuneCountInString(format_message) > 10 { + + initialSessage = segments[len(segments)-1] + new_message = strings.Join(segments[:len(segments)-1], "") + } else { + continue + } } } else { if utf8.RuneCountInString(initialSessage) > 15 { new_message = initialSessage initialSessage = "" - } else if utf8.RuneCountInString(initialSessage) <= 15 && containsPunctuation(initialSessage) { - new_message = initialSessage - initialSessage = "" } else { continue } @@ -544,6 +547,96 @@ func (s *LLMService) handleStreamingResponse(req *http.Request, data map[string] AudioData: audio, // Update to use the correct path to audio data } case "message_end": + // 在流结束前,处理剩余的文本生成音频 + if initialSessage != "" { + // 不管文本长度,直接生成音频 + s_msg := strings.TrimSpace(initialSessage) + // 定义标点符号map + punctuations := map[string]bool{ + ",": true, ",": true, // 逗号 + ".": true, "。": true, // 句号 + "!": true, "!": true, // 感叹号 + "?": true, "?": true, // 问号 + ";": true, ";": true, // 分号 + ":": true, ":": true, // 冒号 + "、": true, + } + + // 删除字符串前后的标点符号 + trimPunctuation := func(s string) string { + if len(s) > 0 { + lastRune, size := utf8.DecodeLastRuneInString(s) + if punctuations[string(lastRune)] { + s = s[:len(s)-size] + } + } + return s + } + + new_message := trimPunctuation(s_msg) + fmt.Println("最后一段文本生成音频:", new_message) + + // 生成语音 + var audio string + for i := 0; i < 1; i++ { + speechResp, err := s.SynthesizeSpeech(new_message, audio_type) + if err != nil { + fmt.Printf("Error synthesizing speech: %v\n", err) + break + } + fmt.Println("语音:", speechResp) + audio = speechResp.Data.Audio + if audio != "" { + // 下载并处理音频 + resp, err := http.Get(audio) + if err != nil { + fmt.Printf("Error downloading audio: %v\n", err) + } else { + defer resp.Body.Close() + audioBytes, err := io.ReadAll(resp.Body) + if err != nil { + fmt.Printf("Error reading audio data: %v\n", err) + } else { + // 保存原始音频 + originalPath := fmt.Sprintf("audio/original_%d.wav", time.Now().UnixNano()) + if err := os.WriteFile(originalPath, audioBytes, 0644); err != nil { + fmt.Printf("Error saving original audio: %v\n", err) + } + + // 静音裁剪 + audioBase64 := base64.StdEncoding.EncodeToString(audioBytes) + trimmedAudio, err := s.TrimAudioSilence(audioBase64) + if err != nil { + fmt.Printf("Error trimming audio silence: %v\n", err) + } else { + audio_path := fmt.Sprintf("trimmed_%d.wav", time.Now().UnixNano()) + outputPath := "audio/" + audio_path + if err := s.SaveBase64AsWAV(trimmedAudio, outputPath); err != nil { + fmt.Printf("Error saving trimmed WAV file: %v\n", err) + } + audio = s.config.FILE_URL + audio_path + } + } + } + break + } + } + + // 发送最后一段文本的消息 + messageChan <- Message{ + Answer: s_msg, + IsEnd: false, + ConversationID: getString(jsonData, "conversation_id"), + TaskID: getString(jsonData, "task_id"), + ClientID: getString(data, "conversation_id"), + AudioData: audio, + } + + // 清空剩余文本 + initialSessage = "" + } + + // 发送结束消息 messageChan <- Message{ Answer: "", IsEnd: true,