Files
HxClaw/cmd/hxclaw/main.go
titor 532466c343 feat: 添加 /context 命令显示会话上下文使用情况
- 集成 picoclaw v0.2.7 的上下文统计功能
- 显示消息数、token 使用量、压缩阈值和剩余空间
- 在 interactiveMode 和 simpleInteractiveMode 中均支持
2026-05-03 02:43:27 +08:00

603 lines
14 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"context"
"fmt"
"math"
"os"
"strings"
"time"
"charm.land/lipgloss/v2"
"github.com/hxclaw/hxclaw/cmd/hxclaw/internal"
"github.com/hxclaw/hxclaw/cmd/hxclaw/internal/memory"
"github.com/muesli/termenv"
"github.com/sipeed/picoclaw/pkg/agent"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
)
// Logo is the emoji printed in the interactive-mode banner by main.
const Logo = "🦐"
// currentSession is the session tracked by this file. It is reset to nil by
// handleNewSessionCommand and read by the "/memory current" subcommand.
// NOTE(review): nothing in this file assigns a non-nil value to it — confirm
// whether another file in the package does.
var currentSession *memory.Session
// main wires up the CLI: it loads the project and agent configuration, creates
// the provider, message bus and agent loop, optionally initializes the memory
// store, and finally hands control to the interactive prompt.
//
// A failure while loading either configuration or while creating the provider
// is written to stderr and ends the process with exit status 1. A failure to
// initialize the memory store is only a warning.
func main() {
	// die reports a fatal startup problem on stderr and exits with status 1.
	die := func(format string, cause any) {
		fmt.Fprintf(os.Stderr, format, cause)
		os.Exit(1)
	}

	if err := internal.LoadProjectConfig(); err != nil {
		die("错误:加载项目配置失败: %v\n", err)
	}
	fmt.Printf("%s HxClaw - PicoClaw 增强版 CLI\n\n", internal.GetProjectConfig().UI.Logo)

	cfg, err := internal.LoadConfig()
	if err != nil {
		die("错误:加载配置失败: %v\n", err)
	}
	logger.ConfigureFromEnv()

	provider, modelID, err := providers.CreateProvider(cfg)
	if err != nil {
		die("错误:创建 Provider 失败: %v\n", err)
	}
	// A model ID reported by the provider overrides the configured default.
	if modelID != "" {
		cfg.Agents.Defaults.ModelName = modelID
	}

	msgBus := bus.NewMessageBus()
	defer msgBus.Close()
	agentLoop := agent.NewAgentLoop(cfg, msgBus, provider)
	defer agentLoop.Close()

	// Log how many tools and skills the agent loop reports at startup.
	info := agentLoop.GetStartupInfo()
	tools := info["tools"].(map[string]any)
	skills := info["skills"].(map[string]any)
	logger.InfoCF("hxclaw", "HxClaw 已初始化",
		map[string]any{
			"tools_count":      tools["count"],
			"skills_total":     skills["total"],
			"skills_available": skills["available"],
		})

	if memCfg := internal.GetProjectConfig().Memory; memCfg.Enabled {
		// Prefer the db_path from the user configuration and fall back to
		// the default path when it is not set.
		dbPath := memCfg.DBPath
		if dbPath == "" {
			dbPath = memory.GetDefaultDBPath()
		}
		fmt.Printf("初始化记忆体db_path: %s\n", dbPath)
		if err := memory.Init(agentLoop, memory.WithDBPath(dbPath)); err == nil {
			fmt.Println("记忆体初始化成功")
			memory.InitVector(
				memory.WithAPIKey(memCfg.Vector.APIKey),
				memory.WithBaseURL(memCfg.Vector.BaseURL),
				memory.WithModel(memCfg.Vector.Model),
			)
		} else {
			fmt.Fprintf(os.Stderr, "警告:初始化记忆体失败: %v将使用无记忆模式\n", err)
		}
	}

	fmt.Printf("%s Interactive mode (Ctrl+C to exit)\n\n", Logo)
	interactiveMode(agentLoop, "cli:default")
}
// interactiveMode runs the readline-based prompt loop for sessionKey.
//
// If readline cannot be initialized it falls back to simpleInteractiveMode.
// Each input line is handled as follows:
//   - empty lines are ignored;
//   - "exit", "quit", an interrupt or EOF end the loop after calling
//     memory.ExportIfNeeded;
//   - a leading "T" (alone or followed by a space) is stripped, toggles TTS
//     and marks the message for speech output;
//   - lines starting with /tts, /new, /memory, /sessions or /context are
//     dispatched to their handlers and never reach the agent;
//   - everything else is sent to the agent through runWithStreaming.
func interactiveMode(agentLoop *agent.AgentLoop, sessionKey string) {
	// Seed the runtime TTS switch from the project configuration before the
	// first prompt is built, so the initial prompt is derived from the same
	// state the loop runs with.
	// NOTE(review): assumes GetTTSPrompt reflects the current TTS state — confirm.
	if internal.GetProjectConfig().TTS.Enabled {
		internal.SetTTSEnabled(true)
	}

	basePrompt := internal.GetProjectConfig().UI.UserIcon
	rl, err := internal.NewReadline(internal.GetTTSPrompt(basePrompt))
	if err != nil {
		fmt.Printf("初始化 readline 失败: %v\n", err)
		fmt.Println("回退到简单输入模式...")
		simpleInteractiveMode(agentLoop, sessionKey)
		return
	}
	defer rl.Close()

	for {
		line, err := rl.Readline()
		if err != nil {
			if err == internal.ErrInterrupt || err == internal.ErrEOF {
				fmt.Println("\n再见!")
				memory.ExportIfNeeded()
				return
			}
			fmt.Printf("读取输入错误: %v\n", err)
			continue
		}

		input := line
		if input == "" {
			continue
		}
		if input == "exit" || input == "quit" {
			fmt.Println("再见!")
			memory.ExportIfNeeded()
			return
		}

		// "T" or "T <message>" requests speech output for this message.
		isTempTTS := false
		if input[0] == 'T' && (len(input) == 1 || input[1] == ' ') {
			input = strings.TrimPrefix(input, "T")
			input = strings.TrimPrefix(input, " ")
			isTempTTS = true
		}

		switch {
		case strings.HasPrefix(input, "/tts"):
			handleTTSCommand(input, rl, basePrompt)
			continue
		case strings.HasPrefix(input, "/new"):
			handleNewSessionCommand(rl, basePrompt)
			continue
		case strings.HasPrefix(input, "/memory"):
			handleMemoryCommand(input)
			continue
		case strings.HasPrefix(input, "/sessions"):
			handleSessionsCommand()
			continue
		case strings.HasPrefix(input, "/context"):
			handleContextCommand(agentLoop, sessionKey)
			continue
		}

		if isTempTTS {
			internal.ToggleTTS()
			// Refresh the prompt after every toggle, in both directions, the
			// same way handleTTSCommand does; refreshing only when TTS was
			// turned on left the prompt stale after it was turned off.
			rl.SetPrompt(internal.GetTTSPrompt(basePrompt))
		}
		if input == "" {
			// A bare "T" leaves nothing to send once the prefix is removed;
			// do not forward an empty message to the agent.
			continue
		}
		runWithStreaming(agentLoop, input, sessionKey, isTempTTS)
	}
}
// simpleInteractiveMode is the plain-reader prompt loop for sessionKey, used
// when readline is unavailable.
//
// It prints the prompt itself before every read and otherwise follows the
// same rules as interactiveMode: empty lines are ignored, "exit"/"quit"/EOF
// end the loop after memory.ExportIfNeeded, a leading "T" toggles TTS and
// marks the message for speech output, slash commands are dispatched to their
// handlers, and any other text is sent to the agent via runWithStreaming.
func simpleInteractiveMode(agentLoop *agent.AgentLoop, sessionKey string) {
	reader := internal.NewSimpleReader()

	// Seed the runtime TTS switch from the project configuration.
	if internal.GetProjectConfig().TTS.Enabled {
		internal.SetTTSEnabled(true)
	}

	for {
		// The prompt is rebuilt on every iteration.
		fmt.Print(internal.GetTTSPrompt(internal.GetProjectConfig().UI.UserIcon))
		line, err := reader.ReadString()
		if err != nil {
			if err == internal.ErrEOF {
				fmt.Println("\n再见!")
				memory.ExportIfNeeded()
				return
			}
			fmt.Printf("读取输入错误: %v\n", err)
			continue
		}

		input := line
		if input == "" {
			continue
		}
		if input == "exit" || input == "quit" {
			fmt.Println("再见!")
			memory.ExportIfNeeded()
			return
		}

		// "T" or "T <message>" requests speech output for this message.
		isTempTTS := false
		if input[0] == 'T' && (len(input) == 1 || input[1] == ' ') {
			input = strings.TrimPrefix(input, "T")
			input = strings.TrimPrefix(input, " ")
			isTempTTS = true
		}

		switch {
		case strings.HasPrefix(input, "/tts"):
			handleTTSCommandSimple(input)
			continue
		case strings.HasPrefix(input, "/new"):
			handleNewSessionCommand(nil, internal.GetProjectConfig().UI.UserIcon)
			continue
		case strings.HasPrefix(input, "/memory"):
			handleMemoryCommand(input)
			continue
		case strings.HasPrefix(input, "/sessions"):
			handleSessionsCommand()
			continue
		case strings.HasPrefix(input, "/context"):
			handleContextCommand(agentLoop, sessionKey)
			continue
		}

		if isTempTTS {
			internal.ToggleTTS()
		}
		if input == "" {
			// A bare "T" leaves nothing to send once the prefix is removed;
			// do not forward an empty message to the agent.
			continue
		}
		runWithStreaming(agentLoop, input, sessionKey, isTempTTS)
	}
}
// runWithStreaming sends one user message to the agent through ProcessDirect
// and presents the result.
//
// When memory is enabled, the context prompt returned by
// memory.GetContextPrompt is prepended to the message sent to the agent, while
// the unmodified input is what gets saved. The response is rendered as
// Markdown and printed line by line, spoken when tempTTS is set or TTS is
// currently enabled, and then saved via memory.SaveChat. A footer with the
// elapsed time and save status is printed last.
//
// An error from ProcessDirect is printed as a warning; an empty response ends
// the call without saving anything.
func runWithStreaming(agentLoop *agent.AgentLoop, input, sessionKey string, tempTTS bool) {
	startTime := time.Now()

	// Keep the raw user input for saving later.
	originalInput := input

	// Inject the hxclaw context summary ahead of the user's question.
	if internal.GetProjectConfig().Memory.Enabled {
		if contextPrompt := memory.GetContextPrompt(input); contextPrompt != "" {
			input = contextPrompt + "\n用户新问题: " + input
		}
	}

	spinner := internal.NewSpinner("思考中...")
	spinner.Start()
	resp, err := agentLoop.ProcessDirect(context.Background(), input, sessionKey)
	spinner.Stop()
	if err != nil {
		fmt.Printf("警告: %v\n", err)
	}
	if resp == "" {
		fmt.Println("(空响应,跳过保存)")
		return
	}

	rendered := internal.RenderMarkdown(resp)
	clearSpinnerLine()
	outputLineByLine(rendered)

	// Decide on speech output from the runtime TTS switch only. Both prompt
	// loops seed that switch from the configuration at startup, so also
	// checking the static config value here would make "/tts off" ineffective
	// whenever the configuration enables TTS.
	if tempTTS || internal.IsTTSEnabled() {
		go internal.SpeakText(resp)
	}

	// Save the exchange to the database.
	var chatCount int
	var saveErr error
	if internal.GetProjectConfig().Memory.Enabled {
		chatCount, saveErr = memory.SaveChat(originalInput, resp, !memory.ShouldSkipSummaryUpdate(originalInput))
		// A new session is required: tell the user instead of reporting a
		// save failure in the footer.
		if saveErr == memory.ErrNeedNewSession {
			fmt.Println("请使用 /new 创建新会话后再对话")
			saveErr = nil
		}
	}

	printElapsed(time.Since(startTime), chatCount, saveErr)
}
// clearSpinnerLine erases the current terminal line, returns the cursor to
// column zero and syncs stdout.
func clearSpinnerLine() {
	termenv.DefaultOutput().ClearLine()
	fmt.Print("\r")
	// The Sync error is ignored, as in the original code.
	_ = os.Stdout.Sync()
}
// outputLineByLine prints text one line at a time and pauses after each
// non-empty line. The pause uses the configured LineDelayMs, except after the
// final line, which uses LastLineDelayMs. Empty lines are printed without a
// pause. One extra newline is printed at the end; empty text prints nothing.
func outputLineByLine(text string) {
	if text == "" {
		return
	}

	rows := strings.Split(text, "\n")
	lastIdx := len(rows) - 1

	streaming := internal.GetProjectConfig().Streaming
	betweenLines := time.Duration(streaming.LineDelayMs) * time.Millisecond
	afterLast := time.Duration(streaming.LastLineDelayMs) * time.Millisecond

	for idx, row := range rows {
		if row == "" {
			lipgloss.Print("\n")
			continue
		}
		lipgloss.Print(row + "\n")

		pause := betweenLines
		if idx == lastIdx {
			pause = afterLast
		}
		time.Sleep(pause)
	}
	lipgloss.Print("\n")
}
// Styles used by printElapsed for the footer printed after each response.
var (
// iconStyle colors the leading footer icon.
iconStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#f0c75e"))
// textStyle colors the elapsed-time and message-count text.
textStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#2b2e32"))
memoryOkStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#4a9e6b")) // dark green: session saved
memoryErrStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#c75050")) // dark red: session save failed
)
func printElapsed(elapsed time.Duration, chatCount int, saveErr error) {
elapsedSec := math.Round(elapsed.Seconds()*10) / 10
elapsedStr := formatDuration(elapsedSec)
icon := iconStyle.Render("▣ ")
timeText := textStyle.Render(fmt.Sprintf("耗时: %s", elapsedStr))
var statusText string
if saveErr != nil {
statusText = memoryErrStyle.Render("会话保存异常")
} else if chatCount > 0 {
statusText = memoryOkStyle.Render("会话已保存")
}
memCountText := textStyle.Render(fmt.Sprintf("当前会话 %d 条消息", chatCount))
if statusText != "" {
fmt.Printf(" %s%s · %s · %s\n\n", icon, timeText, statusText, memCountText)
} else {
fmt.Printf(" %s%s · %s\n\n", icon, timeText, memCountText)
}
}
// formatTokens renders a token count for display: values below 1000 are
// printed as plain integers, larger values as thousands with one decimal
// place and a "k" suffix (for example 1500 becomes "1.5k").
func formatTokens(n int) string {
	if n < 1000 {
		return fmt.Sprint(n)
	}
	thousands := float64(n) / 1000
	return fmt.Sprintf("%.1fk", thousands)
}
// formatDuration renders a duration given in seconds with one decimal place:
// as minutes with an "m" suffix when it is at least 60 seconds, otherwise as
// seconds with an "s" suffix.
func formatDuration(s float64) string {
	value, unit := s, "s"
	if s >= 60 {
		value, unit = s/60, "m"
	}
	return fmt.Sprintf("%.1f%s", value, unit)
}
// handleTTSCommand executes the /tts command for the readline prompt loop.
//
// With no argument it toggles TTS; "on" and "off" set it explicitly; "status"
// reports the current state; anything else prints the usage line. Whenever
// the state is toggled or set, the readline prompt is rebuilt from basePrompt.
func handleTTSCommand(input string, rl *internal.Readline, basePrompt string) {
	// label maps the TTS state to the word used in the messages.
	label := func(on bool) string {
		if on {
			return "开启"
		}
		return "关闭"
	}
	syncPrompt := func() {
		rl.SetPrompt(internal.GetTTSPrompt(basePrompt))
	}

	fields := strings.Fields(input)
	if len(fields) == 1 {
		on := internal.ToggleTTS()
		syncPrompt()
		fmt.Printf("TTS 已%s\n", label(on))
		return
	}

	switch sub := fields[1]; sub {
	case "on", "off":
		want := sub == "on"
		internal.SetTTSEnabled(want)
		syncPrompt()
		fmt.Printf("TTS 已%s\n", label(want))
	case "status":
		fmt.Printf("TTS 状态: %s\n", label(internal.IsTTSEnabled()))
	default:
		fmt.Println("用法: /tts [on|off|status]")
	}
}
// handleTTSCommandSimple executes the /tts command for the plain-reader
// prompt loop, which has no readline prompt to update.
//
// With no argument it toggles TTS; "on" and "off" set it explicitly; "status"
// reports the current state; anything else prints the usage line.
func handleTTSCommandSimple(input string) {
	// label maps the TTS state to the word used in the messages.
	label := func(on bool) string {
		if on {
			return "开启"
		}
		return "关闭"
	}

	fields := strings.Fields(input)
	if len(fields) == 1 {
		internal.ToggleTTS()
		fmt.Printf("TTS 已%s\n", label(internal.IsTTSEnabled()))
		return
	}

	switch sub := fields[1]; sub {
	case "on", "off":
		want := sub == "on"
		internal.SetTTSEnabled(want)
		fmt.Printf("TTS 已%s\n", label(want))
	case "status":
		fmt.Printf("TTS 状态: %s\n", label(internal.IsTTSEnabled()))
	default:
		fmt.Println("用法: /tts [on|off|status]")
	}
}
// handleNewSessionCommand executes the /new command: it clears the
// package-level currentSession and tells the user that the next chat message
// will start a new session. Neither rl nor basePrompt is used.
//
// NOTE(review): only this file's currentSession is reset here; whether the
// memory package's own session state is also reset is not visible from this
// function — confirm.
func handleNewSessionCommand(rl *internal.Readline, basePrompt string) {
	currentSession = nil
	fmt.Print("已重置会话,输入聊天消息后将创建新会话\n")
}
// handleMemoryCommand executes the /memory command.
//
// Subcommands:
//   - (none) or "list": list every stored session with a short ID and a
//     summary truncated to 50 characters;
//   - "show": print the summary of memory.GetCurrentSession, or of the latest
//     stored session when there is no current one;
//   - "current": print the UUID and chat count of the package-level
//     currentSession;
//   - anything else prints the usage line.
//
// NOTE(review): "show" reads the memory package's current session while
// "current" reads this file's currentSession variable — confirm these are
// meant to differ.
func handleMemoryCommand(input string) {
	args := strings.Fields(input)
	if len(args) < 2 || args[1] == "list" {
		sessions, err := memory.ListSessions()
		if err != nil {
			fmt.Printf("查询会话失败: %v\n", err)
			return
		}
		if len(sessions) == 0 {
			fmt.Println("暂无会话记录")
			return
		}
		fmt.Printf("共有 %d 个会话记录:\n", len(sessions))
		for _, s := range sessions {
			summary := s.Summary
			if summary == "" {
				summary = "(无摘要)"
			} else if runes := []rune(summary); len(runes) > 50 {
				// Truncate by runes rather than bytes so that a multi-byte
				// UTF-8 character is never cut in half.
				summary = string(runes[:50]) + "..."
			}
			// Guard the short ID so a UUID shorter than 8 bytes cannot panic.
			id := s.UUID
			if len(id) > 8 {
				id = id[:8]
			}
			fmt.Printf(" - %s: %s\n", id, summary)
		}
		return
	}

	switch args[1] {
	case "show":
		session := memory.GetCurrentSession()
		if session == nil {
			latest, err := memory.GetLatestSession()
			if err != nil || latest == nil {
				fmt.Println("暂无会话记录")
				return
			}
			session = latest
		}
		if session.Summary == "" {
			fmt.Println("当前会话暂无摘要")
			return
		}
		id := session.UUID
		if len(id) > 8 {
			id = id[:8]
		}
		fmt.Printf("=== 会话摘要 (%s) ===\n%s\n", id, session.Summary)
	case "current":
		if currentSession == nil {
			fmt.Println("当前无活跃会话")
			return
		}
		fmt.Printf("当前会话: %s\n", currentSession.UUID)
		fmt.Printf("聊天记录数: %d\n", len(currentSession.ChatIDs))
	default:
		fmt.Println("用法: /memory [list|show|current]")
	}
}
// handleSessionsCommand executes the /sessions command: it lists every stored
// session as "short ID | message count | summary", with the summary truncated
// to 30 characters. A query failure or an empty list is reported instead.
func handleSessionsCommand() {
	sessions, err := memory.ListSessions()
	if err != nil {
		fmt.Printf("查询会话失败: %v\n", err)
		return
	}
	if len(sessions) == 0 {
		fmt.Println("暂无会话记录")
		return
	}

	fmt.Printf("共有 %d 个会话记录:\n", len(sessions))
	for _, s := range sessions {
		summary := s.Summary
		if summary == "" {
			summary = "(无摘要)"
		} else if runes := []rune(summary); len(runes) > 30 {
			// Truncate by runes rather than bytes so that a multi-byte UTF-8
			// character is never cut in half.
			summary = string(runes[:30]) + "..."
		}
		// Guard the short ID so a UUID shorter than 8 bytes cannot panic.
		id := s.UUID
		if len(id) > 8 {
			id = id[:8]
		}
		fmt.Printf(" %s | %d 条消息 | %s\n", id, len(s.ChatIDs), summary)
	}
}
// handleContextCommand executes the /context command: it estimates how much
// of the default agent's context window the session identified by sessionKey
// is using and prints a short report.
//
// Used tokens are the sum of the estimates for the session history, the
// system prompt (built with the session summary) and the tool definitions.
// The compression threshold is the context window minus instance.MaxTokens,
// or the whole window when that difference is negative.
// NOTE(review): MaxTokens is presumably the budget reserved for the model's
// reply — confirm against the agent package.
//
// The report shows two percentages: usage relative to the full context
// window, and progress towards the compression threshold. Both are capped at
// 100.
func handleContextCommand(agentLoop *agent.AgentLoop, sessionKey string) {
	if agentLoop == nil {
		fmt.Println("无法获取上下文信息")
		return
	}
	cfg := agentLoop.GetConfig()
	if cfg == nil {
		fmt.Println("无法获取配置信息")
		return
	}
	registry := agentLoop.GetRegistry()
	if registry == nil {
		fmt.Println("无法获取代理注册信息")
		return
	}
	instance := registry.GetDefaultAgent()
	if instance == nil {
		fmt.Println("无可用代理")
		return
	}
	contextWindow := instance.ContextWindow
	if contextWindow <= 0 {
		fmt.Println("上下文窗口大小未知")
		return
	}

	history := instance.Sessions.GetHistory(sessionKey)
	historyTokens := 0
	for _, m := range history {
		historyTokens += agent.EstimateMessageTokens(m)
	}

	systemTokens := 0
	if instance.ContextBuilder != nil {
		summary := instance.Sessions.GetSummary(sessionKey)
		systemTokens = instance.ContextBuilder.EstimateSystemTokens(summary, nil)
	}

	toolTokens := 0
	if instance.Tools != nil {
		toolTokens = agent.EstimateToolDefsTokens(instance.Tools.ToProviderDefs())
	}

	usedTokens := historyTokens + systemTokens + toolTokens

	compressAt := contextWindow - instance.MaxTokens
	if compressAt < 0 {
		compressAt = contextWindow
	}

	// percentOf returns part as a whole-number percentage of whole, capped at
	// 100, and 0 when whole is not positive.
	percentOf := func(part, whole int) int {
		if whole <= 0 {
			return 0
		}
		pct := part * 100 / whole
		if pct > 100 {
			pct = 100
		}
		return pct
	}

	// The "used" line shows usedTokens against contextWindow, so its
	// percentage is computed against contextWindow as well; previously it
	// repeated the compression-threshold percentage.
	windowPercent := percentOf(usedTokens, contextWindow)
	compressPercent := percentOf(usedTokens, compressAt)

	remaining := compressAt - usedTokens
	if remaining < 0 {
		remaining = 0
	}

	fmt.Println("上下文使用情况")
	fmt.Printf("消息数: %d\n", len(history))
	fmt.Printf("已用: ~%d / %d tokens (%d%%)\n", usedTokens, contextWindow, windowPercent)
	fmt.Printf("压缩阈值: %d tokens\n", compressAt)
	fmt.Printf("压缩进度: %d%%\n", compressPercent)
	fmt.Printf("剩余: ~%d tokens\n", remaining)
}