Files
HxClaw/cmd/hxclaw/main.go
titor 124b0baa26
Some checks failed
Release / build (push) Failing after 3h11m18s
feat: 实现 LLM 文言文摘要生成,优化 Session 创建逻辑
2026-04-27 08:58:08 +08:00

516 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"context"
"fmt"
"math"
"os"
"strings"
"time"
"charm.land/lipgloss/v2"
"github.com/hxclaw/hxclaw/cmd/hxclaw/internal"
"github.com/hxclaw/hxclaw/cmd/hxclaw/internal/memory"
"github.com/muesli/termenv"
"github.com/sipeed/picoclaw/pkg/agent"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
)
// Logo is the built-in shrimp emoji used in the CLI's banner output.
const Logo = "🦐"

// currentSession is a package-level pointer to a memory session.
// NOTE(review): nothing in this file ever assigns it a non-nil value (the
// "/new" handler only resets it to nil), while the memory package exposes its
// own memory.GetCurrentSession(); confirm whether this variable is still needed.
var currentSession *memory.Session
func main() {
if err := internal.LoadProjectConfig(); err != nil {
fmt.Fprintf(os.Stderr, "错误:加载项目配置失败: %v\n", err)
os.Exit(1)
}
logo := internal.GetProjectConfig().UI.Logo
fmt.Printf("%s HxClaw - PicoClaw 增强版 CLI\n\n", logo)
cfg, err := internal.LoadConfig()
if err != nil {
fmt.Fprintf(os.Stderr, "错误:加载配置失败: %v\n", err)
os.Exit(1)
}
logger.ConfigureFromEnv()
provider, modelID, err := providers.CreateProvider(cfg)
if err != nil {
fmt.Fprintf(os.Stderr, "错误:创建 Provider 失败: %v\n", err)
os.Exit(1)
}
if modelID != "" {
cfg.Agents.Defaults.ModelName = modelID
}
msgBus := bus.NewMessageBus()
defer msgBus.Close()
agentLoop := agent.NewAgentLoop(cfg, msgBus, provider)
defer agentLoop.Close()
startupInfo := agentLoop.GetStartupInfo()
logger.InfoCF("hxclaw", "HxClaw 已初始化",
map[string]any{
"tools_count": startupInfo["tools"].(map[string]any)["count"],
"skills_total": startupInfo["skills"].(map[string]any)["total"],
"skills_available": startupInfo["skills"].(map[string]any)["available"],
})
memoryCfg := internal.GetProjectConfig().Memory
if memoryCfg.Enabled {
// 优先使用用户配置中的 db_path如果没有则使用默认路径
dbPath := memoryCfg.DBPath
if dbPath == "" {
dbPath = memory.GetDefaultDBPath()
}
fmt.Printf("初始化记忆体db_path: %s\n", dbPath)
if err := memory.Init(agentLoop, memory.WithDBPath(dbPath)); err != nil {
fmt.Fprintf(os.Stderr, "警告:初始化记忆体失败: %v将使用无记忆模式\n", err)
} else {
fmt.Println("记忆体初始化成功")
memory.InitVector(
memory.WithAPIKey(memoryCfg.Vector.APIKey),
memory.WithBaseURL(memoryCfg.Vector.BaseURL),
memory.WithModel(memoryCfg.Vector.Model),
)
}
}
fmt.Printf("%s Interactive mode (Ctrl+C to exit)\n\n", Logo)
interactiveMode(agentLoop, "cli:default")
}
// interactiveMode runs the readline-based REPL until the user exits with
// Ctrl+C / Ctrl+D or types "exit" / "quit". The slash commands /tts, /new,
// /memory and /sessions are handled locally. Any other input is sent to the
// agent under sessionKey. A leading "T " (or a bare "T") toggles TTS and marks
// the message for speech. If readline cannot be initialized, it falls back to
// simpleInteractiveMode.
func interactiveMode(agentLoop *agent.AgentLoop, sessionKey string) {
	// Apply the configured TTS default before building the first prompt so the
	// prompt's TTS indicator is correct from the start.
	if internal.GetProjectConfig().TTS.Enabled {
		internal.SetTTSEnabled(true)
	}
	basePrompt := internal.GetProjectConfig().UI.UserIcon
	rl, err := internal.NewReadline(internal.GetTTSPrompt(basePrompt))
	if err != nil {
		fmt.Printf("初始化 readline 失败: %v\n", err)
		fmt.Println("回退到简单输入模式...")
		simpleInteractiveMode(agentLoop, sessionKey)
		return
	}
	defer rl.Close()
	for {
		line, err := rl.Readline()
		if err != nil {
			if err == internal.ErrInterrupt || err == internal.ErrEOF {
				fmt.Println("\n再见!")
				memory.ExportIfNeeded()
				return
			}
			fmt.Printf("读取输入错误: %v\n", err)
			continue
		}
		// Ignore surrounding whitespace so blank lines are never sent to the agent.
		input := strings.TrimSpace(line)
		if input == "" {
			continue
		}
		if input == "exit" || input == "quit" {
			fmt.Println("再见!")
			memory.ExportIfNeeded()
			return
		}
		isTempTTS := false
		if input == "T" || strings.HasPrefix(input, "T ") {
			input = strings.TrimSpace(input[1:])
			isTempTTS = true
		}
		if strings.HasPrefix(input, "/tts") {
			handleTTSCommand(input, rl, basePrompt)
			continue
		}
		if strings.HasPrefix(input, "/new") {
			handleNewSessionCommand(rl, basePrompt)
			continue
		}
		if strings.HasPrefix(input, "/memory") {
			handleMemoryCommand(input)
			continue
		}
		if strings.HasPrefix(input, "/sessions") {
			handleSessionsCommand()
			continue
		}
		if isTempTTS {
			// NOTE(review): this flips the persistent TTS state rather than
			// enabling speech for a single reply; confirm that is intended.
			internal.ToggleTTS()
			// Refresh on every toggle (on or off) so the prompt never goes stale.
			rl.SetPrompt(internal.GetTTSPrompt(basePrompt))
		}
		if input == "" {
			// A bare "T" only toggles TTS; there is no message to send.
			continue
		}
		runWithStreaming(agentLoop, input, sessionKey, isTempTTS)
	}
}
// simpleInteractiveMode is the fallback REPL used when readline cannot be
// initialized. It reads lines via internal.NewSimpleReader and supports the
// same commands as interactiveMode. /tts is routed to handleTTSCommandSimple
// because there is no readline prompt object to refresh.
func simpleInteractiveMode(agentLoop *agent.AgentLoop, sessionKey string) {
	reader := internal.NewSimpleReader()
	if internal.GetProjectConfig().TTS.Enabled {
		internal.SetTTSEnabled(true)
	}
	for {
		// The prompt is rebuilt on each iteration, so TTS changes show up immediately.
		fmt.Print(internal.GetTTSPrompt(internal.GetProjectConfig().UI.UserIcon))
		line, err := reader.ReadString()
		if err != nil {
			if err == internal.ErrEOF {
				fmt.Println("\n再见!")
				memory.ExportIfNeeded()
				return
			}
			fmt.Printf("读取输入错误: %v\n", err)
			continue
		}
		// Ignore surrounding whitespace so blank lines are never sent to the agent.
		input := strings.TrimSpace(line)
		if input == "" {
			continue
		}
		if input == "exit" || input == "quit" {
			fmt.Println("再见!")
			memory.ExportIfNeeded()
			return
		}
		isTempTTS := false
		if input == "T" || strings.HasPrefix(input, "T ") {
			input = strings.TrimSpace(input[1:])
			isTempTTS = true
		}
		if strings.HasPrefix(input, "/tts") {
			handleTTSCommandSimple(input)
			continue
		}
		if strings.HasPrefix(input, "/new") {
			handleNewSessionCommand(nil, internal.GetProjectConfig().UI.UserIcon)
			continue
		}
		if strings.HasPrefix(input, "/memory") {
			handleMemoryCommand(input)
			continue
		}
		if strings.HasPrefix(input, "/sessions") {
			handleSessionsCommand()
			continue
		}
		if isTempTTS {
			// NOTE(review): this flips the persistent TTS state rather than
			// enabling speech for a single reply; confirm that is intended.
			internal.ToggleTTS()
		}
		if input == "" {
			// A bare "T" only toggles TTS; there is no message to send.
			continue
		}
		runWithStreaming(agentLoop, input, sessionKey, isTempTTS)
	}
}
// runWithStreaming sends one user message to the agent via ProcessDirect
// (which handles tool calls), renders the reply as Markdown line by line,
// optionally speaks it, and, when memory is enabled, saves the exchange before
// printing the elapsed time and save status.
//
// When memory is enabled, the context prompt from memory.GetContextPrompt is
// prepended to the text sent to the agent, but the chat record is saved with
// the user's original input. tempTTS forces speech for this reply even when
// TTS is currently switched off.
func runWithStreaming(agentLoop *agent.AgentLoop, input, sessionKey string, tempTTS bool) {
	startTime := time.Now()
	// Keep the raw user text for saving; input may be augmented with context below.
	originalInput := input
	memoryEnabled := internal.GetProjectConfig().Memory.Enabled
	if memoryEnabled {
		if contextPrompt := memory.GetContextPrompt(input); contextPrompt != "" {
			input = contextPrompt + "\n用户新问题: " + input
		}
	}
	spinner := internal.NewSpinner("思考中...")
	spinner.Start()
	resp, err := agentLoop.ProcessDirect(context.Background(), input, sessionKey)
	spinner.Stop()
	if err != nil {
		fmt.Printf("警告: %v\n", err)
	}
	if resp == "" {
		fmt.Println("(空响应,跳过保存)")
		return
	}
	rendered := internal.RenderMarkdown(resp)
	clearSpinnerLine()
	outputLineByLine(rendered)
	// Only the runtime switch decides. The configured default is already applied
	// through SetTTSEnabled when the REPL starts, and also ORing in the static
	// config flag here would make "/tts off" ineffective.
	if tempTTS || internal.IsTTSEnabled() {
		go internal.SpeakText(resp)
	}
	// Persist the exchange.
	var chatCount int
	var saveErr error
	if memoryEnabled {
		// The third argument is false for inputs that memory.ShouldSkipSummaryUpdate
		// flags (presumably so they do not refresh the session summary).
		chatCount, saveErr = memory.SaveChat(originalInput, resp, !memory.ShouldSkipSummaryUpdate(originalInput))
		if saveErr == memory.ErrNeedNewSession {
			// Not a failure: ask the user to start a new session instead.
			fmt.Println("请使用 /new 创建新会话后再对话")
			saveErr = nil
		}
	}
	printElapsed(time.Since(startTime), chatCount, saveErr)
}
// clearSpinnerLine erases the terminal line the spinner was drawn on and
// moves the cursor back to column 0, so the reply starts on a clean line.
func clearSpinnerLine() {
	termenv.DefaultOutput().ClearLine()
	fmt.Print("\r")
	_ = os.Stdout.Sync() // best-effort flush; the error is deliberately ignored
}
// outputLineByLine prints text one line at a time through lipgloss for a
// simple streaming effect. Each non-empty line is followed by a pause of
// Streaming.LineDelayMs, or Streaming.LastLineDelayMs for the final line.
// Empty lines are printed without a pause, and one extra newline is printed at
// the end. Empty text prints nothing.
func outputLineByLine(text string) {
	if len(text) == 0 {
		return
	}
	streaming := internal.GetProjectConfig().Streaming
	perLine := time.Duration(streaming.LineDelayMs) * time.Millisecond
	finalPause := time.Duration(streaming.LastLineDelayMs) * time.Millisecond
	rows := strings.Split(text, "\n")
	lastIdx := len(rows) - 1
	for idx, row := range rows {
		if len(row) == 0 {
			lipgloss.Print("\n")
			continue
		}
		lipgloss.Print(row + "\n")
		pause := perLine
		if idx == lastIdx {
			pause = finalPause
		}
		time.Sleep(pause)
	}
	lipgloss.Print("\n")
}
// Terminal styles for the per-reply status footer.
var (
	iconStyle      = lipgloss.NewStyle().Foreground(lipgloss.Color("#f0c75e"))
	textStyle      = lipgloss.NewStyle().Foreground(lipgloss.Color("#2b2e32"))
	memoryOkStyle  = lipgloss.NewStyle().Foreground(lipgloss.Color("#4a9e6b")) // dark green
	memoryErrStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("#c75050")) // dark red
)

