songjvcheng 0300cc96c4
All checks were successful
Gitea Actions Demo / Explore-Gitea-Actions (push) Successful in 1m3s
llm service
2026-01-12 12:37:29 +08:00

1498 lines
44 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"bufio"
"bytes"
"encoding/base64"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"math"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"unicode/utf8"
)
// Config holds the endpoints and credentials used by LLMService.
type Config struct {
// LLMApiURL is the base URL for the "/chat-messages" and "/conversations" endpoints.
LLMApiURL string
// LLMApiKey is sent as a Bearer token to LLMApiURL.
LLMApiKey string
// MiniMaxApiKey is sent as a Bearer token to MiniMaxApiURL.
MiniMaxApiKey string
// MiniMaxApiURL is the speech-synthesis endpoint used by SynthesizeSpeech.
MiniMaxApiURL string
// FILE_URL is the public URL prefix prepended to the names of trimmed audio
// files written under "audio/".
FILE_URL string
// LLMOurApiUrl is the chat-completion endpoint used when llm_type is "ours".
LLMOurApiUrl string
// LLMOurApiKey is sent as a Bearer token to LLMOurApiUrl.
LLMOurApiKey string
}
// LLMService talks to the chat and speech-synthesis backends configured in
// Config, and to the external QA service. Create it with NewLLMService.
type LLMService struct {
// config supplies endpoints and credentials.
config Config
// client is the HTTP client used for most outbound requests.
client *http.Client
}
// Message is one item emitted on the channels returned by the streaming
// methods of LLMService.
type Message struct {
// Answer is the text carried by this message.
Answer string `json:"answer"`
// IsEnd marks the last message of a stream.
IsEnd bool `json:"isEnd"`
// ConversationID and TaskID identify the upstream conversation and task.
ConversationID string `json:"conversation_id"`
TaskID string `json:"task_id"`
// ClientID echoes the caller's "conversation_id" (omitted when empty).
ClientID string `json:"client_id,omitempty"`
// AudioData is a URL of synthesized speech for Answer, if any (omitted when empty).
AudioData string `json:"audio_data,omitempty"`
}
// RequestPayload is the JSON body posted to LLMApiURL + "/chat-messages".
type RequestPayload struct {
Inputs map[string]interface{} `json:"inputs"`
Query string `json:"query"`
// ResponseMode "streaming" selects a streamed reply; any other value a single JSON object.
ResponseMode string `json:"response_mode"`
User string `json:"user"`
ConversationID string `json:"conversation_id"`
Files []interface{} `json:"files"`
// Audio is the voice ID later passed to SynthesizeSpeech; it is also sent upstream as-is.
Audio string `json:"audio"`
// LlmType "ours" routes the request to LLMOurApiUrl instead of LLMApiURL.
LlmType string `json:"llm_type"`
}
// VoiceSetting is the voice section of a SpeechRequest.
type VoiceSetting struct {
// VoiceID selects the synthesized voice (SynthesizeSpeech fills it from its audio argument).
VoiceID string `json:"voice_id"`
Speed float64 `json:"speed"`
Vol float64 `json:"vol"`
Pitch float64 `json:"pitch"`
Emotion string `json:"emotion"`
}
// AudioSetting is the output-format section of a SpeechRequest.
type AudioSetting struct {
SampleRate int `json:"sample_rate"`
Bitrate int `json:"bitrate"`
Format string `json:"format"`
}
// SpeechRequest is the JSON body posted to MiniMaxApiURL by SynthesizeSpeech.
type SpeechRequest struct {
Model string `json:"model"`
Text string `json:"text"`
Stream bool `json:"stream"`
LanguageBoost string `json:"language_boost"`
// OutputFormat "url" (as used by SynthesizeSpeech) makes SpeechData.Audio a downloadable URL.
OutputFormat string `json:"output_format"`
VoiceSetting VoiceSetting `json:"voice_setting"`
AudioSetting AudioSetting `json:"audio_setting"`
}
// SpeechData carries the synthesized audio reference and its status.
type SpeechData struct {
// Audio is downloaded with an HTTP GET by callers, i.e. treated as a URL.
Audio string `json:"audio"`
Status int `json:"status"`
}
// ExtraInfo holds metadata the speech API reports about the generated audio.
// NOTE(review): units (e.g. of AudioLength) are not established in this file.
type ExtraInfo struct {
AudioLength int `json:"audio_length"`
AudioSampleRate int `json:"audio_sample_rate"`
AudioSize int `json:"audio_size"`
AudioBitrate int `json:"audio_bitrate"`
WordCount int `json:"word_count"`
InvisibleCharacterRatio float64 `json:"invisible_character_ratio"`
AudioFormat string `json:"audio_format"`
UsageCharacters int `json:"usage_characters"`
}
// BaseResponse is the status block of a speech API reply.
// NOTE(review): StatusCode 0 presumably means success — confirm with the provider docs.
type BaseResponse struct {
StatusCode int `json:"status_code"`
StatusMsg string `json:"status_msg"`
}
// SpeechResponse is the decoded reply from MiniMaxApiURL.
type SpeechResponse struct {
Data SpeechData `json:"data"`
ExtraInfo ExtraInfo `json:"extra_info"`
TraceID string `json:"trace_id"`
BaseResp BaseResponse `json:"base_resp"`
}
// LLMOurMessage is one chat message sent to LLMOurApiUrl.
type LLMOurMessage struct {
// Role is the speaker role, e.g. "user".
Role string `json:"role"`
Content string `json:"content"`
}
// LLMOurRequestPayload is the JSON body posted to LLMOurApiUrl.
type LLMOurRequestPayload struct {
Model string `json:"model"`
Stream bool `json:"stream"`
StreamOptions map[string]interface{} `json:"stream_options"`
Messages []LLMOurMessage `json:"messages"`
}
// ExtQARequestPayload is the JSON body posted to the external QA
// "completionForExt" endpoint.
type ExtQARequestPayload struct {
TagIDs []int `json:"tag_ids"`
ConversationID string `json:"conversation_id"`
Content string `json:"content"`
}
// ExtQAResponse is one decoded line of the external QA streaming reply.
type ExtQAResponse struct {
RequestID string `json:"request_id"`
Output struct {
// Text is the answer text carried by this line.
Text string `json:"text"`
// FinishReason is compared against "stop"/"length" to detect the end of the stream.
FinishReason string `json:"finish_reason"`
SessionID string `json:"session_id"`
DocReferences []interface{} `json:"doc_references"`
} `json:"output"`
}
// NewLLMService returns an LLMService that uses config for all endpoints and
// credentials. Its HTTP client has no Timeout; note that http.Client.Timeout
// would also bound how long a streamed response body may be read.
func NewLLMService(config Config) *LLMService {
	svc := new(LLMService)
	svc.config = config
	svc.client = new(http.Client)
	return svc
}
// CallLLMAPI sends a chat request built from data. It returns either a
// chan Message (streaming) or a map[string]interface{} (blocking reply).
//
// Keys read from data:
//   - "query", "response_mode", "user", "conversation_id";
//   - "audio", the voice ID forwarded to speech synthesis;
//   - "llm_type";
//   - when llm_type is "ours", an optional "messages" array of
//     {"role", "content"} objects.
//
// With llm_type "ours" the request is a streaming chat-completion call to
// Config.LLMOurApiUrl. If no messages are given, query is sent as a single
// user message. The result is always a chan Message.
//
// Otherwise the request goes to Config.LLMApiURL + "/chat-messages".
// response_mode "streaming" yields a chan Message; any other value yields the
// decoded JSON object.
//
// On error the returned interface{} is a true nil, not a typed nil channel or
// map wrapped in an interface, so callers may compare it with nil.
func (s *LLMService) CallLLMAPI(data map[string]interface{}) (interface{}, error) {
	payload := RequestPayload{
		Inputs:         make(map[string]interface{}),
		Query:          getString(data, "query"),
		ResponseMode:   getString(data, "response_mode"),
		User:           getString(data, "user"),
		ConversationID: getString(data, "conversation_id"),
		Files:          make([]interface{}, 0),
		Audio:          getString(data, "audio"),
		LlmType:        getString(data, "llm_type"),
	}
	fmt.Printf("前端传来的数据:%+v\n", payload)

	if payload.LlmType == "ours" {
		// Build the chat history from data["messages"]; malformed entries are skipped.
		var messages []LLMOurMessage
		if arr, ok := data["messages"].([]interface{}); ok {
			for _, m := range arr {
				if mMap, ok := m.(map[string]interface{}); ok {
					role, _ := mMap["role"].(string)
					content, _ := mMap["content"].(string)
					messages = append(messages, LLMOurMessage{Role: role, Content: content})
				}
			}
		}
		// Fallback: with no explicit history, send the query as a single user message.
		if len(messages) == 0 && payload.Query != "" {
			messages = append(messages, LLMOurMessage{Role: "user", Content: payload.Query})
		}
		ourPayload := LLMOurRequestPayload{
			Model:         "bot-20250522162100-44785", // TODO: take from data or Config
			Stream:        true,
			StreamOptions: map[string]interface{}{"include_usage": true},
			Messages:      messages,
		}
		body, err := json.Marshal(ourPayload)
		if err != nil {
			return nil, fmt.Errorf("error marshaling ourPayload: %v", err)
		}
		fmt.Println(s.config.LLMOurApiUrl)
		req, err := http.NewRequest("POST", s.config.LLMOurApiUrl, bytes.NewBuffer(body))
		if err != nil {
			return nil, fmt.Errorf("error creating request: %v", err)
		}
		req.Header.Set("Authorization", "Bearer "+s.config.LLMOurApiKey)
		req.Header.Set("Content-Type", "application/json")
		ch, err := s.handleStreamingResponseV2(req, data, payload.Audio)
		if err != nil {
			return nil, err
		}
		return ch, nil
	}

	body, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("error marshaling payload: %v", err)
	}
	currentURL := s.config.LLMApiURL + "/chat-messages"
	fmt.Println(currentURL)
	req, err := http.NewRequest("POST", currentURL, bytes.NewBuffer(body))
	if err != nil {
		return nil, fmt.Errorf("error creating request: %v", err)
	}
	req.Header.Set("Authorization", "Bearer "+s.config.LLMApiKey)
	req.Header.Set("Content-Type", "application/json")

	if payload.ResponseMode == "streaming" {
		ch, err := s.handleStreamingResponse(req, data, payload.Audio)
		if err != nil {
			return nil, err
		}
		return ch, nil
	}
	result, err := s.handleNonStreamingResponse(req)
	if err != nil {
		return nil, err
	}
	return result, nil
}
// processStreamSegment appends answer to the pending buffer *initialSessage
// (and to the running transcript *all_message). It then decides whether part
// of the buffer is ready to be spoken.
//
// Nothing is released until the buffer contains punctuation. The buffer is
// then cut after every punctuation rune; punctuation that would start a piece
// (leading or repeated punctuation) is dropped.
//   - With several pieces, all but the last are released once their combined
//     length exceeds 10 runes. The last piece stays buffered.
//   - With a single piece, the whole buffer is released once it exceeds 10
//     runes or ends with "。".
//
// For released text it returns (text, audioURL, true):
//   - text is the released chunk with surrounding whitespace removed;
//   - audioURL is a FILE_URL link to the silence-trimmed local copy of the
//     synthesized speech, or the provider's URL if downloading, trimming or
//     saving failed, or "" if synthesis failed.
//
// Otherwise it returns ("", "", false). audio_type is passed to
// SynthesizeSpeech as the voice ID.
func (s *LLMService) processStreamSegment(initialSessage *string, all_message *string, answer string, audio_type string) (string, string, bool) {
	// Characters that end a speech segment: ASCII and full-width forms.
	const punctuation = ",，.。!！?？;；:：、"
	isPunct := func(r rune) bool { return strings.ContainsRune(punctuation, r) }

	// localize downloads the synthesized audio, keeps the raw file, trims the
	// trailing silence and stores the result under "audio/". Any failing step
	// falls back to remoteURL so the caller still gets playable audio.
	localize := func(remoteURL string) string {
		dl, err := s.client.Get(remoteURL)
		if err != nil {
			fmt.Printf("Error downloading audio: %v\n", err)
			return remoteURL
		}
		defer dl.Body.Close()
		if dl.StatusCode != http.StatusOK {
			fmt.Printf("Error downloading audio: unexpected status code %d\n", dl.StatusCode)
			return remoteURL
		}
		audioBytes, err := io.ReadAll(dl.Body)
		if err != nil {
			fmt.Printf("Error reading audio data: %v\n", err)
			return remoteURL
		}
		originalPath := filepath.Join("audio", fmt.Sprintf("original_%d.wav", time.Now().UnixNano()))
		if err := os.WriteFile(originalPath, audioBytes, 0644); err != nil {
			fmt.Printf("Error saving original audio: %v\n", err)
		}
		trimmedAudio, err := s.TrimAudioSilence(base64.StdEncoding.EncodeToString(audioBytes))
		if err != nil {
			fmt.Printf("Error trimming audio silence: %v\n", err)
			return remoteURL
		}
		audioName := fmt.Sprintf("trimmed_%d.wav", time.Now().UnixNano())
		if err := s.SaveBase64AsWAV(trimmedAudio, filepath.Join("audio", audioName)); err != nil {
			fmt.Printf("Error saving trimmed WAV file: %v\n", err)
			return remoteURL // never hand out a link to a file that was not written
		}
		return s.config.FILE_URL + audioName
	}

	*initialSessage += answer
	*all_message += answer

	buffered := *initialSessage
	if strings.IndexFunc(buffered, isPunct) < 0 {
		return "", "", false
	}

	// Cut the buffer after each punctuation rune.
	var pieces []string
	var cur strings.Builder
	for _, r := range buffered {
		if !isPunct(r) {
			cur.WriteRune(r)
			continue
		}
		if cur.Len() > 0 {
			cur.WriteRune(r)
			pieces = append(pieces, cur.String())
			cur.Reset()
		}
	}
	if cur.Len() > 0 {
		pieces = append(pieces, cur.String())
	}

	var ready string
	if len(pieces) > 1 {
		head := strings.Join(pieces[:len(pieces)-1], "")
		if utf8.RuneCountInString(head) <= 10 {
			return "", "", false
		}
		ready = head
		*initialSessage = pieces[len(pieces)-1]
	} else {
		if utf8.RuneCountInString(buffered) <= 10 && !strings.HasSuffix(buffered, "。") {
			return "", "", false
		}
		ready = buffered
		*initialSessage = ""
	}

	sMsg := strings.TrimSpace(ready)
	// Drop one trailing punctuation rune before synthesis.
	spoken := sMsg
	if r, size := utf8.DecodeLastRuneInString(spoken); size > 0 && isPunct(r) {
		spoken = spoken[:len(spoken)-size]
	}
	if spoken == "" {
		// Nothing left to pronounce (e.g. a lone "。"): skip the TTS call.
		return sMsg, "", true
	}

	speechResp, err := s.SynthesizeSpeech(spoken, audio_type)
	if err != nil {
		fmt.Printf("Error synthesizing speech: %v\n", err)
		return sMsg, "", true
	}
	fmt.Println("触发音频", speechResp)
	if speechResp.Data.Audio == "" {
		return sMsg, "", true
	}
	return sMsg, localize(speechResp.Data.Audio), true
}
// handleStreamingResponse executes req (a streaming "/chat-messages" call). It
// turns the line-oriented event stream into Messages on the returned channel.
//
// Each line may carry a "data: " prefix and must be a JSON object with an
// "event" field.
//
// "message": the "answer" fragment is appended to a pending buffer. Once the
// buffer contains punctuation it is cut into punctuation-terminated pieces.
// With several pieces, everything but the last is released when it exceeds 10
// runes. A single piece is released when it exceeds 15 runes. Released text is
// synthesized to speech and sent as one Message.
//
// "message_end": any remaining buffered text is synthesized and sent. A
// Message with IsEnd set follows, and the stream stops.
//
// Other events and malformed lines are skipped. audio_type is the voice ID
// passed to SynthesizeSpeech. The channel is closed when the stream ends, at
// EOF, or on a read error.
func (s *LLMService) handleStreamingResponse(req *http.Request, data map[string]interface{}, audio_type string) (chan Message, error) {
	resp, err := s.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("error making request: %v", err)
	}
	if resp.StatusCode != http.StatusOK {
		resp.Body.Close() // no reader goroutine will own the body, so release it now
		return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
	}

	// Characters that end a speech segment: ASCII and full-width forms.
	const punctuation = ",，.。!！?？;；:：、"
	isPunct := func(r rune) bool { return strings.ContainsRune(punctuation, r) }

	// trimTrailingPunct drops one trailing punctuation rune before synthesis.
	trimTrailingPunct := func(text string) string {
		if r, size := utf8.DecodeLastRuneInString(text); size > 0 && isPunct(r) {
			return text[:len(text)-size]
		}
		return text
	}

	// splitPieces cuts text after each punctuation rune; punctuation that would
	// start a piece (leading or repeated punctuation) is dropped.
	splitPieces := func(text string) []string {
		var pieces []string
		var cur strings.Builder
		for _, r := range text {
			if !isPunct(r) {
				cur.WriteRune(r)
				continue
			}
			if cur.Len() > 0 {
				cur.WriteRune(r)
				pieces = append(pieces, cur.String())
				cur.Reset()
			}
		}
		if cur.Len() > 0 {
			pieces = append(pieces, cur.String())
		}
		return pieces
	}

	// localizeAudio downloads the synthesized audio, keeps the raw file, trims
	// the trailing silence and stores the result under "audio/". Any failing
	// step falls back to remoteURL so the client still gets playable audio.
	localizeAudio := func(remoteURL string) string {
		dl, err := s.client.Get(remoteURL)
		if err != nil {
			fmt.Printf("Error downloading audio: %v\n", err)
			return remoteURL
		}
		defer dl.Body.Close()
		if dl.StatusCode != http.StatusOK {
			fmt.Printf("Error downloading audio: unexpected status code %d\n", dl.StatusCode)
			return remoteURL
		}
		audioBytes, err := io.ReadAll(dl.Body)
		if err != nil {
			fmt.Printf("Error reading audio data: %v\n", err)
			return remoteURL
		}
		originalPath := filepath.Join("audio", fmt.Sprintf("original_%d.wav", time.Now().UnixNano()))
		if err := os.WriteFile(originalPath, audioBytes, 0644); err != nil {
			fmt.Printf("Error saving original audio: %v\n", err)
		}
		trimmedAudio, err := s.TrimAudioSilence(base64.StdEncoding.EncodeToString(audioBytes))
		if err != nil {
			fmt.Printf("Error trimming audio silence: %v\n", err)
			return remoteURL
		}
		audioName := fmt.Sprintf("trimmed_%d.wav", time.Now().UnixNano())
		if err := s.SaveBase64AsWAV(trimmedAudio, filepath.Join("audio", audioName)); err != nil {
			fmt.Printf("Error saving trimmed WAV file: %v\n", err)
			return remoteURL // never hand out a link to a file that was not written
		}
		return s.config.FILE_URL + audioName
	}

	// speak synthesizes text and returns an audio URL, or "" when there is
	// nothing to say or synthesis fails.
	speak := func(text string) string {
		if text == "" {
			return ""
		}
		speechResp, err := s.SynthesizeSpeech(text, audio_type)
		if err != nil {
			fmt.Printf("Error synthesizing speech: %v\n", err)
			return ""
		}
		fmt.Println("语音:", speechResp)
		if speechResp.Data.Audio == "" {
			return ""
		}
		return localizeAudio(speechResp.Data.Audio)
	}

	messageChan := make(chan Message, 100) // buffered so parsing is not stalled by a slow consumer
	go func() {
		defer resp.Body.Close()
		defer close(messageChan)

		allMessage := "" // full transcript so far; only logged
		pending := ""    // text received but not yet spoken

		// handleLine processes one raw line and reports whether the stream is complete.
		handleLine := func(line string) bool {
			line = strings.TrimSpace(line)
			if line == "" {
				return false
			}
			line = strings.TrimPrefix(line, "data: ")
			var jsonData map[string]interface{}
			if err := json.Unmarshal([]byte(line), &jsonData); err != nil {
				fmt.Printf("Error unmarshaling JSON: %v\n", err)
				return false
			}

			switch getString(jsonData, "event") {
			case "message":
				answer := getString(jsonData, "answer")
				fmt.Println("源文本:", answer)
				pending += answer
				allMessage += answer

				ready := ""
				if strings.IndexFunc(pending, isPunct) >= 0 {
					pieces := splitPieces(pending)
					if len(pieces) > 1 {
						// Every piece but the last ends in punctuation, so head always
						// contains punctuation and a single "> 10 runes" threshold applies.
						head := strings.Join(pieces[:len(pieces)-1], "")
						if utf8.RuneCountInString(head) > 10 {
							ready = head
							pending = pieces[len(pieces)-1]
						}
					} else if utf8.RuneCountInString(pending) > 15 {
						ready = pending
						pending = ""
					}
				}
				if ready == "" {
					return false
				}

				sMsg := strings.TrimSpace(ready)
				spoken := trimTrailingPunct(sMsg)
				fmt.Println("new_message", spoken)
				audio := speak(spoken)
				fmt.Println("所有消息:", allMessage)
				messageChan <- Message{
					Answer:         sMsg,
					IsEnd:          false,
					ConversationID: getString(jsonData, "conversation_id"),
					TaskID:         getString(jsonData, "task_id"),
					ClientID:       getString(data, "conversation_id"),
					AudioData:      audio,
				}

			case "message_end":
				// Speak whatever is still buffered, regardless of its length.
				if sMsg := strings.TrimSpace(pending); sMsg != "" {
					spoken := trimTrailingPunct(sMsg)
					fmt.Println("最后一段文本生成音频:", spoken)
					messageChan <- Message{
						Answer:         sMsg,
						IsEnd:          false,
						ConversationID: getString(jsonData, "conversation_id"),
						TaskID:         getString(jsonData, "task_id"),
						ClientID:       getString(data, "conversation_id"),
						AudioData:      speak(spoken),
					}
				}
				pending = ""
				messageChan <- Message{
					Answer:         "",
					IsEnd:          true,
					ConversationID: getString(jsonData, "conversation_id"),
					TaskID:         getString(jsonData, "task_id"),
				}
				return true
			}
			return false
		}

		reader := bufio.NewReader(resp.Body)
		for {
			// ReadString can return a final unterminated line together with an
			// error, so the line is handled before the error is inspected.
			line, readErr := reader.ReadString('\n')
			if handleLine(line) {
				return
			}
			if readErr != nil {
				// Stop on any error: retrying a failed body read would spin forever.
				if readErr != io.EOF {
					fmt.Printf("Error reading line: %v\n", readErr)
				}
				return
			}
		}
	}()
	return messageChan, nil
}
// handleStreamingResponseV2 executes req (a streaming chat-completion call to
// LLMOurApiUrl) and converts its "data:" lines into Messages.
//
// Each line is either "[DONE]" or a JSON chunk whose choices[0].delta.content
// holds the next text fragment. Other chunks, such as usage-only ones with no
// choices, are skipped. Fragments go through processStreamSegment, which
// decides when a segment is complete and synthesizes its audio.
//
// Text still buffered when the stream finishes ("[DONE]", EOF or a read
// error) is synthesized and sent as a final segment. This keeps trailing
// sentences that are shorter than processStreamSegment's thresholds from
// being lost. An IsEnd Message is sent only after "[DONE]".
//
// ConversationID, TaskID and ClientID come from data's "conversation_id" and
// "task_id". The channel is closed when the reader goroutine exits.
func (s *LLMService) handleStreamingResponseV2(req *http.Request, data map[string]interface{}, audio_type string) (chan Message, error) {
	resp, err := s.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("error making request: %v", err)
	}
	if resp.StatusCode != http.StatusOK {
		resp.Body.Close() // no reader goroutine will own the body, so release it now
		return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
	}

	messageChan := make(chan Message, 100)
	go func() {
		defer resp.Body.Close()
		defer close(messageChan)

		conversationID := getString(data, "conversation_id")
		taskID := getString(data, "task_id")
		allMessage := "" // transcript maintained by processStreamSegment
		pending := ""    // text received but not yet spoken

		// remainderAudio synthesizes text (minus one trailing punctuation rune)
		// and returns a FILE_URL link to the silence-trimmed local copy. It
		// returns the provider URL if that pipeline fails, or "" if synthesis fails.
		remainderAudio := func(text string) string {
			const punctuation = ",，.。!！?？;；:：、"
			spoken := text
			if r, size := utf8.DecodeLastRuneInString(spoken); size > 0 && strings.ContainsRune(punctuation, r) {
				spoken = spoken[:len(spoken)-size]
			}
			if spoken == "" {
				return ""
			}
			speechResp, err := s.SynthesizeSpeech(spoken, audio_type)
			if err != nil {
				fmt.Printf("Error synthesizing speech: %v\n", err)
				return ""
			}
			remoteURL := speechResp.Data.Audio
			if remoteURL == "" {
				return ""
			}
			dl, err := s.client.Get(remoteURL)
			if err != nil {
				fmt.Printf("Error downloading audio: %v\n", err)
				return remoteURL
			}
			defer dl.Body.Close()
			if dl.StatusCode != http.StatusOK {
				fmt.Printf("Error downloading audio: unexpected status code %d\n", dl.StatusCode)
				return remoteURL
			}
			audioBytes, err := io.ReadAll(dl.Body)
			if err != nil {
				fmt.Printf("Error reading audio data: %v\n", err)
				return remoteURL
			}
			originalPath := filepath.Join("audio", fmt.Sprintf("original_%d.wav", time.Now().UnixNano()))
			if err := os.WriteFile(originalPath, audioBytes, 0644); err != nil {
				fmt.Printf("Error saving original audio: %v\n", err)
			}
			trimmedAudio, err := s.TrimAudioSilence(base64.StdEncoding.EncodeToString(audioBytes))
			if err != nil {
				fmt.Printf("Error trimming audio silence: %v\n", err)
				return remoteURL
			}
			audioName := fmt.Sprintf("trimmed_%d.wav", time.Now().UnixNano())
			if err := s.SaveBase64AsWAV(trimmedAudio, filepath.Join("audio", audioName)); err != nil {
				fmt.Printf("Error saving trimmed WAV file: %v\n", err)
				return remoteURL
			}
			return s.config.FILE_URL + audioName
		}

		// flushPending sends any still-buffered text as a last segment.
		flushPending := func() {
			text := strings.TrimSpace(pending)
			pending = ""
			if text == "" {
				return
			}
			fmt.Println("最后一段文本生成音频:", text)
			messageChan <- Message{
				Answer:         text,
				IsEnd:          false,
				ConversationID: conversationID,
				TaskID:         taskID,
				ClientID:       conversationID,
				AudioData:      remainderAudio(text),
			}
		}

		// handleLine processes one raw line and reports whether the stream is complete.
		handleLine := func(line string) bool {
			line = strings.TrimSpace(line)
			if strings.HasPrefix(line, "data:") {
				line = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
			}
			if line == "" {
				return false
			}
			if line == "[DONE]" {
				flushPending()
				messageChan <- Message{
					Answer:         "",
					IsEnd:          true,
					ConversationID: conversationID,
					TaskID:         taskID,
				}
				return true
			}

			var jsonData map[string]interface{}
			if err := json.Unmarshal([]byte(line), &jsonData); err != nil {
				fmt.Printf("Error unmarshaling JSON: %v\n", err)
				return false
			}
			choices, ok := jsonData["choices"].([]interface{})
			if !ok || len(choices) == 0 {
				return false
			}
			choice, ok := choices[0].(map[string]interface{})
			if !ok {
				return false
			}
			delta, ok := choice["delta"].(map[string]interface{})
			if !ok {
				return false
			}
			content, _ := delta["content"].(string)
			if content == "" {
				return false
			}

			newMessage, audio, needSend := s.processStreamSegment(&pending, &allMessage, content, audio_type)
			if !needSend {
				return false
			}
			messageChan <- Message{
				Answer:         newMessage,
				IsEnd:          false,
				ConversationID: conversationID,
				TaskID:         taskID,
				ClientID:       conversationID,
				AudioData:      audio,
			}
			return false
		}

		reader := bufio.NewReader(resp.Body)
		for {
			// ReadString can return a final unterminated line together with an
			// error, so the line is handled before the error is inspected.
			line, readErr := reader.ReadString('\n')
			if handleLine(line) {
				return
			}
			if readErr != nil {
				// Stop on any error (retrying would spin), but still deliver buffered text.
				if readErr != io.EOF {
					fmt.Printf("Error reading line: %v\n", readErr)
				}
				flushPending()
				return
			}
		}
	}()
	return messageChan, nil
}
// handleNonStreamingResponse sends req with the service client and decodes the
// JSON object in the reply body. A transport failure, a status other than
// 200 OK, or an undecodable body is reported as an error.
func (s *LLMService) handleNonStreamingResponse(req *http.Request) (map[string]interface{}, error) {
	resp, doErr := s.client.Do(req)
	if doErr != nil {
		return nil, fmt.Errorf("error making request: %v", doErr)
	}
	defer resp.Body.Close()

	if status := resp.StatusCode; status != http.StatusOK {
		return nil, fmt.Errorf("unexpected status code: %d", status)
	}

	var decoded map[string]interface{}
	if decodeErr := json.NewDecoder(resp.Body).Decode(&decoded); decodeErr != nil {
		return nil, fmt.Errorf("error decoding response: %v", decodeErr)
	}
	return decoded, nil
}
// StopConversation asks the chat backend to stop generating for taskID. It
// POSTs to LLMApiURL + "/chat-messages/{taskID}/stop" and returns the decoded
// JSON reply. The reply is returned whatever the HTTP status is.
//
// taskID is placed verbatim into the URL path. Values that are empty, "." or
// "..", or that contain '/', a backslash, '?', '#' or '%', are therefore
// rejected rather than allowed to address a different endpoint with the
// service's API key.
//
// NOTE(review): no request body is sent. Confirm the upstream stop endpoint
// does not require one (for example a "user" field).
func (s *LLMService) StopConversation(taskID string) (map[string]interface{}, error) {
	if taskID == "" || taskID == "." || taskID == ".." || strings.ContainsAny(taskID, "/\\?#%") {
		return nil, fmt.Errorf("invalid task id %q", taskID)
	}
	req, err := http.NewRequest("POST", fmt.Sprintf("%s/chat-messages/%s/stop", s.config.LLMApiURL, taskID), nil)
	if err != nil {
		return nil, fmt.Errorf("error creating request: %v", err)
	}
	req.Header.Set("Authorization", "Bearer "+s.config.LLMApiKey)
	req.Header.Set("Content-Type", "application/json")
	resp, err := s.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("error making request: %v", err)
	}
	defer resp.Body.Close()
	var result map[string]interface{}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, fmt.Errorf("error decoding response: %v", err)
	}
	return result, nil
}
// DeleteConversation deletes conversationID on the chat backend. It sends
// DELETE LLMApiURL + "/conversations/{conversationID}" with the JSON body
// {"user": user} and returns the decoded JSON reply, whatever the HTTP status.
//
// conversationID is placed verbatim into the URL path. Values that are empty,
// "." or "..", or that contain '/', a backslash, '?', '#' or '%', are therefore
// rejected rather than allowed to address a different endpoint with the
// service's API key.
func (s *LLMService) DeleteConversation(conversationID, user string) (map[string]interface{}, error) {
	if conversationID == "" || conversationID == "." || conversationID == ".." ||
		strings.ContainsAny(conversationID, "/\\?#%") {
		return nil, fmt.Errorf("invalid conversation id %q", conversationID)
	}
	// The caller-supplied user was previously dropped; the request declares a
	// JSON content type, so send it as the body.
	body, err := json.Marshal(map[string]string{"user": user})
	if err != nil {
		return nil, fmt.Errorf("error marshaling request body: %v", err)
	}
	req, err := http.NewRequest("DELETE", fmt.Sprintf("%s/conversations/%s", s.config.LLMApiURL, conversationID), bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("error creating request: %v", err)
	}
	req.Header.Set("Authorization", "Bearer "+s.config.LLMApiKey)
	req.Header.Set("Content-Type", "application/json")
	resp, err := s.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("error making request: %v", err)
	}
	defer resp.Body.Close()
	var result map[string]interface{}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, fmt.Errorf("error decoding response: %v", err)
	}
	return result, nil
}
// SynthesizeSpeech requests speech for text from MiniMaxApiURL using voice ID
// audio. Output is WAV, 32 kHz, returned as a URL. It returns the decoded
// reply.
//
// Errors are returned (not logged) for:
//   - request construction or transport failures;
//   - a status other than 200 OK; the start of the body is included for
//     diagnostics;
//   - an undecodable body;
//   - a reply whose base_resp reports a non-zero status code and carries no
//     audio. NOTE(review): this assumes status_code 0 means success; confirm
//     with the provider docs.
func (s *LLMService) SynthesizeSpeech(text string, audio string) (*SpeechResponse, error) {
	payload := SpeechRequest{
		Model:         "speech-02-hd",
		Text:          text,
		Stream:        false,
		LanguageBoost: "auto",
		OutputFormat:  "url",
		VoiceSetting: VoiceSetting{
			VoiceID: audio,
			Speed:   1,
			Vol:     1,
			Pitch:   -1,
			Emotion: "neutral",
		},
		AudioSetting: AudioSetting{
			SampleRate: 32000,
			Bitrate:    128000,
			Format:     "wav",
		},
	}
	jsonData, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("error marshaling speech request: %v", err)
	}
	req, err := http.NewRequest("POST", s.config.MiniMaxApiURL, bytes.NewBuffer(jsonData))
	if err != nil {
		return nil, fmt.Errorf("error creating speech request: %v", err)
	}
	req.Header.Set("Authorization", "Bearer "+s.config.MiniMaxApiKey)
	req.Header.Set("Content-Type", "application/json")
	resp, err := s.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("error making speech request: %v", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		// Best effort: the body is only used to make the error message useful.
		snippet, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
		return nil, fmt.Errorf("unexpected status code: %d, body: %q", resp.StatusCode, snippet)
	}
	var result SpeechResponse
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, fmt.Errorf("error decoding speech response: %v", err)
	}
	// A 200 reply can still carry a provider-level failure; surface it when no audio came back.
	if result.BaseResp.StatusCode != 0 && result.Data.Audio == "" {
		return nil, fmt.Errorf("speech synthesis failed: status %d: %s", result.BaseResp.StatusCode, result.BaseResp.StatusMsg)
	}
	return &result, nil
}
// StreamTextResponse streams a fixed, canned reply for conversationID. Each
// segment is sent 500 ms apart, followed by a Message with IsEnd set. The
// channel is closed afterwards, and the returned error is always nil.
func (s *LLMService) StreamTextResponse(conversationID string) (chan Message, error) {
	out := make(chan Message, 100)
	canned := [...]string{
		"好的,",
		"我已经成功替换了文本内容。",
		"新的文本是一段连续的描述,",
		"没有换行,",
		"总共65个字符",
		"符合100字以内的要求",
		"并且是一个连续的段落。",
		"现在我需要完成任务。",
	}
	go func() {
		defer close(out)
		taskID := "task_" + time.Now().Format("20060102150405")
		for i := range canned {
			out <- Message{
				Answer:         canned[i],
				IsEnd:          false,
				ConversationID: conversationID,
				TaskID:         taskID,
				ClientID:       conversationID,
			}
			time.Sleep(500 * time.Millisecond) // pace the canned reply like a live stream
		}
		out <- Message{
			Answer:         "",
			IsEnd:          true,
			ConversationID: conversationID,
			TaskID:         taskID,
		}
	}()
	return out, nil
}
// getString returns data[key] when it is present and holds a string. It
// returns "" otherwise: for a missing key, a nil map, or a non-string value.
func getString(data map[string]interface{}, key string) string {
	str, _ := data[key].(string)
	return str
}
// TrimAudioSilence removes trailing silence from a base64-encoded WAV file and
// returns the shortened file, also base64-encoded.
//
// The RIFF chunk list is scanned for the "fmt " and "data" chunks. Other
// chunks, and anything after "data", are skipped and not copied to the
// output; the fmt chunk, including any extension bytes, is copied unchanged.
//
// A frame counts as non-silent when any of its channels exceeds about 1% of
// 16-bit full scale. 0.1 s of audio is kept after the last non-silent frame,
// so an all-silent file is cut to its first 0.1 s.
//
// Supported formats:
//   - integer PCM of 8 bits (unsigned) or 16/24/32 bits; for samples wider
//     than 16 bits only the most significant 16 bits are inspected;
//   - IEEE float data (format tag 3) cannot be judged by this threshold and
//     is returned unchanged.
func (s *LLMService) TrimAudioSilence(audioData string) (string, error) {
	decodedData, err := base64.StdEncoding.DecodeString(audioData)
	if err != nil {
		return "", fmt.Errorf("error decoding base64 audio: %v", err)
	}
	buf := bytes.NewReader(decodedData)

	var riffHeader struct {
		ChunkID   [4]byte
		ChunkSize uint32
		Format    [4]byte
	}
	if err := binary.Read(buf, binary.LittleEndian, &riffHeader); err != nil {
		return "", fmt.Errorf("error reading RIFF header: %v", err)
	}
	if string(riffHeader.ChunkID[:]) != "RIFF" || string(riffHeader.Format[:]) != "WAVE" {
		return "", fmt.Errorf("invalid WAV format")
	}

	// Walk the chunk list up to the data chunk. RIFF chunks are word aligned:
	// an odd-sized payload is followed by one pad byte.
	var (
		fmtRaw     []byte // complete fmt payload, extension bytes included
		audioBytes []byte
	)
scan:
	for {
		var hdr struct {
			ID   [4]byte
			Size uint32
		}
		if err := binary.Read(buf, binary.LittleEndian, &hdr); err != nil {
			return "", fmt.Errorf("error reading chunk header: %v", err)
		}
		switch string(hdr.ID[:]) {
		case "fmt ":
			if hdr.Size < 16 || int64(hdr.Size) > int64(buf.Len()) {
				return "", fmt.Errorf("error reading fmt chunk: invalid size %d", hdr.Size)
			}
			fmtRaw = make([]byte, hdr.Size)
			if _, err := io.ReadFull(buf, fmtRaw); err != nil {
				return "", fmt.Errorf("error reading fmt chunk: %v", err)
			}
			if _, err := buf.Seek(int64(hdr.Size&1), io.SeekCurrent); err != nil {
				return "", fmt.Errorf("error skipping chunk: %v", err)
			}
		case "data":
			if fmtRaw == nil {
				return "", fmt.Errorf("error reading fmt chunk: data chunk found before fmt chunk")
			}
			// Streamed WAVs may declare a larger (placeholder) size than is
			// present; never allocate or read past the end of the input.
			size := int64(hdr.Size)
			if remaining := int64(buf.Len()); size > remaining {
				size = remaining
			}
			audioBytes = make([]byte, size)
			if _, err := io.ReadFull(buf, audioBytes); err != nil {
				return "", fmt.Errorf("error reading audio data: %v", err)
			}
			break scan
		default:
			if _, err := buf.Seek(int64(hdr.Size)+int64(hdr.Size&1), io.SeekCurrent); err != nil {
				return "", fmt.Errorf("error skipping chunk: %v", err)
			}
		}
	}

	formatTag := binary.LittleEndian.Uint16(fmtRaw[0:2])
	numChannels := int(binary.LittleEndian.Uint16(fmtRaw[2:4]))
	sampleRate := binary.LittleEndian.Uint32(fmtRaw[4:8])
	bitsPerSample := binary.LittleEndian.Uint16(fmtRaw[14:16])
	// WAVE_FORMAT_EXTENSIBLE keeps the real format code in the first two bytes
	// of its sub-format GUID, at offset 24 of the fmt payload.
	if formatTag == 0xFFFE && len(fmtRaw) >= 26 {
		formatTag = binary.LittleEndian.Uint16(fmtRaw[24:26])
	}
	if formatTag == 3 { // IEEE float: amplitude threshold below does not apply
		return audioData, nil
	}

	bytesPerSample := int(bitsPerSample / 8)
	if bytesPerSample == 0 {
		bytesPerSample = 2 // default to 16-bit if unknown
	}
	if numChannels < 1 {
		numChannels = 1
	}
	frameSize := bytesPerSample * numChannels
	numFrames := len(audioBytes) / frameSize

	// sampleAt returns the sample at byte offset off, scaled to the int16 range.
	sampleAt := func(off int) int16 {
		if bytesPerSample == 1 {
			// 8-bit PCM is unsigned and centred on 128.
			return (int16(audioBytes[off]) - 128) * 256
		}
		// Little-endian: the most significant 16 bits are the sample's last two bytes.
		hi := off + bytesPerSample - 2
		return int16(binary.LittleEndian.Uint16(audioBytes[hi : hi+2]))
	}

	const threshold = 327.0 // about 1% of 16-bit full scale (32768 * 0.01)
	lastNonSilent := 0
	for f := 0; f < numFrames; f++ {
		base := f * frameSize
		for c := 0; c < numChannels; c++ {
			if math.Abs(float64(sampleAt(base+c*bytesPerSample))) > threshold {
				lastNonSilent = f
				break
			}
		}
	}

	// Keep a 0.1 s tail after the last non-silent frame, cutting on a frame boundary.
	keepFrames := lastNonSilent + int(float64(sampleRate)*0.1)
	if keepFrames > numFrames {
		keepFrames = numFrames
	}
	newDataSize := keepFrames * frameSize
	trimmedAudio := audioBytes[:newDataSize]

	fmtPad := len(fmtRaw) & 1
	dataPad := newDataSize & 1
	var out bytes.Buffer
	out.Grow(12 + 8 + len(fmtRaw) + fmtPad + 8 + newDataSize + dataPad)
	var word [4]byte
	writeChunkHeader := func(id string, size int) {
		out.WriteString(id)
		binary.LittleEndian.PutUint32(word[:], uint32(size))
		out.Write(word[:])
	}
	// RIFF size covers "WAVE" plus every chunk (header, payload and pad) that follows.
	writeChunkHeader("RIFF", 4+8+len(fmtRaw)+fmtPad+8+newDataSize+dataPad)
	out.WriteString("WAVE")
	writeChunkHeader("fmt ", len(fmtRaw))
	out.Write(fmtRaw)
	if fmtPad == 1 {
		out.WriteByte(0)
	}
	writeChunkHeader("data", newDataSize)
	out.Write(trimmedAudio)
	if dataPad == 1 {
		out.WriteByte(0)
	}
	return base64.StdEncoding.EncodeToString(out.Bytes()), nil
}
// CallExtQAAPI posts a question to the external QA service at the fixed
// endpoint below. It returns the chan Message produced by
// handleStreamingResponseForExt.
//
// Keys read from data:
//   - "tag_ids": a []int, or a []interface{} of JSON numbers; non-numeric
//     entries are skipped;
//   - "conversation_id";
//   - "content".
//
// On error the returned interface{} is a true nil rather than a typed nil
// channel wrapped in an interface.
func (s *LLMService) CallExtQAAPI(data map[string]interface{}) (interface{}, error) {
	var tagIDs []int
	switch raw := data["tag_ids"].(type) {
	case []int:
		tagIDs = raw
	case []interface{}:
		for _, v := range raw {
			// encoding/json decodes numbers into interface{} as float64.
			if id, ok := v.(float64); ok {
				tagIDs = append(tagIDs, int(id))
			}
		}
	}

	jsonData, err := json.Marshal(ExtQARequestPayload{
		TagIDs:         tagIDs,
		ConversationID: getString(data, "conversation_id"),
		Content:        getString(data, "content"),
	})
	if err != nil {
		return nil, fmt.Errorf("error marshaling payload: %v", err)
	}

	// TODO: move this hard-coded endpoint into Config.
	url := "http://47.100.108.206:30028/api/qa/v1/chat/completionForExt"
	fmt.Printf("Sending request to %s with payload: %s\n", url, string(jsonData))
	req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
	if err != nil {
		return nil, fmt.Errorf("error creating request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")

	ch, err := s.handleStreamingResponseForExt(req, data)
	if err != nil {
		return nil, err
	}
	return ch, nil
}
// CallExtQAAPIStreamDirect posts a question to the external QA service at
// ai.ces-invest.com. It returns the chan Message produced by
// handleStreamingResponseForExtDirect, which forwards each reply line as-is.
//
// Keys read from data:
//   - "tag_ids": a []int, or a []interface{} of JSON numbers; non-numeric
//     entries are skipped;
//   - "conversation_id";
//   - "content".
//
// On error the returned interface{} is a true nil rather than a typed nil
// channel wrapped in an interface.
func (s *LLMService) CallExtQAAPIStreamDirect(data map[string]interface{}) (interface{}, error) {
	var tagIDs []int
	switch raw := data["tag_ids"].(type) {
	case []int:
		tagIDs = raw
	case []interface{}:
		for _, v := range raw {
			// encoding/json decodes numbers into interface{} as float64.
			if id, ok := v.(float64); ok {
				tagIDs = append(tagIDs, int(id))
			}
		}
	}

	jsonData, err := json.Marshal(ExtQARequestPayload{
		TagIDs:         tagIDs,
		ConversationID: getString(data, "conversation_id"),
		Content:        getString(data, "content"),
	})
	if err != nil {
		return nil, fmt.Errorf("error marshaling payload: %v", err)
	}

	// TODO: move this hard-coded endpoint into Config.
	url := "https://ai.ces-invest.com/api/qa/v1/chat/completionForExt"
	fmt.Printf("Sending request to %s with payload: %s\n", url, string(jsonData))
	req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
	if err != nil {
		return nil, fmt.Errorf("error creating request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")

	ch, err := s.handleStreamingResponseForExtDirect(req, data)
	if err != nil {
		return nil, err
	}
	return ch, nil
}
// handleStreamingResponseForExtDirect executes req against the external QA
// service. It forwards every decoded reply line as one Message, without
// re-segmenting the text or synthesizing audio.
//
// Lines may carry a "data:" prefix; blank and non-JSON lines are skipped. A
// line whose output.finish_reason is "stop" or "length" is sent with IsEnd
// set and ends the stream.
//
// All Messages use data's "conversation_id" and a task ID generated once per
// call. On a non-200 status the body is logged and an error is returned. The
// channel is closed when the stream ends, at EOF, or on a read error.
func (s *LLMService) handleStreamingResponseForExtDirect(req *http.Request, data map[string]interface{}) (chan Message, error) {
	// Reuse the service client; fall back to the default one for a zero-value LLMService.
	client := s.client
	if client == nil {
		client = http.DefaultClient
	}
	resp, err := client.Do(req)
	if err != nil {
		fmt.Printf("Error making external request: %v\n", err)
		return nil, fmt.Errorf("error making request: %v", err)
	}
	fmt.Printf("External API response status: %d\n", resp.StatusCode)
	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body) // best effort: the body is only logged
		fmt.Printf("External API error body: %s\n", string(body))
		resp.Body.Close()
		return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
	}

	messageChan := make(chan Message, 100)
	go func() {
		defer resp.Body.Close()
		defer close(messageChan)

		conversationID := getString(data, "conversation_id")
		// The external API may not provide a task ID in every chunk, so one is generated here.
		taskID := fmt.Sprintf("task_%d", time.Now().UnixNano())

		// forwardLine handles one raw line and reports whether the stream is finished.
		forwardLine := func(line string) bool {
			line = strings.TrimSpace(line)
			if strings.HasPrefix(line, "data:") {
				line = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
			}
			if line == "" {
				return false
			}
			fmt.Printf("Processing line: %s\n", line)
			var response ExtQAResponse
			if err := json.Unmarshal([]byte(line), &response); err != nil {
				fmt.Printf("Error unmarshaling JSON: %v, line: %s\n", err, line)
				return false
			}
			isEnd := response.Output.FinishReason == "stop" || response.Output.FinishReason == "length"
			messageChan <- Message{
				Answer:         response.Output.Text,
				IsEnd:          isEnd,
				ConversationID: conversationID,
				TaskID:         taskID,
			}
			return isEnd
		}

		reader := bufio.NewReader(resp.Body)
		for {
			// ReadString can return a final unterminated line together with an
			// error, so the line is handled before the error is inspected.
			line, readErr := reader.ReadString('\n')
			if forwardLine(line) {
				return
			}
			if readErr != nil {
				// Stop on any error: retrying a failed body read would spin forever.
				if readErr != io.EOF {
					fmt.Printf("Error reading line: %v\n", readErr)
				}
				return
			}
		}
	}()
	return messageChan, nil
}
// handleStreamingResponseForExt sends req to the external QA API and re-chunks
// the streamed answer into sentence-sized Message values.
//
// Each response line may carry a "data:" prefix and must decode as an
// ExtQAResponse; lines that fail to decode are logged and skipped. Answer text
// is buffered and emitted (whitespace-trimmed, IsEnd=false) whenever
// extNextSegment reports that enough of it is ready. When a chunk reports
// finish reason "stop" or "length", any buffered text is flushed and a final
// Message with an empty Answer and IsEnd=true is sent. If the stream ends
// without such a chunk (EOF or read error), buffered text is still flushed but
// no IsEnd message is sent. ConversationID and TaskID are taken from the most
// recently decoded chunk (Output.SessionID and RequestID). data is not used.
//
// A non-nil error is returned only when the request itself fails or the
// response status is not 200 OK. The channel is closed when the stream is done.
func (s *LLMService) handleStreamingResponseForExt(req *http.Request, data map[string]interface{}) (chan Message, error) {
	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		fmt.Printf("Error making external request: %v\n", err)
		return nil, fmt.Errorf("error making request: %w", err)
	}
	fmt.Printf("External API response status: %d\n", resp.StatusCode)
	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body) // best effort: the body is only logged
		fmt.Printf("External API error body: %s\n", string(body))
		resp.Body.Close()
		return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
	}

	messageChan := make(chan Message, 100) // Buffered channel for better performance
	go func() {
		defer resp.Body.Close()
		defer close(messageChan)

		var (
			pending        string          // answer text received but not yet emitted
			allMessage     strings.Builder // whole answer so far, used for logging only
			conversationID string
			taskID         string
		)

		send := func(answer string, isEnd bool) {
			messageChan <- Message{
				Answer:         answer,
				IsEnd:          isEnd,
				ConversationID: conversationID,
				TaskID:         taskID,
			}
		}

		// flush emits whatever text is still buffered, if any.
		flush := func() {
			last := strings.TrimSpace(pending)
			pending = ""
			if last == "" {
				return
			}
			fmt.Println("最后一段文本:", last)
			send(last, false)
		}

		// handle processes one payload line and reports whether the answer
		// is complete.
		handle := func(line string) bool {
			fmt.Printf("Processing line: %s\n", line)
			var response ExtQAResponse
			if err := json.Unmarshal([]byte(line), &response); err != nil {
				fmt.Printf("Error unmarshaling JSON: %v, line: %s\n", err, line)
				return false
			}
			answer := response.Output.Text
			conversationID = response.Output.SessionID
			taskID = response.RequestID
			reason := response.Output.FinishReason
			finished := reason == "stop" || reason == "length"

			if answer != "" || !finished {
				fmt.Println("源文本:", answer)
				pending += answer
				allMessage.WriteString(answer)
				var segment string
				segment, pending = extNextSegment(pending)
				if segment != "" {
					segment = strings.TrimSpace(segment)
					fmt.Println("new_message", extTrimTrailingPunctuation(segment))
					fmt.Println("所有消息:", allMessage.String())
					send(segment, false)
				}
			}
			if !finished {
				return false
			}
			flush()
			send("", true)
			return true
		}

		reader := bufio.NewReader(resp.Body)
		for {
			raw, readErr := reader.ReadString('\n')
			// ReadString can return a last line lacking '\n' together with
			// io.EOF, so the text is handled before the error is inspected.
			line := strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(raw), "data:"))
			if line != "" && handle(line) {
				return
			}
			if readErr != nil {
				// Stop on any read error: retrying a broken body typically
				// yields the same error again and would loop forever.
				if readErr != io.EOF {
					fmt.Printf("Error reading line: %v\n", readErr)
				}
				// The stream ended without a finish reason; do not drop the
				// text that is still buffered.
				flush()
				return
			}
		}
	}()
	return messageChan, nil
}

