Files
HxClaw/cmd/hxclaw/internal/memory/vector.go

234 lines
4.8 KiB
Go
Raw Normal View History

package memory
import (
"bytes"
"context"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"math"
"net/http"
"sync"
)
// VectorService produces text embeddings by POSTing a JSON body
// {"input": ..., "model": Model} to BaseURL+"/embeddings" and reading the
// vectors from the response's "data[].embedding" arrays.
type VectorService struct {
	// APIKey is sent as "Authorization: Bearer <APIKey>".
	APIKey string
	// BaseURL is the API root; "/embeddings" is appended to it.
	BaseURL string
	// Model is the model name sent with every request.
	Model string
	// Dimension is the length, in float32 components, of the all-zero
	// vector Generate returns for empty text.
	Dimension int
}

// vectorSvc is the process-wide service installed by InitVector and read by
// GetVectorService. It must only be accessed while holding vectorMu.
var (
	vectorSvc *VectorService
	vectorMu  sync.RWMutex
)

// VectorOption customizes a VectorService while InitVector builds it.
type VectorOption func(*VectorService)

// WithAPIKey sets the API key sent as a Bearer token.
func WithAPIKey(key string) VectorOption {
	return func(v *VectorService) { v.APIKey = key }
}

// WithBaseURL sets the API root that "/embeddings" is appended to.
func WithBaseURL(url string) VectorOption {
	return func(v *VectorService) { v.BaseURL = url }
}

// WithModel sets the embedding model name sent with each request.
func WithModel(model string) VectorOption {
	return func(v *VectorService) { v.Model = model }
}

// WithDimension sets the length of the zero vector used for empty text.
func WithDimension(dim int) VectorOption {
	return func(v *VectorService) { v.Dimension = dim }
}

// InitVector builds a VectorService and installs it as the global service
// returned by GetVectorService.
//
// It starts from these defaults:
//   - APIKey ""
//   - BaseURL "https://api.siliconflow.cn/v1"
//   - Model "BAAI/bge-m3"
//   - Dimension 1024
//
// It then applies opts in order; nil options are skipped. Each call installs
// a new instance and never mutates one that was installed earlier. The
// returned error is currently always nil.
func InitVector(opts ...VectorOption) error {
	// Configure a private instance first and publish it under the write lock,
	// so a concurrent GetVectorService neither races with this assignment nor
	// observes a service whose options have not been applied yet.
	svc := &VectorService{
		APIKey:    "",
		BaseURL:   "https://api.siliconflow.cn/v1",
		Model:     "BAAI/bge-m3",
		Dimension: 1024,
	}
	for _, opt := range opts {
		if opt != nil {
			opt(svc)
		}
	}

	vectorMu.Lock()
	vectorSvc = svc
	vectorMu.Unlock()
	return nil
}

// GetVectorService returns the service most recently installed by InitVector,
// or nil if InitVector has not been called.
func GetVectorService() *VectorService {
	vectorMu.RLock()
	defer vectorMu.RUnlock()
	return vectorSvc
}
// vectorHTTPTimeout caps one embeddings round trip (connect, send, and read
// the whole response). The untyped constant is in nanoseconds, the unit of
// time.Duration, so it equals 60 seconds.
const vectorHTTPTimeout = 60 * 1000 * 1000 * 1000

// vectorHTTPClient is shared by every embeddings request. Sharing it lets idle
// keep-alive connections be reused instead of building a client per call.
// The timeout stops a stalled endpoint from blocking a caller indefinitely.
var vectorHTTPClient = &http.Client{Timeout: vectorHTTPTimeout}

// Generate returns the embedding of text encoded as consecutive little-endian
// float32 values (4 bytes per component).
//
// Empty text short-circuits to v.Dimension*4 zero bytes without contacting the
// API. Otherwise the first vector in the API response is used. An error is
// returned if:
//   - the request fails,
//   - the API answers with a non-200 status,
//   - the response cannot be decoded, or
//   - the response contains no vector.
func (v *VectorService) Generate(text string) ([]byte, error) {
	if text == "" {
		return make([]byte, v.Dimension*4), nil
	}
	vecs, err := v.requestEmbeddings(text)
	if err != nil {
		return nil, err
	}
	if len(vecs) == 0 || len(vecs[0]) == 0 {
		return nil, fmt.Errorf("未获取到向量")
	}
	return float32sToLE(vecs[0]), nil
}

