Files
HxClaw/cmd/hxclaw/main.go

217 lines
5.2 KiB
Go
Raw Normal View History

package main
import (
"context"
"fmt"
"os"
"strings"
"github.com/hxclaw/hxclaw/cmd/hxclaw/internal"
"github.com/muesli/termenv"
"github.com/sipeed/picoclaw/pkg/agent"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
)
const (
	// Logo is the emoji prefixed to the CLI banner, prompts and replies.
	Logo = "🦐"
)
// main is the HxClaw CLI entry point. It loads the configuration, configures
// logging from the environment, creates the LLM provider, wires up the message
// bus and agent loop, logs a summary of the loaded tools and skills, and then
// hands control to the interactive prompt. Configuration or provider setup
// failures are reported on stderr and exit with status 1.
func main() {
	fmt.Printf("%s HxClaw - PicoClaw 增强版 CLI\n\n", Logo)

	cfg, err := internal.LoadConfig()
	if err != nil {
		fmt.Fprintf(os.Stderr, "错误:加载配置失败: %v\n", err)
		os.Exit(1)
	}
	logger.ConfigureFromEnv()

	provider, modelID, err := providers.CreateProvider(cfg)
	if err != nil {
		fmt.Fprintf(os.Stderr, "错误:创建 Provider 失败: %v\n", err)
		os.Exit(1)
	}
	// A model ID returned by the provider factory overrides the configured default.
	if modelID != "" {
		cfg.Agents.Defaults.ModelName = modelID
	}

	msgBus := bus.NewMessageBus()
	defer msgBus.Close()
	agentLoop := agent.NewAgentLoop(cfg, msgBus, provider)
	defer agentLoop.Close()

	// Checked lookups: an unchecked type assertion here would panic and crash
	// the CLI at startup if a section were missing or had a different type.
	startupInfo := agentLoop.GetStartupInfo()
	logger.InfoCF("hxclaw", "HxClaw 已初始化",
		map[string]any{
			"tools_count":      startupField(startupInfo, "tools", "count"),
			"skills_total":     startupField(startupInfo, "skills", "total"),
			"skills_available": startupField(startupInfo, "skills", "available"),
		})

	fmt.Printf("%s Interactive mode (Ctrl+C to exit)\n\n", Logo)
	interactiveMode(agentLoop, "cli:default")
}