// printElapsed prints the one-line footer shown after each reply. It shows
// the elapsed time rounded to 0.1s, a save-status badge (an error badge when
// saveErr is non-nil, a success badge when chatCount > 0, nothing otherwise),
// and the current session's message count, joined by " · ".
func printElapsed(elapsed time.Duration, chatCount int, saveErr error) {
	rounded := math.Round(elapsed.Seconds()*10) / 10
	segments := []string{
		textStyle.Render(fmt.Sprintf("耗时: %s", formatDuration(rounded))),
	}
	switch {
	case saveErr != nil:
		segments = append(segments, memoryErrStyle.Render("会话保存异常"))
	case chatCount > 0:
		segments = append(segments, memoryOkStyle.Render("会话已保存"))
	}
	segments = append(segments, textStyle.Render(fmt.Sprintf("当前会话 %d 条消息", chatCount)))
	fmt.Printf(" %s%s\n\n", iconStyle.Render("▣ "), strings.Join(segments, " · "))
}
// formatTokens renders a token count compactly. Values below 1000 are printed
// as plain integers; larger values are shown in thousands with one decimal
// place and a "k" suffix (e.g. 1234 -> "1.2k").
func formatTokens(n int) string {
	if n < 1000 {
		return fmt.Sprint(n)
	}
	return fmt.Sprintf("%.1fk", float64(n)/1000)
}
// formatDuration renders a duration given in seconds with one decimal place.
// Values under a minute use an "s" suffix; 60 seconds or more are converted
// to minutes with an "m" suffix (e.g. 90 -> "1.5m").
func formatDuration(s float64) string {
	value, unit := s, "s"
	if s >= 60 {
		value, unit = s/60, "m"
	}
	return fmt.Sprintf("%.1f%s", value, unit)
}
// handleTTSCommand handles "/tts [on|off|status]" for the readline REPL.
// With no sub-command it toggles TTS. After every state change, the prompt is
// rebuilt with internal.GetTTSPrompt(basePrompt) so its TTS indicator stays in
// sync. rl may be nil, the same convention the simple REPL uses for
// handleNewSessionCommand; the prompt refresh is then skipped instead of
// panicking.
func handleTTSCommand(input string, rl *internal.Readline, basePrompt string) {
	refreshPrompt := func() {
		if rl != nil {
			rl.SetPrompt(internal.GetTTSPrompt(basePrompt))
		}
	}
	args := strings.Fields(input)
	if len(args) < 2 {
		enabled := internal.ToggleTTS()
		refreshPrompt()
		status := "关闭"
		if enabled {
			status = "开启"
		}
		fmt.Printf("TTS 已%s\n", status)
		return
	}
	switch args[1] {
	case "on":
		internal.SetTTSEnabled(true)
		refreshPrompt()
		fmt.Println("TTS 已开启")
	case "off":
		internal.SetTTSEnabled(false)
		refreshPrompt()
		fmt.Println("TTS 已关闭")
	case "status":
		status := "关闭"
		if internal.IsTTSEnabled() {
			status = "开启"
		}
		fmt.Printf("TTS 状态: %s\n", status)
	default:
		fmt.Println("用法: /tts [on|off|status]")
	}
}
// handleTTSCommandSimple handles "/tts [on|off|status]" for the fallback REPL,
// which has no readline prompt to refresh. With no sub-command it toggles TTS
// and reports the resulting state.
func handleTTSCommandSimple(input string) {
	fields := strings.Fields(input)
	describe := func(on bool) string {
		if on {
			return "开启"
		}
		return "关闭"
	}
	if len(fields) != 1 {
		switch fields[1] {
		case "on":
			internal.SetTTSEnabled(true)
			fmt.Println("TTS 已开启")
		case "off":
			internal.SetTTSEnabled(false)
			fmt.Println("TTS 已关闭")
		case "status":
			fmt.Printf("TTS 状态: %s\n", describe(internal.IsTTSEnabled()))
		default:
			fmt.Println("用法: /tts [on|off|status]")
		}
		return
	}
	internal.ToggleTTS()
	fmt.Printf("TTS 已%s\n", describe(internal.IsTTSEnabled()))
}
// handleNewSessionCommand handles "/new". It clears the package-level
// currentSession and tells the user that a new session will be created with
// the next chat message. rl and basePrompt are accepted but currently unused.
// NOTE(review): this does not call into the memory package, so whatever
// session state it keeps (what memory.GetCurrentSession returns, or the
// condition behind memory.ErrNeedNewSession) is not reset here. Confirm that
// "/new" actually lets the memory package start a new session.
func handleNewSessionCommand(rl *internal.Readline, basePrompt string) {
	currentSession = nil
	fmt.Println("已重置会话,输入聊天消息后将创建新会话")
}
// handleMemoryCommand handles "/memory [list|show|current]".
//
//   - list (the default): prints every stored session with a shortened summary.
//   - show: prints the summary of memory.GetCurrentSession(), falling back to
//     memory.GetLatestSession() when there is no current session.
//   - current: prints the UUID and chat-record count of
//     memory.GetCurrentSession().
func handleMemoryCommand(input string) {
	args := strings.Fields(input)
	if len(args) < 2 || args[1] == "list" {
		sessions, err := memory.ListSessions()
		if err != nil {
			fmt.Printf("查询会话失败: %v\n", err)
			return
		}
		if len(sessions) == 0 {
			fmt.Println("暂无会话记录")
			return
		}
		fmt.Printf("共有 %d 个会话记录:\n", len(sessions))
		for _, s := range sessions {
			summary := truncateSummary(s.Summary, 50)
			if summary == "" {
				summary = "(无摘要)"
			}
			fmt.Printf(" - %s: %s\n", shortUUID(s.UUID), summary)
		}
		return
	}
	switch args[1] {
	case "show":
		session := memory.GetCurrentSession()
		if session == nil {
			latest, err := memory.GetLatestSession()
			if err != nil || latest == nil {
				fmt.Println("暂无会话记录")
				return
			}
			session = latest
		}
		if session.Summary == "" {
			fmt.Println("当前会话暂无摘要")
			return
		}
		fmt.Printf("=== 会话摘要 (%s) ===\n%s\n", shortUUID(session.UUID), session.Summary)
	case "current":
		// Ask the memory package directly; the package-level currentSession in
		// this file is never assigned a non-nil value.
		session := memory.GetCurrentSession()
		if session == nil {
			fmt.Println("当前无活跃会话")
			return
		}
		fmt.Printf("当前会话: %s\n", session.UUID)
		fmt.Printf("聊天记录数: %d\n", len(session.ChatIDs))
	default:
		fmt.Println("用法: /memory [list|show|current]")
	}
}

