fix: 补充提交 memory 模块和 tts/userdir 文件
This commit is contained in:
234
cmd/hxclaw/internal/memory/vector.go
Normal file
234
cmd/hxclaw/internal/memory/vector.go
Normal file
@@ -0,0 +1,234 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// VectorService holds the settings used to request text embeddings from an
// HTTP embeddings endpoint.
type VectorService struct {
	// APIKey is sent as the Bearer token in the Authorization header.
	APIKey string
	// BaseURL is the API root; "/embeddings" is appended to it verbatim.
	BaseURL string
	// Model is the embedding model name sent in the request body.
	Model string
	// Dimension is the number of float32 components (4 bytes each) in the
	// all-zero vector that Generate returns for empty input.
	Dimension int
}
|
||||
|
||||
var (
|
||||
vectorSvc *VectorService
|
||||
vectorMu sync.RWMutex
|
||||
)
|
||||
|
||||
// VectorOption is a functional option that mutates a VectorService in place.
// InitVector applies the options it receives in argument order.
type VectorOption func(*VectorService)
|
||||
|
||||
func WithAPIKey(key string) VectorOption {
|
||||
return func(v *VectorService) {
|
||||
v.APIKey = key
|
||||
}
|
||||
}
|
||||
|
||||
func WithBaseURL(url string) VectorOption {
|
||||
return func(v *VectorService) {
|
||||
v.BaseURL = url
|
||||
}
|
||||
}
|
||||
|
||||
func WithModel(model string) VectorOption {
|
||||
return func(v *VectorService) {
|
||||
v.Model = model
|
||||
}
|
||||
}
|
||||
|
||||
func WithDimension(dim int) VectorOption {
|
||||
return func(v *VectorService) {
|
||||
v.Dimension = dim
|
||||
}
|
||||
}
|
||||
|
||||
func InitVector(opts ...VectorOption) error {
|
||||
vectorSvc = &VectorService{
|
||||
APIKey: "",
|
||||
BaseURL: "https://api.siliconflow.cn/v1",
|
||||
Model: "BAAI/bge-m3",
|
||||
Dimension: 1024,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(vectorSvc)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetVectorService() *VectorService {
|
||||
vectorMu.RLock()
|
||||
defer vectorMu.RUnlock()
|
||||
return vectorSvc
|
||||
}
|
||||
|
||||
func (v *VectorService) Generate(text string) ([]byte, error) {
|
||||
if text == "" {
|
||||
return make([]byte, v.Dimension*4), nil
|
||||
}
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"input": text,
|
||||
"model": v.Model,
|
||||
}
|
||||
bodyBytes, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(
|
||||
context.Background(),
|
||||
"POST",
|
||||
v.BaseURL+"/embeddings",
|
||||
bytes.NewReader(bodyBytes),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+v.APIKey)
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("请求失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API 返回错误: %d %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Data []struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, fmt.Errorf("解析响应失败: %w", err)
|
||||
}
|
||||
|
||||
if len(result.Data) == 0 || len(result.Data[0].Embedding) == 0 {
|
||||
return nil, fmt.Errorf("未获取到向量")
|
||||
}
|
||||
|
||||
embedding := result.Data[0].Embedding
|
||||
buf := new(bytes.Buffer)
|
||||
if err := binary.Write(buf, binary.LittleEndian, embedding); err != nil {
|
||||
return nil, fmt.Errorf("编码向量失败: %w", err)
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func (v *VectorService) GenerateBatch(texts []string) ([][]byte, error) {
|
||||
if len(texts) == 0 {
|
||||
return [][]byte{}, nil
|
||||
}
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
"input": texts,
|
||||
"model": v.Model,
|
||||
}
|
||||
bodyBytes, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(
|
||||
context.Background(),
|
||||
"POST",
|
||||
v.BaseURL+"/embeddings",
|
||||
bytes.NewReader(bodyBytes),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+v.APIKey)
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("请求失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取响应失败: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API 返回错误: %d %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Data []struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, fmt.Errorf("解析响应失败: %w", err)
|
||||
}
|
||||
|
||||
embeddings := make([][]byte, len(result.Data))
|
||||
for i, data := range result.Data {
|
||||
embedding := data.Embedding
|
||||
buf := new(bytes.Buffer)
|
||||
if err := binary.Write(buf, binary.LittleEndian, embedding); err != nil {
|
||||
continue
|
||||
}
|
||||
embeddings[i] = buf.Bytes()
|
||||
}
|
||||
|
||||
return embeddings, nil
|
||||
}
|
||||
|
||||
func (v *VectorService) CosineSimilarity(a, b []byte) float64 {
|
||||
if len(a) == 0 || len(b) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
vecLen := len(a) / 4
|
||||
if len(b)/4 < vecLen {
|
||||
vecLen = len(b) / 4
|
||||
}
|
||||
|
||||
var dotProduct float32
|
||||
var normA float32
|
||||
var normB float32
|
||||
|
||||
for i := 0; i < vecLen; i++ {
|
||||
offset := i * 4
|
||||
fa := binary.LittleEndian.Uint32(a[offset : offset+4])
|
||||
fb := binary.LittleEndian.Uint32(b[offset : offset+4])
|
||||
f32a := math.Float32frombits(fa)
|
||||
f32b := math.Float32frombits(fb)
|
||||
dotProduct += f32a * f32b
|
||||
normA += f32a * f32a
|
||||
normB += f32b * f32b
|
||||
}
|
||||
|
||||
if normA == 0 || normB == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
return float64(dotProduct / (normA * normB))
|
||||
}
|
||||
Reference in New Issue
Block a user