Files
HxClaw/cmd/hxclaw/main.go

521 lines
12 KiB
Go
Raw Normal View History

package main
import (
"context"
"fmt"
"math"
"os"
"strings"
"time"
"charm.land/lipgloss/v2"
"github.com/hxclaw/hxclaw/cmd/hxclaw/internal"
"github.com/hxclaw/hxclaw/cmd/hxclaw/internal/memory"
"github.com/muesli/termenv"
"github.com/sipeed/picoclaw/pkg/agent"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
)
const (
	// Logo is the built-in logo glyph used in the CLI banner.
	Logo = "🦐"
)

var (
	// currentSession caches the active memory session for "/memory current".
	// It is reset to nil when "/new" creates a session.
	currentSession *memory.Session
)
// main is the HxClaw CLI entry point. It does the following in order:
//   - loads the project and agent configuration;
//   - creates the LLM provider and the agent loop;
//   - optionally initializes the memory store;
//   - runs the interactive REPL until the user quits.
//
// Configuration and provider errors are fatal (exit status 1). Memory
// initialization errors are only reported as warnings.
func main() {
	if err := internal.LoadProjectConfig(); err != nil {
		fmt.Fprintf(os.Stderr, "错误:加载项目配置失败: %v\n", err)
		os.Exit(1)
	}

	// Use the configured logo for every banner line, falling back to the
	// built-in glyph when the config leaves it empty.
	logo := internal.GetProjectConfig().UI.Logo
	if logo == "" {
		logo = Logo
	}
	fmt.Printf("%s HxClaw - PicoClaw 增强版 CLI\n\n", logo)

	cfg, err := internal.LoadConfig()
	if err != nil {
		fmt.Fprintf(os.Stderr, "错误:加载配置失败: %v\n", err)
		os.Exit(1)
	}
	logger.ConfigureFromEnv()

	provider, modelID, err := providers.CreateProvider(cfg)
	if err != nil {
		fmt.Fprintf(os.Stderr, "错误:创建 Provider 失败: %v\n", err)
		os.Exit(1)
	}
	// Override the configured default model with the one CreateProvider
	// reports, if it reports one.
	if modelID != "" {
		cfg.Agents.Defaults.ModelName = modelID
	}

	msgBus := bus.NewMessageBus()
	defer msgBus.Close()
	agentLoop := agent.NewAgentLoop(cfg, msgBus, provider)
	defer agentLoop.Close()

	// Use comma-ok assertions. A missing or differently shaped section then
	// logs <nil> instead of panicking at start-up; indexing a nil map is safe.
	startupInfo := agentLoop.GetStartupInfo()
	tools, _ := startupInfo["tools"].(map[string]any)
	skills, _ := startupInfo["skills"].(map[string]any)
	logger.InfoCF("hxclaw", "HxClaw 已初始化",
		map[string]any{
			"tools_count":      tools["count"],
			"skills_total":     skills["total"],
			"skills_available": skills["available"],
		})

	setupMemory()

	fmt.Printf("%s Interactive mode (Ctrl+C to exit)\n\n", logo)
	interactiveMode(agentLoop, "cli:default")
}

