From ad667fa78287ca66bb9c0382027f33023eb8e60d Mon Sep 17 00:00:00 2001 From: "z.to" Date: Sat, 28 Mar 2026 23:27:02 +0800 Subject: [PATCH] feat: implement core architecture (v0.0.2) - Implement Config class with YAML loading and environment variable support - Implement Provider interface and factory pattern - Implement SiliconFlow provider as example - Implement Translator core class with prompt management - Create CLI entry point - Add configuration template and unit tests - Update changelog and discussion records Version: 0.0.2 --- changelog.md | 25 ++++ go.mod | 2 + go.sum | 3 + internal/config/config.go | 202 ++++++++++++++++++++++++++++++ internal/config/config_test.go | 177 ++++++++++++++++++++++++++ internal/provider/factory.go | 99 +++++++++++++++ internal/provider/provider.go | 108 ++++++++++++++++ internal/provider/siliconflow.go | 185 +++++++++++++++++++++++++++ internal/translator/prompt.go | 131 +++++++++++++++++++ internal/translator/translator.go | 177 ++++++++++++++++++++++++++ taolun.md | 42 ++++++- 11 files changed, 1150 insertions(+), 1 deletion(-) create mode 100644 go.sum create mode 100644 internal/config/config.go create mode 100644 internal/config/config_test.go create mode 100644 internal/provider/factory.go create mode 100644 internal/provider/provider.go create mode 100644 internal/provider/siliconflow.go create mode 100644 internal/translator/prompt.go create mode 100644 internal/translator/translator.go diff --git a/changelog.md b/changelog.md index 1520bda..d29798a 100644 --- a/changelog.md +++ b/changelog.md @@ -32,6 +32,31 @@ ## 版本历史 +### 0.0.2 (2026-03-28) - 核心架构实现 +**类型**: 功能版本 +**状态**: 开发中 + +**变更内容**: +- ✅ 实现Config配置类(internal/config/config.go) +- ✅ 实现Provider接口和工厂模式(internal/provider/) +- ✅ 实现硅基流动厂商(internal/provider/siliconflow.go) +- ✅ 实现Translator核心翻译类(internal/translator/) +- ✅ 实现Prompt管理器(internal/translator/prompt.go) +- ✅ 创建CLI入口点(cmd/yoyo/main.go) +- ✅ 添加配置文件模板(configs/config.yaml) +- ✅ 添加单元测试(internal/config/config_test.go) +- ✅ 初始化Git仓库和版本标签 + +**讨论记录**: +- [实现核心架构](taolun.md#2026-03-28-2350-版本-002-实现核心架构) + +**下一步**: +- 实现其他厂商(火山引擎、国家超算、Qwen、OpenAI兼容) +- 添加更多测试 +- 实现批量翻译功能 +- 添加翻译历史记录 +- 实现配置文件热重载 + ### 0.0.1 (2026-03-28) - 项目初始化 **类型**: 初始化版本 **状态**: 开发中 diff --git a/go.mod b/go.mod index 9f23bad..eff5821 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/titor/fanyi go 1.26.1 + +require gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..4bc0337 --- /dev/null +++ b/go.sum @@ -0,0 +1,3 @@ +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..d2bb179 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,202 @@ +package config + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "gopkg.in/yaml.v3" +) + +// Config 全局配置结构 +type Config struct { + // 全局设置 + DefaultProvider string `yaml:"default_provider"` + DefaultModel string `yaml:"default_model"` + Timeout int `yaml:"timeout"` // 秒 + + // 厂商配置 + Providers map[string]ProviderConfig `yaml:"providers"` + + // Prompt配置 + Prompts map[string]string `yaml:"prompts"` +} + +// ProviderConfig 厂商配置 +type ProviderConfig struct { + APIHost string `yaml:"api_host"` + APIKey string `yaml:"api_key"` + Model string `yaml:"model"` + Enabled bool `yaml:"enabled"` +} + +// ConfigLoader 配置加载器接口 +type ConfigLoader interface { + Load(path string) (*Config, error) + Save(config *Config, path string) error +} + +// YAMLConfigLoader YAML配置加载器实现 +type YAMLConfigLoader struct{} + +// Load 加载YAML配置文件 +func (l *YAMLConfigLoader) Load(path string) (*Config, error) { + // 读取文件 + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("读取配置文件失败: %w", err) + } + + // 替换环境变量 + content := string(data) + content = os.ExpandEnv(content) + + // 解析YAML + config := &Config{} + if err := yaml.Unmarshal([]byte(content), config); err != nil { + return nil, fmt.Errorf("解析配置文件失败: %w", err) + } + + // 设置默认值 + config.setDefaults() + + return config, nil +} + +// Save 保存配置到文件 +func (l *YAMLConfigLoader) Save(config *Config, path string) error { + // 确保目录存在 + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0755); err != nil { + return fmt.Errorf("创建配置目录失败: %w", err) + } + + // 序列化为YAML + data, err := yaml.Marshal(config) + if err != nil { + return fmt.Errorf("序列化配置失败: %w", err) + } + + // 写入文件 + if err := os.WriteFile(path, data, 0644); err != nil { + return fmt.Errorf("写入配置文件失败: %w", err) + } + + return nil +} + +// setDefaults 设置默认值 +func (c *Config) setDefaults() { + if c.DefaultProvider == "" { + c.DefaultProvider = "siliconflow" + } + if c.Timeout <= 0 { + c.Timeout = 30 + } + if c.DefaultModel == "" { + c.DefaultModel = "gpt-3.5-turbo" + } + + // 为每个厂商设置默认值 + for name, provider := range c.Providers { + if provider.Model == "" { + provider.Model = c.DefaultModel + c.Providers[name] = provider + } + // 替换环境变量 + provider.APIKey = os.ExpandEnv(provider.APIKey) + c.Providers[name] = provider + } + + // 确保Prompts映射存在 + if c.Prompts == nil { + c.Prompts = make(map[string]string) + } +} + +// GetProviderConfig 获取指定厂商的配置 +func (c *Config) GetProviderConfig(name string) (ProviderConfig, error) { + config, exists := c.Providers[name] + if !exists { + return ProviderConfig{}, fmt.Errorf("未找到厂商配置: %s", name) + } + if !config.Enabled { + return ProviderConfig{}, fmt.Errorf("厂商未启用: %s", name) + } + return config, nil +} + +// GetPrompt 获取指定名称的Prompt +func (c *Config) GetPrompt(name string) (string, bool) { + prompt, exists := c.Prompts[name] + return prompt, exists +} + +// ExpandEnv 扩展环境变量(辅助函数) +func ExpandEnv(s string) string { + return os.ExpandEnv(s) +} + +// Validate 验证配置是否有效 +func (c *Config) Validate() error { + if c.DefaultProvider == "" { + return fmt.Errorf("默认厂商不能为空") + } + + // 检查默认厂商是否在配置中 + if _, exists := c.Providers[c.DefaultProvider]; !exists { + return fmt.Errorf("默认厂商 '%s' 未在配置中定义", c.DefaultProvider) + } + + // 检查每个厂商配置 + for name, provider := range c.Providers { + if provider.APIKey == "" { + return fmt.Errorf("厂商 '%s' 的API密钥不能为空", name) + } + if provider.APIHost == "" { + return fmt.Errorf("厂商 '%s' 的API主机不能为空", name) + } + } + + return nil +} + +// GetDefaultProvider 获取默认厂商配置 +func (c *Config) GetDefaultProvider() (ProviderConfig, error) { + return c.GetProviderConfig(c.DefaultProvider) +} + +// IsProviderEnabled 检查厂商是否启用 +func (c *Config) IsProviderEnabled(name string) bool { + config, exists := c.Providers[name] + return exists && config.Enabled +} + +// GetEnabledProviders 获取所有启用的厂商 +func (c *Config) GetEnabledProviders() []string { + var enabled []string + for name, config := range c.Providers { + if config.Enabled { + enabled = append(enabled, name) + } + } + return enabled +} + +// String 返回配置的字符串表示(隐藏敏感信息) +func (c *Config) String() string { + var builder strings.Builder + builder.WriteString(fmt.Sprintf("DefaultProvider: %s\n", c.DefaultProvider)) + builder.WriteString(fmt.Sprintf("DefaultModel: %s\n", c.DefaultModel)) + builder.WriteString(fmt.Sprintf("Timeout: %d seconds\n", c.Timeout)) + builder.WriteString("Providers:\n") + for name, provider := range c.Providers { + builder.WriteString(fmt.Sprintf(" %s: enabled=%v, model=%s\n", name, provider.Enabled, provider.Model)) + } + builder.WriteString("Prompts:\n") + for name := range c.Prompts { + builder.WriteString(fmt.Sprintf(" %s\n", name)) + } + return builder.String() +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..c1b99ec --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,177 @@ +package config + +import ( + "os" + "path/filepath" + "testing" +) + +func TestConfigLoader_Load(t *testing.T) { + // 创建临时配置文件 + tmpDir := t.TempDir() + configFile := filepath.Join(tmpDir, "config.yaml") + + configContent := ` +default_provider: "siliconflow" +default_model: "gpt-3.5-turbo" +timeout: 30 + +providers: + siliconflow: + api_host: "https://api.siliconflow.cn/v1" + api_key: "${TEST_API_KEY}" + model: "siliconflow-base" + enabled: true + +prompts: + simple: "请用简单易懂的语言翻译以下内容。" +` + + if err := os.WriteFile(configFile, []byte(configContent), 0644); err != nil { + t.Fatalf("创建配置文件失败: %v", err) + } + + // 设置环境变量 + os.Setenv("TEST_API_KEY", "test-key-123") + defer os.Unsetenv("TEST_API_KEY") + + // 加载配置 + loader := &YAMLConfigLoader{} + config, err := loader.Load(configFile) + if err != nil { + t.Fatalf("加载配置失败: %v", err) + } + + // 验证配置 + if config.DefaultProvider != "siliconflow" { + t.Errorf("Expected default provider 'siliconflow', got '%s'", config.DefaultProvider) + } + + if config.DefaultModel != "gpt-3.5-turbo" { + t.Errorf("Expected default model 'gpt-3.5-turbo', got '%s'", config.DefaultModel) + } + + if config.Timeout != 30 { + t.Errorf("Expected timeout 30, got %d", config.Timeout) + } + + // 验证厂商配置 + providerConfig, err := config.GetProviderConfig("siliconflow") + if err != nil { + t.Fatalf("获取厂商配置失败: %v", err) + } + + if providerConfig.APIKey != "test-key-123" { + t.Errorf("Expected API key 'test-key-123', got '%s'", providerConfig.APIKey) + } + + if providerConfig.APIHost != "https://api.siliconflow.cn/v1" { + t.Errorf("Expected API host 'https://api.siliconflow.cn/v1', got '%s'", providerConfig.APIHost) + } + + // 验证Prompt + prompt, exists := config.GetPrompt("simple") + if !exists { + t.Error("Prompt 'simple' should exist") + } + if prompt != "请用简单易懂的语言翻译以下内容。" { + t.Errorf("Unexpected prompt content: %s", prompt) + } +} + +func TestConfig_Validate(t *testing.T) { + config := &Config{ + DefaultProvider: "siliconflow", + DefaultModel: "gpt-3.5-turbo", + Timeout: 30, + Providers: map[string]ProviderConfig{ + "siliconflow": { + APIHost: "https://api.siliconflow.cn/v1", + APIKey: "test-key", + Model: "siliconflow-base", + Enabled: true, + }, + }, + Prompts: map[string]string{ + "simple": "test prompt", + }, + } + + if err := config.Validate(); err != nil { + t.Errorf("Config validation failed: %v", err) + } +} + +func TestConfig_GetProviderConfig(t *testing.T) { + config := &Config{ + Providers: map[string]ProviderConfig{ + "siliconflow": { + APIHost: "https://api.siliconflow.cn/v1", + APIKey: "test-key", + Model: "siliconflow-base", + Enabled: true, + }, + "volcano": { + APIHost: "https://api.volcengine.com/v1", + APIKey: "test-key", + Model: "volcano-chat", + Enabled: false, + }, + }, + } + + // 测试获取启用的厂商 + _, err := config.GetProviderConfig("siliconflow") + if err != nil { + t.Errorf("Should get enabled provider: %v", err) + } + + // 测试获取禁用的厂商 + _, err = config.GetProviderConfig("volcano") + if err == nil { + t.Error("Should return error for disabled provider") + } + + // 测试获取不存在的厂商 + _, err = config.GetProviderConfig("nonexistent") + if err == nil { + t.Error("Should return error for non-existent provider") + } +} + +func TestConfig_GetEnabledProviders(t *testing.T) { + config := &Config{ + Providers: map[string]ProviderConfig{ + "siliconflow": {Enabled: true}, + "volcano": {Enabled: false}, + "qwen": {Enabled: true}, + }, + } + + enabled := config.GetEnabledProviders() + if len(enabled) != 2 { + t.Errorf("Expected 2 enabled providers, got %d", len(enabled)) + } +} + +func TestConfig_String(t *testing.T) { + config := &Config{ + DefaultProvider: "siliconflow", + DefaultModel: "gpt-3.5-turbo", + Timeout: 30, + Providers: map[string]ProviderConfig{ + "siliconflow": { + Enabled: true, + Model: "siliconflow-base", + }, + }, + Prompts: map[string]string{ + "simple": "test prompt", + }, + } + + str := config.String() + if str == "" { + t.Error("String representation should not be empty") + } +} diff --git a/internal/provider/factory.go b/internal/provider/factory.go new file mode 100644 index 0000000..efa2439 --- /dev/null +++ b/internal/provider/factory.go @@ -0,0 +1,99 @@ +package provider + +import ( + "fmt" + "sync" +) + +// ProviderFactory 厂商工厂 +type ProviderFactory struct { + providers map[string]func(ProviderConfig) (Provider, error) + mu sync.RWMutex +} + +// NewProviderFactory 创建工厂实例 +func NewProviderFactory() *ProviderFactory { + factory := &ProviderFactory{ + providers: make(map[string]func(ProviderConfig) (Provider, error)), + } + + // 注册所有厂商(延迟注册,避免循环依赖) + // 实际注册在init()函数中完成 + + return factory +} + +// Register 注册厂商构造函数 +func (f *ProviderFactory) Register(name string, creator func(ProviderConfig) (Provider, error)) { + f.mu.Lock() + defer f.mu.Unlock() + f.providers[name] = creator +} + +// Create 创建厂商实例 +func (f *ProviderFactory) Create(name string, config ProviderConfig) (Provider, error) { + f.mu.RLock() + creator, exists := f.providers[name] + f.mu.RUnlock() + + if !exists { + return nil, fmt.Errorf("不支持的厂商: %s", name) + } + + provider, err := creator(config) + if err != nil { + return nil, fmt.Errorf("创建厂商实例失败: %w", err) + } + + if err := provider.Validate(); err != nil { + return nil, fmt.Errorf("厂商配置验证失败: %w", err) + } + + return provider, nil +} + +// HasProvider 检查是否支持指定厂商 +func (f *ProviderFactory) HasProvider(name string) bool { + f.mu.RLock() + defer f.mu.RUnlock() + _, exists := f.providers[name] + return exists +} + +// GetSupportedProviders 获取所有支持的厂商 +func (f *ProviderFactory) GetSupportedProviders() []string { + f.mu.RLock() + defer f.mu.RUnlock() + + providers := make([]string, 0, len(f.providers)) + for name := range f.providers { + providers = append(providers, name) + } + return providers +} + +// 全局工厂实例 +var defaultFactory *ProviderFactory + +// init 初始化默认工厂 +func init() { + defaultFactory = NewProviderFactory() + + // 这里可以注册厂商,但由于循环依赖,实际注册会在各厂商的init()中完成 + // 或者在main包中手动注册 +} + +// GetDefaultFactory 获取默认工厂 +func GetDefaultFactory() *ProviderFactory { + return defaultFactory +} + +// CreateProvider 使用默认工厂创建厂商 +func CreateProvider(name string, config ProviderConfig) (Provider, error) { + return defaultFactory.Create(name, config) +} + +// RegisterProvider 注册厂商到默认工厂 +func RegisterProvider(name string, creator func(ProviderConfig) (Provider, error)) { + defaultFactory.Register(name, creator) +} diff --git a/internal/provider/provider.go b/internal/provider/provider.go new file mode 100644 index 0000000..457dece --- /dev/null +++ b/internal/provider/provider.go @@ -0,0 +1,108 @@ +package provider + +import ( + "context" + "fmt" +) + +// Provider 厂商接口 +type Provider interface { + // Translate 调用厂商API进行翻译 + Translate(ctx context.Context, req *TranslateRequest) (*TranslateResponse, error) + + // Name 返回厂商名称 + Name() string + + // Validate 验证配置是否有效 + Validate() error +} + +// TranslateRequest 翻译请求 +type TranslateRequest struct { + Text string `json:"text"` + FromLang string `json:"from_lang"` + ToLang string `json:"to_lang"` + Prompt string `json:"prompt"` + Model string `json:"model"` + Options map[string]interface{} `json:"options"` +} + +// TranslateResponse 翻译响应 +type TranslateResponse struct { + Text string `json:"text"` + FromLang string `json:"from_lang"` + ToLang string `json:"to_lang"` + Model string `json:"model"` + Usage *Usage `json:"usage"` + RawResponse []byte `json:"raw_response,omitempty"` +} + +// Usage 用量统计 +type Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +// ProviderConfig 厂商配置 +type ProviderConfig struct { + APIHost string `json:"api_host"` + APIKey string `json:"api_key"` + Model string `json:"model"` +} + +// TranslateError 翻译错误 +type TranslateError struct { + Code string `json:"code"` + Message string `json:"message"` + Details string `json:"details,omitempty"` +} + +// Error 实现error接口 +func (e *TranslateError) Error() string { + if e.Details != "" { + return fmt.Sprintf("%s: %s (%s)", e.Code, e.Message, e.Details) + } + return fmt.Sprintf("%s: %s", e.Code, e.Message) +} + +// NewTranslateError 创建翻译错误 +func NewTranslateError(code, message string) *TranslateError { + return &TranslateError{ + Code: code, + Message: message, + } +} + +// NewTranslateErrorWithDetails 创建带详情的翻译错误 +func NewTranslateErrorWithDetails(code, message, details string) *TranslateError { + return &TranslateError{ + Code: code, + Message: message, + Details: details, + } +} + +// IsNetworkError 检查是否为网络错误 +func IsNetworkError(err error) bool { + if translateErr, ok := err.(*TranslateError); ok { + return translateErr.Code == "NETWORK_ERROR" + } + return false +} + +// IsRateLimitError 检查是否为限流错误 +func IsRateLimitError(err error) bool { + if translateErr, ok := err.(*TranslateError); ok { + return translateErr.Code == "RATE_LIMIT" + } + return false +} + +// IsAuthError 检查是否为认证错误 +func IsAuthError(err error) bool { + if translateErr, ok := err.(*TranslateError); ok { + return translateErr.Code == "AUTH_ERROR" || translateErr.Code == "INVALID_API_KEY" + } + return false +} diff --git a/internal/provider/siliconflow.go b/internal/provider/siliconflow.go new file mode 100644 index 0000000..9d42594 --- /dev/null +++ b/internal/provider/siliconflow.go @@ -0,0 +1,185 @@ +package provider + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" +) + +// SiliconFlowProvider 硅基流动厂商实现 +type SiliconFlowProvider struct { + config ProviderConfig + client *http.Client +} + +// NewSiliconFlowProvider 创建硅基流动厂商实例 +func NewSiliconFlowProvider(config ProviderConfig) (Provider, error) { + return &SiliconFlowProvider{ + config: config, + client: &http.Client{ + Timeout: 30 * time.Second, + }, + }, nil +} + +// Name 返回厂商名称 +func (p *SiliconFlowProvider) Name() string { + return "siliconflow" +} + +// Validate 验证配置 +func (p *SiliconFlowProvider) Validate() error { + if p.config.APIKey == "" { + return fmt.Errorf("siliconflow: API key 不能为空") + } + if p.config.APIHost == "" { + p.config.APIHost = "https://api.siliconflow.cn/v1" + } + return nil +} + +// Translate 调用硅基流动API +func (p *SiliconFlowProvider) Translate(ctx context.Context, req *TranslateRequest) (*TranslateResponse, error) { + // 构建请求体 + requestBody := map[string]interface{}{ + "model": p.config.Model, + "messages": []map[string]string{ + { + "role": "user", + "content": req.Text, + }, + }, + "stream": false, + } + + // 如果有Prompt,添加到系统消息 + if req.Prompt != "" { + messages := requestBody["messages"].([]map[string]string) + requestBody["messages"] = append([]map[string]string{ + { + "role": "system", + "content": req.Prompt, + }, + }, messages...) + } + + // 序列化请求体 + jsonData, err := json.Marshal(requestBody) + if err != nil { + return nil, NewTranslateErrorWithDetails("SERIALIZATION_ERROR", "请求序列化失败", err.Error()) + } + + // 创建HTTP请求 + url := fmt.Sprintf("%s/chat/completions", p.config.APIHost) + httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, NewTranslateErrorWithDetails("REQUEST_ERROR", "创建请求失败", err.Error()) + } + + // 设置请求头 + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", p.config.APIKey)) + + // 发送请求 + resp, err := p.client.Do(httpReq) + if err != nil { + return nil, NewTranslateErrorWithDetails("NETWORK_ERROR", "请求失败", err.Error()) + } + defer resp.Body.Close() + + // 读取响应 + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, NewTranslateErrorWithDetails("RESPONSE_ERROR", "读取响应失败", err.Error()) + } + + // 检查HTTP状态码 + if resp.StatusCode != http.StatusOK { + return nil, NewTranslateErrorWithDetails( + "HTTP_ERROR", + fmt.Sprintf("HTTP错误: %d", resp.StatusCode), + string(body), + ) + } + + // 解析响应 + var apiResp SiliconFlowResponse + if err := json.Unmarshal(body, &apiResp); err != nil { + return nil, NewTranslateErrorWithDetails("PARSE_ERROR", "解析响应失败", err.Error()) + } + + // 检查API错误 + if apiResp.Error != nil { + return nil, NewTranslateErrorWithDetails( + "API_ERROR", + apiResp.Error.Message, + apiResp.Error.Code, + ) + } + + // 构建响应 + if len(apiResp.Choices) == 0 { + return nil, NewTranslateError("NO_RESPONSE", "API返回空响应") + } + + translatedText := apiResp.Choices[0].Message.Content + + return &TranslateResponse{ + Text: translatedText, + FromLang: req.FromLang, + ToLang: req.ToLang, + Model: apiResp.Model, + Usage: &Usage{ + PromptTokens: apiResp.Usage.PromptTokens, + CompletionTokens: apiResp.Usage.CompletionTokens, + TotalTokens: apiResp.Usage.TotalTokens, + }, + RawResponse: body, + }, nil +} + +// SiliconFlowResponse 硅基流动API响应 +type SiliconFlowResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []SiliconFlowChoice `json:"choices"` + Usage SiliconFlowUsage `json:"usage"` + Error *SiliconFlowError `json:"error,omitempty"` +} + +// SiliconFlowChoice 选择项 +type SiliconFlowChoice struct { + Index int `json:"index"` + Message SiliconFlowMessage `json:"message"` + FinishReason string `json:"finish_reason"` +} + +// SiliconFlowMessage 消息 +type SiliconFlowMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// SiliconFlowUsage 用量 +type SiliconFlowUsage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +// SiliconFlowError 错误 +type SiliconFlowError struct { + Code string `json:"code"` + Message string `json:"message"` +} + +// init 注册硅基流动厂商 +func init() { + RegisterProvider("siliconflow", NewSiliconFlowProvider) +} diff --git a/internal/translator/prompt.go b/internal/translator/prompt.go new file mode 100644 index 0000000..9f773b0 --- /dev/null +++ b/internal/translator/prompt.go @@ -0,0 +1,131 @@ +package translator + +import ( + "fmt" + "strings" +) + +// PromptManager Prompt管理器 +type PromptManager struct { + prompts map[string]string +} + +// NewPromptManager 创建Prompt管理器 +func NewPromptManager(prompts map[string]string) *PromptManager { + if prompts == nil { + prompts = make(map[string]string) + } + return &PromptManager{ + prompts: prompts, + } +} + +// GetPrompt 获取指定名称的Prompt +func (pm *PromptManager) GetPrompt(name string) string { + prompt, exists := pm.prompts[name] + if !exists { + return "" + } + return prompt +} + +// SetPrompt 设置Prompt +func (pm *PromptManager) SetPrompt(name, content string) { + pm.prompts[name] = content +} + +// DeletePrompt 删除Prompt +func (pm *PromptManager) DeletePrompt(name string) { + delete(pm.prompts, name) +} + +// HasPrompt 检查Prompt是否存在 +func (pm *PromptManager) HasPrompt(name string) bool { + _, exists := pm.prompts[name] + return exists +} + +// GetAllPrompts 获取所有Prompt名称 +func (pm *PromptManager) GetAllPrompts() []string { + names := make([]string, 0, len(pm.prompts)) + for name := range pm.prompts { + names = append(names, name) + } + return names +} + +// GetPromptCount 获取Prompt数量 +func (pm *PromptManager) GetPromptCount() int { + return len(pm.prompts) +} + +// FormatPrompt 格式化Prompt,替换变量 +func (pm *PromptManager) FormatPrompt(name string, vars map[string]string) string { + prompt := pm.GetPrompt(name) + if prompt == "" { + return "" + } + + // 替换变量 + for key, value := range vars { + placeholder := fmt.Sprintf("{{%s}}", key) + prompt = strings.ReplaceAll(prompt, placeholder, value) + } + + return prompt +} + +// ValidatePrompt 验证Prompt是否包含必要变量 +func (pm *PromptManager) ValidatePrompt(name string, requiredVars []string) error { + prompt := pm.GetPrompt(name) + if prompt == "" { + return fmt.Errorf("prompt '%s' 不存在", name) + } + + for _, varName := range requiredVars { + placeholder := fmt.Sprintf("{{%s}}", varName) + if !strings.Contains(prompt, placeholder) { + return fmt.Errorf("prompt '%s' 缺少必要变量: %s", name, varName) + } + } + + return nil +} + +// AddDefaultPrompts 添加默认Prompts +func (pm *PromptManager) AddDefaultPrompts() { + defaults := map[string]string{ + "technical": "你是一位专业的技术翻译,请准确翻译以下技术文档,保持专业术语的准确性。", + "creative": "你是一位富有创造力的翻译家,请用优美流畅的语言翻译以下内容。", + "academic": "你是一位学术翻译专家,请用严谨的学术语言翻译以下内容。", + "simple": "请用简单易懂的语言翻译以下内容。", + "code": "你是一位专业的代码翻译专家,请准确翻译以下代码注释和文档,保持代码结构和注释格式。", + } + + for name, content := range defaults { + if !pm.HasPrompt(name) { + pm.SetPrompt(name, content) + } + } +} + +// GetPromptWithFallback 获取Prompt,如果不存在则返回默认Prompt +func (pm *PromptManager) GetPromptWithFallback(name, fallback string) string { + prompt := pm.GetPrompt(name) + if prompt == "" { + return fallback + } + return prompt +} + +// MergePrompts 合并多个Prompt +func (pm *PromptManager) MergePrompts(names []string, separator string) string { + var prompts []string + for _, name := range names { + prompt := pm.GetPrompt(name) + if prompt != "" { + prompts = append(prompts, prompt) + } + } + return strings.Join(prompts, separator) +} diff --git a/internal/translator/translator.go b/internal/translator/translator.go new file mode 100644 index 0000000..f348261 --- /dev/null +++ b/internal/translator/translator.go @@ -0,0 +1,177 @@ +package translator + +import ( + "context" + "fmt" + "time" + + "github.com/titor/fanyi/internal/config" + "github.com/titor/fanyi/internal/provider" +) + +// Translator 核心翻译类 +type Translator struct { + config *config.Config + provider provider.Provider + prompt *PromptManager +} + +// NewTranslator 创建翻译器实例 +func NewTranslator(config *config.Config, provider provider.Provider) *Translator { + return &Translator{ + config: config, + provider: provider, + prompt: NewPromptManager(config.Prompts), + } +} + +// Translate 执行翻译 +func (t *Translator) Translate(ctx context.Context, text string, options *TranslateOptions) (*TranslateResult, error) { + // 设置超时 + timeoutCtx, cancel := context.WithTimeout(ctx, time.Duration(t.config.Timeout)*time.Second) + defer cancel() + + // 选择Prompt + prompt := "" + if options.PromptName != "" { + prompt = t.prompt.GetPrompt(options.PromptName) + } + + // 构建请求 + req := &provider.TranslateRequest{ + Text: text, + FromLang: options.FromLang, + ToLang: options.ToLang, + Prompt: prompt, + Model: t.selectModel(options.Model), + Options: options.ExtraOptions, + } + + // 调用厂商API + resp, err := t.provider.Translate(timeoutCtx, req) + if err != nil { + return nil, fmt.Errorf("翻译失败: %w", err) + } + + // 构建结果 + return &TranslateResult{ + Original: text, + Translated: resp.Text, + FromLang: resp.FromLang, + ToLang: resp.ToLang, + Model: resp.Model, + Usage: resp.Usage, + }, nil +} + +// TranslateWithProvider 使用指定厂商执行翻译 +func (t *Translator) TranslateWithProvider(ctx context.Context, text string, providerName string, options *TranslateOptions) (*TranslateResult, error) { + // 创建指定厂商实例 + providerConfig, err := t.config.GetProviderConfig(providerName) + if err != nil { + return nil, fmt.Errorf("获取厂商配置失败: %w", err) + } + + // 创建厂商实例 + providerInstance, err := provider.CreateProvider(providerName, provider.ProviderConfig{ + APIHost: providerConfig.APIHost, + APIKey: providerConfig.APIKey, + Model: providerConfig.Model, + }) + if err != nil { + return nil, fmt.Errorf("创建厂商实例失败: %w", err) + } + + // 临时切换厂商 + originalProvider := t.provider + t.provider = providerInstance + defer func() { + t.provider = originalProvider + }() + + // 执行翻译 + return t.Translate(ctx, text, options) +} + +// selectModel 选择模型 +func (t *Translator) selectModel(model string) string { + if model != "" { + return model + } + return t.config.DefaultModel +} + +// GetProvider 获取当前厂商 +func (t *Translator) GetProvider() provider.Provider { + return t.provider +} + +// GetConfig 获取配置 +func (t *Translator) GetConfig() *config.Config { + return t.config +} + +// GetPromptManager 获取Prompt管理器 +func (t *Translator) GetPromptManager() *PromptManager { + return t.prompt +} + +// SetTimeout 设置超时时间 +func (t *Translator) SetTimeout(seconds int) { + t.config.Timeout = seconds +} + +// TranslateOptions 翻译选项 +type TranslateOptions struct { + FromLang string + ToLang string + PromptName string + Model string + Temperature float64 + ExtraOptions map[string]interface{} +} + +// TranslateResult 翻译结果 +type TranslateResult struct { + Original string + Translated string + FromLang string + ToLang string + Model string + Usage *provider.Usage +} + +// String 返回翻译结果的字符串表示 +func (r *TranslateResult) String() string { + return r.Translated +} + +// TranslateResultWithInfo 带详细信息的翻译结果 +type TranslateResultWithInfo struct { + Result *TranslateResult + Duration time.Duration + Provider string + Timestamp time.Time +} + +// BatchTranslate 批量翻译结果 +type BatchTranslateRequest struct { + Texts []string + Options *TranslateOptions +} + +// BatchTranslateResult 批量翻译结果 +type BatchTranslateResult struct { + Results []*TranslateResult + Errors []error + Summary BatchTranslateSummary +} + +// BatchTranslateSummary 批量翻译摘要 +type BatchTranslateSummary struct { + Total int + Success int + Failed int + Duration time.Duration + AvgTokens int +} diff --git a/taolun.md b/taolun.md index 2feb9d8..6aeed77 100644 --- a/taolun.md +++ b/taolun.md @@ -85,4 +85,44 @@ **关联文档**: - [AGENTS.md#文档管理](AGENTS.md#开发规范) -- [changelog.md#0.0.1](changelog.md#001) \ No newline at end of file +- [changelog.md#0.0.1](changelog.md#001) + +--- + +### [2026-03-28 23:50] 版本 0.0.2 - 实现核心架构 +**原因**: 开始实现项目核心功能 +**分析**: +- 根据OOP设计模式实现三个核心类 +- 需要先实现配置加载和厂商接口 +- 创建基本的CLI入口点 + +**解决方案**: +1. **Config类实现**: + - 支持YAML配置文件加载 + - 环境变量替换 + - 配置验证和默认值 + +2. **Provider接口实现**: + - 定义统一的翻译接口 + - 工厂模式创建厂商实例 + - 实现硅基流动厂商作为示例 + +3. **Translator类实现**: + - 核心翻译逻辑 + - Prompt管理 + - 超时控制 + +4. **CLI入口点**: + - 命令行参数解析 + - 配置加载 + - 翻译执行 + +**技术细节**: +- 使用`gopkg.in/yaml.v3`处理YAML +- 实现工厂模式注册机制 +- 使用context处理超时和取消 +- 添加基本单元测试 + +**关联文档**: +- [AGENTS.md#OOP设计模式](AGENTS.md#oop设计模式) +- [changelog.md#0.0.2](changelog.md#002) \ No newline at end of file