go_digital_backend/service/llm_service.go
package service
import (
"bufio"
"bytes"
"encoding/base64"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"math"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"unicode/utf8"
)
// Config holds the configuration for the LLM service
type Config struct {
LLMApiURL string
LLMApiKey string
MiniMaxApiKey string
MiniMaxApiURL string
FILE_URL string
LLMOurApiUrl string
LLMOurApiKey string
}
// LLMService handles communication with the LLM API
type LLMService struct {
config Config
client *http.Client
}
// Message represents a single message in the conversation
type Message struct {
Answer string `json:"answer"`
IsEnd bool `json:"isEnd"`
ConversationID string `json:"conversation_id"`
TaskID string `json:"task_id"`
ClientID string `json:"client_id,omitempty"`
AudioData string `json:"audio_data,omitempty"`
}
// RequestPayload represents the payload sent to the LLM API
type RequestPayload struct {
Inputs map[string]interface{} `json:"inputs"`
Query string `json:"query"`
ResponseMode string `json:"response_mode"`
User string `json:"user"`
ConversationID string `json:"conversation_id"`
Files []interface{} `json:"files"`
Audio string `json:"audio"`
LlmType string `json:"llm_type"`
}
// VoiceSetting represents voice configuration
type VoiceSetting struct {
VoiceID string `json:"voice_id"`
Speed float64 `json:"speed"`
Vol float64 `json:"vol"`
Pitch float64 `json:"pitch"`
Emotion string `json:"emotion"`
}
// AudioSetting represents audio configuration
type AudioSetting struct {
SampleRate int `json:"sample_rate"`
Bitrate int `json:"bitrate"`
Format string `json:"format"`
}
// SpeechRequest represents the speech synthesis request payload
type SpeechRequest struct {
Model string `json:"model"`
Text string `json:"text"`
Stream bool `json:"stream"`
LanguageBoost string `json:"language_boost"`
OutputFormat string `json:"output_format"`
VoiceSetting VoiceSetting `json:"voice_setting"`
AudioSetting AudioSetting `json:"audio_setting"`
}
// SpeechData represents the speech data in the response
type SpeechData struct {
Audio string `json:"audio"`
Status int `json:"status"`
}
// ExtraInfo represents additional information about the speech
type ExtraInfo struct {
AudioLength int `json:"audio_length"`
AudioSampleRate int `json:"audio_sample_rate"`
AudioSize int `json:"audio_size"`
AudioBitrate int `json:"audio_bitrate"`
WordCount int `json:"word_count"`
InvisibleCharacterRatio float64 `json:"invisible_character_ratio"`
AudioFormat string `json:"audio_format"`
UsageCharacters int `json:"usage_characters"`
}
// BaseResponse represents the base response structure
type BaseResponse struct {
StatusCode int `json:"status_code"`
StatusMsg string `json:"status_msg"`
}
// SpeechResponse represents the speech synthesis response
type SpeechResponse struct {
Data SpeechData `json:"data"`
ExtraInfo ExtraInfo `json:"extra_info"`
TraceID string `json:"trace_id"`
BaseResp BaseResponse `json:"base_resp"`
}
type LLMOurMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type LLMOurRequestPayload struct {
Model string `json:"model"`
Stream bool `json:"stream"`
StreamOptions map[string]interface{} `json:"stream_options"`
Messages []LLMOurMessage `json:"messages"`
}
// NewLLMService creates a new instance of LLMService
func NewLLMService(config Config) *LLMService {
return &LLMService{
config: config,
client: &http.Client{},
}
}
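// exampleNewLLMService is a hypothetical construction sketch, not part of the
// original service; every field value below is a placeholder.
func exampleNewLLMService() *LLMService {
	cfg := Config{
		LLMApiURL:     "https://dify.example.com/v1",           // placeholder chat API base URL
		LLMApiKey:     "app-xxxxxxxx",                          // placeholder chat API key
		MiniMaxApiURL: "https://api.minimax.example/v1/t2a_v2", // placeholder TTS endpoint
		MiniMaxApiKey: "mm-xxxxxxxx",                           // placeholder TTS key
		FILE_URL:      "https://files.example.com/audio/",      // public prefix for saved WAV files
		LLMOurApiUrl:  "https://llm.example.com/v1/chat",       // placeholder in-house endpoint
		LLMOurApiKey:  "our-xxxxxxxx",                          // placeholder in-house key
	}
	return NewLLMService(cfg)
}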
// CallLLMAPI handles both streaming and non-streaming API calls
func (s *LLMService) CallLLMAPI(data map[string]interface{}) (interface{}, error) {
payload := RequestPayload{
Inputs: make(map[string]interface{}),
Query: getString(data, "query"),
ResponseMode: getString(data, "response_mode"),
User: getString(data, "user"),
ConversationID: getString(data, "conversation_id"),
Files: make([]interface{}, 0),
Audio: getString(data, "audio"),
LlmType: getString(data, "llm_type"),
}
fmt.Printf("前端传来的数据:%+v\n", payload)
jsonData, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("error marshaling payload: %v", err)
}
currentUrl := s.config.LLMApiURL + "/chat-messages"
fmt.Println(currentUrl)
var req *http.Request
if payload.LlmType == "ours" {
// Build the messages array dynamically from the request data
var messages []LLMOurMessage
if msgs, ok := data["messages"]; ok {
if arr, ok := msgs.([]interface{}); ok {
for _, m := range arr {
if mMap, ok := m.(map[string]interface{}); ok {
role, _ := mMap["role"].(string)
content, _ := mMap["content"].(string)
messages = append(messages, LLMOurMessage{Role: role, Content: content})
}
}
}
}
// Fallback: if no messages are provided, use the query as a single user message
if len(messages) == 0 && payload.Query != "" {
messages = append(messages, LLMOurMessage{Role: "user", Content: payload.Query})
}
ourPayload := LLMOurRequestPayload{
Model: "bot-20250522162100-44785", // 可根据 data 或配置传入
Stream: true,
StreamOptions: map[string]interface{}{"include_usage": true},
Messages: messages,
}
jsonData, err = json.Marshal(ourPayload)
if err != nil {
return nil, fmt.Errorf("error marshaling ourPayload: %v", err)
}
currentUrl = s.config.LLMOurApiUrl
req, err = http.NewRequest("POST", currentUrl, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("error creating request: %v", err)
}
req.Header.Set("Authorization", "Bearer "+s.config.LLMOurApiKey)
req.Header.Set("Content-Type", "application/json")
return s.handleStreamingResponseV2(req, data, payload.Audio)
}
req, err = http.NewRequest("POST", currentUrl, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("error creating request: %v", err)
}
req.Header.Set("Authorization", "Bearer "+s.config.LLMApiKey)
req.Header.Set("Content-Type", "application/json")
isStreaming := payload.ResponseMode == "streaming"
if isStreaming {
return s.handleStreamingResponse(req, data, payload.Audio)
}
return s.handleNonStreamingResponse(req)
}
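// exampleStreamingCall is a hypothetical usage sketch of CallLLMAPI in streaming
// mode, not part of the original file; all map values are placeholders. In
// streaming mode the returned interface{} carries a chan Message.
func exampleStreamingCall(svc *LLMService) {
	result, err := svc.CallLLMAPI(map[string]interface{}{
		"query":           "介绍一下你自己",
		"response_mode":   "streaming",
		"user":            "demo-user",
		"conversation_id": "",
		"audio":           "voice-id-placeholder", // placeholder MiniMax voice_id used for TTS
	})
	if err != nil {
		fmt.Println("LLM call failed:", err)
		return
	}
	ch, ok := result.(chan Message)
	if !ok {
		fmt.Println("non-streaming response:", result)
		return
	}
	for msg := range ch {
		fmt.Printf("answer=%q audio=%q end=%v\n", msg.Answer, msg.AudioData, msg.IsEnd)
	}
}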
// processStreamSegment handles streaming text segmentation and speech synthesis; it returns the flushed message, the audio URL, and whether a message should be sent.
func (s *LLMService) processStreamSegment(initialSessage *string, all_message *string, answer string, audio_type string) (string, string, bool) {
// Punctuation lookup table (ASCII and full-width forms)
punctuations := map[string]bool{
",": true, ",": true, // comma
".": true, "。": true, // full stop
"!": true, "!": true, // exclamation mark
"?": true, "?": true, // question mark
";": true, ";": true, // semicolon
":": true, ":": true, // colon
"、": true, // enumeration comma
}
// trimPunctuation removes a single trailing punctuation mark, if present
trimPunctuation := func(s string) string {
if len(s) > 0 {
lastRune, size := utf8.DecodeLastRuneInString(s)
if punctuations[string(lastRune)] {
s = s[:len(s)-size]
}
}
return s
}
// containsPunctuation reports whether the string contains any punctuation mark
containsPunctuation := func(s string) bool {
for _, char := range s {
if punctuations[string(char)] {
return true
}
}
return false
}
// splitByPunctuation splits the text at punctuation marks, keeping each mark with its segment
splitByPunctuation := func(s string) []string {
var result []string
var current string
for _, char := range s {
if punctuations[string(char)] {
if current != "" {
result = append(result, current+string(char))
current = ""
}
} else {
current += string(char)
}
}
if current != "" {
result = append(result, current)
}
return result
}
*initialSessage += answer
*all_message += answer
new_message := ""
if containsPunctuation(*initialSessage) {
segments := splitByPunctuation(*initialSessage)
if len(segments) > 1 {
format_message := strings.Join(segments[:len(segments)-1], "")
if utf8.RuneCountInString(format_message) > 10 {
*initialSessage = segments[len(segments)-1]
new_message = strings.Join(segments[:len(segments)-1], "")
} else {
return "", "", false
}
} else {
if utf8.RuneCountInString(*initialSessage) > 10 {
new_message = *initialSessage
*initialSessage = ""
} else if utf8.RuneCountInString(*initialSessage) <= 10 && strings.HasSuffix(*initialSessage, "。") {
new_message = *initialSessage
*initialSessage = ""
} else {
return "", "", false
}
}
}
if new_message == "" {
return "", "", false
}
s_msg := strings.TrimSpace(new_message)
new_message = trimPunctuation(s_msg)
audio := ""
for i := 0; i < 1; i++ {
speechResp, err := s.SynthesizeSpeech(new_message, audio_type)
if err != nil {
fmt.Printf("Error synthesizing speech: %v\n", err)
break
}
fmt.Println("触发音频", speechResp)
audio = speechResp.Data.Audio
if audio != "" {
resp, err := http.Get(audio)
if err != nil {
fmt.Printf("Error downloading audio: %v\n", err)
} else {
defer resp.Body.Close()
audioBytes, err := io.ReadAll(resp.Body)
if err != nil {
fmt.Printf("Error reading audio data: %v\n", err)
} else {
originalPath := fmt.Sprintf("audio/original_%d.wav", time.Now().UnixNano())
if err := os.WriteFile(originalPath, audioBytes, 0644); err != nil {
fmt.Printf("Error saving original audio: %v\n", err)
}
audioBase64 := base64.StdEncoding.EncodeToString(audioBytes)
trimmedAudio, err := s.TrimAudioSilence(audioBase64)
if err != nil {
fmt.Printf("Error trimming audio silence: %v\n", err)
} else {
audio_path := fmt.Sprintf("trimmed_%d.wav", time.Now().UnixNano())
outputPath := "audio/" + audio_path
if err := s.SaveBase64AsWAV(trimmedAudio, outputPath); err != nil {
fmt.Printf("Error saving trimmed WAV file: %v\n", err)
}
audio = s.config.FILE_URL + audio_path
}
}
}
break
}
}
return s_msg, audio, true
}
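// exampleProcessStreamSegment is a hypothetical sketch, not part of the original
// file: it feeds streamed chunks into processStreamSegment, which flushes a segment
// once a punctuation-delimited piece exceeds ten runes (or a short piece ends with
// "。"). Each flush triggers a real SynthesizeSpeech call, so valid MiniMax
// credentials are required for the audio URL to be non-empty.
func exampleProcessStreamSegment(svc *LLMService) {
	var pending, full string
	for _, chunk := range []string{"今天的天气非常", "不错而且阳光", "明媚,", "适合出门散步。"} {
		segment, audioURL, ready := svc.processStreamSegment(&pending, &full, chunk, "voice-id-placeholder")
		if ready {
			fmt.Printf("segment=%q audio=%q\n", segment, audioURL)
		}
	}
	fmt.Println("accumulated:", full)
}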
// handleStreamingResponse processes streaming responses
func (s *LLMService) handleStreamingResponse(req *http.Request, data map[string]interface{}, audio_type string) (chan Message, error) {
resp, err := s.client.Do(req)
if err != nil {
return nil, fmt.Errorf("error making request: %v", err)
}
if resp.StatusCode != http.StatusOK {
resp.Body.Close()
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
messageChan := make(chan Message, 100) // Buffered channel for better performance
all_message := ""
initialSessage := ""
go func() {
defer resp.Body.Close()
defer close(messageChan)
reader := bufio.NewReader(resp.Body)
for {
line, err := reader.ReadString('\n')
if err != nil {
if err == io.EOF {
break
}
fmt.Printf("Error reading line: %v\n", err)
continue
}
line = strings.TrimSpace(line)
if line == "" {
continue
}
// Remove "data: " prefix if present
line = strings.TrimPrefix(line, "data: ")
var jsonData map[string]interface{}
if err := json.Unmarshal([]byte(line), &jsonData); err != nil {
fmt.Printf("Error unmarshaling JSON: %v\n", err)
continue
}
event := getString(jsonData, "event")
switch event {
case "message":
answer := getString(jsonData, "answer")
fmt.Println("源文本:", answer)
var audio string
// Punctuation lookup table (ASCII and full-width forms)
punctuations := map[string]bool{
",": true, ",": true, // comma
".": true, "。": true, // full stop
"!": true, "!": true, // exclamation mark
"?": true, "?": true, // question mark
";": true, ";": true, // semicolon
":": true, // full-width colon
"、": true, // enumeration comma
}
// trimPunctuation removes a single trailing punctuation mark, if present
trimPunctuation := func(s string) string {
if len(s) > 0 {
// Decode the last rune in the string
lastRune, size := utf8.DecodeLastRuneInString(s)
if punctuations[string(lastRune)] {
s = s[:len(s)-size]
}
}
return s
}
// containsPunctuation reports whether the string contains any punctuation mark
containsPunctuation := func(s string) bool {
for _, char := range s {
if punctuations[string(char)] {
return true
}
}
return false
}
// splitByPunctuation splits the text at punctuation marks, keeping each mark with its segment
splitByPunctuation := func(s string) []string {
var result []string
var current string
for _, char := range s {
if punctuations[string(char)] {
if current != "" {
result = append(result, current+string(char))
current = ""
}
} else {
current += string(char)
}
}
if current != "" {
result = append(result, current)
}
return result
}
new_message := ""
initialSessage += answer
all_message += answer
if containsPunctuation(initialSessage) {
segments := splitByPunctuation(initialSessage)
// fmt.Printf("原始文本: %s\n", initialSessage)
// fmt.Printf("分割后的片段数量: %d\n", len(segments))
// for i, segment := range segments {
// fmt.Printf("片段 %d: %s\n", i+1, segment)
// }
if len(segments) > 1 {
format_message := strings.Join(segments[:len(segments)-1], "")
// Check whether the completed segments exceed 15 characters
if utf8.RuneCountInString(format_message) > 15 {
initialSessage = segments[len(segments)-1]
// Move the completed segments into new_message (the last, incomplete segment was kept above)
new_message = strings.Join(segments[:len(segments)-1], "")
// initialSessage = ""
} else {
continue
}
} else {
if utf8.RuneCountInString(initialSessage) > 15 {
new_message = initialSessage
initialSessage = ""
} else if utf8.RuneCountInString(initialSessage) <= 15 && containsPunctuation(initialSessage) {
new_message = initialSessage
initialSessage = ""
} else {
continue
}
}
// fmt.Printf("新消息: %s\n", new_message)
// fmt.Printf("剩余文本: %s\n", initialSessage)
}
if new_message == "" {
continue
}
s_msg := strings.TrimSpace(new_message)
// Trim punctuation from the message
new_message = trimPunctuation(s_msg)
fmt.Println("new_message", new_message)
// println(new_message)
// At most one attempt: break on error or once audio is obtained
for i := 0; i < 1; i++ {
speechResp, err := s.SynthesizeSpeech(new_message, audio_type)
if err != nil {
fmt.Printf("Error synthesizing speech: %v\n", err)
break // bail out if the speech API returns an error
}
fmt.Println("语音:", speechResp)
audio = speechResp.Data.Audio
if audio != "" {
// Download audio from URL and trim silence
resp, err := http.Get(audio)
if err != nil {
fmt.Printf("Error downloading audio: %v\n", err)
} else {
audioBytes, err := io.ReadAll(resp.Body)
resp.Body.Close() // close right away; a deferred close here would not run until the streaming goroutine exits
if err != nil {
fmt.Printf("Error reading audio data: %v\n", err)
} else {
// Save original audio first
originalPath := fmt.Sprintf("audio/original_%d.wav", time.Now().UnixNano())
if err := os.WriteFile(originalPath, audioBytes, 0644); err != nil {
fmt.Printf("Error saving original audio: %v\n", err)
}
// Convert audio bytes to base64 for processing
audioBase64 := base64.StdEncoding.EncodeToString(audioBytes)
trimmedAudio, err := s.TrimAudioSilence(audioBase64)
if err != nil {
fmt.Printf("Error trimming audio silence: %v\n", err)
} else {
// Save the trimmed audio as WAV file
audio_path := fmt.Sprintf("trimmed_%d.wav", time.Now().UnixNano())
outputPath := "audio/" + audio_path
if err := s.SaveBase64AsWAV(trimmedAudio, outputPath); err != nil {
fmt.Printf("Error saving trimmed WAV file: %v\n", err)
}
audio = s.config.FILE_URL + audio_path
}
}
}
break // exit once audio has been obtained
}
// fmt.Println("audio is empty, retry", speechResp)
// time.Sleep(1 * time.Second)
}
fmt.Println("所有消息:", all_message)
messageChan <- Message{
Answer: s_msg,
IsEnd: false,
ConversationID: getString(jsonData, "conversation_id"),
TaskID: getString(jsonData, "task_id"),
ClientID: getString(data, "conversation_id"),
AudioData: audio, // Update to use the correct path to audio data
}
case "message_end":
messageChan <- Message{
Answer: "",
IsEnd: true,
ConversationID: getString(jsonData, "conversation_id"),
TaskID: getString(jsonData, "task_id"),
}
return
}
}
}()
return messageChan, nil
}
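// exampleChatSSEEvent is a hypothetical sample of one SSE line consumed by
// handleStreamingResponse after the "data: " prefix is stripped. The shape is
// inferred from the field accesses above ("event", "answer", "conversation_id",
// "task_id"), not copied from upstream API documentation.
const exampleChatSSEEvent = `{"event":"message","answer":"你好,","conversation_id":"conv-123","task_id":"task-456"}`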
// handleStreamingResponseV2 handles the new streaming response format (choices/delta chunks terminated by [DONE])
func (s *LLMService) handleStreamingResponseV2(req *http.Request, data map[string]interface{}, audio_type string) (chan Message, error) {
resp, err := s.client.Do(req)
if err != nil {
return nil, fmt.Errorf("error making request: %v", err)
}
if resp.StatusCode != http.StatusOK {
resp.Body.Close()
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
messageChan := make(chan Message, 100)
all_message := ""
initialSessage := ""
go func() {
defer resp.Body.Close()
defer close(messageChan)
reader := bufio.NewReader(resp.Body)
for {
line, err := reader.ReadString('\n')
if err != nil {
if err == io.EOF {
break
}
fmt.Printf("Error reading line: %v\n", err)
continue
}
line = strings.TrimSpace(line)
if line == "" {
continue
}
// line = strings.TrimSpace(line)
if strings.HasPrefix(line, "data:") {
line = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
}
// fmt.Println("line: ", line)
if line == "[DONE]" {
messageChan <- Message{
Answer: "",
IsEnd: true,
ConversationID: getString(data, "conversation_id"),
TaskID: getString(data, "task_id"),
}
return
}
var jsonData map[string]interface{}
if err := json.Unmarshal([]byte(line), &jsonData); err != nil {
fmt.Printf("Error unmarshaling JSON: %v\n", err)
continue
}
choices, ok := jsonData["choices"].([]interface{})
if !ok || len(choices) == 0 {
continue
}
choice, ok := choices[0].(map[string]interface{})
if !ok {
continue
}
delta, ok := choice["delta"].(map[string]interface{})
if !ok {
continue
}
content, _ := delta["content"].(string)
if content == "" {
continue
}
new_message, audio, needSend := s.processStreamSegment(&initialSessage, &all_message, content, audio_type)
if !needSend {
continue
}
messageChan <- Message{
Answer: new_message,
IsEnd: false,
ConversationID: getString(data, "conversation_id"),
TaskID: getString(data, "task_id"),
ClientID: getString(data, "conversation_id"),
AudioData: audio,
}
}
}()
return messageChan, nil
}
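// exampleOurChunkSSE is a hypothetical sample of one streaming line consumed by
// handleStreamingResponseV2 (choices/delta chunks), inferred from the parsing
// logic above; the stream is expected to terminate with the line "data: [DONE]".
const exampleOurChunkSSE = `data: {"choices":[{"delta":{"content":"你好"}}]}`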
// handleNonStreamingResponse processes non-streaming responses
func (s *LLMService) handleNonStreamingResponse(req *http.Request) (map[string]interface{}, error) {
resp, err := s.client.Do(req)
if err != nil {
return nil, fmt.Errorf("error making request: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
var result map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("error decoding response: %v", err)
}
return result, nil
}
// StopConversation stops an ongoing conversation
func (s *LLMService) StopConversation(taskID string) (map[string]interface{}, error) {
req, err := http.NewRequest("POST", fmt.Sprintf("%s/chat-messages/%s/stop", s.config.LLMApiURL, taskID), nil)
if err != nil {
return nil, fmt.Errorf("error creating request: %v", err)
}
req.Header.Set("Authorization", "Bearer "+s.config.LLMApiKey)
req.Header.Set("Content-Type", "application/json")
resp, err := s.client.Do(req)
if err != nil {
return nil, fmt.Errorf("error making request: %v", err)
}
defer resp.Body.Close()
var result map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("error decoding response: %v", err)
}
return result, nil
}
// DeleteConversation deletes a conversation
func (s *LLMService) DeleteConversation(conversationID, user string) (map[string]interface{}, error) {
req, err := http.NewRequest("DELETE", fmt.Sprintf("%s/conversations/%s", s.config.LLMApiURL, conversationID), nil)
if err != nil {
return nil, fmt.Errorf("error creating request: %v", err)
}
req.Header.Set("Authorization", "Bearer "+s.config.LLMApiKey)
req.Header.Set("Content-Type", "application/json")
resp, err := s.client.Do(req)
if err != nil {
return nil, fmt.Errorf("error making request: %v", err)
}
defer resp.Body.Close()
var result map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return nil, fmt.Errorf("error decoding response: %v", err)
}
return result, nil
}
// SynthesizeSpeech converts text to speech
func (s *LLMService) SynthesizeSpeech(text string, audio string) (*SpeechResponse, error) {
payload := SpeechRequest{
Model: "speech-02-hd",
Text: text,
Stream: false,
LanguageBoost: "auto",
OutputFormat: "url",
VoiceSetting: VoiceSetting{
VoiceID: audio,
Speed: 1,
Vol: 1,
Pitch: -1,
Emotion: "neutral",
},
AudioSetting: AudioSetting{
SampleRate: 32000,
Bitrate: 128000,
Format: "wav",
},
}
jsonData, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("error marshaling speech request: %v", err)
}
req, err := http.NewRequest("POST", s.config.MiniMaxApiURL, bytes.NewBuffer(jsonData))
if err != nil {
fmt.Println("error creating speech request: ", err)
return nil, fmt.Errorf("error creating speech request: %v", err)
}
req.Header.Set("Authorization", "Bearer "+s.config.MiniMaxApiKey)
req.Header.Set("Content-Type", "application/json")
resp, err := s.client.Do(req)
if err != nil {
fmt.Println("error making speech request: ", err)
return nil, fmt.Errorf("error making speech request: %v", err)
}
defer resp.Body.Close()
// fmt.Println(resp.Body)
if resp.StatusCode != http.StatusOK {
fmt.Println("unexpected status code: ", resp.StatusCode)
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
var result SpeechResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
fmt.Println("error decoding speech response: ", err)
return nil, fmt.Errorf("error decoding speech response: %v", err)
}
return &result, nil
}
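// exampleSynthesize is a hypothetical usage sketch of SynthesizeSpeech, not part
// of the original file; the voice id is a placeholder. With OutputFormat "url" the
// Data.Audio field is expected to hold a downloadable WAV URL.
func exampleSynthesize(svc *LLMService) {
	resp, err := svc.SynthesizeSpeech("你好,欢迎使用语音合成。", "voice-id-placeholder")
	if err != nil {
		fmt.Println("speech synthesis failed:", err)
		return
	}
	fmt.Println("audio URL:", resp.Data.Audio, "trace:", resp.TraceID)
}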
// StreamTextResponse handles streaming text output with predefined segments
func (s *LLMService) StreamTextResponse(conversationID string) (chan Message, error) {
messageChan := make(chan Message, 100)
segments := []string{
"好的,",
"我已经成功替换了文本内容。",
"新的文本是一段连续的描述,",
"没有换行,",
"总共65个字符",
"符合100字以内的要求",
"并且是一个连续的段落。",
"现在我需要完成任务。",
}
go func() {
defer close(messageChan)
taskID := "task_" + time.Now().Format("20060102150405")
for _, segment := range segments {
// Send message
messageChan <- Message{
Answer: segment,
IsEnd: false,
ConversationID: conversationID,
TaskID: taskID,
ClientID: conversationID,
}
// Add delay between segments
time.Sleep(500 * time.Millisecond)
}
// Send end message
messageChan <- Message{
Answer: "",
IsEnd: true,
ConversationID: conversationID,
TaskID: taskID,
}
}()
return messageChan, nil
}
// Helper function to safely get string values from interface{}
func getString(data map[string]interface{}, key string) string {
if val, ok := data[key]; ok {
if str, ok := val.(string); ok {
return str
}
}
return ""
}
// TrimAudioSilence trims the silence at the end of the audio data
func (s *LLMService) TrimAudioSilence(audioData string) (string, error) {
// Decode base64 audio data
decodedData, err := base64.StdEncoding.DecodeString(audioData)
if err != nil {
return "", fmt.Errorf("error decoding base64 audio: %v", err)
}
// Create a buffer from the decoded data
buf := bytes.NewReader(decodedData)
// Read RIFF header
var riffHeader struct {
ChunkID [4]byte
ChunkSize uint32
Format [4]byte
}
if err := binary.Read(buf, binary.LittleEndian, &riffHeader); err != nil {
return "", fmt.Errorf("error reading RIFF header: %v", err)
}
// Verify RIFF header
if string(riffHeader.ChunkID[:]) != "RIFF" || string(riffHeader.Format[:]) != "WAVE" {
return "", fmt.Errorf("invalid WAV format")
}
// Read fmt chunk
var fmtChunk struct {
Subchunk1ID [4]byte
Subchunk1Size uint32
AudioFormat uint16
NumChannels uint16
SampleRate uint32
ByteRate uint32
BlockAlign uint16
BitsPerSample uint16
}
if err := binary.Read(buf, binary.LittleEndian, &fmtChunk); err != nil {
return "", fmt.Errorf("error reading fmt chunk: %v", err)
}
// Skip any extra bytes in fmt chunk
if fmtChunk.Subchunk1Size > 16 {
extraBytes := make([]byte, fmtChunk.Subchunk1Size-16)
if _, err := buf.Read(extraBytes); err != nil {
return "", fmt.Errorf("error skipping extra fmt bytes: %v", err)
}
}
// Find data chunk
var dataChunk struct {
Subchunk2ID [4]byte
Subchunk2Size uint32
}
for {
if err := binary.Read(buf, binary.LittleEndian, &dataChunk); err != nil {
return "", fmt.Errorf("error reading chunk header: %v", err)
}
if string(dataChunk.Subchunk2ID[:]) == "data" {
break
}
// Skip this chunk
if _, err := buf.Seek(int64(dataChunk.Subchunk2Size), io.SeekCurrent); err != nil {
return "", fmt.Errorf("error skipping chunk: %v", err)
}
}
// Read audio data
audioBytes := make([]byte, dataChunk.Subchunk2Size)
if _, err := buf.Read(audioBytes); err != nil {
return "", fmt.Errorf("error reading audio data: %v", err)
}
// Calculate samples per channel
samplesPerChannel := len(audioBytes) / int(fmtChunk.BlockAlign)
channels := int(fmtChunk.NumChannels)
bytesPerSample := int(fmtChunk.BitsPerSample) / 8
// Find the last non-silent sample
lastNonSilent := 0
silenceThreshold := 0.01 // Adjust this threshold as needed
for i := 0; i < samplesPerChannel; i++ {
isSilent := true
for ch := 0; ch < channels; ch++ {
offset := i*int(fmtChunk.BlockAlign) + ch*bytesPerSample
if offset+bytesPerSample > len(audioBytes) {
continue
}
// Convert bytes to sample value
var sample int16
if err := binary.Read(bytes.NewReader(audioBytes[offset:offset+bytesPerSample]), binary.LittleEndian, &sample); err != nil {
continue
}
// Normalize sample to [-1, 1] range
normalizedSample := float64(sample) / 32768.0
if math.Abs(normalizedSample) > silenceThreshold {
isSilent = false
break
}
}
if !isSilent {
lastNonSilent = i
}
}
// Add a small buffer (e.g., 0.1 seconds) after the last non-silent sample
bufferSamples := int(float64(fmtChunk.SampleRate) * 0.1)
lastSample := lastNonSilent + bufferSamples
if lastSample > samplesPerChannel {
lastSample = samplesPerChannel
}
// Calculate new data size
newDataSize := lastSample * int(fmtChunk.BlockAlign)
trimmedAudio := audioBytes[:newDataSize]
// Create new buffer for the trimmed audio
var newBuf bytes.Buffer
// Write RIFF header
riffHeader.ChunkSize = uint32(36 + newDataSize)
if err := binary.Write(&newBuf, binary.LittleEndian, riffHeader); err != nil {
return "", fmt.Errorf("error writing RIFF header: %v", err)
}
// Write fmt chunk
if err := binary.Write(&newBuf, binary.LittleEndian, fmtChunk); err != nil {
return "", fmt.Errorf("error writing fmt chunk: %v", err)
}
// Write data chunk header
dataChunk.Subchunk2Size = uint32(newDataSize)
if err := binary.Write(&newBuf, binary.LittleEndian, dataChunk); err != nil {
return "", fmt.Errorf("error writing data chunk header: %v", err)
}
// Write trimmed audio data
if _, err := newBuf.Write(trimmedAudio); err != nil {
return "", fmt.Errorf("error writing trimmed audio data: %v", err)
}
// Encode back to base64
return base64.StdEncoding.EncodeToString(newBuf.Bytes()), nil
}
// SaveBase64AsWAV saves base64 encoded audio data as a WAV file
func (s *LLMService) SaveBase64AsWAV(base64Data string, outputPath string) error {
// Decode base64 data
audioData, err := base64.StdEncoding.DecodeString(base64Data)
if err != nil {
return fmt.Errorf("error decoding base64 data: %v", err)
}
// Validate WAV header
if len(audioData) < 44 { // WAV header is 44 bytes
return fmt.Errorf("invalid WAV data: too short")
}
// Check RIFF header
if string(audioData[0:4]) != "RIFF" {
return fmt.Errorf("invalid WAV format: missing RIFF header")
}
// Check WAVE format
if string(audioData[8:12]) != "WAVE" {
return fmt.Errorf("invalid WAV format: missing WAVE format")
}
// Create output directory if it doesn't exist
dir := filepath.Dir(outputPath)
if err := os.MkdirAll(dir, 0755); err != nil {
return fmt.Errorf("error creating directory: %v", err)
}
// Write the audio data to file
if err := os.WriteFile(outputPath, audioData, 0644); err != nil {
return fmt.Errorf("error writing WAV file: %v", err)
}
return nil
}
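// exampleTrimWAV is a hypothetical sketch, not part of the original service,
// combining TrimAudioSilence and SaveBase64AsWAV on a local PCM WAV file; the
// input and output paths are placeholders.
func exampleTrimWAV(svc *LLMService) error {
	raw, err := os.ReadFile("audio/input.wav") // placeholder path to a 16-bit PCM WAV file
	if err != nil {
		return err
	}
	trimmed, err := svc.TrimAudioSilence(base64.StdEncoding.EncodeToString(raw))
	if err != nil {
		return err
	}
	return svc.SaveBase64AsWAV(trimmed, "audio/input_trimmed.wav") // placeholder output path
}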