// setupMemory initializes the memory database and vector settings when memory
// is enabled in the project config. The configured db_path takes precedence;
// otherwise memory.GetDefaultDBPath() is used. A failed memory.Init is reported
// as a warning and is not fatal.
//
// NOTE(review): other code paths decide whether to use memory from the config
// flag alone, so they still call into the memory package after a failed Init.
// Confirm the memory package tolerates that.
func setupMemory() {
	memoryCfg := internal.GetProjectConfig().Memory
	if !memoryCfg.Enabled {
		return
	}
	dbPath := memoryCfg.DBPath
	if dbPath == "" {
		dbPath = memory.GetDefaultDBPath()
	}
	fmt.Printf("初始化记忆体db_path: %s\n", dbPath)
	if err := memory.Init(memory.WithDBPath(dbPath)); err != nil {
		fmt.Fprintf(os.Stderr, "警告:初始化记忆体失败: %v将使用无记忆模式\n", err)
		return
	}
	fmt.Println("记忆体初始化成功")
	memory.InitVector(
		memory.WithAPIKey(memoryCfg.Vector.APIKey),
		memory.WithBaseURL(memoryCfg.Vector.BaseURL),
		memory.WithModel(memoryCfg.Vector.Model),
	)
}
// interactiveMode runs the REPL on top of internal.NewReadline, falling back to
// simpleInteractiveMode if the line editor cannot be created. It returns on
// interrupt, EOF, "exit" or "quit", calling memory.ExportIfNeeded first.
func interactiveMode(agentLoop *agent.AgentLoop, sessionKey string) {
	basePrompt := internal.GetProjectConfig().UI.UserIcon

	// Apply the configured TTS default BEFORE building the prompt.
	// GetTTSPrompt reflects the current TTS state (the /tts handlers rebuild
	// the prompt after each change). Building the prompt first would show TTS
	// as off even when the config turns it on.
	if internal.GetProjectConfig().TTS.Enabled {
		internal.SetTTSEnabled(true)
	}

	rl, err := internal.NewReadline(internal.GetTTSPrompt(basePrompt))
	if err != nil {
		fmt.Printf("初始化 readline 失败: %v\n", err)
		fmt.Println("回退到简单输入模式...")
		simpleInteractiveMode(agentLoop, sessionKey)
		return
	}
	defer rl.Close()

	for {
		line, err := rl.Readline()
		if err != nil {
			if err == internal.ErrInterrupt || err == internal.ErrEOF {
				fmt.Println("\n再见!")
				memory.ExportIfNeeded()
				return
			}
			fmt.Printf("读取输入错误: %v\n", err)
			continue
		}
		if !processInputLine(agentLoop, sessionKey, line, rl, basePrompt) {
			return
		}
	}
}

// simpleInteractiveMode is the fallback REPL, built on internal.NewSimpleReader.
// It rebuilds the prompt before every read, so TTS changes show immediately.
// It returns on EOF, "exit" or "quit", calling memory.ExportIfNeeded first.
func simpleInteractiveMode(agentLoop *agent.AgentLoop, sessionKey string) {
	reader := internal.NewSimpleReader()
	if internal.GetProjectConfig().TTS.Enabled {
		internal.SetTTSEnabled(true)
	}
	for {
		basePrompt := internal.GetProjectConfig().UI.UserIcon
		fmt.Print(internal.GetTTSPrompt(basePrompt))
		line, err := reader.ReadString()
		if err != nil {
			if err == internal.ErrEOF {
				fmt.Println("\n再见!")
				memory.ExportIfNeeded()
				return
			}
			fmt.Printf("读取输入错误: %v\n", err)
			continue
		}
		if !processInputLine(agentLoop, sessionKey, line, nil, basePrompt) {
			return
		}
	}
}

// processInputLine handles one line of user input for either REPL. It returns
// false when the user asked to quit and true otherwise. rl is nil in simple
// mode.
func processInputLine(agentLoop *agent.AgentLoop, sessionKey, line string, rl *internal.Readline, basePrompt string) bool {
	input := strings.TrimSpace(line)
	if input == "" {
		return true
	}
	if input == "exit" || input == "quit" {
		fmt.Println("再见!")
		memory.ExportIfNeeded()
		return false
	}

	// The "T" prefix asks runWithStreaming to speak THIS reply only. It
	// deliberately does not flip the global TTS toggle; /tts does that.
	input, tempTTS := stripTempTTSPrefix(input)
	if input == "" {
		// A bare "T" carries no question; don't send an empty request.
		return true
	}
	if dispatchSlashCommand(input, rl, basePrompt) {
		return true
	}
	runWithStreaming(agentLoop, input, sessionKey, tempTTS)
	return true
}

// stripTempTTSPrefix detects the one-shot speech prefix. The prefix is a
// leading "T" that is either the whole input or is followed by a space. It
// returns the remaining text, whitespace-trimmed, and whether the prefix was
// present.
func stripTempTTSPrefix(input string) (string, bool) {
	if input != "T" && !strings.HasPrefix(input, "T ") {
		return input, false
	}
	return strings.TrimSpace(strings.TrimPrefix(input, "T")), true
}

