package memory

import (
	"bytes"
	"context"
	"encoding/binary"
	"encoding/json"
	"fmt"
	"io"
	"math"
	"net/http"
	"sync"
	"time"
)

// vectorRequestTimeout bounds a single embeddings request (connect, send and
// read of the response body) so a stalled server cannot block callers forever.
const vectorRequestTimeout = 60 * time.Second

// vectorHTTPClient is shared by all embedding requests so that connections
// can be reused instead of building a new client on every call.
var vectorHTTPClient = &http.Client{Timeout: vectorRequestTimeout}

// VectorService holds the configuration used to request text embeddings from
// an HTTP embeddings endpoint and provides helpers for working with the
// resulting vectors.
type VectorService struct {
	// APIKey is sent as a Bearer token in the Authorization header.
	APIKey string
	// BaseURL is the API root; "/embeddings" is appended to it.
	BaseURL string
	// Model is the embedding model name sent in the request body.
	Model string
	// Dimension is the number of float32 components assumed for the zero
	// vector that Generate returns for empty input.
	Dimension int
}

var (
	// vectorSvc is the package-wide service instance set by InitVector.
	vectorSvc *VectorService
	// vectorMu guards vectorSvc.
	vectorMu sync.RWMutex
)

// VectorOption customizes a VectorService during InitVector.
type VectorOption func(*VectorService)

// WithAPIKey sets the API key used for the Authorization header.
func WithAPIKey(key string) VectorOption {
	return func(v *VectorService) {
		v.APIKey = key
	}
}

// WithBaseURL sets the API root URL.
func WithBaseURL(url string) VectorOption {
	return func(v *VectorService) {
		v.BaseURL = url
	}
}

// WithModel sets the embedding model name.
func WithModel(model string) VectorOption {
	return func(v *VectorService) {
		v.Model = model
	}
}

// WithDimension sets the embedding dimension.
func WithDimension(dim int) VectorOption {
	return func(v *VectorService) {
		v.Dimension = dim
	}
}

// InitVector builds the package-wide VectorService from the built-in defaults
// (base URL "https://api.siliconflow.cn/v1", model "BAAI/bge-m3", dimension
// 1024, empty API key), applies opts in order, and publishes the result.
// Nil options are skipped. It currently always returns nil.
func InitVector(opts ...VectorOption) error {
	svc := &VectorService{
		APIKey:    "",
		BaseURL:   "https://api.siliconflow.cn/v1",
		Model:     "BAAI/bge-m3",
		Dimension: 1024,
	}
	for _, opt := range opts {
		if opt != nil {
			opt(svc)
		}
	}

	// Publish only the fully configured instance, and do so under the write
	// lock so concurrent GetVectorService calls never race with this store.
	vectorMu.Lock()
	vectorSvc = svc
	vectorMu.Unlock()
	return nil
}

// GetVectorService returns the instance most recently stored by InitVector.
// The result is nil if nothing has been stored yet.
func GetVectorService() *VectorService {
	vectorMu.RLock()
	defer vectorMu.RUnlock()
	return vectorSvc
}

