fix: 补充提交 memory 模块和 tts/userdir 文件

This commit is contained in:
2026-04-27 07:15:08 +08:00
parent f6332fbaaf
commit 662c4e05a4
8 changed files with 1651 additions and 0 deletions

View File

@@ -0,0 +1,371 @@
package memory
import (
"context"
"database/sql"
"fmt"
"os"
"path/filepath"
"github.com/google/uuid"
"github.com/hxclaw/hxclaw/cmd/hxclaw/internal"
_ "github.com/tursodatabase/libsql-client-go/libsql"
)
// DB wraps the *sql.DB handle that all memory-package queries run against.
type DB struct {
db *sql.DB
}
// Package-level state assigned by Init.
var (
// db is the shared handle returned by GetDB; Init assigns it once the
// database has been opened and pinged.
db *DB
// cfg is the configuration resolved by the most recent Init call.
cfg *DBConfig
)
// DBConfig holds the settings Init uses to open the database.
type DBConfig struct {
// DBPath is the database file path; when empty, Init falls back to
// getDefaultDBPath.
DBPath string
}
// DBSOption mutates a DBConfig. Init applies the options in order after the
// project configuration has been loaded, so options override configured values.
type DBSOption func(*DBConfig)
// WithDBPath returns an option that sets DBConfig.DBPath to path.
func WithDBPath(path string) DBSOption {
	apply := func(target *DBConfig) {
		target.DBPath = path
	}
	return apply
}
// getDefaultDBPath returns the database file path reported by
// internal.GetDBFile.
func getDefaultDBPath() string {
return internal.GetDBFile()
}
// GetDefaultDBPath returns the default database path (exported wrapper around
// getDefaultDBPath).
func GetDefaultDBPath() string {
return getDefaultDBPath()
}
// Init opens the memory database and makes sure its schema exists.
//
// The path comes from the project configuration's Memory.DBPath, may be
// overridden by opts, and falls back to getDefaultDBPath when empty. The
// parent directory is created if necessary. On success the handle is
// published through the package-level db (see GetDB).
//
// It returns an error when the directory cannot be created, the database
// cannot be opened or pinged, or the schema cannot be created.
func Init(opts ...DBSOption) error {
	memoryCfg := internal.GetProjectConfig().Memory
	cfg = &DBConfig{
		DBPath: memoryCfg.DBPath,
	}
	for _, opt := range opts {
		opt(cfg)
	}

	path := cfg.DBPath
	if path == "" {
		path = getDefaultDBPath()
	}

	// Surface directory problems here instead of as an opaque open/ping error.
	if dir := filepath.Dir(path); dir != "" {
		if err := os.MkdirAll(dir, 0755); err != nil {
			return fmt.Errorf("创建数据库目录失败: %w", err)
		}
	}

	dbInst, err := sql.Open("libsql", "file:"+path)
	if err != nil {
		return fmt.Errorf("打开数据库失败: %w", err)
	}
	if err := dbInst.Ping(); err != nil {
		// Best effort: do not leak the handle when the database is unusable.
		_ = dbInst.Close()
		return fmt.Errorf("ping 数据库失败: %w", err)
	}

	db = &DB{db: dbInst}
	return db.createTables()
}
// GetDB returns the package-level database handle, which is nil until Init
// has assigned it.
func GetDB() *DB {
return db
}
// createTables creates the sessions and chats tables and their indexes when
// they do not exist yet. The statements run in order and the first failure is
// returned wrapped.
func (d *DB) createTables() error {
	ctx := context.Background()
	for _, stmt := range []string{
		`CREATE TABLE IF NOT EXISTS sessions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
uuid TEXT UNIQUE NOT NULL,
summary TEXT,
summary_embedding BLOB,
chat_ids TEXT,
created_at INTEGER NOT NULL,
updated_at INTEGER NOT NULL
)`,
		`CREATE TABLE IF NOT EXISTS chats (
id INTEGER PRIMARY KEY AUTOINCREMENT,
session_id INTEGER NOT NULL,
user_input TEXT NOT NULL,
ai_replies TEXT,
summary TEXT,
summary_embedding BLOB,
created_at INTEGER NOT NULL,
updated_at INTEGER NOT NULL,
FOREIGN KEY(session_id) REFERENCES sessions(id)
)`,
		`CREATE INDEX IF NOT EXISTS idx_chats_session_id ON chats(session_id)`,
		`CREATE INDEX IF NOT EXISTS idx_sessions_uuid ON sessions(uuid)`,
	} {
		_, err := d.db.ExecContext(ctx, stmt)
		if err != nil {
			return fmt.Errorf("创建表失败: %w", err)
		}
	}
	return nil
}
// CreateSession inserts a new session row for uuid with an empty chat-ID list
// and returns the session with its database ID filled in.
func (d *DB) CreateSession(uuid string) (*Session, error) {
	const insertSQL = `INSERT INTO sessions (uuid, summary, summary_embedding, chat_ids, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)`

	session := NewSession(uuid)
	res, err := d.db.ExecContext(
		context.Background(),
		insertSQL,
		session.UUID, session.Summary, session.SummaryEmbedding, "[]", session.CreatedAt, session.UpdatedAt,
	)
	if err != nil {
		return nil, fmt.Errorf("创建会话失败: %w", err)
	}

	newID, err := res.LastInsertId()
	if err != nil {
		return nil, fmt.Errorf("获取ID失败: %w", err)
	}
	session.ID = newID
	return session, nil
}
// GetSessionByUUID loads the session whose uuid column equals uuid.
// It returns (nil, nil) when no such row exists.
func (d *DB) GetSessionByUUID(uuid string) (*Session, error) {
	const query = `SELECT id, uuid, summary, summary_embedding, chat_ids, created_at, updated_at FROM sessions WHERE uuid = ?`

	var (
		found  Session
		rawIDs string
	)
	row := d.db.QueryRowContext(context.Background(), query, uuid)
	switch err := row.Scan(&found.ID, &found.UUID, &found.Summary, &found.SummaryEmbedding, &rawIDs, &found.CreatedAt, &found.UpdatedAt); {
	case err == sql.ErrNoRows:
		return nil, nil
	case err != nil:
		return nil, fmt.Errorf("查询会话失败: %w", err)
	}

	// chat_ids is stored as a JSON array in a TEXT column.
	if err := found.SetChatIDsFromJSON(rawIDs); err != nil {
		return nil, fmt.Errorf("解析chat_ids失败: %w", err)
	}
	return &found, nil
}
// GetSessionByID loads the session with primary key id.
// It returns (nil, nil) when no such row exists.
func (d *DB) GetSessionByID(id int64) (*Session, error) {
	const query = `SELECT id, uuid, summary, summary_embedding, chat_ids, created_at, updated_at FROM sessions WHERE id = ?`

	var (
		found  Session
		rawIDs string
	)
	row := d.db.QueryRowContext(context.Background(), query, id)
	switch err := row.Scan(&found.ID, &found.UUID, &found.Summary, &found.SummaryEmbedding, &rawIDs, &found.CreatedAt, &found.UpdatedAt); {
	case err == sql.ErrNoRows:
		return nil, nil
	case err != nil:
		return nil, fmt.Errorf("查询会话失败: %w", err)
	}

	// chat_ids is stored as a JSON array in a TEXT column.
	if err := found.SetChatIDsFromJSON(rawIDs); err != nil {
		return nil, fmt.Errorf("解析chat_ids失败: %w", err)
	}
	return &found, nil
}
// UpdateSession writes the summary, summary embedding, chat-ID list and
// updated_at of s back to the row identified by s.ID.
func (d *DB) UpdateSession(s *Session) error {
	const updateSQL = `UPDATE sessions SET summary = ?, summary_embedding = ?, chat_ids = ?, updated_at = ? WHERE id = ?`

	encodedIDs, err := s.GetChatIDsJSON()
	if err != nil {
		return fmt.Errorf("序列化chat_ids失败: %w", err)
	}

	if _, err := d.db.ExecContext(
		context.Background(),
		updateSQL,
		s.Summary, s.SummaryEmbedding, encodedIDs, s.UpdatedAt, s.ID,
	); err != nil {
		return fmt.Errorf("更新会话失败: %w", err)
	}
	return nil
}
// CreateChat inserts a chat row for sessionID holding userInput and an empty
// reply list, and returns the chat with its database ID filled in.
func (d *DB) CreateChat(sessionID int64, userInput string) (*Chat, error) {
	const insertSQL = `INSERT INTO chats (session_id, user_input, ai_replies, summary, summary_embedding, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?)`

	chat := NewChat(sessionID, userInput)
	res, err := d.db.ExecContext(
		context.Background(),
		insertSQL,
		chat.SessionID, chat.UserInput, "[]", chat.Summary, chat.SummaryEmbedding, chat.CreatedAt, chat.UpdatedAt,
	)
	if err != nil {
		return nil, fmt.Errorf("创建聊天记录失败: %w", err)
	}

	newID, err := res.LastInsertId()
	if err != nil {
		return nil, fmt.Errorf("获取ID失败: %w", err)
	}
	chat.ID = newID
	return chat, nil
}
// GetChatByID loads the chat with primary key id.
// It returns (nil, nil) when no such row exists.
func (d *DB) GetChatByID(id int64) (*Chat, error) {
	const query = `SELECT id, session_id, user_input, ai_replies, summary, summary_embedding, created_at, updated_at FROM chats WHERE id = ?`

	var (
		found      Chat
		rawReplies string
	)
	row := d.db.QueryRowContext(context.Background(), query, id)
	switch err := row.Scan(&found.ID, &found.SessionID, &found.UserInput, &rawReplies, &found.Summary, &found.SummaryEmbedding, &found.CreatedAt, &found.UpdatedAt); {
	case err == sql.ErrNoRows:
		return nil, nil
	case err != nil:
		return nil, fmt.Errorf("查询聊天记录失败: %w", err)
	}

	// ai_replies is stored as a JSON array in a TEXT column.
	if err := found.SetAIRepliesFromJSON(rawReplies); err != nil {
		return nil, fmt.Errorf("解析ai_replies失败: %w", err)
	}
	return &found, nil
}
// UpdateChat writes the replies, summary, summary embedding and updated_at of
// c back to the row identified by c.ID.
func (d *DB) UpdateChat(c *Chat) error {
	const updateSQL = `UPDATE chats SET ai_replies = ?, summary = ?, summary_embedding = ?, updated_at = ? WHERE id = ?`

	encodedReplies, err := c.GetAIRepliesJSON()
	if err != nil {
		return fmt.Errorf("序列化ai_replies失败: %w", err)
	}

	if _, err := d.db.ExecContext(
		context.Background(),
		updateSQL,
		encodedReplies, c.Summary, c.SummaryEmbedding, c.UpdatedAt, c.ID,
	); err != nil {
		return fmt.Errorf("更新聊天记录失败: %w", err)
	}
	return nil
}
// GetLatestSession returns the session with the highest id, or (nil, nil)
// when the sessions table is empty.
func (d *DB) GetLatestSession() (*Session, error) {
	const query = `SELECT id, uuid, summary, summary_embedding, chat_ids, created_at, updated_at FROM sessions ORDER BY id DESC LIMIT 1`

	var (
		latest Session
		rawIDs string
	)
	row := d.db.QueryRowContext(context.Background(), query)
	switch err := row.Scan(&latest.ID, &latest.UUID, &latest.Summary, &latest.SummaryEmbedding, &rawIDs, &latest.CreatedAt, &latest.UpdatedAt); {
	case err == sql.ErrNoRows:
		return nil, nil
	case err != nil:
		return nil, fmt.Errorf("查询最新会话失败: %w", err)
	}

	if err := latest.SetChatIDsFromJSON(rawIDs); err != nil {
		return nil, fmt.Errorf("解析chat_ids失败: %w", err)
	}
	return &latest, nil
}
// GetAllSessions returns every session, newest (highest id) first.
//
// It fails on the first row that cannot be scanned or whose chat_ids JSON
// cannot be parsed, and also when the row iteration itself ends with an error,
// so a truncated result is never reported as success.
func (d *DB) GetAllSessions() ([]*Session, error) {
	rows, err := d.db.QueryContext(
		context.Background(),
		`SELECT id, uuid, summary, summary_embedding, chat_ids, created_at, updated_at FROM sessions ORDER BY id DESC`,
	)
	if err != nil {
		return nil, fmt.Errorf("查询会话列表失败: %w", err)
	}
	defer rows.Close()

	var sessions []*Session
	for rows.Next() {
		s := &Session{}
		var chatIDsJSON string
		if err := rows.Scan(&s.ID, &s.UUID, &s.Summary, &s.SummaryEmbedding, &chatIDsJSON, &s.CreatedAt, &s.UpdatedAt); err != nil {
			return nil, fmt.Errorf("扫描会话失败: %w", err)
		}
		if err := s.SetChatIDsFromJSON(chatIDsJSON); err != nil {
			return nil, fmt.Errorf("解析chat_ids失败: %w", err)
		}
		sessions = append(sessions, s)
	}
	// rows.Next returns false both at end-of-data and on failure; only
	// rows.Err distinguishes the two.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("遍历会话失败: %w", err)
	}
	return sessions, nil
}
// GetChatsBySessionID returns all chats of the given session ordered by id.
//
// It fails on the first row that cannot be scanned or whose ai_replies JSON
// cannot be parsed, and also when the row iteration itself ends with an error,
// so a truncated result is never reported as success.
func (d *DB) GetChatsBySessionID(sessionID int64) ([]*Chat, error) {
	rows, err := d.db.QueryContext(
		context.Background(),
		`SELECT id, session_id, user_input, ai_replies, summary, summary_embedding, created_at, updated_at FROM chats WHERE session_id = ? ORDER BY id`,
		sessionID,
	)
	if err != nil {
		return nil, fmt.Errorf("查询聊天记录失败: %w", err)
	}
	defer rows.Close()

	var chats []*Chat
	for rows.Next() {
		c := &Chat{}
		var aiRepliesJSON string
		if err := rows.Scan(&c.ID, &c.SessionID, &c.UserInput, &aiRepliesJSON, &c.Summary, &c.SummaryEmbedding, &c.CreatedAt, &c.UpdatedAt); err != nil {
			return nil, fmt.Errorf("扫描聊天记录失败: %w", err)
		}
		if err := c.SetAIRepliesFromJSON(aiRepliesJSON); err != nil {
			return nil, fmt.Errorf("解析ai_replies失败: %w", err)
		}
		chats = append(chats, c)
	}
	// rows.Next returns false both at end-of-data and on failure; only
	// rows.Err distinguishes the two.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("遍历聊天记录失败: %w", err)
	}
	return chats, nil
}
// Close closes the database handle owned by the receiver.
//
// It is a no-op returning nil when the receiver or its handle is nil. It acts
// on the receiver itself rather than on the package-level db, so a DB value
// that is not the global one is closed correctly as well.
func (d *DB) Close() error {
	if d == nil || d.db == nil {
		return nil
	}
	return d.db.Close()
}
// CreateNewSession creates a session with a freshly generated UUID in the
// package-level database. It returns an error when Init has not set the
// database yet.
func CreateNewSession() (*Session, error) {
	if db != nil {
		return db.CreateSession(generateUUID())
	}
	return nil, fmt.Errorf("数据库未初始化")
}
// GetLatestSession returns the newest session from the package-level
// database, or (nil, nil) when the database has not been initialised.
func GetLatestSession() (*Session, error) {
	if db != nil {
		return db.GetLatestSession()
	}
	return nil, nil
}
// GetAllSessions returns every session from the package-level database, or
// (nil, nil) when the database has not been initialised.
func GetAllSessions() ([]*Session, error) {
	if db != nil {
		return db.GetAllSessions()
	}
	return nil, nil
}
// ListSessions is an alias for GetAllSessions.
func ListSessions() ([]*Session, error) {
return GetAllSessions()
}
// GetCurrentSession returns the package-level currentSession, which may be nil.
func GetCurrentSession() *Session {
return currentSession
}
// currentSession is the session the package is currently working with; it is
// exposed through GetCurrentSession and starts out nil.
var currentSession *Session
// generateUUID returns a new UUID from uuid.New in its string form.
func generateUUID() string {
return uuid.New().String()
}