// dispatchSlashCommand runs input if it is a built-in slash command and reports
// whether it was one. When rl is nil (simple mode), /tts goes through
// handleTTSCommandSimple, which does not touch a readline prompt.
func dispatchSlashCommand(input string, rl *internal.Readline, basePrompt string) bool {
	switch {
	case strings.HasPrefix(input, "/tts"):
		if rl != nil {
			handleTTSCommand(input, rl, basePrompt)
		} else {
			handleTTSCommandSimple(input)
		}
	case strings.HasPrefix(input, "/new"):
		handleNewSessionCommand(rl, basePrompt)
	case strings.HasPrefix(input, "/memory"):
		handleMemoryCommand(input)
	case strings.HasPrefix(input, "/sessions"):
		handleSessionsCommand()
	default:
		return false
	}
	return true
}
// runWithStreaming sends input to the agent via ProcessDirect, which supports
// tool calls. It then:
//   - renders the Markdown reply line by line;
//   - optionally speaks the reply;
//   - saves the exchange to memory when memory is enabled;
//   - prints a timing/status footer.
//
// tempTTS requests speech for this reply even if TTS is globally off. An
// empty reply is reported and nothing is saved.
func runWithStreaming(agentLoop *agent.AgentLoop, input, sessionKey string, tempTTS bool) {
	startTime := time.Now()
	memoryEnabled := internal.GetProjectConfig().Memory.Enabled

	// Send the agent a context-augmented prompt, but keep `input` untouched:
	// memory must record what the user actually typed.
	prompt := input
	if memoryEnabled {
		if contextPrompt := memory.GetContextPrompt(input); contextPrompt != "" {
			prompt = contextPrompt + "\n用户新问题: " + input
		}
	}

	spinner := internal.NewSpinner("思考中...")
	spinner.Start()
	resp, err := agentLoop.ProcessDirect(context.Background(), prompt, sessionKey)
	spinner.Stop()
	if err != nil {
		fmt.Printf("警告: %v\n", err)
	}
	if resp == "" {
		fmt.Println("(空响应,跳过保存)")
		return
	}

	rendered := internal.RenderMarkdown(resp)
	clearSpinnerLine()
	outputLineByLine(rendered)

	// IsTTSEnabled already includes the config default, because the REPLs
	// call SetTTSEnabled(true) at start-up when TTS is enabled in config. It
	// also reflects any later /tts on|off. Checking the config flag here as
	// well would make "/tts off" ineffective.
	if tempTTS || internal.IsTTSEnabled() {
		go internal.SpeakText(resp)
	}

	var chatCount int
	var saveErr error
	if memoryEnabled {
		chatCount, saveErr = memory.SaveChat(input, resp, !memory.ShouldSkipSummaryUpdate(input))
		// The store refuses to append to this session; this is not a failure,
		// just tell the user to start a new one.
		if saveErr == memory.ErrNeedNewSession {
			fmt.Println("请使用 /new 创建新会话后再对话")
			saveErr = nil
		}
	}
	printElapsed(time.Since(startTime), chatCount, saveErr)
}
// clearSpinnerLine erases the current terminal line through termenv, returns
// the cursor to column 0, and flushes stdout.
func clearSpinnerLine() {
	termenv.DefaultOutput().ClearLine()
	fmt.Print("\r")
	// Best effort: a Sync error (e.g. stdout is a pipe) changes nothing here.
	_ = os.Stdout.Sync()
}
// outputLineByLine prints text one line at a time, pausing after each
// non-empty line to give a streaming effect. Pauses come from the project's
// Streaming config:
//   - LineDelayMs after every non-empty line except the last;
//   - LastLineDelayMs after the last line, if it is non-empty.
//
// Empty lines are printed without a pause, and one extra blank line is printed
// at the end. Empty text prints nothing at all.
func outputLineByLine(text string) {
	if text == "" {
		return
	}
	streaming := internal.GetProjectConfig().Streaming
	pause := time.Duration(streaming.LineDelayMs) * time.Millisecond
	finalPause := time.Duration(streaming.LastLineDelayMs) * time.Millisecond

	rows := strings.Split(text, "\n")
	lastIdx := len(rows) - 1
	for idx, row := range rows {
		lipgloss.Print(row + "\n")
		switch {
		case row == "":
			// Blank lines are not paused on.
		case idx == lastIdx:
			time.Sleep(finalPause)
		default:
			time.Sleep(pause)
		}
	}
	lipgloss.Print("\n")
}
// Footer styles used by printElapsed.

