Files
yoyo/internal/onboard/onboard.go

306 lines
7.6 KiB
Go
Raw Normal View History

package onboard
import (
"errors"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"charm.land/huh/v2"
"github.com/titor/fanyi/internal/config"
"github.com/titor/fanyi/internal/lang"
)
// RunOnboard 启动配置向导
func RunOnboard(force bool) error {
configPath := config.GetUserConfigPath()
// 检查配置文件是否存在
if _, err := os.Stat(configPath); err == nil && !force {
var overwrite bool
form := huh.NewForm(
huh.NewGroup(
huh.NewConfirm().
Title("检测到配置文件已存在,是否要重新配置?").
Affirmative("是").
Negative("否").
Value(&overwrite),
),
)
if err := form.Run(); err != nil {
if errors.Is(err, huh.ErrUserAborted) {
fmt.Println("\n你已取消本次配置")
return nil
}
return fmt.Errorf("用户输入错误: %w", err)
}
if !overwrite {
fmt.Println("配置已取消。")
return nil
}
}
// 步骤1: 选择主要厂商
var providerName string
providerForm := huh.NewForm(
huh.NewGroup(
huh.NewSelect[string]().
Title("请选择要使用的翻译服务提供商").
Options(
huh.NewOption("硅基流动 (推荐,免费额度)", "siliconflow"),
huh.NewOption("火山引擎", "volcano"),
huh.NewOption("国家超算", "national"),
huh.NewOption("Qwen (通义千问)", "qwen"),
huh.NewOption("OpenAI兼容格式", "openai"),
).
Value(&providerName),
),
)
if err := providerForm.Run(); err != nil {
if errors.Is(err, huh.ErrUserAborted) {
fmt.Println("\n你已取消本次配置")
return nil
}
return fmt.Errorf("选择厂商失败: %w", err)
}
// 步骤2: 配置主要厂商
providerConfig, err := ConfigureProviderHuh(providerName)
if err != nil {
if errors.Is(err, huh.ErrUserAborted) {
fmt.Println("\n你已取消本次配置")
return nil
}
return fmt.Errorf("配置厂商失败: %w", err)
}
// 步骤3: 全局设置
globalConfig, err := GlobalSettingsHuh()
if err != nil {
if errors.Is(err, huh.ErrUserAborted) {
fmt.Println("\n你已取消本次配置")
return nil
}
return fmt.Errorf("全局设置失败: %w", err)
}
// 步骤4: 确认并保存配置
configData := BuildConfig(providerName, providerConfig, globalConfig)
var confirmSave bool
confirmForm := huh.NewForm(
huh.NewGroup(
huh.NewConfirm().
Title("确认保存配置?").
Description(fmt.Sprintf("配置文件将保存到: %s", configPath)).
Affirmative("是,保存").
Negative("否,取消").
Value(&confirmSave),
),
)
if err := confirmForm.Run(); err != nil {
if errors.Is(err, huh.ErrUserAborted) {
fmt.Println("\n你已取消本次配置")
return nil
}
return fmt.Errorf("用户输入错误: %w", err)
}
if !confirmSave {
fmt.Println("配置已取消。")
return nil
}
if err := SaveConfig(configData, configPath); err != nil {
return fmt.Errorf("保存配置失败: %w", err)
}
fmt.Printf("\n配置完成! 配置文件已保存到: %s\n", configPath)
fmt.Println("\n您现在可以使用以下命令进行翻译:")
fmt.Println(" yoyo \"Hello world\"")
fmt.Println(" yoyo --lang=cn \"Hello world\"")
fmt.Println("\n更多帮助请运行: yoyo --help")
return nil
}
// GlobalConfig 全局设置配置
type GlobalConfig struct {
DefaultProvider string
DefaultModel string
Timeout int
DefaultSourceLang string
DefaultTargetLang string
}
// ConfigureProviderHuh 使用 huh 配置厂商
func ConfigureProviderHuh(providerName string) (config.ProviderConfig, error) {
defaults := map[string]config.ProviderConfig{
"siliconflow": {
APIHost: "https://api.siliconflow.cn/v1",
Model: "siliconflow-base",
Enabled: true,
},
"volcano": {
APIHost: "https://api.volcengine.com/v1",
Model: "volcano-chat",
Enabled: true,
},
"national": {
APIHost: "https://api.nsc.gov.cn/v1",
Model: "nsc-base",
Enabled: true,
},
"qwen": {
APIHost: "https://dashscope.aliyuncs.com/compatible-mode/v1",
Model: "qwen-turbo",
Enabled: true,
},
"openai": {
APIHost: "https://api.openai.com/v1",
Model: "gpt-3.5-turbo",
Enabled: true,
},
}
defaultConfig := defaults[providerName]
cfg := config.ProviderConfig{
APIHost: defaultConfig.APIHost,
Model: defaultConfig.Model,
Enabled: defaultConfig.Enabled,
}
var apiKey string
apiKeyForm := huh.NewForm(
huh.NewGroup(
huh.NewInput().
Title(fmt.Sprintf("请输入 %s 的API密钥", providerName)).
Description("API密钥用于身份验证将存储在配置文件中").
Value(&apiKey).
Validate(func(str string) error {
if strings.TrimSpace(str) == "" {
return fmt.Errorf("API密钥不能为空")
}
return nil
}),
huh.NewInput().
Title("API HOST").
Description("直接回车使用默认值").
Value(&cfg.APIHost).
Placeholder(defaultConfig.APIHost),
huh.NewInput().
Title("默认模型").
Description("直接回车使用默认值").
Value(&cfg.Model).
Placeholder(defaultConfig.Model),
),
)
if err := apiKeyForm.Run(); err != nil {
return config.ProviderConfig{}, err
}
cfg.APIKey = apiKey
return cfg, nil
}
// GlobalSettingsHuh 使用 huh 进行全局设置
func GlobalSettingsHuh() (*GlobalConfig, error) {
cfg := &GlobalConfig{
DefaultProvider: "siliconflow",
DefaultModel: "siliconflow-base",
Timeout: 30,
DefaultSourceLang: "auto",
DefaultTargetLang: "zh-CN",
}
targetLangOptions := lang.GetCommonLanguages()
var options []huh.Option[string]
for _, code := range targetLangOptions {
options = append(options, huh.NewOption(
fmt.Sprintf("%s (%s)", code, lang.GetLanguageName(code)),
code,
))
}
var timeoutStr string
form := huh.NewForm(
huh.NewGroup(
huh.NewSelect[string]().
Title("请选择默认目标语言").
Options(options...).
Value(&cfg.DefaultTargetLang),
huh.NewInput().
Title("API超时时间(秒)").
Value(&timeoutStr).
Placeholder("30"),
),
)
if err := form.Run(); err != nil {
return nil, err
}
if timeout := parseIntOrDefault(timeoutStr, 30); timeout > 0 {
cfg.Timeout = timeout
}
return cfg, nil
}
// BuildConfig 构建配置对象
func BuildConfig(providerName string, providerConfig config.ProviderConfig, globalConfig *GlobalConfig) *config.Config {
providers := map[string]config.ProviderConfig{
providerName: providerConfig,
}
prompts := map[string]string{
"technical": "你是一位专业的技术翻译,请准确翻译以下技术文档,保持专业术语的准确性。",
"creative": "你是一位富有创造力的翻译家,请用优美流畅的语言翻译以下内容。",
"academic": "你是一位学术翻译专家,请用严谨的学术语言翻译以下内容。",
"simple": "请用简单易懂的语言翻译以下内容。",
}
return &config.Config{
DefaultProvider: providerName,
DefaultModel: providerConfig.Model,
Timeout: globalConfig.Timeout,
DefaultSourceLang: globalConfig.DefaultSourceLang,
DefaultTargetLang: globalConfig.DefaultTargetLang,
Providers: providers,
Prompts: prompts,
}
}
// SaveConfig 保存配置文件
func SaveConfig(cfg *config.Config, path string) error {
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0755); err != nil {
return fmt.Errorf("创建配置目录失败: %w", err)
}
loader := &config.YAMLConfigLoader{}
return loader.Save(cfg, path)
}
// parseIntOrDefault 解析整数,失败时返回默认值
func parseIntOrDefault(s string, defaultValue int) int {
s = strings.TrimSpace(s)
if s == "" {
return defaultValue
}
result, err := strconv.Atoi(s)
if err != nil {
return defaultValue
}
if result <= 0 {
return defaultValue
}
return result
}