// fetchEmbeddings posts input (a single string or a slice of strings) to the
// embeddings endpoint and returns the decoded vectors in response order.
func (v *VectorService) fetchEmbeddings(input interface{}) ([][]float32, error) {
	payload, err := json.Marshal(map[string]interface{}{
		"input": input,
		"model": v.Model,
	})
	if err != nil {
		return nil, fmt.Errorf("序列化请求失败: %w", err)
	}

	req, err := http.NewRequestWithContext(
		context.Background(),
		http.MethodPost,
		v.BaseURL+"/embeddings",
		bytes.NewReader(payload),
	)
	if err != nil {
		return nil, fmt.Errorf("创建请求失败: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+v.APIKey)

	resp, err := vectorHTTPClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("请求失败: %w", err)
	}
	defer resp.Body.Close()

	// The body is read before the status check so that it can be included in
	// the error message for non-200 responses.
	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("读取响应失败: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("API 返回错误: %d %s", resp.StatusCode, string(respBody))
	}

	var result struct {
		Data []struct {
			Embedding []float32 `json:"embedding"`
		} `json:"data"`
	}
	if err := json.Unmarshal(respBody, &result); err != nil {
		return nil, fmt.Errorf("解析响应失败: %w", err)
	}

	vectors := make([][]float32, len(result.Data))
	for i := range result.Data {
		vectors[i] = result.Data[i].Embedding
	}
	return vectors, nil
}

// encodeEmbedding serializes vec as consecutive little-endian IEEE 754
// float32 values (4 bytes per component), the layout CosineSimilarity reads.
// An empty vector encodes to nil.
func encodeEmbedding(vec []float32) []byte {
	if len(vec) == 0 {
		return nil
	}
	out := make([]byte, len(vec)*4)
	for i, f := range vec {
		binary.LittleEndian.PutUint32(out[i*4:], math.Float32bits(f))
	}
	return out
}

// Generate returns the embedding of text encoded as little-endian float32
// bytes. For empty text no request is made and a zero-filled buffer of
// Dimension*4 bytes is returned. An error is returned if the request fails,
// the API answers with a non-200 status, or the response has no embedding.
func (v *VectorService) Generate(text string) ([]byte, error) {
	if text == "" {
		return make([]byte, v.Dimension*4), nil
	}

	vectors, err := v.fetchEmbeddings(text)
	if err != nil {
		return nil, err
	}
	if len(vectors) == 0 || len(vectors[0]) == 0 {
		return nil, fmt.Errorf("未获取到向量")
	}
	return encodeEmbedding(vectors[0]), nil
}

// GenerateBatch requests embeddings for all texts in a single call and
// returns one little-endian float32 byte slice per response entry, in
// response order. An empty input yields an empty, non-nil result without any
// request. An error is returned if the request fails or if the number of
// returned embeddings differs from len(texts), since the result could then
// no longer be matched to the inputs by position.
func (v *VectorService) GenerateBatch(texts []string) ([][]byte, error) {
	if len(texts) == 0 {
		return [][]byte{}, nil
	}

	vectors, err := v.fetchEmbeddings(texts)
	if err != nil {
		return nil, err
	}
	if len(vectors) != len(texts) {
		return nil, fmt.Errorf("返回的向量数量不匹配: 期望 %d, 实际 %d", len(texts), len(vectors))
	}

	embeddings := make([][]byte, len(vectors))
	for i, vec := range vectors {
		embeddings[i] = encodeEmbedding(vec)
	}
	return embeddings, nil
}

// CosineSimilarity returns the cosine of the angle between two vectors given
// as little-endian float32 byte slices: dot(a, b) / (|a| * |b|). If the
// slices hold a different number of components only the common prefix is
// compared; trailing bytes that do not form a full component are ignored.
// It returns 0 when either input is empty or has zero magnitude.
func (v *VectorService) CosineSimilarity(a, b []byte) float64 {
	if len(a) == 0 || len(b) == 0 {
		return 0
	}

	n := len(a) / 4
	if m := len(b) / 4; m < n {
		n = m
	}

	// Accumulate in float64 to limit rounding error over long vectors.
	var dot, sumSqA, sumSqB float64
	for i := 0; i < n; i++ {
		off := i * 4
		fa := float64(math.Float32frombits(binary.LittleEndian.Uint32(a[off : off+4])))
		fb := float64(math.Float32frombits(binary.LittleEndian.Uint32(b[off : off+4])))
		dot += fa * fb
		sumSqA += fa * fa
		sumSqB += fb * fb
	}
	if sumSqA == 0 || sumSqB == 0 {
		return 0
	}

	// sumSqA and sumSqB are squared magnitudes; the denominator must be the
	// product of the magnitudes themselves.
	return dot / (math.Sqrt(sumSqA) * math.Sqrt(sumSqB))
}