// isExtSegmentPunct reports whether r is a mark at which streamed answer text
// may be segmented. Full-width marks are written as escapes because they are
// easily confused with, or lost in favour of, their ASCII lookalikes.
func isExtSegmentPunct(r rune) bool {
	switch r {
	case ',', '\uff0c', // comma, full-width comma
		'.', '\u3002', // full stop, ideographic full stop
		'!', '\uff01', // exclamation mark, full-width exclamation mark
		'?', '\uff1f', // question mark, full-width question mark
		';', '\uff1b', // semicolon, full-width semicolon
		'\uff1a', // full-width colon (ASCII ':' is not included)
		'\u3001': // ideographic enumeration comma
		return true
	}
	return false
}

// extContainsPunctuation reports whether s contains any segmenting mark.
func extContainsPunctuation(s string) bool {
	return strings.IndexFunc(s, isExtSegmentPunct) >= 0
}

// extTrimTrailingPunctuation removes a single segmenting mark from the end of s.
func extTrimTrailingPunctuation(s string) string {
	if r, size := utf8.DecodeLastRuneInString(s); size > 0 && isExtSegmentPunct(r) {
		return s[:len(s)-size]
	}
	return s
}

// extSplitByPunctuation splits s after each segmenting mark, keeping the mark
// on the segment it closes. Repeated marks stay attached to the preceding
// segment and a leading mark starts the first segment, so concatenating the
// result always reproduces s. Trailing text without a mark is the last segment.
func extSplitByPunctuation(s string) []string {
	var segments []string
	var current strings.Builder
	for _, r := range s {
		switch {
		case !isExtSegmentPunct(r):
			current.WriteRune(r)
		case current.Len() > 0:
			current.WriteRune(r)
			segments = append(segments, current.String())
			current.Reset()
		case len(segments) > 0:
			// A repeated mark such as "！！" belongs to the segment it follows.
			segments[len(segments)-1] += string(r)
		default:
			// A leading mark is kept at the start of the first segment.
			current.WriteRune(r)
		}
	}
	if current.Len() > 0 {
		segments = append(segments, current.String())
	}
	return segments
}