View File

@@ -0,0 +1,110 @@
package memory
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"time"
"github.com/hxclaw/hxclaw/cmd/hxclaw/internal"
)
// ExportDocument is the top-level JSON document written by ExportToFile.
type ExportDocument struct {
// Version is the export format version; ExportToFile writes 1.
Version int `json:"version"`
// ExportedAt is the export time formatted as RFC 3339.
ExportedAt string `json:"exported_at"`
// Sessions holds one entry per exported session.
Sessions []*SessionData `json:"sessions"`
}
// SessionData is the exported form of a Session together with its chats.
// The fields mirror Session, minus the embedding, plus the Chats list.
type SessionData struct {
ID int64 `json:"id"`
UUID string `json:"uuid"`
Summary string `json:"summary"`
ChatIDs []int64 `json:"chat_ids"`
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`
// Chats are the chat records exported for this session.
Chats []*Chat `json:"chats"`
}
// ExportIfNeeded exports the memory database to "export-data.json" in the
// configuration directory when both Memory.Enabled and Memory.AutoExport are
// set and the database has been initialised. The outcome is printed to stdout;
// a failure is reported as a warning and not returned.
func ExportIfNeeded() {
	memoryCfg := internal.GetProjectConfig().Memory
	if !(memoryCfg.Enabled && memoryCfg.AutoExport) {
		return
	}
	if GetDB() == nil {
		return
	}

	target := filepath.Join(internal.GetConfigDir(), "export-data.json")
	err := ExportToFile(target)
	if err == nil {
		fmt.Printf("数据已导出: %s\n", target)
		return
	}
	fmt.Printf("警告:导出数据失败: %v\n", err)
}
func ExportToFile(filename string) error {
db := GetDB()
if db == nil {
return fmt.Errorf("数据库未初始化")
}
sessions, err := db.GetAllSessions()
if err != nil {
return fmt.Errorf("查询会话失败: %w", err)
}
var doc ExportDocument
doc.Version = 1
doc.ExportedAt = time.Now().Format(time.RFC3339)
if _, err := os.Stat(filename); err == nil {
data, readErr := os.ReadFile(filename)
if readErr == nil {
json.Unmarshal(data, &doc)
}
}
sessionMap := make(map[string]*SessionData)
for i := range doc.Sessions {
sessionMap[doc.Sessions[i].UUID] = doc.Sessions[i]
}
for _, s := range sessions {
if existing, ok := sessionMap[s.UUID]; ok {
existing.Summary = s.Summary
existing.ChatIDs = s.ChatIDs
existing.UpdatedAt = s.UpdatedAt
} else {
sessionData := &SessionData{
ID: s.ID,
UUID: s.UUID,
Summary: s.Summary,
ChatIDs: s.ChatIDs,
CreatedAt: s.CreatedAt,
UpdatedAt: s.UpdatedAt,
Chats: []*Chat{},
}
doc.Sessions = append(doc.Sessions, sessionData)
sessionMap[s.UUID] = sessionData
}
chats, err := db.GetChatsBySessionID(s.ID)
if err != nil {
continue
}
sessionMap[s.UUID].Chats = append(sessionMap[s.UUID].Chats, chats...)
}
file, err := os.Create(filename)
if err != nil {
return fmt.Errorf("创建文件失败: %w", err)
}
defer file.Close()
encoder := json.NewEncoder(file)
encoder.SetIndent("", " ")
return encoder.Encode(doc)
}

View File

@@ -0,0 +1,98 @@
package memory
import (
"encoding/json"
"time"
)
// Session is a conversation session persisted in the sessions table.
type Session struct {
ID int64 `json:"id"`
UUID string `json:"uuid"`
Summary string `json:"summary"`
// SummaryEmbedding is the stored embedding blob for Summary; it is excluded
// from JSON output.
SummaryEmbedding []byte `json:"-"`
// ChatIDs lists the IDs of the chats that belong to this session.
ChatIDs []int64 `json:"chat_ids"`
// CreatedAt and UpdatedAt are Unix timestamps in seconds.
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`
}
// Chat is one user input with its AI replies, persisted in the chats table.
type Chat struct {
ID int64 `json:"id"`
SessionID int64 `json:"session_id"`
UserInput string `json:"user_input"`
// AIReplies holds the replies recorded for UserInput, in the order added.
AIReplies []string `json:"ai_replies"`
Summary string `json:"summary"`
// SummaryEmbedding is the stored embedding blob for this chat; it is
// excluded from JSON output.
SummaryEmbedding []byte `json:"-"`
// CreatedAt and UpdatedAt are Unix timestamps in seconds.
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`
}
// NewSession returns an unsaved Session for uuid with an empty, non-nil chat
// list and both timestamps set to the current Unix time.
func NewSession(uuid string) *Session {
	stamp := time.Now().Unix()
	session := &Session{UUID: uuid, ChatIDs: []int64{}}
	session.CreatedAt, session.UpdatedAt = stamp, stamp
	return session
}
// NewChat returns an unsaved Chat for sessionID holding userInput, with an
// empty, non-nil reply list and both timestamps set to the current Unix time.
func NewChat(sessionID int64, userInput string) *Chat {
	stamp := time.Now().Unix()
	chat := &Chat{SessionID: sessionID, UserInput: userInput, AIReplies: []string{}}
	chat.CreatedAt, chat.UpdatedAt = stamp, stamp
	return chat
}
// AddChatID appends chatID to the session's chat list and refreshes UpdatedAt
// to the current Unix time.
func (s *Session) AddChatID(chatID int64) {
	s.UpdatedAt = time.Now().Unix()
	s.ChatIDs = append(s.ChatIDs, chatID)
}
// GetChatIDsJSON returns ChatIDs encoded as a JSON array string.
func (s *Session) GetChatIDsJSON() (string, error) {
	encoded, err := json.Marshal(s.ChatIDs)
	if err == nil {
		return string(encoded), nil
	}
	return "", err
}
// SetChatIDsFromJSON replaces ChatIDs with the JSON array in data. An empty
// string yields an empty, non-nil list; malformed JSON is returned as an error.
func (s *Session) SetChatIDsFromJSON(data string) error {
	if len(data) > 0 {
		return json.Unmarshal([]byte(data), &s.ChatIDs)
	}
	s.ChatIDs = []int64{}
	return nil
}
// AddAIReply appends reply to the chat's reply list and refreshes UpdatedAt to
// the current Unix time.
func (c *Chat) AddAIReply(reply string) {
	c.UpdatedAt = time.Now().Unix()
	c.AIReplies = append(c.AIReplies, reply)
}
// GetAIRepliesJSON returns AIReplies encoded as a JSON array string.
func (c *Chat) GetAIRepliesJSON() (string, error) {
	encoded, err := json.Marshal(c.AIReplies)
	if err == nil {
		return string(encoded), nil
	}
	return "", err
}
// SetAIRepliesFromJSON replaces AIReplies with the JSON array in data. An
// empty string yields an empty, non-nil list; malformed JSON is returned as an
// error.
func (c *Chat) SetAIRepliesFromJSON(data string) error {
	if len(data) > 0 {
		return json.Unmarshal([]byte(data), &c.AIReplies)
	}
	c.AIReplies = []string{}
	return nil
}
// GenerateSummary builds a naive summary: userInput and aiReply joined by a
// newline, truncated to at most 200 bytes.
//
// The cut never lands inside a multi-byte UTF-8 character: it is moved back to
// the start of the character that would have been split, so the result may be
// a few bytes shorter than 200 but stays valid UTF-8 for valid input.
func GenerateSummary(userInput, aiReply string) string {
	const maxBytes = 200

	fullText := userInput + "\n" + aiReply
	if len(fullText) <= maxBytes {
		return fullText
	}

	cut := maxBytes
	// UTF-8 continuation bytes look like 10xxxxxx; step back until the byte at
	// the cut position starts a character.
	for cut > 0 && fullText[cut]&0xC0 == 0x80 {
		cut--
	}
	return fullText[:cut]
}

View File

@@ -0,0 +1,443 @@
package memory
import (
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/hxclaw/hxclaw/cmd/hxclaw/internal"
)
// lastContext holds the most recent non-early-return result of
// GetContextPrompt.
var lastContext string
// GetContextPrompt assembles the memory context that accompanies userInput.
//
// It returns "" when the database is not initialised, memory is disabled, or
// no session can be resolved (the latest stored session is adopted as
// currentSession when none is set). Otherwise the result concatenates, in
// order: the PicoClaw long-term memory, the current session summary, and —
// when shouldRecall reports true — the output of detectAndRecall. The result
// is also stored in lastContext.
func GetContextPrompt(userInput string) string {
	store := GetDB()
	if store == nil {
		return ""
	}
	if !internal.GetProjectConfig().Memory.Enabled {
		return ""
	}

	if currentSession == nil {
		latest, err := store.GetLatestSession()
		if err != nil || latest == nil {
			return ""
		}
		currentSession = latest
	}
	active := currentSession

	var prompt strings.Builder

	// Long-term memory maintained by picoclaw.
	if longTerm := readPicoClawMemory(); longTerm != "" {
		prompt.WriteString("=== 长期记忆 ===\n" + longTerm + "\n============\n")
	}

	// Summary of the active session.
	if active.Summary != "" {
		prompt.WriteString("=== 当前会话摘要 ===\n" + active.Summary + "\n============\n")
	}

	if shouldRecall(userInput) {
		prompt.WriteString(detectAndRecall(userInput))
	}

	lastContext = prompt.String()
	return lastContext
}
// readPicoClawMemory returns the body text of memory/MEMORY.md inside the
// PicoClaw workspace.
//
// Lines are trimmed; blank lines and heading lines (those starting with '#')
// are dropped, as is everything before the first heading. It returns "" when
// the workspace location is unknown or the file cannot be read — the workspace
// check prevents falling back to a path relative to the current directory.
func readPicoClawMemory() string {
	workspace := getPicoclawWorkspace()
	if workspace == "" {
		return ""
	}
	data, err := os.ReadFile(filepath.Join(workspace, "memory", "MEMORY.md"))
	if err != nil {
		return ""
	}

	var sb strings.Builder
	inSection := false
	for _, line := range strings.Split(strings.TrimSpace(string(data)), "\n") {
		trimmed := strings.TrimSpace(line)
		switch {
		case trimmed == "":
			// skip blank lines
		case strings.HasPrefix(trimmed, "#"):
			// A heading opens a section; the heading text itself is not kept.
			inSection = true
		case inSection:
			sb.WriteString(trimmed)
			sb.WriteString("\n")
		}
	}
	return strings.TrimSpace(sb.String())
}
// getPicoclawWorkspace returns <home>/.picoclaw/workspace, or "" when the
// user's home directory cannot be determined.
func getPicoclawWorkspace() string {
	if homeDir, err := os.UserHomeDir(); err == nil {
		return filepath.Join(homeDir, ".picoclaw", "workspace")
	}
	return ""
}
// shouldRecall reports whether input should trigger a memory recall: it starts
// with "/recall" or "/memory", contains one of the configured recall keywords,
// or — when auto-recall is enabled — shouldAutoRecall accepts it.
func shouldRecall(input string) bool {
	recallCfg := internal.GetProjectConfig().Memory.Recall

	if strings.HasPrefix(input, "/recall") || strings.HasPrefix(input, "/memory") {
		return true
	}
	for _, kw := range recallCfg.Keywords {
		if strings.Contains(input, kw) {
			return true
		}
	}
	return recallCfg.AutoRecall && shouldAutoRecall(input)
}
// shouldAutoRecall reports whether input is similar enough to the current
// session's summary to warrant a recall. The threshold comes from
// Memory.Recall.SimilarityThreshold and defaults to 0.7 when not positive.
// It returns false when the vector service, the current session, its summary
// or its embedding is missing, or when embedding the input fails.
func shouldAutoRecall(input string) bool {
	vs := GetVectorService()
	if vs == nil {
		return false
	}

	threshold := internal.GetProjectConfig().Memory.Recall.SimilarityThreshold
	if threshold <= 0 {
		threshold = 0.7
	}

	active := currentSession
	switch {
	case active == nil:
		return false
	case active.Summary == "", len(active.SummaryEmbedding) == 0:
		return false
	}

	inputEmb, err := vs.Generate(input)
	if err != nil {
		return false
	}
	return vs.CosineSimilarity(inputEmb, active.SummaryEmbedding) > threshold
}
// detectAndRecall turns a recall request into recalled text.
//
// A bare "/recall" returns the formatted session history. Otherwise the input
// (with a leading "/recall" stripped) is checked for the built-in recall
// keywords and the first match is handled by recallByKeyword. It returns ""
// when nothing matches or when the history lookup fails; a failed lookup is
// never dereferenced.
//
// NOTE(review): "/recall <topic>" without a keyword, and "/memory", currently
// yield "" although shouldRecall accepts them — confirm whether a topic search
// is intended there.
func detectAndRecall(input string) string {
	input = strings.TrimSpace(input)
	if strings.HasPrefix(input, "/recall") {
		input = strings.TrimSpace(strings.TrimPrefix(input, "/recall"))
		if input == "" {
			result, err := RecallHistory()
			if err != nil || result == nil {
				return ""
			}
			return FormatRecallResults(result)
		}
	}

	for _, keyword := range []string{"之前", "聊过", "记得", "曾经", "谈论过", "提过"} {
		if strings.Contains(input, keyword) {
			return recallByKeyword(input, keyword)
		}
	}
	return ""
}
// recallByKeyword resolves a keyword-triggered recall for input.
//
// Decision order:
//  1. keyword is "之前", "聊过" or "曾经" and input is a general question
//     (contains "什么" or "都聊", or is exactly "之前") → full history;
//  2. same keywords with a non-empty extracted topic → topic search;
//  3. input contains "那次" → full history;
//  4. any non-empty extracted topic → topic search.
//
// It returns "" when the database is not initialised, nothing applies, or the
// underlying recall fails (for example because the vector service is not
// initialised); a failed recall is never dereferenced.
func recallByKeyword(input, keyword string) string {
	if GetDB() == nil {
		return ""
	}

	history := func() string {
		result, err := RecallHistory()
		if err != nil || result == nil {
			return ""
		}
		return FormatRecallResults(result)
	}
	byTopic := func(topic string) string {
		result, err := RecallTopic(topic, 5)
		if err != nil || result == nil {
			return ""
		}
		return FormatRecallResults(result)
	}

	// extractTopic is pure, so one call serves both topic branches.
	topic := extractTopic(input)

	if keyword == "之前" || keyword == "聊过" || keyword == "曾经" {
		if strings.Contains(input, "什么") || input == "之前" || strings.Contains(input, "都聊") {
			return history()
		}
		if topic != "" {
			return byTopic(topic)
		}
	}

	if strings.Contains(input, "那次") {
		return history()
	}
	if topic != "" {
		return byTopic(topic)
	}
	return ""
}
// extractTopic strips the recall filler words from input (every occurrence,
// one word after another) and trims surrounding spaces and punctuation,
// leaving the topic the user is asking about. The result may be empty.
func extractTopic(input string) string {
	fillers := [...]string{"之前", "聊过", "谈论过", "提过", "记得", "曾经", "关于", "我问过"}

	remaining := input
	for i := range fillers {
		remaining = strings.ReplaceAll(remaining, fillers[i], "")
	}
	return strings.Trim(remaining, " ,。.?!:")
}
// ShouldSkipSummaryUpdate reports whether input is a recall request, using the
// same test as shouldRecall.
func ShouldSkipSummaryUpdate(input string) bool {
return shouldRecall(input)
}
// SaveChat persists one exchange (userInput and aiReply) in the current
// session and returns the number of chats the session now holds.
//
// The current session is resolved as: currentSession, else the latest stored
// session, else a newly created one. When useSessionSummary is true the
// session summary is extended with aiReply; for recall queries callers pass
// false and the summary is left untouched.
//
// Embeddings are generated in background goroutines, started only after the
// chat and session rows have been written, so they never race with the
// synchronous updates in this call.
//
// It returns (0, nil) when memory is disabled and an error when the database
// is not initialised or any write fails.
//
// NOTE(review): the session embedding goroutine still shares *Session with
// later SaveChat calls; protecting currentSession with a mutex would need
// changes outside this function.
func SaveChat(userInput, aiReply string, useSessionSummary bool) (int, error) {
	store := GetDB()
	if store == nil {
		return 0, fmt.Errorf("数据库未初始化")
	}

	memoryCfg := internal.GetProjectConfig().Memory
	if !memoryCfg.Enabled {
		return 0, nil
	}

	// Resolve or create the session.
	session := currentSession
	if session == nil {
		var err error
		session, err = store.GetLatestSession()
		if err != nil || session == nil {
			session, err = createSession()
			if err != nil {
				return 0, fmt.Errorf("创建会话失败: %w", err)
			}
		}
		currentSession = session
	}

	// Create the chat row, then fill in the reply and its summary.
	chat, err := store.CreateChat(session.ID, userInput)
	if err != nil {
		return 0, fmt.Errorf("创建聊天记录失败: %w", err)
	}
	chat.AddAIReply(aiReply)
	chat.Summary = GenerateSummary(userInput, aiReply)
	if err := store.UpdateChat(chat); err != nil {
		return 0, fmt.Errorf("更新聊天记录失败: %w", err)
	}

	// Update the session. Only ordinary conversation feeds the summary; the
	// value written here is also what stays in memory, so the next call
	// extends it instead of starting over from the stale summary.
	session.AddChatID(chat.ID)
	if useSessionSummary {
		session.Summary = GenerateSummary("", session.Summary+"\n"+aiReply)
	}
	if err := store.UpdateSession(session); err != nil {
		return 0, fmt.Errorf("更新会话失败: %w", err)
	}
	currentSession = session
	chatCount := len(session.ChatIDs)

	// Kick off embedding generation last: nothing below touches chat, and the
	// rows the goroutines will update already exist with their final text.
	if GetVectorService() != nil && memoryCfg.Vector.APIKey != "" {
		go generateEmbedding(chat)
		if useSessionSummary {
			go generateSessionEmbedding(session)
		}
	}

	return chatCount, nil
}
// generateEmbedding embeds the chat's user input plus its latest AI reply,
// stores the result on chat and persists it with UpdateChat. Failures are
// printed with a "[memory]" prefix and otherwise ignored. It does nothing when
// no vector service is configured.
func generateEmbedding(chat *Chat) {
	vs := GetVectorService()
	if vs == nil {
		return
	}

	text := chat.UserInput
	if n := len(chat.AIReplies); n > 0 {
		text += "\n" + chat.AIReplies[n-1]
	}

	embedding, err := vs.Generate(text)
	if err != nil {
		fmt.Printf("[memory] 向量生成失败: %v\n", err)
		return
	}

	chat.SummaryEmbedding, chat.UpdatedAt = embedding, time.Now().Unix()
	if saveErr := GetDB().UpdateChat(chat); saveErr != nil {
		fmt.Printf("[memory] 向量保存失败: %v\n", saveErr)
	}
}
// generateSessionEmbedding embeds the session summary, stores the result on
// session and persists it with UpdateSession. Failures are printed with a
// "[memory]" prefix and otherwise ignored. It does nothing when no vector
// service is configured.
func generateSessionEmbedding(session *Session) {
	vs := GetVectorService()
	if vs == nil {
		return
	}

	embedding, err := vs.Generate(session.Summary)
	if err != nil {
		fmt.Printf("[memory] Session 向量生成失败: %v\n", err)
		return
	}

	session.SummaryEmbedding, session.UpdatedAt = embedding, time.Now().Unix()
	if saveErr := GetDB().UpdateSession(session); saveErr != nil {
		fmt.Printf("[memory] Session 向量保存失败: %v\n", saveErr)
	}
}
// createSession is a package-private alias for CreateNewSession.
func createSession() (*Session, error) {
return CreateNewSession()
}
// SearchResult pairs a chat with its similarity score for a query, as returned
// by SearchSimilar.
type SearchResult struct {
Chat *Chat
Similarity float64
}
// SearchSimilar returns the chats whose stored embeddings are most similar to
// query, best match first.
//
// topK is capped at Memory.Vector.MaxSearchResults (10 when that setting is
// not positive) and at the number of chats that have an embedding; a negative
// topK is treated as 0 instead of panicking. The returned slice is never nil
// on success.
//
// It returns an error when the database or the vector service is not
// initialised, when the query cannot be embedded, or when the chats cannot be
// loaded.
func SearchSimilar(query string, topK int) ([]SearchResult, error) {
	db := GetDB()
	if db == nil {
		return nil, fmt.Errorf("数据库未初始化")
	}
	vs := GetVectorService()
	if vs == nil {
		return nil, fmt.Errorf("向量服务未初始化")
	}

	// Clamp topK to the configured maximum.
	maxResults := internal.GetProjectConfig().Memory.Vector.MaxSearchResults
	if maxResults <= 0 {
		maxResults = 10
	}
	if topK > maxResults {
		topK = maxResults
	}
	if topK < 0 {
		topK = 0
	}

	queryEmbedding, err := vs.Generate(query)
	if err != nil {
		return nil, fmt.Errorf("查询向量生成失败: %w", err)
	}

	chats, err := db.GetAllChats()
	if err != nil {
		return nil, fmt.Errorf("获取聊天记录失败: %w", err)
	}

	// Score every chat that carries an embedding.
	scored := make([]SearchResult, 0, len(chats))
	for _, chat := range chats {
		if len(chat.SummaryEmbedding) == 0 {
			continue
		}
		scored = append(scored, SearchResult{
			Chat:       chat,
			Similarity: vs.CosineSimilarity(queryEmbedding, chat.SummaryEmbedding),
		})
	}

	if topK > len(scored) {
		topK = len(scored)
	}

	// Partial selection: only the first topK positions need to be in order, so
	// the cost is O(n*topK) rather than a full O(n^2) sort.
	for i := 0; i < topK; i++ {
		best := i
		for j := i + 1; j < len(scored); j++ {
			if scored[j].Similarity > scored[best].Similarity {
				best = j
			}
		}
		scored[i], scored[best] = scored[best], scored[i]
	}

	// Copy so the result does not pin the full candidate slice in memory.
	results := make([]SearchResult, topK)
	copy(results, scored)
	return results, nil
}
// GetAllChats returns every chat that has a non-empty summary embedding,
// newest (highest id) first.
//
// Reading is best-effort per row: a row that cannot be scanned is skipped and
// a row whose ai_replies JSON is invalid is returned with an empty reply list;
// both cases are logged with a "[memory]" prefix. A failure of the query or of
// the row iteration itself is returned as an error.
func (d *DB) GetAllChats() ([]*Chat, error) {
	rows, err := d.db.Query(
		`SELECT id, session_id, user_input, ai_replies, summary, summary_embedding, created_at, updated_at FROM chats WHERE summary_embedding IS NOT NULL AND length(summary_embedding) > 0 ORDER BY id DESC`,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var chats []*Chat
	for rows.Next() {
		c := &Chat{}
		var aiRepliesJSON string
		if err := rows.Scan(&c.ID, &c.SessionID, &c.UserInput, &aiRepliesJSON, &c.Summary, &c.SummaryEmbedding, &c.CreatedAt, &c.UpdatedAt); err != nil {
			fmt.Printf("[memory] 跳过无法读取的聊天记录: %v\n", err)
			continue
		}
		if err := c.SetAIRepliesFromJSON(aiRepliesJSON); err != nil {
			// The embedding and summary are still usable for search.
			fmt.Printf("[memory] 聊天记录 %d 的 ai_replies 解析失败: %v\n", c.ID, err)
			c.AIReplies = []string{}
		}
		chats = append(chats, c)
	}
	// rows.Next returns false both at end-of-data and on failure; only
	// rows.Err distinguishes the two.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("遍历聊天记录失败: %w", err)
	}
	return chats, nil
}

View File

@@ -0,0 +1,214 @@
package memory
import (
"fmt"
"github.com/hxclaw/hxclaw/cmd/hxclaw/internal"
)
// RecallResult is the outcome of a recall operation.
type RecallResult struct {
// Type names the kind of recall: "history", "topic", "session" or
// "within_session".
Type string `json:"type"`
// Message is the human-readable text built for the recall.
Message string `json:"message"`
// Sessions and Chats carry the matched records, when there are any.
Sessions []*Session `json:"sessions,omitempty"`
Chats []*Chat `json:"chats,omitempty"`
}
// RecallHistory lists the summaries of all stored sessions.
//
// The result has Type "history"; Message enumerates the sessions in the order
// GetAllSessions returns them, with a placeholder for sessions that have no
// summary yet. It returns an error when the database is not initialised or the
// query fails; the query error is wrapped so callers can inspect it with
// errors.Is / errors.As.
func RecallHistory() (*RecallResult, error) {
	db := GetDB()
	if db == nil {
		return nil, fmt.Errorf("数据库未初始化")
	}

	sessions, err := db.GetAllSessions()
	if err != nil {
		return nil, fmt.Errorf("查询会话失败: %w", err)
	}
	if len(sessions) == 0 {
		return &RecallResult{
			Type:    "history",
			Message: "暂无历史会话记录",
		}, nil
	}

	msg := "以下是所有历史会话摘要:\n\n"
	for i, s := range sessions {
		summary := s.Summary
		if summary == "" {
			summary = "(暂无摘要)"
		}
		msg += fmt.Sprintf("**会话 %d** (ID: %d)\n%s\n\n", i+1, s.ID, summary)
	}

	return &RecallResult{
		Type:     "history",
		Message:  msg,
		Sessions: sessions,
	}, nil
}
// RecallTopic searches all chats for content related to query and reports the
// best-matching chat of each session.
//
// A non-positive topK falls back to Memory.Vector.MaxSearchResults, or 5 when
// that is not positive either. Sessions appear in the order their first hit
// occurs in the search results, so the output is deterministic. The result has
// Type "topic".
//
// It returns an error when the vector service is not initialised or the search
// fails; the search error is wrapped.
func RecallTopic(query string, topK int) (*RecallResult, error) {
	if GetVectorService() == nil {
		return nil, fmt.Errorf("向量服务未初始化")
	}

	if topK <= 0 {
		topK = internal.GetProjectConfig().Memory.Vector.MaxSearchResults
		if topK <= 0 {
			topK = 5
		}
	}

	results, err := SearchSimilar(query, topK)
	if err != nil {
		return nil, fmt.Errorf("向量检索失败: %w", err)
	}
	if len(results) == 0 {
		return &RecallResult{
			Type:    "topic",
			Message: fmt.Sprintf("未找到与「%s」相关的内容", query),
		}, nil
	}

	// Keep one hit per session — the most similar — remembering first-seen
	// order in a slice, because ranging over a map would shuffle the output.
	position := make(map[int64]int, len(results)) // session ID -> index in picked
	picked := make([]SearchResult, 0, len(results))
	for _, r := range results {
		if i, seen := position[r.Chat.SessionID]; seen {
			if r.Similarity > picked[i].Similarity {
				picked[i] = r
			}
			continue
		}
		position[r.Chat.SessionID] = len(picked)
		picked = append(picked, r)
	}

	msg := fmt.Sprintf("找到与「%s」相关的内容\n\n", query)
	chats := make([]*Chat, 0, len(picked))
	for _, hit := range picked {
		msg += fmt.Sprintf("**会话 %d** (相似度: %.2f)\n%s\n\n",
			hit.Chat.SessionID, hit.Similarity, hit.Chat.Summary)
		chats = append(chats, hit.Chat)
	}

	return &RecallResult{
		Type:    "topic",
		Message: msg,
		Chats:   chats,
	}, nil
}
// RecallSession returns the summary of the session with the given ID.
//
// The result has Type "session". A missing session is not an error: the
// message then says the ID was not found and Sessions is empty. It returns an
// error when the database is not initialised or the lookup fails; the lookup
// error is wrapped so callers can inspect it with errors.Is / errors.As.
func RecallSession(sessionID int64) (*RecallResult, error) {
	db := GetDB()
	if db == nil {
		return nil, fmt.Errorf("数据库未初始化")
	}

	session, err := db.GetSessionByID(sessionID)
	if err != nil {
		return nil, fmt.Errorf("查询会话失败: %w", err)
	}
	if session == nil {
		return &RecallResult{
			Type:    "session",
			Message: fmt.Sprintf("未找到会话 ID: %d", sessionID),
		}, nil
	}

	return &RecallResult{
		Type:     "session",
		Message:  fmt.Sprintf("**会话 %d 的摘要**\n\n%s", sessionID, session.Summary),
		Sessions: []*Session{session},
	}, nil
}
// RecallWithinSession finds up to three chats of one session that are most
// similar to query, best match first.
//
// Chats without an embedding are ignored. The result has Type
// "within_session"; a session without chats yields a message saying so.
//
// It returns an error when the database or the vector service is not
// initialised, the chats cannot be loaded, or the query cannot be embedded;
// underlying errors are wrapped.
func RecallWithinSession(sessionID int64, query string) (*RecallResult, error) {
	db := GetDB()
	if db == nil {
		return nil, fmt.Errorf("数据库未初始化")
	}

	chats, err := db.GetChatsBySessionID(sessionID)
	if err != nil {
		return nil, fmt.Errorf("查询聊天记录失败: %w", err)
	}
	if len(chats) == 0 {
		return &RecallResult{
			Type:    "within_session",
			Message: fmt.Sprintf("会话 %d 中暂无聊天记录", sessionID),
		}, nil
	}

	vs := GetVectorService()
	if vs == nil {
		return nil, fmt.Errorf("向量服务未初始化")
	}
	queryEmb, err := vs.Generate(query)
	if err != nil {
		return nil, fmt.Errorf("生成查询向量失败: %w", err)
	}

	type scoredChat struct {
		chat       *Chat
		similarity float64
	}
	scored := make([]scoredChat, 0, len(chats))
	for _, c := range chats {
		if len(c.SummaryEmbedding) == 0 {
			continue
		}
		scored = append(scored, scoredChat{
			chat:       c,
			similarity: vs.CosineSimilarity(queryEmb, c.SummaryEmbedding),
		})
	}

	topK := 3
	if topK > len(scored) {
		topK = len(scored)
	}

	// Partial selection: only the first topK positions need to be in order.
	for i := 0; i < topK; i++ {
		best := i
		for j := i + 1; j < len(scored); j++ {
			if scored[j].similarity > scored[best].similarity {
				best = j
			}
		}
		scored[i], scored[best] = scored[best], scored[i]
	}

	msg := fmt.Sprintf("会话 %d 中与「%s」相关内容\n\n", sessionID, query)
	var matchedChats []*Chat
	for _, hit := range scored[:topK] {
		msg += fmt.Sprintf("**相似度: %.2f**\n%s\n\n", hit.similarity, hit.chat.Summary)
		matchedChats = append(matchedChats, hit.chat)
	}

	return &RecallResult{
		Type:    "within_session",
		Message: msg,
		Chats:   matchedChats,
	}, nil
}
// FormatRecallResults returns the text to show for result: its Message, or ""
// when result is nil. The nil case matters because callers pass on the result
// of recall functions that return nil together with an error.
func FormatRecallResults(result *RecallResult) string {
	if result == nil {
		return ""
	}
	return result.Message
}

View File

@@ -0,0 +1,234 @@
package memory
import (
"bytes"
"context"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"math"
"net/http"
"sync"
)
// VectorService is a client for an HTTP embeddings API.
type VectorService struct {
// APIKey is sent as a Bearer token in the Authorization header.
APIKey string
// BaseURL is the API root; requests go to BaseURL + "/embeddings".
BaseURL string
// Model is the embedding model name sent with every request.
Model string
// Dimension sizes the zero vector that Generate returns for empty text
// (Dimension*4 bytes).
Dimension int
}
// Package-level vector service state.
var (
// vectorSvc is the service created by InitVector; nil until then.
vectorSvc *VectorService
// vectorMu is read-locked by GetVectorService while it reads vectorSvc.
vectorMu sync.RWMutex
)
// VectorOption mutates a VectorService; InitVector applies the options in
// order on top of its defaults.
type VectorOption func(*VectorService)
// WithAPIKey returns an option that sets VectorService.APIKey to key.
func WithAPIKey(key string) VectorOption {
	apply := func(svc *VectorService) {
		svc.APIKey = key
	}
	return apply
}
// WithBaseURL returns an option that sets VectorService.BaseURL to url.
func WithBaseURL(url string) VectorOption {
	apply := func(svc *VectorService) {
		svc.BaseURL = url
	}
	return apply
}
// WithModel returns an option that sets VectorService.Model to model.
func WithModel(model string) VectorOption {
	apply := func(svc *VectorService) {
		svc.Model = model
	}
	return apply
}
// WithDimension returns an option that sets VectorService.Dimension to dim.
func WithDimension(dim int) VectorOption {
	apply := func(svc *VectorService) {
		svc.Dimension = dim
	}
	return apply
}
// InitVector creates the package-level vector service from the built-in
// defaults (SiliconFlow base URL, model "BAAI/bge-m3", dimension 1024, empty
// API key) with opts applied on top, replacing any previous service.
//
// The service is fully configured before it is published, and it is published
// under vectorMu — the lock GetVectorService reads under — so concurrent
// readers never observe a half-configured service. It currently always returns
// nil.
func InitVector(opts ...VectorOption) error {
	svc := &VectorService{
		APIKey:    "",
		BaseURL:   "https://api.siliconflow.cn/v1",
		Model:     "BAAI/bge-m3",
		Dimension: 1024,
	}
	for _, opt := range opts {
		opt(svc)
	}

	vectorMu.Lock()
	vectorSvc = svc
	vectorMu.Unlock()
	return nil
}
// GetVectorService returns the package-level vector service, read under
// vectorMu's read lock. It is nil until InitVector has run.
func GetVectorService() *VectorService {
	vectorMu.RLock()
	svc := vectorSvc
	vectorMu.RUnlock()
	return svc
}
// Generate returns the embedding of text as little-endian float32 bytes.
//
// Empty text is not sent to the API; a zero vector of Dimension*4 bytes is
// returned instead. Otherwise text is POSTed to BaseURL+"/embeddings" with the
// configured model and API key. The whole exchange, including reading the
// response body, is bounded by a 60-second timeout so a stalled server cannot
// block the caller indefinitely.
//
// It returns an error when the request cannot be built or sent, the server
// answers with a status other than 200, or the response contains no embedding.
func (v *VectorService) Generate(text string) ([]byte, error) {
	if text == "" {
		return make([]byte, v.Dimension*4), nil
	}

	bodyBytes, err := json.Marshal(map[string]interface{}{
		"input": text,
		"model": v.Model,
	})
	if err != nil {
		return nil, fmt.Errorf("序列化请求失败: %w", err)
	}

	// 60 seconds in nanoseconds, the unit of time.Duration. Written as an
	// untyped constant because this file does not import the time package.
	const requestTimeout = 60_000_000_000
	ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, "POST", v.BaseURL+"/embeddings", bytes.NewReader(bodyBytes))
	if err != nil {
		return nil, fmt.Errorf("创建请求失败: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+v.APIKey)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("请求失败: %w", err)
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("读取响应失败: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("API 返回错误: %d %s", resp.StatusCode, string(respBody))
	}

	var result struct {
		Data []struct {
			Embedding []float32 `json:"embedding"`
		} `json:"data"`
	}
	if err := json.Unmarshal(respBody, &result); err != nil {
		return nil, fmt.Errorf("解析响应失败: %w", err)
	}
	if len(result.Data) == 0 || len(result.Data[0].Embedding) == 0 {
		return nil, fmt.Errorf("未获取到向量")
	}

	buf := new(bytes.Buffer)
	if err := binary.Write(buf, binary.LittleEndian, result.Data[0].Embedding); err != nil {
		return nil, fmt.Errorf("编码向量失败: %w", err)
	}
	return buf.Bytes(), nil
}
// GenerateBatch returns one embedding per entry of texts, each encoded as
// little-endian float32 bytes, in the order the API returns them.
//
// An empty input yields an empty, non-nil slice without calling the API.
// Otherwise all texts are POSTed in a single request to
// BaseURL+"/embeddings". The whole exchange is bounded by a 60-second timeout.
//
// It returns an error when the request cannot be built or sent, the server
// answers with a status other than 200, the number of returned embeddings
// differs from len(texts), or an embedding cannot be encoded — the result
// never contains silent nil gaps.
func (v *VectorService) GenerateBatch(texts []string) ([][]byte, error) {
	if len(texts) == 0 {
		return [][]byte{}, nil
	}

	bodyBytes, err := json.Marshal(map[string]interface{}{
		"input": texts,
		"model": v.Model,
	})
	if err != nil {
		return nil, fmt.Errorf("序列化请求失败: %w", err)
	}

	// 60 seconds in nanoseconds, the unit of time.Duration. Written as an
	// untyped constant because this file does not import the time package.
	const requestTimeout = 60_000_000_000
	ctx, cancel := context.WithTimeout(context.Background(), requestTimeout)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, "POST", v.BaseURL+"/embeddings", bytes.NewReader(bodyBytes))
	if err != nil {
		return nil, fmt.Errorf("创建请求失败: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+v.APIKey)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("请求失败: %w", err)
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("读取响应失败: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("API 返回错误: %d %s", resp.StatusCode, string(respBody))
	}

	var result struct {
		Data []struct {
			Embedding []float32 `json:"embedding"`
		} `json:"data"`
	}
	if err := json.Unmarshal(respBody, &result); err != nil {
		return nil, fmt.Errorf("解析响应失败: %w", err)
	}
	// Callers match embeddings to texts by index, so a count mismatch would
	// silently pair vectors with the wrong text.
	if len(result.Data) != len(texts) {
		return nil, fmt.Errorf("向量数量不匹配: 期望 %d, 实际 %d", len(texts), len(result.Data))
	}

	embeddings := make([][]byte, len(result.Data))
	for i, item := range result.Data {
		buf := new(bytes.Buffer)
		if err := binary.Write(buf, binary.LittleEndian, item.Embedding); err != nil {
			return nil, fmt.Errorf("编码向量失败: %w", err)
		}
		embeddings[i] = buf.Bytes()
	}
	return embeddings, nil
}
// CosineSimilarity returns the cosine similarity of two embeddings encoded as
// little-endian float32 bytes (the format Generate produces).
//
// The result is dot(a, b) / (|a| * |b|). When the vectors differ in length
// only the common prefix is compared; trailing bytes that do not form a whole
// float32 are ignored. It returns 0 when either input is empty or has zero
// magnitude. Sums are accumulated in float64 to limit rounding error over long
// vectors.
func (v *VectorService) CosineSimilarity(a, b []byte) float64 {
	if len(a) == 0 || len(b) == 0 {
		return 0
	}

	n := len(a) / 4
	if m := len(b) / 4; m < n {
		n = m
	}

	var dot, sqNormA, sqNormB float64
	for i := 0; i < n; i++ {
		off := i * 4
		fa := float64(math.Float32frombits(binary.LittleEndian.Uint32(a[off : off+4])))
		fb := float64(math.Float32frombits(binary.LittleEndian.Uint32(b[off : off+4])))
		dot += fa * fb
		sqNormA += fa * fa
		sqNormB += fb * fb
	}
	if sqNormA == 0 || sqNormB == 0 {
		return 0
	}
	// sqNormA and sqNormB are squared magnitudes; the denominator needs the
	// magnitudes themselves.
	return dot / (math.Sqrt(sqNormA) * math.Sqrt(sqNormB))
}