234 lines
4.8 KiB
Go
234 lines
4.8 KiB
Go
package memory
|
|
|
|
import (
	"bytes"
	"context"
	"encoding/binary"
	"encoding/json"
	"fmt"
	"io"
	"math"
	"net/http"
	"sync"
	"time"
)
|
|
|
|
// VectorService holds the settings used to request text embeddings from an
// HTTP embeddings endpoint and to work with the resulting vectors.
type VectorService struct {
	// APIKey is sent as the bearer token in the Authorization header.
	APIKey string
	// BaseURL is the API root; "/embeddings" is appended to it for requests.
	BaseURL string
	// Model is the model identifier placed in each request body.
	Model string
	// Dimension is the number of float32 components used to size the
	// all-zero vector returned for empty input text.
	Dimension int
}
|
|
|
|
var (
|
|
vectorSvc *VectorService
|
|
vectorMu sync.RWMutex
|
|
)
|
|
|
|
// VectorOption mutates a VectorService. InitVector applies each option, in
// the order given, on top of its default settings.
type VectorOption func(*VectorService)
|
|
|
|
func WithAPIKey(key string) VectorOption {
|
|
return func(v *VectorService) {
|
|
v.APIKey = key
|
|
}
|
|
}
|
|
|
|
func WithBaseURL(url string) VectorOption {
|
|
return func(v *VectorService) {
|
|
v.BaseURL = url
|
|
}
|
|
}
|
|
|
|
func WithModel(model string) VectorOption {
|
|
return func(v *VectorService) {
|
|
v.Model = model
|
|
}
|
|
}
|
|
|
|
func WithDimension(dim int) VectorOption {
|
|
return func(v *VectorService) {
|
|
v.Dimension = dim
|
|
}
|
|
}
|
|
|
|
func InitVector(opts ...VectorOption) error {
|
|
vectorSvc = &VectorService{
|
|
APIKey: "",
|
|
BaseURL: "https://api.siliconflow.cn/v1",
|
|
Model: "BAAI/bge-m3",
|
|
Dimension: 1024,
|
|
}
|
|
for _, opt := range opts {
|
|
opt(vectorSvc)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func GetVectorService() *VectorService {
|
|
vectorMu.RLock()
|
|
defer vectorMu.RUnlock()
|
|
return vectorSvc
|
|
}
|
|
|
|
func (v *VectorService) Generate(text string) ([]byte, error) {
|
|
if text == "" {
|
|
return make([]byte, v.Dimension*4), nil
|
|
}
|
|
|
|
reqBody := map[string]interface{}{
|
|
"input": text,
|
|
"model": v.Model,
|
|
}
|
|
bodyBytes, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("序列化请求失败: %w", err)
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(
|
|
context.Background(),
|
|
"POST",
|
|
v.BaseURL+"/embeddings",
|
|
bytes.NewReader(bodyBytes),
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("创建请求失败: %w", err)
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", "Bearer "+v.APIKey)
|
|
|
|
client := &http.Client{}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("请求失败: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
respBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("读取响应失败: %w", err)
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("API 返回错误: %d %s", resp.StatusCode, string(respBody))
|
|
}
|
|
|
|
var result struct {
|
|
Data []struct {
|
|
Embedding []float32 `json:"embedding"`
|
|
} `json:"data"`
|
|
}
|
|
|
|
if err := json.Unmarshal(respBody, &result); err != nil {
|
|
return nil, fmt.Errorf("解析响应失败: %w", err)
|
|
}
|
|
|
|
if len(result.Data) == 0 || len(result.Data[0].Embedding) == 0 {
|
|
return nil, fmt.Errorf("未获取到向量")
|
|
}
|
|
|
|
embedding := result.Data[0].Embedding
|
|
buf := new(bytes.Buffer)
|
|
if err := binary.Write(buf, binary.LittleEndian, embedding); err != nil {
|
|
return nil, fmt.Errorf("编码向量失败: %w", err)
|
|
}
|
|
|
|
return buf.Bytes(), nil
|
|
}
|
|
|
|
func (v *VectorService) GenerateBatch(texts []string) ([][]byte, error) {
|
|
if len(texts) == 0 {
|
|
return [][]byte{}, nil
|
|
}
|
|
|
|
reqBody := map[string]interface{}{
|
|
"input": texts,
|
|
"model": v.Model,
|
|
}
|
|
bodyBytes, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("序列化请求失败: %w", err)
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(
|
|
context.Background(),
|
|
"POST",
|
|
v.BaseURL+"/embeddings",
|
|
bytes.NewReader(bodyBytes),
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("创建请求失败: %w", err)
|
|
}
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", "Bearer "+v.APIKey)
|
|
|
|
client := &http.Client{}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("请求失败: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
respBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("读取响应失败: %w", err)
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("API 返回错误: %d %s", resp.StatusCode, string(respBody))
|
|
}
|
|
|
|
var result struct {
|
|
Data []struct {
|
|
Embedding []float32 `json:"embedding"`
|
|
} `json:"data"`
|
|
}
|
|
|
|
if err := json.Unmarshal(respBody, &result); err != nil {
|
|
return nil, fmt.Errorf("解析响应失败: %w", err)
|
|
}
|
|
|
|
embeddings := make([][]byte, len(result.Data))
|
|
for i, data := range result.Data {
|
|
embedding := data.Embedding
|
|
buf := new(bytes.Buffer)
|
|
if err := binary.Write(buf, binary.LittleEndian, embedding); err != nil {
|
|
continue
|
|
}
|
|
embeddings[i] = buf.Bytes()
|
|
}
|
|
|
|
return embeddings, nil
|
|
}
|
|
|
|
func (v *VectorService) CosineSimilarity(a, b []byte) float64 {
|
|
if len(a) == 0 || len(b) == 0 {
|
|
return 0
|
|
}
|
|
|
|
vecLen := len(a) / 4
|
|
if len(b)/4 < vecLen {
|
|
vecLen = len(b) / 4
|
|
}
|
|
|
|
var dotProduct float32
|
|
var normA float32
|
|
var normB float32
|
|
|
|
for i := 0; i < vecLen; i++ {
|
|
offset := i * 4
|
|
fa := binary.LittleEndian.Uint32(a[offset : offset+4])
|
|
fb := binary.LittleEndian.Uint32(b[offset : offset+4])
|
|
f32a := math.Float32frombits(fa)
|
|
f32b := math.Float32frombits(fb)
|
|
dotProduct += f32a * f32b
|
|
normA += f32a * f32a
|
|
normB += f32b * f32b
|
|
}
|
|
|
|
if normA == 0 || normB == 0 {
|
|
return 0
|
|
}
|
|
|
|
return float64(dotProduct / (normA * normB))
|
|
} |