// extNextSegment decides how much of the buffered answer text is ready to be
// emitted. It returns the text to emit now (empty when nothing is ready) and
// the text that stays buffered; emit+rest always equals buffer.
//   - No segmenting mark yet: nothing is emitted.
//   - Several segments: all but the last are emitted once together they are
//     longer than 10 runes; the last (possibly incomplete) one stays buffered.
//   - A single segment: it is emitted whole once it is longer than 15 runes.
func extNextSegment(buffer string) (emit, rest string) {
	if !extContainsPunctuation(buffer) {
		return "", buffer
	}
	segments := extSplitByPunctuation(buffer)
	if len(segments) > 1 {
		last := len(segments) - 1
		head := strings.Join(segments[:last], "")
		// Every segment but the last ends with a mark, so head always contains
		// punctuation and the only threshold that matters is 10 runes.
		if utf8.RuneCountInString(head) > 10 {
			return head, segments[last]
		}
		return "", buffer
	}
	if utf8.RuneCountInString(buffer) > 15 {
		return buffer, ""
	}
	return "", buffer
}
// SaveBase64AsWAV decodes base64Data (standard, padded base64) and writes the
// resulting bytes to outputPath, creating missing parent directories with mode
// 0755 and the file with mode 0644.
//
// The decoded data is only sanity-checked, not parsed: it must be at least 44
// bytes long (the size of a canonical WAV header), begin with "RIFF", and hold
// "WAVE" at bytes 8-11. Decoding and filesystem errors are wrapped with %w so
// callers can inspect the cause with errors.Is / errors.As.
func (s *LLMService) SaveBase64AsWAV(base64Data string, outputPath string) error {
	audioData, err := base64.StdEncoding.DecodeString(base64Data)
	if err != nil {
		return fmt.Errorf("error decoding base64 audio: %w", err)
	}

	// Minimal RIFF/WAVE header check; the remaining header fields are not validated.
	const wavHeaderSize = 44
	if len(audioData) < wavHeaderSize {
		return fmt.Errorf("invalid WAV data: too short")
	}
	if string(audioData[0:4]) != "RIFF" {
		return fmt.Errorf("invalid WAV format: missing RIFF header")
	}
	if string(audioData[8:12]) != "WAVE" {
		return fmt.Errorf("invalid WAV format: missing WAVE format")
	}

	if err := os.MkdirAll(filepath.Dir(outputPath), 0755); err != nil {
		return fmt.Errorf("error creating directory: %w", err)
	}
	if err := os.WriteFile(outputPath, audioData, 0644); err != nil {
		return fmt.Errorf("error writing WAV file: %w", err)
	}
	return nil
}