// startupField returns info[section][key] when info[section] is a
// map[string]any, and nil otherwise (missing section, other type, or missing key).
func startupField(info map[string]any, section, key string) any {
	sub, ok := info[section].(map[string]any)
	if !ok {
		return nil
	}
	return sub[key]
}
// interactiveMode runs the line-editing prompt loop for sessionKey. Each
// non-blank line is sent to the agent through runWithStreaming. The loop ends
// on "exit"/"quit", or when Readline reports internal.ErrInterrupt or
// internal.ErrEOF. If readline cannot be initialised, it falls back to
// simpleInteractiveMode.
func interactiveMode(agentLoop *agent.AgentLoop, sessionKey string) {
	prompt := fmt.Sprintf("%s You: ", Logo)
	rl, err := internal.NewReadline(prompt)
	if err != nil {
		fmt.Printf("初始化 readline 失败: %v\n", err)
		fmt.Println("回退到简单输入模式...")
		simpleInteractiveMode(agentLoop, sessionKey)
		return
	}
	defer rl.Close()

	for {
		line, err := rl.Readline()
		if err != nil {
			if err == internal.ErrInterrupt || err == internal.ErrEOF {
				fmt.Println("\n再见!")
				return
			}
			// Other read errors are reported and the prompt is shown again.
			fmt.Printf("读取输入错误: %v\n", err)
			continue
		}

		// Trim so that whitespace-only lines are skipped rather than sent to
		// the model, and so that "exit " or " quit" still end the session.
		input := strings.TrimSpace(line)
		switch input {
		case "":
			continue
		case "exit", "quit":
			fmt.Println("再见!")
			return
		}
		runWithStreaming(agentLoop, input, sessionKey)
	}
}
// simpleInteractiveMode is the fallback prompt loop used when readline is
// unavailable. It reads lines with internal.NewSimpleReader and sends each
// non-blank line to the agent through runWithStreaming. The loop ends on
// "exit"/"quit" or internal.ErrEOF.
func simpleInteractiveMode(agentLoop *agent.AgentLoop, sessionKey string) {
	reader := internal.NewSimpleReader()
	for {
		fmt.Printf("%s You: ", Logo)
		line, err := reader.ReadString()
		if err != nil {
			if err == internal.ErrEOF {
				fmt.Println("\n再见!")
				return
			}
			// Other read errors are reported and the prompt is shown again.
			fmt.Printf("读取输入错误: %v\n", err)
			continue
		}

		// Trim so that whitespace-only lines are skipped rather than sent to
		// the model, and so that "exit " or " quit" still end the session.
		input := strings.TrimSpace(line)
		switch input {
		case "":
			continue
		case "exit", "quit":
			fmt.Println("再见!")
			return
		}
		runWithStreaming(agentLoop, input, sessionKey)
	}
}
// runWithStreaming sends input to the default agent and prints its reply.
//
// If the agent's provider implements providers.StreamingProvider, the request
// is built from the session history and summary and streamed. Tokens are
// printed as they arrive, behind a "思考中..." spinner that is shown until the
// first token. When the stream finishes, the raw text is redrawn as rendered
// Markdown (when rendering changes it), and the user/assistant pair is
// appended to the session. Otherwise the request goes through
// agentLoop.ProcessDirect and the (rendered) response is printed.
func runWithStreaming(agentLoop *agent.AgentLoop, input, sessionKey string) {
	agentInstance := agentLoop.GetRegistry().GetDefaultAgent()
	if agentInstance == nil {
		fmt.Println("错误:无法获取 Agent 实例")
		return
	}
	provider := agentInstance.Provider
	ctx := context.Background()

	sp, ok := provider.(providers.StreamingProvider)
	if !ok {
		// Non-streaming fallback: let the agent loop handle the whole turn.
		response, err := agentLoop.ProcessDirect(ctx, input, sessionKey)
		if err != nil {
			fmt.Printf("错误: %v\n", err)
			return
		}
		rendered := internal.RenderMarkdown(response)
		if rendered != "" && rendered != response {
			fmt.Printf("\n%s\n\n", rendered)
		} else {
			fmt.Printf("\n%s %s\n\n", Logo, response)
		}
		return
	}

	// Build the prompt from the stored history and summary, plus the new input.
	history := agentInstance.Sessions.GetHistory(sessionKey)
	summary := agentInstance.Sessions.GetSummary(sessionKey)
	messages := agentInstance.ContextBuilder.BuildMessages(
		history,
		summary,
		input,
		nil,   // media
		"cli", // channel
		sessionKey,
		"", // senderID
		"", // senderDisplayName
	)
	toolDefs := agentInstance.Tools.ToProviderDefs()

	// The spinner must stop on every path, including a stream that fails or
	// yields no text before the first token. stopSpinner is idempotent; the
	// defer only matters if ChatStream panics.
	spinner := internal.NewSpinner("思考中...")
	spinner.Start()
	spinnerStopped := false
	stopSpinner := func() {
		if !spinnerStopped {
			spinner.Stop()
			spinnerStopped = true
		}
	}
	defer stopSpinner()
	fmt.Print("\n")

	var result strings.Builder
	printedLen := 0
	// NOTE(review): the response returned by ChatStream is discarded. If it can
	// carry tool calls, they are never executed on this path. Confirm intent.
	_, err := sp.ChatStream(ctx, messages, toolDefs, agentInstance.Model, nil, func(accumulated string) {
		if len(accumulated) > 0 {
			stopSpinner()
		}
		// The callback receives the full text so far; print only the new suffix.
		if len(accumulated) <= printedLen {
			return
		}
		newText := accumulated[printedLen:]
		fmt.Print(newText)
		os.Stdout.Sync() // best effort; the error is ignored (e.g. not supported on some terminals)
		result.WriteString(newText)
		printedLen = len(accumulated)
	})
	stopSpinner()
	if err != nil {
		fmt.Printf("流式调用错误: %v\n", err)
		return
	}
	// An empty reply is not recorded; the user message is dropped from history too.
	if result.Len() == 0 {
		return
	}

	allOutput := result.String()
	redrawAsMarkdown(allOutput)
	fmt.Println()
	fmt.Println()
	agentInstance.Sessions.AddMessage(sessionKey, "user", input)
	agentInstance.Sessions.AddMessage(sessionKey, "assistant", allOutput)
}

// redrawAsMarkdown replaces raw with its Markdown rendering. raw must be the
// text just streamed to stdout, with the cursor still at its end. Output is
// left untouched when rendering yields an empty string or the same text.
func redrawAsMarkdown(raw string) {
	rendered := internal.RenderMarkdown(raw)
	if rendered == "" || rendered == raw {
		return
	}
	// Erase exactly the rows holding raw, bottom-up, starting with the row the
	// cursor is on. NOTE(review): only hard newlines are counted; rows the
	// terminal soft-wrapped are not erased.
	rows := strings.Count(raw, "\n") + 1
	out := termenv.DefaultOutput()
	for i := 0; i < rows; i++ {
		if i > 0 {
			out.CursorUp(1)
		}
		out.ClearLine()
	}
	// CursorUp keeps the current column; return to column 0 before redrawing.
	fmt.Print("\r")
	fmt.Print(rendered)
}