All checks were successful
Gitea Actions Demo / Explore-Gitea-Actions (push) Successful in 52s
726 lines
21 KiB
Go
726 lines
21 KiB
Go
package service
|
||
|
||
import (
|
||
"bufio"
|
||
"bytes"
|
||
"encoding/base64"
|
||
"encoding/binary"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"math"
|
||
"net/http"
|
||
"os"
|
||
"path/filepath"
|
||
"strings"
|
||
"time"
|
||
"unicode/utf8"
|
||
)
|
||
|
||
// Config holds the configuration for the LLM service
|
||
type Config struct {
|
||
LLMApiURL string
|
||
LLMApiKey string
|
||
MiniMaxApiKey string
|
||
MiniMaxApiURL string
|
||
FILE_URL string
|
||
}
|
||
|
||
// LLMService handles communication with the LLM API
|
||
type LLMService struct {
|
||
config Config
|
||
client *http.Client
|
||
}
|
||
|
||
// Message represents a single message in the conversation
|
||
type Message struct {
|
||
Answer string `json:"answer"`
|
||
IsEnd bool `json:"isEnd"`
|
||
ConversationID string `json:"conversation_id"`
|
||
TaskID string `json:"task_id"`
|
||
ClientID string `json:"client_id,omitempty"`
|
||
AudioData string `json:"audio_data,omitempty"`
|
||
}
|
||
|
||
// RequestPayload represents the payload sent to the LLM API
|
||
type RequestPayload struct {
|
||
Inputs map[string]interface{} `json:"inputs"`
|
||
Query string `json:"query"`
|
||
ResponseMode string `json:"response_mode"`
|
||
User string `json:"user"`
|
||
ConversationID string `json:"conversation_id"`
|
||
Files []interface{} `json:"files"`
|
||
Audio string `json:"audio"`
|
||
}
|
||
|
||
// VoiceSetting represents voice configuration
|
||
type VoiceSetting struct {
|
||
VoiceID string `json:"voice_id"`
|
||
Speed float64 `json:"speed"`
|
||
Vol float64 `json:"vol"`
|
||
Pitch float64 `json:"pitch"`
|
||
Emotion string `json:"emotion"`
|
||
}
|
||
|
||
// AudioSetting represents audio configuration
|
||
type AudioSetting struct {
|
||
SampleRate int `json:"sample_rate"`
|
||
Bitrate int `json:"bitrate"`
|
||
Format string `json:"format"`
|
||
}
|
||
|
||
// SpeechRequest represents the speech synthesis request payload
|
||
type SpeechRequest struct {
|
||
Model string `json:"model"`
|
||
Text string `json:"text"`
|
||
Stream bool `json:"stream"`
|
||
LanguageBoost string `json:"language_boost"`
|
||
OutputFormat string `json:"output_format"`
|
||
VoiceSetting VoiceSetting `json:"voice_setting"`
|
||
AudioSetting AudioSetting `json:"audio_setting"`
|
||
}
|
||
|
||
// SpeechData represents the speech data in the response
|
||
type SpeechData struct {
|
||
Audio string `json:"audio"`
|
||
Status int `json:"status"`
|
||
}
|
||
|
||
// ExtraInfo represents additional information about the speech
|
||
type ExtraInfo struct {
|
||
AudioLength int `json:"audio_length"`
|
||
AudioSampleRate int `json:"audio_sample_rate"`
|
||
AudioSize int `json:"audio_size"`
|
||
AudioBitrate int `json:"audio_bitrate"`
|
||
WordCount int `json:"word_count"`
|
||
InvisibleCharacterRatio float64 `json:"invisible_character_ratio"`
|
||
AudioFormat string `json:"audio_format"`
|
||
UsageCharacters int `json:"usage_characters"`
|
||
}
|
||
|
||
// BaseResponse represents the base response structure
|
||
type BaseResponse struct {
|
||
StatusCode int `json:"status_code"`
|
||
StatusMsg string `json:"status_msg"`
|
||
}
|
||
|
||
// SpeechResponse represents the speech synthesis response
|
||
type SpeechResponse struct {
|
||
Data SpeechData `json:"data"`
|
||
ExtraInfo ExtraInfo `json:"extra_info"`
|
||
TraceID string `json:"trace_id"`
|
||
BaseResp BaseResponse `json:"base_resp"`
|
||
}
|
||
|
||
// NewLLMService creates a new instance of LLMService
|
||
func NewLLMService(config Config) *LLMService {
|
||
return &LLMService{
|
||
config: config,
|
||
client: &http.Client{},
|
||
}
|
||
}
|
||
|
||
// CallLLMAPI handles both streaming and non-streaming API calls
|
||
func (s *LLMService) CallLLMAPI(data map[string]interface{}) (interface{}, error) {
|
||
payload := RequestPayload{
|
||
Inputs: make(map[string]interface{}),
|
||
Query: getString(data, "query"),
|
||
ResponseMode: getString(data, "response_mode"),
|
||
User: getString(data, "user"),
|
||
ConversationID: getString(data, "conversation_id"),
|
||
Files: make([]interface{}, 0),
|
||
Audio: getString(data, "audio"),
|
||
}
|
||
|
||
fmt.Printf("前端传来的数据:%+v\n", payload)
|
||
jsonData, err := json.Marshal(payload)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("error marshaling payload: %v", err)
|
||
}
|
||
fmt.Println(s.config.LLMApiURL + "/chat-messages")
|
||
req, err := http.NewRequest("POST", s.config.LLMApiURL+"/chat-messages", bytes.NewBuffer(jsonData))
|
||
// req, err := http.NewRequest("GET", "http://localhost:8080/stream-text", nil)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("error creating request: %v", err)
|
||
}
|
||
|
||
req.Header.Set("Authorization", "Bearer "+s.config.LLMApiKey)
|
||
req.Header.Set("Content-Type", "application/json")
|
||
|
||
isStreaming := payload.ResponseMode == "streaming"
|
||
if isStreaming {
|
||
return s.handleStreamingResponse(req, data, payload.Audio)
|
||
}
|
||
|
||
return s.handleNonStreamingResponse(req)
|
||
}
|
||
|
||
// handleStreamingResponse processes streaming responses
|
||
func (s *LLMService) handleStreamingResponse(req *http.Request, data map[string]interface{}, audio_type string) (chan Message, error) {
|
||
resp, err := s.client.Do(req)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("error making request: %v", err)
|
||
}
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||
}
|
||
|
||
messageChan := make(chan Message, 100) // Buffered channel for better performance
|
||
initialSessage := ""
|
||
go func() {
|
||
defer resp.Body.Close()
|
||
defer close(messageChan)
|
||
reader := bufio.NewReader(resp.Body)
|
||
for {
|
||
line, err := reader.ReadString('\n')
|
||
if err != nil {
|
||
if err == io.EOF {
|
||
break
|
||
}
|
||
fmt.Printf("Error reading line: %v\n", err)
|
||
continue
|
||
}
|
||
|
||
line = strings.TrimSpace(line)
|
||
if line == "" {
|
||
continue
|
||
}
|
||
|
||
// Remove "data: " prefix if present
|
||
line = strings.TrimPrefix(line, "data: ")
|
||
|
||
var jsonData map[string]interface{}
|
||
if err := json.Unmarshal([]byte(line), &jsonData); err != nil {
|
||
fmt.Printf("Error unmarshaling JSON: %v\n", err)
|
||
continue
|
||
}
|
||
|
||
event := getString(jsonData, "event")
|
||
switch event {
|
||
case "message":
|
||
answer := getString(jsonData, "answer")
|
||
var audio string
|
||
|
||
// 定义标点符号map
|
||
punctuations := map[string]bool{
|
||
",": true, ",": true, // 逗号
|
||
".": true, "。": true, // 句号
|
||
"!": true, "!": true, // 感叹号
|
||
"?": true, "?": true, // 问号
|
||
";": true, ";": true, // 分号
|
||
":": true, ":": true, // 冒号
|
||
"、": true,
|
||
}
|
||
|
||
// 删除字符串前后的标点符号
|
||
trimPunctuation := func(s string) string {
|
||
if len(s) > 0 {
|
||
// 获取最后一个字符的 rune
|
||
lastRune, size := utf8.DecodeLastRuneInString(s)
|
||
if punctuations[string(lastRune)] {
|
||
s = s[:len(s)-size]
|
||
}
|
||
}
|
||
return s
|
||
}
|
||
|
||
// 判断字符串是否包含标点符号
|
||
containsPunctuation := func(s string) bool {
|
||
for _, char := range s {
|
||
if punctuations[string(char)] {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// 按标点符号分割文本
|
||
splitByPunctuation := func(s string) []string {
|
||
var result []string
|
||
var current string
|
||
for _, char := range s {
|
||
if punctuations[string(char)] {
|
||
if current != "" {
|
||
result = append(result, current+string(char))
|
||
current = ""
|
||
}
|
||
} else {
|
||
current += string(char)
|
||
}
|
||
}
|
||
if current != "" {
|
||
result = append(result, current)
|
||
}
|
||
return result
|
||
}
|
||
new_message := ""
|
||
initialSessage += answer
|
||
if containsPunctuation(initialSessage) {
|
||
segments := splitByPunctuation(initialSessage)
|
||
// fmt.Printf("原始文本: %s\n", initialSessage)
|
||
// fmt.Printf("分割后的片段数量: %d\n", len(segments))
|
||
// for i, segment := range segments {
|
||
// fmt.Printf("片段 %d: %s\n", i+1, segment)
|
||
// }
|
||
if len(segments) > 1 {
|
||
initialSessage = segments[len(segments)-1]
|
||
new_message = strings.Join(segments[:len(segments)-1], "")
|
||
} else {
|
||
new_message = initialSessage
|
||
initialSessage = ""
|
||
}
|
||
// fmt.Printf("新消息: %s\n", new_message)
|
||
// fmt.Printf("剩余文本: %s\n", initialSessage)
|
||
}
|
||
|
||
if new_message == "" {
|
||
continue
|
||
}
|
||
s_msg := strings.TrimSpace(new_message)
|
||
// Trim punctuation from the message
|
||
new_message = trimPunctuation(s_msg)
|
||
// fmt.Println("new_message", new_message)
|
||
|
||
// 最多重试一次
|
||
for i := 0; i < 1; i++ {
|
||
speechResp, err := s.SynthesizeSpeech(new_message, audio_type)
|
||
if err != nil {
|
||
fmt.Printf("Error synthesizing speech: %v\n", err)
|
||
break // 语音接口报错直接跳出
|
||
}
|
||
fmt.Println("语音:", speechResp)
|
||
audio = speechResp.Data.Audio
|
||
if audio != "" {
|
||
// Download audio from URL and trim silence
|
||
resp, err := http.Get(audio)
|
||
if err != nil {
|
||
fmt.Printf("Error downloading audio: %v\n", err)
|
||
} else {
|
||
defer resp.Body.Close()
|
||
audioBytes, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
fmt.Printf("Error reading audio data: %v\n", err)
|
||
} else {
|
||
// Save original audio first
|
||
originalPath := fmt.Sprintf("audio/original_%d.wav", time.Now().Unix())
|
||
if err := os.WriteFile(originalPath, audioBytes, 0644); err != nil {
|
||
fmt.Printf("Error saving original audio: %v\n", err)
|
||
}
|
||
|
||
// Convert audio bytes to base64 for processing
|
||
audioBase64 := base64.StdEncoding.EncodeToString(audioBytes)
|
||
trimmedAudio, err := s.TrimAudioSilence(audioBase64)
|
||
if err != nil {
|
||
fmt.Printf("Error trimming audio silence: %v\n", err)
|
||
} else {
|
||
// Save the trimmed audio as WAV file
|
||
audio_path := fmt.Sprintf("trimmed_%d.wav", time.Now().Unix())
|
||
outputPath := "audio/" + audio_path
|
||
if err := s.SaveBase64AsWAV(trimmedAudio, outputPath); err != nil {
|
||
fmt.Printf("Error saving trimmed WAV file: %v\n", err)
|
||
}
|
||
audio = s.config.FILE_URL + audio_path
|
||
}
|
||
}
|
||
}
|
||
break // 获取到音频就退出
|
||
}
|
||
fmt.Println("audio is empty, retry", speechResp)
|
||
// time.Sleep(1 * time.Second)
|
||
}
|
||
|
||
messageChan <- Message{
|
||
Answer: new_message,
|
||
IsEnd: false,
|
||
ConversationID: getString(jsonData, "conversation_id"),
|
||
TaskID: getString(jsonData, "task_id"),
|
||
ClientID: getString(data, "conversation_id"),
|
||
AudioData: audio, // Update to use the correct path to audio data
|
||
}
|
||
case "message_end":
|
||
messageChan <- Message{
|
||
Answer: "",
|
||
IsEnd: true,
|
||
ConversationID: getString(jsonData, "conversation_id"),
|
||
TaskID: getString(jsonData, "task_id"),
|
||
}
|
||
return
|
||
}
|
||
}
|
||
}()
|
||
|
||
return messageChan, nil
|
||
}
|
||
|
||
// handleNonStreamingResponse processes non-streaming responses
|
||
func (s *LLMService) handleNonStreamingResponse(req *http.Request) (map[string]interface{}, error) {
|
||
resp, err := s.client.Do(req)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("error making request: %v", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||
}
|
||
|
||
var result map[string]interface{}
|
||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||
return nil, fmt.Errorf("error decoding response: %v", err)
|
||
}
|
||
|
||
return result, nil
|
||
}
|
||
|
||
// StopConversation stops an ongoing conversation
|
||
func (s *LLMService) StopConversation(taskID string) (map[string]interface{}, error) {
|
||
req, err := http.NewRequest("POST", fmt.Sprintf("%s/chat-messages/%s/stop", s.config.LLMApiURL, taskID), nil)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("error creating request: %v", err)
|
||
}
|
||
|
||
req.Header.Set("Authorization", "Bearer "+s.config.LLMApiKey)
|
||
req.Header.Set("Content-Type", "application/json")
|
||
|
||
resp, err := s.client.Do(req)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("error making request: %v", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
var result map[string]interface{}
|
||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||
return nil, fmt.Errorf("error decoding response: %v", err)
|
||
}
|
||
|
||
return result, nil
|
||
}
|
||
|
||
// DeleteConversation deletes a conversation
|
||
func (s *LLMService) DeleteConversation(conversationID, user string) (map[string]interface{}, error) {
|
||
req, err := http.NewRequest("DELETE", fmt.Sprintf("%s/conversations/%s", s.config.LLMApiURL, conversationID), nil)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("error creating request: %v", err)
|
||
}
|
||
|
||
req.Header.Set("Authorization", "Bearer "+s.config.LLMApiKey)
|
||
req.Header.Set("Content-Type", "application/json")
|
||
|
||
resp, err := s.client.Do(req)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("error making request: %v", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
var result map[string]interface{}
|
||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||
return nil, fmt.Errorf("error decoding response: %v", err)
|
||
}
|
||
|
||
return result, nil
|
||
}
|
||
|
||
// SynthesizeSpeech converts text to speech
|
||
func (s *LLMService) SynthesizeSpeech(text string, audio string) (*SpeechResponse, error) {
|
||
payload := SpeechRequest{
|
||
Model: "speech-02-turbo",
|
||
Text: text,
|
||
Stream: false,
|
||
LanguageBoost: "auto",
|
||
OutputFormat: "url",
|
||
VoiceSetting: VoiceSetting{
|
||
VoiceID: audio,
|
||
Speed: 1,
|
||
Vol: 1,
|
||
Pitch: 0,
|
||
Emotion: "happy",
|
||
},
|
||
AudioSetting: AudioSetting{
|
||
SampleRate: 32000,
|
||
Bitrate: 128000,
|
||
Format: "wav",
|
||
},
|
||
}
|
||
|
||
jsonData, err := json.Marshal(payload)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("error marshaling speech request: %v", err)
|
||
}
|
||
|
||
req, err := http.NewRequest("POST", s.config.MiniMaxApiURL, bytes.NewBuffer(jsonData))
|
||
if err != nil {
|
||
fmt.Println("error creating speech request: ", err)
|
||
return nil, fmt.Errorf("error creating speech request: %v", err)
|
||
}
|
||
|
||
req.Header.Set("Authorization", "Bearer "+s.config.MiniMaxApiKey)
|
||
req.Header.Set("Content-Type", "application/json")
|
||
|
||
resp, err := s.client.Do(req)
|
||
if err != nil {
|
||
fmt.Println("error making speech request: ", err)
|
||
return nil, fmt.Errorf("error making speech request: %v", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
// fmt.Println(resp.Body)
|
||
if resp.StatusCode != http.StatusOK {
|
||
fmt.Println("unexpected status code: ", resp.StatusCode)
|
||
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
|
||
}
|
||
|
||
var result SpeechResponse
|
||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||
fmt.Println("error decoding speech response: ", err)
|
||
return nil, fmt.Errorf("error decoding speech response: %v", err)
|
||
}
|
||
|
||
return &result, nil
|
||
}
|
||
|
||
// StreamTextResponse handles streaming text output with predefined segments
|
||
func (s *LLMService) StreamTextResponse(conversationID string) (chan Message, error) {
|
||
messageChan := make(chan Message, 100)
|
||
|
||
segments := []string{
|
||
"好的,",
|
||
"我已经成功替换了文本内容。",
|
||
"新的文本是一段连续的描述,",
|
||
"没有换行,",
|
||
"总共65个字符,",
|
||
"符合100字以内的要求,",
|
||
"并且是一个连续的段落。",
|
||
"现在我需要完成任务。",
|
||
}
|
||
|
||
go func() {
|
||
defer close(messageChan)
|
||
taskID := "task_" + time.Now().Format("20060102150405")
|
||
|
||
for _, segment := range segments {
|
||
// Send message
|
||
messageChan <- Message{
|
||
Answer: segment,
|
||
IsEnd: false,
|
||
ConversationID: conversationID,
|
||
TaskID: taskID,
|
||
ClientID: conversationID,
|
||
}
|
||
|
||
// Add delay between segments
|
||
time.Sleep(500 * time.Millisecond)
|
||
}
|
||
|
||
// Send end message
|
||
messageChan <- Message{
|
||
Answer: "",
|
||
IsEnd: true,
|
||
ConversationID: conversationID,
|
||
TaskID: taskID,
|
||
}
|
||
}()
|
||
|
||
return messageChan, nil
|
||
}
|
||
|
||
// Helper function to safely get string values from interface{}
|
||
func getString(data map[string]interface{}, key string) string {
|
||
if val, ok := data[key]; ok {
|
||
if str, ok := val.(string); ok {
|
||
return str
|
||
}
|
||
}
|
||
return ""
|
||
}
|
||
|
||
// TrimAudioSilence trims the silence at the end of the audio data
|
||
func (s *LLMService) TrimAudioSilence(audioData string) (string, error) {
|
||
// Decode base64 audio data
|
||
decodedData, err := base64.StdEncoding.DecodeString(audioData)
|
||
if err != nil {
|
||
return "", fmt.Errorf("error decoding base64 audio: %v", err)
|
||
}
|
||
|
||
// Create a buffer from the decoded data
|
||
buf := bytes.NewReader(decodedData)
|
||
|
||
// Read RIFF header
|
||
var riffHeader struct {
|
||
ChunkID [4]byte
|
||
ChunkSize uint32
|
||
Format [4]byte
|
||
}
|
||
|
||
if err := binary.Read(buf, binary.LittleEndian, &riffHeader); err != nil {
|
||
return "", fmt.Errorf("error reading RIFF header: %v", err)
|
||
}
|
||
|
||
// Verify RIFF header
|
||
if string(riffHeader.ChunkID[:]) != "RIFF" || string(riffHeader.Format[:]) != "WAVE" {
|
||
return "", fmt.Errorf("invalid WAV format")
|
||
}
|
||
|
||
// Read fmt chunk
|
||
var fmtChunk struct {
|
||
Subchunk1ID [4]byte
|
||
Subchunk1Size uint32
|
||
AudioFormat uint16
|
||
NumChannels uint16
|
||
SampleRate uint32
|
||
ByteRate uint32
|
||
BlockAlign uint16
|
||
BitsPerSample uint16
|
||
}
|
||
|
||
if err := binary.Read(buf, binary.LittleEndian, &fmtChunk); err != nil {
|
||
return "", fmt.Errorf("error reading fmt chunk: %v", err)
|
||
}
|
||
|
||
// Skip any extra bytes in fmt chunk
|
||
if fmtChunk.Subchunk1Size > 16 {
|
||
extraBytes := make([]byte, fmtChunk.Subchunk1Size-16)
|
||
if _, err := buf.Read(extraBytes); err != nil {
|
||
return "", fmt.Errorf("error skipping extra fmt bytes: %v", err)
|
||
}
|
||
}
|
||
|
||
// Find data chunk
|
||
var dataChunk struct {
|
||
Subchunk2ID [4]byte
|
||
Subchunk2Size uint32
|
||
}
|
||
|
||
for {
|
||
if err := binary.Read(buf, binary.LittleEndian, &dataChunk); err != nil {
|
||
return "", fmt.Errorf("error reading chunk header: %v", err)
|
||
}
|
||
|
||
if string(dataChunk.Subchunk2ID[:]) == "data" {
|
||
break
|
||
}
|
||
|
||
// Skip this chunk
|
||
if _, err := buf.Seek(int64(dataChunk.Subchunk2Size), io.SeekCurrent); err != nil {
|
||
return "", fmt.Errorf("error skipping chunk: %v", err)
|
||
}
|
||
}
|
||
|
||
// Read audio data
|
||
audioBytes := make([]byte, dataChunk.Subchunk2Size)
|
||
if _, err := buf.Read(audioBytes); err != nil {
|
||
return "", fmt.Errorf("error reading audio data: %v", err)
|
||
}
|
||
|
||
// Calculate samples per channel
|
||
samplesPerChannel := len(audioBytes) / int(fmtChunk.BlockAlign)
|
||
channels := int(fmtChunk.NumChannels)
|
||
bytesPerSample := int(fmtChunk.BitsPerSample) / 8
|
||
|
||
// Find the last non-silent sample
|
||
lastNonSilent := 0
|
||
silenceThreshold := 0.01 // Adjust this threshold as needed
|
||
|
||
for i := 0; i < samplesPerChannel; i++ {
|
||
isSilent := true
|
||
for ch := 0; ch < channels; ch++ {
|
||
offset := i*int(fmtChunk.BlockAlign) + ch*bytesPerSample
|
||
if offset+bytesPerSample > len(audioBytes) {
|
||
continue
|
||
}
|
||
|
||
// Convert bytes to sample value
|
||
var sample int16
|
||
if err := binary.Read(bytes.NewReader(audioBytes[offset:offset+bytesPerSample]), binary.LittleEndian, &sample); err != nil {
|
||
continue
|
||
}
|
||
|
||
// Normalize sample to [-1, 1] range
|
||
normalizedSample := float64(sample) / 32768.0
|
||
if math.Abs(normalizedSample) > silenceThreshold {
|
||
isSilent = false
|
||
break
|
||
}
|
||
}
|
||
|
||
if !isSilent {
|
||
lastNonSilent = i
|
||
}
|
||
}
|
||
|
||
// Add a small buffer (e.g., 0.1 seconds) after the last non-silent sample
|
||
bufferSamples := int(float64(fmtChunk.SampleRate) * 0.1)
|
||
lastSample := lastNonSilent + bufferSamples
|
||
if lastSample > samplesPerChannel {
|
||
lastSample = samplesPerChannel
|
||
}
|
||
|
||
// Calculate new data size
|
||
newDataSize := lastSample * int(fmtChunk.BlockAlign)
|
||
trimmedAudio := audioBytes[:newDataSize]
|
||
|
||
// Create new buffer for the trimmed audio
|
||
var newBuf bytes.Buffer
|
||
|
||
// Write RIFF header
|
||
riffHeader.ChunkSize = uint32(36 + newDataSize)
|
||
if err := binary.Write(&newBuf, binary.LittleEndian, riffHeader); err != nil {
|
||
return "", fmt.Errorf("error writing RIFF header: %v", err)
|
||
}
|
||
|
||
// Write fmt chunk
|
||
if err := binary.Write(&newBuf, binary.LittleEndian, fmtChunk); err != nil {
|
||
return "", fmt.Errorf("error writing fmt chunk: %v", err)
|
||
}
|
||
|
||
// Write data chunk header
|
||
dataChunk.Subchunk2Size = uint32(newDataSize)
|
||
if err := binary.Write(&newBuf, binary.LittleEndian, dataChunk); err != nil {
|
||
return "", fmt.Errorf("error writing data chunk header: %v", err)
|
||
}
|
||
|
||
// Write trimmed audio data
|
||
if _, err := newBuf.Write(trimmedAudio); err != nil {
|
||
return "", fmt.Errorf("error writing trimmed audio data: %v", err)
|
||
}
|
||
|
||
// Encode back to base64
|
||
return base64.StdEncoding.EncodeToString(newBuf.Bytes()), nil
|
||
}
|
||
|
||
// SaveBase64AsWAV saves base64 encoded audio data as a WAV file
|
||
func (s *LLMService) SaveBase64AsWAV(base64Data string, outputPath string) error {
|
||
// Decode base64 data
|
||
audioData, err := base64.StdEncoding.DecodeString(base64Data)
|
||
if err != nil {
|
||
return fmt.Errorf("error decoding base64 data: %v", err)
|
||
}
|
||
|
||
// Validate WAV header
|
||
if len(audioData) < 44 { // WAV header is 44 bytes
|
||
return fmt.Errorf("invalid WAV data: too short")
|
||
}
|
||
|
||
// Check RIFF header
|
||
if string(audioData[0:4]) != "RIFF" {
|
||
return fmt.Errorf("invalid WAV format: missing RIFF header")
|
||
}
|
||
|
||
// Check WAVE format
|
||
if string(audioData[8:12]) != "WAVE" {
|
||
return fmt.Errorf("invalid WAV format: missing WAVE format")
|
||
}
|
||
|
||
// Create output directory if it doesn't exist
|
||
dir := filepath.Dir(outputPath)
|
||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||
return fmt.Errorf("error creating directory: %v", err)
|
||
}
|
||
|
||
// Write the audio data to file
|
||
if err := os.WriteFile(outputPath, audioData, 0644); err != nil {
|
||
return fmt.Errorf("error writing WAV file: %v", err)
|
||
}
|
||
|
||
return nil
|
||
}
|