Files
HxClaw/cmd/hxclaw/main.go
titor 13ece24893 feat: 添加流式输出动画效果
- 添加 spinner 组件,使用 bubbletea v2 的 MiniDot 动画
- 用户输入后显示思考中动画
- 第一个 token 返回后显示思考完成
- 流式输出完成后添加空行分隔
2026-04-11 23:55:43 +08:00

199 lines
4.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"context"
"fmt"
"os"
"strings"
"github.com/hxclaw/hxclaw/cmd/hxclaw/internal"
"github.com/sipeed/picoclaw/pkg/agent"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
)
// Logo is the emoji prefix printed in front of the CLI banner, the input
// prompt, and non-streamed agent replies.
const Logo = "🦐"
// main is the CLI entry point. It loads the configuration, creates the
// provider, wires up the message bus and agent loop, logs a short startup
// summary, and then hands control to the interactive prompt loop.
//
// Configuration or provider failures are reported on stderr and terminate the
// process with exit status 1.
func main() {
	fmt.Printf("%s HxClaw - PicoClaw 增强版 CLI\n\n", Logo)

	cfg, err := internal.LoadConfig()
	if err != nil {
		fmt.Fprintf(os.Stderr, "错误:加载配置失败: %v\n", err)
		os.Exit(1)
	}

	logger.ConfigureFromEnv()

	provider, modelID, err := providers.CreateProvider(cfg)
	if err != nil {
		fmt.Fprintf(os.Stderr, "错误:创建 Provider 失败: %v\n", err)
		os.Exit(1)
	}
	// A non-empty model ID returned by the provider factory overrides the
	// configured default model name.
	if modelID != "" {
		cfg.Agents.Defaults.ModelName = modelID
	}

	msgBus := bus.NewMessageBus()
	defer msgBus.Close()

	agentLoop := agent.NewAgentLoop(cfg, msgBus, provider)
	defer agentLoop.Close()

	startupInfo := agentLoop.GetStartupInfo()
	// Use checked type assertions: if a section is missing or has a different
	// shape the result is a nil map, and indexing a nil map yields nil instead
	// of panicking. A log line must never be able to crash startup.
	toolsInfo, _ := startupInfo["tools"].(map[string]any)
	skillsInfo, _ := startupInfo["skills"].(map[string]any)
	logger.InfoCF("hxclaw", "HxClaw 已初始化",
		map[string]any{
			"tools_count":      toolsInfo["count"],
			"skills_total":     skillsInfo["total"],
			"skills_available": skillsInfo["available"],
		})

	fmt.Printf("%s Interactive mode (Ctrl+C to exit)\n\n", Logo)
	interactiveMode(agentLoop, "cli:default")
}
// interactiveMode runs the readline-based prompt loop for the given session.
//
// If the readline helper cannot be initialised it reports the failure and
// falls back to simpleInteractiveMode. The loop ends when the read returns
// internal.ErrInterrupt or internal.ErrEOF, or when the user types "exit" or
// "quit". Every other non-empty line is forwarded to runWithStreaming.
func interactiveMode(agentLoop *agent.AgentLoop, sessionKey string) {
	prompt := fmt.Sprintf("%s You: ", Logo)

	rl, err := internal.NewReadline(prompt)
	if err != nil {
		fmt.Printf("初始化 readline 失败: %v\n", err)
		fmt.Println("回退到简单输入模式...")
		simpleInteractiveMode(agentLoop, sessionKey)
		return
	}
	defer rl.Close()

	for {
		line, err := rl.Readline()
		if err != nil {
			if err == internal.ErrInterrupt || err == internal.ErrEOF {
				fmt.Println("\n再见!")
				return
			}
			fmt.Printf("读取输入错误: %v\n", err)
			continue
		}

		// Normalise surrounding whitespace so that inputs such as "exit " are
		// recognised and whitespace-only lines are not sent to the agent.
		input := strings.TrimSpace(line)
		switch input {
		case "":
			continue
		case "exit", "quit":
			fmt.Println("再见!")
			return
		}

		runWithStreaming(agentLoop, input, sessionKey)
	}
}
// simpleInteractiveMode is the fallback prompt loop used when the readline
// helper is unavailable. It prints the prompt itself and reads lines through
// internal.NewSimpleReader.
//
// The loop ends on internal.ErrEOF or when the user types "exit" or "quit".
// Every other non-empty line is forwarded to runWithStreaming.
func simpleInteractiveMode(agentLoop *agent.AgentLoop, sessionKey string) {
	reader := internal.NewSimpleReader()

	for {
		fmt.Printf("%s You: ", Logo)

		line, err := reader.ReadString()
		if err != nil {
			if err == internal.ErrEOF {
				fmt.Println("\n再见!")
				return
			}
			fmt.Printf("读取输入错误: %v\n", err)
			continue
		}

		// Normalise surrounding whitespace (including any trailing line
		// terminator) so that "exit"/"quit" are matched reliably and blank
		// lines are skipped instead of being sent to the agent.
		input := strings.TrimSpace(line)
		switch input {
		case "":
			continue
		case "exit", "quit":
			fmt.Println("再见!")
			return
		}

		runWithStreaming(agentLoop, input, sessionKey)
	}
}
// runWithStreaming sends one user input to the default agent and prints the
// reply.
//
// When the agent's provider implements providers.StreamingProvider, the reply
// is streamed to stdout as it arrives and, if any text was produced, the
// user/assistant pair is appended to the session. Otherwise the call falls
// back to agentLoop.ProcessDirect and prints the complete response at once.
//
// Errors are reported on stdout; the function never returns them to the caller.
func runWithStreaming(agentLoop *agent.AgentLoop, input, sessionKey string) {
	agentInstance := agentLoop.GetRegistry().GetDefaultAgent()
	if agentInstance == nil {
		fmt.Println("错误:无法获取 Agent 实例")
		return
	}

	ctx := context.Background()

	sp, ok := agentInstance.Provider.(providers.StreamingProvider)
	if !ok {
		// The provider cannot stream: fall back to the regular request path.
		response, err := agentLoop.ProcessDirect(ctx, input, sessionKey)
		if err != nil {
			fmt.Printf("错误: %v\n", err)
			return
		}
		fmt.Printf("\n%s %s\n\n", Logo, response)
		return
	}

	// Build the request from the stored session history and summary.
	history := agentInstance.Sessions.GetHistory(sessionKey)
	summary := agentInstance.Sessions.GetSummary(sessionKey)
	messages := agentInstance.ContextBuilder.BuildMessages(
		history,
		summary,
		input,
		nil,   // media
		"cli", // channel
		sessionKey,
		"", // senderID
		"", // senderDisplayName
	)
	toolDefs := agentInstance.Tools.ToProviderDefs()

	// Show the "thinking" spinner until the first output arrives.
	spinner := internal.NewSpinner("思考中...")
	spinner.Start()
	fmt.Print("\n")

	// stopSpinner stops the spinner at most once. It is called from the token
	// callback and again after the stream returns, so the spinner is also
	// stopped when the stream fails or ends without producing any text.
	spinnerActive := true
	stopSpinner := func() {
		if spinnerActive {
			spinner.Stop()
			spinnerActive = false
		}
	}

	var result strings.Builder
	var printedLen int
	_, err := sp.ChatStream(ctx, messages, toolDefs, agentInstance.Model, nil, func(accumulated string) {
		// The callback receives the text accumulated so far; only the part
		// beyond what has already been printed is new.
		if len(accumulated) <= printedLen {
			return
		}
		stopSpinner()
		delta := accumulated[printedLen:]
		fmt.Print(delta)
		os.Stdout.Sync() // best effort: push the partial output to the terminal
		result.WriteString(delta)
		printedLen = len(accumulated)
	})
	stopSpinner()

	if err != nil {
		// Terminate any partially printed reply so the error starts on its
		// own line.
		if printedLen > 0 {
			fmt.Println()
		}
		fmt.Printf("流式调用错误: %v\n", err)
		return
	}

	// End the streamed reply and leave a blank line before the next prompt.
	fmt.Println()
	fmt.Println()

	// Persist the exchange only when the assistant actually produced text.
	if result.Len() > 0 {
		agentInstance.Sessions.AddMessage(sessionKey, "user", input)
		agentInstance.Sessions.AddMessage(sessionKey, "assistant", result.String())
	}
}