// handleSessionsCommand handles "/sessions". It prints one line per stored
// session with its short UUID, message count and a shortened summary.
func handleSessionsCommand() {
	sessions, err := memory.ListSessions()
	if err != nil {
		fmt.Printf("查询会话失败: %v\n", err)
		return
	}
	if len(sessions) == 0 {
		fmt.Println("暂无会话记录")
		return
	}
	fmt.Printf("共有 %d 个会话记录:\n", len(sessions))
	for _, s := range sessions {
		summary := truncateSummary(s.Summary, 30)
		if summary == "" {
			summary = "(无摘要)"
		}
		fmt.Printf(" %s | %d 条消息 | %s\n", shortUUID(s.UUID), len(s.ChatIDs), summary)
	}
}

// truncateSummary returns s unchanged when it is at most limit bytes long.
// Otherwise it cuts s at the last UTF-8 rune boundary at or before limit bytes
// and appends "...". Cutting on a rune boundary matters because summaries are
// often CJK text, where a plain byte slice would split a character and print
// garbage. For ASCII text the result equals s[:limit] + "...".
func truncateSummary(s string, limit int) string {
	if len(s) <= limit {
		return s
	}
	cut := 0
	for i := range s { // i walks the byte offset of each rune start
		if i > limit {
			break
		}
		cut = i
	}
	return s[:cut] + "..."
}

// shortUUID returns the first 8 bytes of id, or id itself when it is shorter
// (so a malformed ID cannot cause a slice-bounds panic).
func shortUUID(id string) string {
	if len(id) <= 8 {
		return id
	}
	return id[:8]
}