// GenerateBatch embeds texts with a single API call. It returns one encoded
// vector per input, in input order, using the same byte layout as Generate.
//
// Empty strings get v.Dimension*4 zero bytes, exactly as Generate does, and are
// not sent to the API. Returned vectors are matched to the remaining inputs by
// their position in the response's "data" array.
//
// The call fails instead of returning a short or partially-nil result when the
// API returns a different number of vectors than were requested, or when any
// vector is empty. An empty texts slice yields an empty, non-nil result.
func (v *VectorService) GenerateBatch(texts []string) ([][]byte, error) {
	if len(texts) == 0 {
		return [][]byte{}, nil
	}

	out := make([][]byte, len(texts))
	// pending holds the non-empty texts to send. slots[j] is the index in out
	// that the j-th returned vector belongs to.
	pending := make([]string, 0, len(texts))
	slots := make([]int, 0, len(texts))
	for i, text := range texts {
		if text == "" {
			out[i] = make([]byte, v.Dimension*4)
			continue
		}
		pending = append(pending, text)
		slots = append(slots, i)
	}
	if len(pending) == 0 {
		return out, nil
	}

	vecs, err := v.requestEmbeddings(pending)
	if err != nil {
		return nil, err
	}
	if len(vecs) != len(pending) {
		return nil, fmt.Errorf("向量数量不匹配: 期望 %d, 实际 %d", len(pending), len(vecs))
	}
	for j, vec := range vecs {
		if len(vec) == 0 {
			return nil, fmt.Errorf("未获取到向量: 第 %d 条文本", slots[j])
		}
		out[slots[j]] = float32sToLE(vec)
	}
	return out, nil
}

// requestEmbeddings POSTs {"input": input, "model": v.Model} to
// v.BaseURL+"/embeddings" with v.APIKey as a Bearer token. It returns the
// "embedding" arrays of the response's "data" items in the order the API
// listed them. input is a string (Generate) or a []string (GenerateBatch).
func (v *VectorService) requestEmbeddings(input interface{}) ([][]float32, error) {
	payload := struct {
		Input interface{} `json:"input"`
		Model string      `json:"model"`
	}{Input: input, Model: v.Model}
	bodyBytes, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("序列化请求失败: %w", err)
	}

	req, err := http.NewRequestWithContext(
		context.Background(),
		http.MethodPost,
		v.BaseURL+"/embeddings",
		bytes.NewReader(bodyBytes),
	)
	if err != nil {
		return nil, fmt.Errorf("创建请求失败: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+v.APIKey)

	resp, err := vectorHTTPClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("请求失败: %w", err)
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("读取响应失败: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("API 返回错误: %d %s", resp.StatusCode, string(respBody))
	}

	var result struct {
		Data []struct {
			Embedding []float32 `json:"embedding"`
		} `json:"data"`
	}
	if err := json.Unmarshal(respBody, &result); err != nil {
		return nil, fmt.Errorf("解析响应失败: %w", err)
	}
	vecs := make([][]float32, len(result.Data))
	for i := range result.Data {
		vecs[i] = result.Data[i].Embedding
	}
	return vecs, nil
}

// float32sToLE packs vec as consecutive little-endian IEEE-754 float32 values.
// This is the same layout binary.Write(w, binary.LittleEndian, vec) produces,
// but without the reflection binary.Write relies on.
func float32sToLE(vec []float32) []byte {
	out := make([]byte, 4*len(vec))
	for i, f := range vec {
		binary.LittleEndian.PutUint32(out[4*i:], math.Float32bits(f))
	}
	return out
}
// CosineSimilarity returns the cosine of the angle between two vectors encoded
// as consecutive little-endian float32 values (the layout Generate produces).
// The result is dot(a, b) / (|a| * |b|), clamped to [-1, 1].
//
// Only the first min(len(a), len(b))/4 components are compared, so vectors of
// different lengths are compared on their common prefix. Trailing bytes that
// do not form a whole float32 are ignored.
//
// It returns 0 when either input has no complete component or when either
// compared prefix has zero magnitude. The receiver is not used.
func (v *VectorService) CosineSimilarity(a, b []byte) float64 {
	n := len(a) / 4
	if m := len(b) / 4; m < n {
		n = m
	}

	// Accumulate in float64; float32 running sums over long vectors can lose
	// noticeable precision.
	var dot, sumSqA, sumSqB float64
	for i := 0; i < n; i++ {
		off := 4 * i
		x := float64(math.Float32frombits(binary.LittleEndian.Uint32(a[off:])))
		y := float64(math.Float32frombits(binary.LittleEndian.Uint32(b[off:])))
		dot += x * y
		sumSqA += x * x
		sumSqB += y * y
	}
	if sumSqA == 0 || sumSqB == 0 {
		return 0
	}

	// Cosine divides by the product of the Euclidean norms (square roots of the
	// sums of squares), not by the sums of squares themselves.
	sim := dot / (math.Sqrt(sumSqA) * math.Sqrt(sumSqB))
	// Rounding can push the quotient marginally outside [-1, 1].
	return math.Max(-1, math.Min(1, sim))
}