// iconStyle colors the "▣" footer icon gold.
var iconStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#f0c75e"))

// textStyle colors ordinary footer text dark grey.
// NOTE(review): this may be hard to read on dark terminal backgrounds.
var textStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#2b2e32"))

// memoryOkStyle (dark green) marks a successfully saved session.
var memoryOkStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#4a9e6b"))

// memoryErrStyle (dark red) marks a failed session save.
var memoryErrStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#c75050"))
func printElapsed(elapsed time.Duration, chatCount int, saveErr error) {
elapsedSec := math.Round(elapsed.Seconds()*10) / 10
elapsedStr := formatDuration(elapsedSec)
icon := iconStyle.Render("▣ ")
timeText := textStyle.Render(fmt.Sprintf("耗时: %s", elapsedStr))
var statusText string
if saveErr != nil {
statusText = memoryErrStyle.Render("会话保存异常")
} else if chatCount > 0 {
statusText = memoryOkStyle.Render("会话已保存")
}
memCountText := textStyle.Render(fmt.Sprintf("当前会话 %d 条消息", chatCount))
if statusText != "" {
fmt.Printf(" %s%s · %s · %s\n\n", icon, timeText, statusText, memCountText)
} else {
fmt.Printf(" %s%s · %s\n\n", icon, timeText, memCountText)
}
}
// formatTokens renders a token count compactly. Counts of 1000 or more are
// shown in thousands with one decimal and a "k" suffix (1234 -> "1.2k").
// Smaller counts, including negative ones, are printed as plain integers.
func formatTokens(n int) string {
	if n < 1000 {
		return fmt.Sprint(n)
	}
	thousands := float64(n) / 1e3
	return fmt.Sprintf("%.1fk", thousands)
}
// formatDuration formats a span given in seconds with one decimal place.
// Spans under a minute are shown in seconds ("12.3s"); longer spans are shown
// in fractional minutes ("1.5m").
func formatDuration(s float64) string {
	value, unit := s, "s"
	if s >= 60 {
		value, unit = s/60, "m"
	}
	return fmt.Sprintf("%.1f%s", value, unit)
}
// handleTTSCommand processes "/tts [on|off|status]" in readline mode. When the
// command may have changed the TTS state, the readline prompt is rebuilt with
// internal.GetTTSPrompt before the result is printed. A nil rl is tolerated.
func handleTTSCommand(input string, rl *internal.Readline, basePrompt string) {
	msg, changed := evalTTSCommand(input)
	if changed && rl != nil {
		rl.SetPrompt(internal.GetTTSPrompt(basePrompt))
	}
	fmt.Println(msg)
}

// handleTTSCommandSimple processes "/tts [on|off|status]" in simple-input
// mode. That mode rebuilds its prompt before every read, so only the result
// is printed.
func handleTTSCommandSimple(input string) {
	msg, _ := evalTTSCommand(input)
	fmt.Println(msg)
}

// ttsUsage is the help text shown for a malformed /tts command.
const ttsUsage = "用法: /tts [on|off|status]"

// evalTTSCommand applies a /tts command. It returns the message to show and
// whether the TTS state was modified.
//   - A bare "/tts" toggles TTS.
//   - "on" and "off" set TTS explicitly.
//   - "status" only reports the current state.
//   - Anything else yields the usage text. This includes empty input and a
//     mistyped command word such as "/ttsx", which previously toggled TTS.
func evalTTSCommand(input string) (msg string, changed bool) {
	args := strings.Fields(input)
	if len(args) == 0 || args[0] != "/tts" {
		return ttsUsage, false
	}
	if len(args) == 1 {
		return "TTS 已" + ttsStateLabel(internal.ToggleTTS()), true
	}
	switch args[1] {
	case "on":
		internal.SetTTSEnabled(true)
		return "TTS 已开启", true
	case "off":
		internal.SetTTSEnabled(false)
		return "TTS 已关闭", true
	case "status":
		return "TTS 状态: " + ttsStateLabel(internal.IsTTSEnabled()), false
	default:
		return ttsUsage, false
	}
}

