2026-04-27 07:15:08 +08:00
package memory
import (
	"fmt"
	"os"
	"path/filepath"
	"sort"
	"strings"
	"time"

	"github.com/hxclaw/hxclaw/cmd/hxclaw/internal"
)
var (
	// lastContext holds the most recent text assembled by GetContextPrompt.
	lastContext string

	// ErrNeedNewSession is returned by SaveChat when there is no current
	// session, so that the caller can create one.
	ErrNeedNewSession = fmt.Errorf("需要创建新会话")
)
2026-04-27 07:15:08 +08:00
func GetContextPrompt ( userInput string ) string {
db := GetDB ( )
if db == nil {
return ""
}
memoryCfg := internal . GetProjectConfig ( ) . Memory
if ! memoryCfg . Enabled {
return ""
}
session := currentSession
if session == nil {
2026-04-27 08:03:16 +08:00
return "" // 没有 session, 不提供上下文
2026-04-27 07:15:08 +08:00
}
var context string
// 读取 picoclaw 的长期记忆
picoMemory := readPicoClawMemory ( )
if picoMemory != "" {
context += "=== 长期记忆 ===\n" + picoMemory + "\n============\n"
}
// 添加会话摘要
if session . Summary != "" {
context += "=== 当前会话摘要 ===\n" + session . Summary + "\n============\n"
}
if shouldRecall ( userInput ) {
recallResult := detectAndRecall ( userInput )
if recallResult != "" {
context += recallResult
}
}
lastContext = context
return context
}
// readPicoClawMemory returns the body of picoclaw's long-term memory file,
// <workspace>/memory/MEMORY.md.
//
// Blank lines and heading lines (lines starting with "#") are dropped, as is
// everything before the first heading. The remaining lines are trimmed and
// joined with "\n". It returns "" when the workspace cannot be located or the
// file cannot be read.
func readPicoClawMemory() string {
	workspace := getPicoclawWorkspace()
	if workspace == "" {
		// Without a home directory, filepath.Join would produce the relative
		// path "memory/MEMORY.md" and silently read whatever file happens to
		// sit in the current working directory.
		return ""
	}
	data, err := os.ReadFile(filepath.Join(workspace, "memory", "MEMORY.md"))
	if err != nil {
		return ""
	}

	var sb strings.Builder
	inSection := false // becomes true at the first heading and stays true
	for _, line := range strings.Split(string(data), "\n") {
		trimmed := strings.TrimSpace(line)
		if trimmed == "" {
			continue
		}
		if strings.HasPrefix(trimmed, "#") {
			inSection = true
			continue
		}
		if inSection {
			sb.WriteString(trimmed)
			sb.WriteString("\n")
		}
	}
	return strings.TrimSpace(sb.String())
}
// getPicoclawWorkspace returns picoclaw's workspace directory,
// <home>/.picoclaw/workspace, where <home> comes from os.UserHomeDir. It
// returns "" when the home directory cannot be determined.
func getPicoclawWorkspace() string {
	if home, err := os.UserHomeDir(); err == nil {
		return filepath.Join(home, ".picoclaw", "workspace")
	}
	return ""
}
// shouldRecall reports whether input should trigger a memory recall.
//
// Recall is triggered when the input, ignoring surrounding whitespace, starts
// with "/recall" or "/memory"; when it contains one of the configured
// Memory.Recall.Keywords; or, when Memory.Recall.AutoRecall is enabled, when
// shouldAutoRecall reports a match.
func shouldRecall(input string) bool {
	// Trim before the prefix test, as detectAndRecall does, so that
	// "  /recall" is treated as a command consistently in both places.
	trimmed := strings.TrimSpace(input)
	for _, cmd := range []string{"/recall", "/memory"} {
		if strings.HasPrefix(trimmed, cmd) {
			return true
		}
	}

	recallCfg := internal.GetProjectConfig().Memory.Recall
	for _, keyword := range recallCfg.Keywords {
		// strings.Contains(s, "") is always true, so a stray empty keyword in
		// the config would otherwise turn every input into a recall query.
		if keyword == "" {
			continue
		}
		if strings.Contains(input, keyword) {
			return true
		}
	}
	if recallCfg.AutoRecall {
		return shouldAutoRecall(input)
	}
	return false
}
// shouldAutoRecall reports whether input is semantically close to the current
// session's summary. It compares the cosine similarity of input's embedding
// and the session's SummaryEmbedding against Memory.Recall.SimilarityThreshold,
// and succeeds only when the similarity is strictly greater.
//
// It returns false when there is no vector service, no current session, the
// session has no summary or summary embedding, or generating input's embedding
// fails. Each call that reaches the comparison issues one embedding request.
func shouldAutoRecall(input string) bool {
	vs := GetVectorService()
	if vs == nil {
		return false
	}

	threshold := internal.GetProjectConfig().Memory.Recall.SimilarityThreshold
	if threshold <= 0 {
		threshold = 0.7 // fallback when the config leaves the threshold unset or non-positive
	}

	sess := currentSession
	if sess == nil || sess.Summary == "" || len(sess.SummaryEmbedding) == 0 {
		return false
	}

	queryVec, err := vs.Generate(input)
	if err != nil {
		return false
	}
	return vs.CosineSimilarity(queryVec, sess.SummaryEmbedding) > threshold
}
// detectAndRecall performs the recall that input asks for and returns the
// formatted results, or "" when nothing applies.
//
//   - "/recall" on its own recalls the whole history via RecallHistory.
//   - If the input (after stripping a leading "/recall") contains one of the
//     built-in recall phrases, it is delegated to recallByKeyword with the
//     first phrase found, in list order.
//   - "/recall <topic>" without such a phrase recalls <topic> via RecallTopic.
//     Previously this case fell through and returned "", silently ignoring an
//     explicit command.
//
// Errors from the recall functions are ignored (best-effort). The result, as
// returned alongside any error, is passed to FormatRecallResults.
//
// NOTE(review): shouldRecall also accepts "/memory" and configured keywords,
// and neither is handled here. Such input returns "" unless it happens to
// contain a built-in phrase.
func detectAndRecall(input string) string {
	input = strings.TrimSpace(input)

	isRecallCmd := strings.HasPrefix(input, "/recall")
	if isRecallCmd {
		input = strings.TrimSpace(strings.TrimPrefix(input, "/recall"))
		if input == "" {
			result, _ := RecallHistory() // best-effort, see doc comment
			return FormatRecallResults(result)
		}
	}

	for _, keyword := range []string{"之前", "聊过", "记得", "曾经", "谈论过", "提过"} {
		if strings.Contains(input, keyword) {
			return recallByKeyword(input, keyword)
		}
	}

	// Explicit "/recall <topic>": honour the command even without a recall phrase.
	if isRecallCmd {
		if topic := extractTopic(input); topic != "" {
			result, _ := RecallTopic(topic, 5) // best-effort, see doc comment
			return FormatRecallResults(result)
		}
	}
	return ""
}
// recallByKeyword answers a recall request in which keyword (one of the
// built-in recall phrases) was found in input. It returns the formatted
// results, or "" when the database is unavailable or no topic can be derived.
//
// Resolution order:
//  1. For the general keywords 之前/聊过/曾经, a question such as "what did we
//     talk about" (input contains 什么 or 都聊, or is exactly 之前) recalls the
//     whole history.
//  2. For those same keywords, a non-empty topic from extractTopic is recalled
//     via RecallTopic.
//  3. Input mentioning 那次 ("that time") recalls the whole history.
//  4. Otherwise a non-empty topic is recalled via RecallTopic.
//
// Errors from the recall functions are ignored (best-effort).
func recallByKeyword(input, keyword string) string {
	if GetDB() == nil {
		return ""
	}

	generalKeyword := keyword == "之前" || keyword == "聊过" || keyword == "曾经"
	if generalKeyword && (strings.Contains(input, "什么") || input == "之前" || strings.Contains(input, "都聊")) {
		result, _ := RecallHistory() // best-effort
		return FormatRecallResults(result)
	}

	// extractTopic is deterministic, so compute it once for both topic branches.
	topic := extractTopic(input)
	if generalKeyword && topic != "" {
		result, _ := RecallTopic(topic, 5) // best-effort
		return FormatRecallResults(result)
	}

	// NOTE(review): the original tested strings.Contains(input, "那次") twice in
	// one condition. The second operand was probably meant to be a different
	// phrase (perhaps "上次"?); confirm the intent before adding it.
	if strings.Contains(input, "那次") {
		result, _ := RecallHistory() // best-effort
		return FormatRecallResults(result)
	}

	if topic != "" {
		result, _ := RecallTopic(topic, 5) // best-effort
		return FormatRecallResults(result)
	}
	return ""
}
// extractTopic strips recall phrasing from input and returns the remainder as
// the topic to search for.
//
// Every occurrence (not only a leading one) of 之前, 聊过, 谈论过, 提过, 记得,
// 曾经, 关于 and 我问过 is removed. Then spaces and both ASCII and full-width
// punctuation (, ， 。 . ? ？ ! ！ : ：) are trimmed from both ends.
func extractTopic(input string) string {
	topic := input
	for _, phrase := range []string{"之前", "聊过", "谈论过", "提过", "记得", "曾经", "关于", "我问过"} {
		topic = strings.ReplaceAll(topic, phrase, "")
	}
	// Chinese input usually ends with full-width punctuation such as "？". The
	// previous cutset had lost the full-width forms, so "golang？" survived
	// untrimmed.
	return strings.Trim(topic, " ，,。.？?！!：:")
}
// ShouldSkipSummaryUpdate reports whether input is a recall query, using
// exactly the same test as shouldRecall (a recall command, a configured
// keyword, or an auto-recall match). Recall queries are meant to leave the
// session summary unchanged.
func ShouldSkipSummaryUpdate(input string) bool {
	isRecallQuery := shouldRecall(input)
	return isRecallQuery
}
func SaveChat ( userInput , aiReply string , useSessionSummary bool ) ( int , error ) {
db := GetDB ( )
if db == nil {
return 0 , fmt . Errorf ( "数据库未初始化" )
}
memoryCfg := internal . GetProjectConfig ( ) . Memory
if ! memoryCfg . Enabled {
return 0 , nil
}
// 获取或创建 Session
session := currentSession
if session == nil {
2026-04-27 08:00:20 +08:00
// 如果没有 session, 返回错误让用户创建
return 0 , ErrNeedNewSession
2026-04-27 07:15:08 +08:00
}
// 保存原始 session summary( 用于恢复)
originalSummary := session . Summary
// 创建聊天记录
chat , err := db . CreateChat ( session . ID , userInput )
if err != nil {
return 0 , fmt . Errorf ( "创建聊天记录失败: %v" , err )
}
// 添加 AI 回复
chat . AddAIReply ( aiReply )
chat . Summary = GenerateSummary ( userInput , aiReply )
chat . UpdatedAt = time . Now ( ) . Unix ( )
// 生成并保存向量
vs := GetVectorService ( )
if vs != nil && memoryCfg . Vector . APIKey != "" {
2026-04-27 08:03:16 +08:00
generateEmbedding ( chat )
2026-04-27 07:15:08 +08:00
}
if err := db . UpdateChat ( chat ) ; err != nil {
return 0 , fmt . Errorf ( "更新聊天记录失败: %v" , err )
}
// 更新 Session
session . AddChatID ( chat . ID )
// 只有普通对话才更新 session summary, recall 查询保持原 summary
if useSessionSummary {
session . Summary = GenerateSummary ( "" , session . Summary + "\n" + aiReply )
}
session . UpdatedAt = time . Now ( ) . Unix ( )
// Session 也要生成向量
if vs != nil && memoryCfg . Vector . APIKey != "" && useSessionSummary {
2026-04-27 08:03:16 +08:00
generateSessionEmbedding ( session )
2026-04-27 07:15:08 +08:00
}
if err := db . UpdateSession ( session ) ; err != nil {
return 0 , fmt . Errorf ( "更新会话失败: %v" , err )
}
// 恢复原始 session summary( 避免 recall 结果污染)
session . Summary = originalSummary
currentSession = session
return len ( session . ChatIDs ) , nil
}
// generateEmbedding computes a vector for chat from its user input plus its
// latest AI reply (if any). It stores the vector in chat.SummaryEmbedding,
// bumps chat.UpdatedAt and saves the chat.
//
// This is best-effort: when there is no vector service it does nothing, and
// generation or save failures are printed rather than returned.
func generateEmbedding(chat *Chat) {
	vs := GetVectorService()
	if vs == nil {
		return
	}

	text := chat.UserInput
	if n := len(chat.AIReplies); n > 0 {
		text += "\n" + chat.AIReplies[n-1]
	}
	embedding, err := vs.Generate(text)
	if err != nil {
		fmt.Printf("[memory] 向量生成失败: %v\n", err)
		return
	}
	chat.SummaryEmbedding = embedding
	chat.UpdatedAt = time.Now().Unix()

	// GetDB can return nil (every other caller in this file checks for it).
	// Calling a method on the nil *DB would very likely dereference it and panic.
	db := GetDB()
	if db == nil {
		fmt.Println("[memory] 向量保存失败: 数据库未初始化")
		return
	}
	if err := db.UpdateChat(chat); err != nil {
		fmt.Printf("[memory] 向量保存失败: %v\n", err)
	}
}
// generateSessionEmbedding computes a vector for session.Summary. It stores the
// vector in session.SummaryEmbedding, bumps session.UpdatedAt and saves the
// session.
//
// This is best-effort: when there is no vector service it does nothing, and
// generation or save failures are printed rather than returned.
func generateSessionEmbedding(session *Session) {
	vs := GetVectorService()
	if vs == nil {
		return
	}

	embedding, err := vs.Generate(session.Summary)
	if err != nil {
		fmt.Printf("[memory] Session 向量生成失败: %v\n", err)
		return
	}
	session.SummaryEmbedding = embedding
	session.UpdatedAt = time.Now().Unix()

	// GetDB can return nil (every other caller in this file checks for it).
	// Calling a method on the nil *DB would very likely dereference it and panic.
	db := GetDB()
	if db == nil {
		fmt.Println("[memory] Session 向量保存失败: 数据库未初始化")
		return
	}
	if err := db.UpdateSession(session); err != nil {
		fmt.Printf("[memory] Session 向量保存失败: %v\n", err)
	}
}
// createSession is a package-private shorthand for CreateNewSession. It
// returns exactly what CreateNewSession returns.
func createSession() (*Session, error) {
	s, err := CreateNewSession()
	return s, err
}
// SearchResult is one hit returned by SearchSimilar: a stored chat paired with
// the cosine similarity between its embedding and the query's embedding.
type SearchResult struct {
	// Chat is the matching chat record.
	Chat *Chat
	// Similarity is the cosine similarity score. SearchSimilar orders its
	// results by this value, highest first.
	Similarity float64
}
// SearchSimilar returns up to topK stored chats whose embeddings are most
// similar (by cosine similarity) to the embedding of query, best match first.
//
// topK is capped at the configured Memory.Vector.MaxSearchResults, which
// defaults to 10 when unset or non-positive. A topK <= 0 yields an empty,
// non-nil result without calling the vector service.
//
// An error is returned when the database or vector service is not
// initialised, when the query embedding cannot be generated, or when chats
// cannot be loaded.
func SearchSimilar(query string, topK int) ([]SearchResult, error) {
	db := GetDB()
	if db == nil {
		return nil, fmt.Errorf("数据库未初始化")
	}
	vs := GetVectorService()
	if vs == nil {
		return nil, fmt.Errorf("向量服务未初始化")
	}

	maxResults := internal.GetProjectConfig().Memory.Vector.MaxSearchResults
	if maxResults <= 0 {
		maxResults = 10
	}
	if topK > maxResults {
		topK = maxResults
	}
	if topK <= 0 {
		// A negative topK used to reach make([]SearchResult, topK) and panic.
		return []SearchResult{}, nil
	}

	queryEmbedding, err := vs.Generate(query)
	if err != nil {
		return nil, fmt.Errorf("查询向量生成失败: %w", err)
	}

	// GetAllChats only returns chats that have an embedding. The length check
	// below is a cheap extra guard.
	chats, err := db.GetAllChats()
	if err != nil {
		return nil, fmt.Errorf("获取聊天记录失败: %w", err)
	}

	results := make([]SearchResult, 0, len(chats))
	for _, chat := range chats {
		if len(chat.SummaryEmbedding) == 0 {
			continue
		}
		results = append(results, SearchResult{
			Chat:       chat,
			Similarity: vs.CosineSimilarity(queryEmbedding, chat.SummaryEmbedding),
		})
	}

	// Highest similarity first. This O(n log n) sort replaces the former O(n²)
	// exchange sort, and stability keeps GetAllChats' order (highest id first)
	// among equal scores.
	sort.SliceStable(results, func(i, j int) bool {
		return results[i].Similarity > results[j].Similarity
	})
	if len(results) > topK {
		results = results[:topK]
	}
	return results, nil
}
// GetAllChats loads every chat that has a non-empty summary embedding,
// ordered by id descending. Despite the name, chats without an embedding are
// excluded by the query.
//
// Rows that fail to scan are skipped. An error is returned if the query fails
// or if iterating the result set stops with an error.
func (db *DB) GetAllChats() ([]*Chat, error) {
	rows, err := db.db.Query(
		` SELECT id, session_id, user_input, ai_replies, summary, summary_embedding, created_at, updated_at FROM chats WHERE summary_embedding IS NOT NULL AND length(summary_embedding) > 0 ORDER BY id DESC `,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var chats []*Chat
	for rows.Next() {
		var c Chat
		var aiRepliesJSON string
		var summaryEmbedding []byte
		if err := rows.Scan(&c.ID, &c.SessionID, &c.UserInput, &aiRepliesJSON, &c.Summary, &summaryEmbedding, &c.CreatedAt, &c.UpdatedAt); err != nil {
			// Best-effort: skip a malformed row rather than failing the whole load.
			continue
		}
		c.SummaryEmbedding = summaryEmbedding
		c.SetAIRepliesFromJSON(aiRepliesJSON)
		chats = append(chats, &c)
	}
	// rows.Next returns false both at end of data and on failure. Without this
	// check, an error mid-iteration would be reported as a partial success.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return chats, nil
}