// ttsStateLabel returns the user-facing on/off label for a TTS state.
func ttsStateLabel(on bool) string {
	if on {
		return "开启"
	}
	return "关闭"
}
// handleNewSessionCommand starts a fresh memory session via
// memory.CreateNewSession. On success it prints the new session ID and clears
// the cached currentSession; on failure it prints the error. rl and basePrompt
// are not used by this body.
func handleNewSessionCommand(rl *internal.Readline, basePrompt string) {
	sessionID, err := memory.CreateNewSession()
	if err == nil {
		fmt.Printf("已创建新会话: %s\n", sessionID)
		currentSession = nil
		return
	}
	fmt.Printf("创建新会话失败: %v\n", err)
}
// handleMemoryCommand implements "/memory [list|show|current]":
//   - list (default): every session with a summary preview of up to 50
//     characters;
//   - show: the full summary of the current session, or of the latest session
//     when there is no current one;
//   - current: the active session's ID and its number of chat records.
func handleMemoryCommand(input string) {
	args := strings.Fields(input)
	if len(args) <= 1 || args[1] == "list" {
		sessions, err := memory.ListSessions()
		if err != nil {
			fmt.Printf("查询会话失败: %v\n", err)
			return
		}
		if len(sessions) == 0 {
			fmt.Println("暂无会话记录")
			return
		}
		fmt.Printf("共有 %d 个会话记录:\n", len(sessions))
		for _, s := range sessions {
			fmt.Printf(" - %s: %s\n", shortSessionID(s.UUID), summaryPreview(s.Summary, 50))
		}
		return
	}

	switch args[1] {
	case "show":
		session := memory.GetCurrentSession()
		if session == nil {
			latest, err := memory.GetLatestSession()
			if err != nil || latest == nil {
				fmt.Println("暂无会话记录")
				return
			}
			session = latest
		}
		if session.Summary == "" {
			fmt.Println("当前会话暂无摘要")
			return
		}
		fmt.Printf("=== 会话摘要 (%s) ===\n%s\n", shortSessionID(session.UUID), session.Summary)
	case "current":
		if currentSession != nil {
			fmt.Printf("当前会话: %s\n", currentSession.UUID)
			fmt.Printf("聊天记录数: %d\n", len(currentSession.ChatIDs))
			return
		}
		// NOTE(review): currentSession is never assigned a non-nil value in
		// this file. Fall back to the memory package's current session so
		// this command can report the live session.
		live := memory.GetCurrentSession()
		if live == nil {
			fmt.Println("当前无活跃会话")
			return
		}
		fmt.Printf("当前会话: %s\n", live.UUID)
		fmt.Printf("聊天记录数: %d\n", len(live.ChatIDs))
	default:
		fmt.Println("用法: /memory [list|show|current]")
	}
}

// handleSessionsCommand implements "/sessions". For each session it prints the
// short ID, the message count and a summary preview of up to 30 characters.
func handleSessionsCommand() {
	sessions, err := memory.ListSessions()
	if err != nil {
		fmt.Printf("查询会话失败: %v\n", err)
		return
	}
	if len(sessions) == 0 {
		fmt.Println("暂无会话记录")
		return
	}
	fmt.Printf("共有 %d 个会话记录:\n", len(sessions))
	for _, s := range sessions {
		fmt.Printf(" %s | %d 条消息 | %s\n", shortSessionID(s.UUID), len(s.ChatIDs), summaryPreview(s.Summary, 30))
	}
}

// summaryPreview shortens summary for display:
//   - an empty summary becomes "(无摘要)";
//   - a summary longer than limit characters (runes) is cut to limit runes
//     and "..." is appended;
//   - anything else is returned unchanged.
//
// Cutting on rune boundaries keeps multi-byte text such as Chinese valid UTF-8;
// the previous byte slicing could split a character in half.
func summaryPreview(summary string, limit int) string {
	if summary == "" {
		return "(无摘要)"
	}
	runes := 0
	for byteIdx := range summary { // byteIdx is the start of each rune
		if runes == limit {
			return summary[:byteIdx] + "..."
		}
		runes++
	}
	return summary
}

// shortSessionID returns the first 8 bytes of a session ID for compact display,
// or the whole ID if it is shorter. This avoids a slice-bounds panic.
func shortSessionID(id string) string {
	const prefixLen = 8
	if len(id) <= prefixLen {
		return id
	}
	return id[:prefixLen]
}