merge: v0.0.2 core architecture implementation
This commit is contained in:
25
changelog.md
25
changelog.md
@@ -32,6 +32,31 @@
|
||||
|
||||
## 版本历史
|
||||
|
||||
### 0.0.2 (2026-03-28) - 核心架构实现
|
||||
**类型**: 功能版本
|
||||
**状态**: 开发中
|
||||
|
||||
**变更内容**:
|
||||
- ✅ 实现Config配置类(internal/config/config.go)
|
||||
- ✅ 实现Provider接口和工厂模式(internal/provider/)
|
||||
- ✅ 实现硅基流动厂商(internal/provider/siliconflow.go)
|
||||
- ✅ 实现Translator核心翻译类(internal/translator/)
|
||||
- ✅ 实现Prompt管理器(internal/translator/prompt.go)
|
||||
- ✅ 创建CLI入口点(cmd/yoyo/main.go)
|
||||
- ✅ 添加配置文件模板(configs/config.yaml)
|
||||
- ✅ 添加单元测试(internal/config/config_test.go)
|
||||
- ✅ 初始化Git仓库和版本标签
|
||||
|
||||
**讨论记录**:
|
||||
- [实现核心架构](taolun.md#2026-03-28-2350-版本-002-实现核心架构)
|
||||
|
||||
**下一步**:
|
||||
- 实现其他厂商(火山引擎、国家超算、Qwen、OpenAI兼容)
|
||||
- 添加更多测试
|
||||
- 实现批量翻译功能
|
||||
- 添加翻译历史记录
|
||||
- 实现配置文件热重载
|
||||
|
||||
### 0.0.1 (2026-03-28) - 项目初始化
|
||||
**类型**: 初始化版本
|
||||
**状态**: 开发中
|
||||
|
||||
2
go.mod
2
go.mod
@@ -1,3 +1,5 @@
|
||||
module github.com/titor/fanyi
|
||||
|
||||
go 1.26.1
|
||||
|
||||
require gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
|
||||
3
go.sum
Normal file
3
go.sum
Normal file
@@ -0,0 +1,3 @@
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
202
internal/config/config.go
Normal file
202
internal/config/config.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Config 全局配置结构
|
||||
type Config struct {
|
||||
// 全局设置
|
||||
DefaultProvider string `yaml:"default_provider"`
|
||||
DefaultModel string `yaml:"default_model"`
|
||||
Timeout int `yaml:"timeout"` // 秒
|
||||
|
||||
// 厂商配置
|
||||
Providers map[string]ProviderConfig `yaml:"providers"`
|
||||
|
||||
// Prompt配置
|
||||
Prompts map[string]string `yaml:"prompts"`
|
||||
}
|
||||
|
||||
// ProviderConfig 厂商配置
|
||||
type ProviderConfig struct {
|
||||
APIHost string `yaml:"api_host"`
|
||||
APIKey string `yaml:"api_key"`
|
||||
Model string `yaml:"model"`
|
||||
Enabled bool `yaml:"enabled"`
|
||||
}
|
||||
|
||||
// ConfigLoader 配置加载器接口
|
||||
type ConfigLoader interface {
|
||||
Load(path string) (*Config, error)
|
||||
Save(config *Config, path string) error
|
||||
}
|
||||
|
||||
// YAMLConfigLoader YAML配置加载器实现
|
||||
type YAMLConfigLoader struct{}
|
||||
|
||||
// Load 加载YAML配置文件
|
||||
func (l *YAMLConfigLoader) Load(path string) (*Config, error) {
|
||||
// 读取文件
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
// 替换环境变量
|
||||
content := string(data)
|
||||
content = os.ExpandEnv(content)
|
||||
|
||||
// 解析YAML
|
||||
config := &Config{}
|
||||
if err := yaml.Unmarshal([]byte(content), config); err != nil {
|
||||
return nil, fmt.Errorf("解析配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
// 设置默认值
|
||||
config.setDefaults()
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// Save 保存配置到文件
|
||||
func (l *YAMLConfigLoader) Save(config *Config, path string) error {
|
||||
// 确保目录存在
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return fmt.Errorf("创建配置目录失败: %w", err)
|
||||
}
|
||||
|
||||
// 序列化为YAML
|
||||
data, err := yaml.Marshal(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化配置失败: %w", err)
|
||||
}
|
||||
|
||||
// 写入文件
|
||||
if err := os.WriteFile(path, data, 0644); err != nil {
|
||||
return fmt.Errorf("写入配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setDefaults 设置默认值
|
||||
func (c *Config) setDefaults() {
|
||||
if c.DefaultProvider == "" {
|
||||
c.DefaultProvider = "siliconflow"
|
||||
}
|
||||
if c.Timeout <= 0 {
|
||||
c.Timeout = 30
|
||||
}
|
||||
if c.DefaultModel == "" {
|
||||
c.DefaultModel = "gpt-3.5-turbo"
|
||||
}
|
||||
|
||||
// 为每个厂商设置默认值
|
||||
for name, provider := range c.Providers {
|
||||
if provider.Model == "" {
|
||||
provider.Model = c.DefaultModel
|
||||
c.Providers[name] = provider
|
||||
}
|
||||
// 替换环境变量
|
||||
provider.APIKey = os.ExpandEnv(provider.APIKey)
|
||||
c.Providers[name] = provider
|
||||
}
|
||||
|
||||
// 确保Prompts映射存在
|
||||
if c.Prompts == nil {
|
||||
c.Prompts = make(map[string]string)
|
||||
}
|
||||
}
|
||||
|
||||
// GetProviderConfig 获取指定厂商的配置
|
||||
func (c *Config) GetProviderConfig(name string) (ProviderConfig, error) {
|
||||
config, exists := c.Providers[name]
|
||||
if !exists {
|
||||
return ProviderConfig{}, fmt.Errorf("未找到厂商配置: %s", name)
|
||||
}
|
||||
if !config.Enabled {
|
||||
return ProviderConfig{}, fmt.Errorf("厂商未启用: %s", name)
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// GetPrompt 获取指定名称的Prompt
|
||||
func (c *Config) GetPrompt(name string) (string, bool) {
|
||||
prompt, exists := c.Prompts[name]
|
||||
return prompt, exists
|
||||
}
|
||||
|
||||
// ExpandEnv 扩展环境变量(辅助函数)
|
||||
func ExpandEnv(s string) string {
|
||||
return os.ExpandEnv(s)
|
||||
}
|
||||
|
||||
// Validate 验证配置是否有效
|
||||
func (c *Config) Validate() error {
|
||||
if c.DefaultProvider == "" {
|
||||
return fmt.Errorf("默认厂商不能为空")
|
||||
}
|
||||
|
||||
// 检查默认厂商是否在配置中
|
||||
if _, exists := c.Providers[c.DefaultProvider]; !exists {
|
||||
return fmt.Errorf("默认厂商 '%s' 未在配置中定义", c.DefaultProvider)
|
||||
}
|
||||
|
||||
// 检查每个厂商配置
|
||||
for name, provider := range c.Providers {
|
||||
if provider.APIKey == "" {
|
||||
return fmt.Errorf("厂商 '%s' 的API密钥不能为空", name)
|
||||
}
|
||||
if provider.APIHost == "" {
|
||||
return fmt.Errorf("厂商 '%s' 的API主机不能为空", name)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDefaultProvider 获取默认厂商配置
|
||||
func (c *Config) GetDefaultProvider() (ProviderConfig, error) {
|
||||
return c.GetProviderConfig(c.DefaultProvider)
|
||||
}
|
||||
|
||||
// IsProviderEnabled 检查厂商是否启用
|
||||
func (c *Config) IsProviderEnabled(name string) bool {
|
||||
config, exists := c.Providers[name]
|
||||
return exists && config.Enabled
|
||||
}
|
||||
|
||||
// GetEnabledProviders 获取所有启用的厂商
|
||||
func (c *Config) GetEnabledProviders() []string {
|
||||
var enabled []string
|
||||
for name, config := range c.Providers {
|
||||
if config.Enabled {
|
||||
enabled = append(enabled, name)
|
||||
}
|
||||
}
|
||||
return enabled
|
||||
}
|
||||
|
||||
// String 返回配置的字符串表示(隐藏敏感信息)
|
||||
func (c *Config) String() string {
|
||||
var builder strings.Builder
|
||||
builder.WriteString(fmt.Sprintf("DefaultProvider: %s\n", c.DefaultProvider))
|
||||
builder.WriteString(fmt.Sprintf("DefaultModel: %s\n", c.DefaultModel))
|
||||
builder.WriteString(fmt.Sprintf("Timeout: %d seconds\n", c.Timeout))
|
||||
builder.WriteString("Providers:\n")
|
||||
for name, provider := range c.Providers {
|
||||
builder.WriteString(fmt.Sprintf(" %s: enabled=%v, model=%s\n", name, provider.Enabled, provider.Model))
|
||||
}
|
||||
builder.WriteString("Prompts:\n")
|
||||
for name := range c.Prompts {
|
||||
builder.WriteString(fmt.Sprintf(" %s\n", name))
|
||||
}
|
||||
return builder.String()
|
||||
}
|
||||
177
internal/config/config_test.go
Normal file
177
internal/config/config_test.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConfigLoader_Load(t *testing.T) {
|
||||
// 创建临时配置文件
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
configContent := `
|
||||
default_provider: "siliconflow"
|
||||
default_model: "gpt-3.5-turbo"
|
||||
timeout: 30
|
||||
|
||||
providers:
|
||||
siliconflow:
|
||||
api_host: "https://api.siliconflow.cn/v1"
|
||||
api_key: "${TEST_API_KEY}"
|
||||
model: "siliconflow-base"
|
||||
enabled: true
|
||||
|
||||
prompts:
|
||||
simple: "请用简单易懂的语言翻译以下内容。"
|
||||
`
|
||||
|
||||
if err := os.WriteFile(configFile, []byte(configContent), 0644); err != nil {
|
||||
t.Fatalf("创建配置文件失败: %v", err)
|
||||
}
|
||||
|
||||
// 设置环境变量
|
||||
os.Setenv("TEST_API_KEY", "test-key-123")
|
||||
defer os.Unsetenv("TEST_API_KEY")
|
||||
|
||||
// 加载配置
|
||||
loader := &YAMLConfigLoader{}
|
||||
config, err := loader.Load(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("加载配置失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证配置
|
||||
if config.DefaultProvider != "siliconflow" {
|
||||
t.Errorf("Expected default provider 'siliconflow', got '%s'", config.DefaultProvider)
|
||||
}
|
||||
|
||||
if config.DefaultModel != "gpt-3.5-turbo" {
|
||||
t.Errorf("Expected default model 'gpt-3.5-turbo', got '%s'", config.DefaultModel)
|
||||
}
|
||||
|
||||
if config.Timeout != 30 {
|
||||
t.Errorf("Expected timeout 30, got %d", config.Timeout)
|
||||
}
|
||||
|
||||
// 验证厂商配置
|
||||
providerConfig, err := config.GetProviderConfig("siliconflow")
|
||||
if err != nil {
|
||||
t.Fatalf("获取厂商配置失败: %v", err)
|
||||
}
|
||||
|
||||
if providerConfig.APIKey != "test-key-123" {
|
||||
t.Errorf("Expected API key 'test-key-123', got '%s'", providerConfig.APIKey)
|
||||
}
|
||||
|
||||
if providerConfig.APIHost != "https://api.siliconflow.cn/v1" {
|
||||
t.Errorf("Expected API host 'https://api.siliconflow.cn/v1', got '%s'", providerConfig.APIHost)
|
||||
}
|
||||
|
||||
// 验证Prompt
|
||||
prompt, exists := config.GetPrompt("simple")
|
||||
if !exists {
|
||||
t.Error("Prompt 'simple' should exist")
|
||||
}
|
||||
if prompt != "请用简单易懂的语言翻译以下内容。" {
|
||||
t.Errorf("Unexpected prompt content: %s", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Validate(t *testing.T) {
|
||||
config := &Config{
|
||||
DefaultProvider: "siliconflow",
|
||||
DefaultModel: "gpt-3.5-turbo",
|
||||
Timeout: 30,
|
||||
Providers: map[string]ProviderConfig{
|
||||
"siliconflow": {
|
||||
APIHost: "https://api.siliconflow.cn/v1",
|
||||
APIKey: "test-key",
|
||||
Model: "siliconflow-base",
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
Prompts: map[string]string{
|
||||
"simple": "test prompt",
|
||||
},
|
||||
}
|
||||
|
||||
if err := config.Validate(); err != nil {
|
||||
t.Errorf("Config validation failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_GetProviderConfig(t *testing.T) {
|
||||
config := &Config{
|
||||
Providers: map[string]ProviderConfig{
|
||||
"siliconflow": {
|
||||
APIHost: "https://api.siliconflow.cn/v1",
|
||||
APIKey: "test-key",
|
||||
Model: "siliconflow-base",
|
||||
Enabled: true,
|
||||
},
|
||||
"volcano": {
|
||||
APIHost: "https://api.volcengine.com/v1",
|
||||
APIKey: "test-key",
|
||||
Model: "volcano-chat",
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 测试获取启用的厂商
|
||||
_, err := config.GetProviderConfig("siliconflow")
|
||||
if err != nil {
|
||||
t.Errorf("Should get enabled provider: %v", err)
|
||||
}
|
||||
|
||||
// 测试获取禁用的厂商
|
||||
_, err = config.GetProviderConfig("volcano")
|
||||
if err == nil {
|
||||
t.Error("Should return error for disabled provider")
|
||||
}
|
||||
|
||||
// 测试获取不存在的厂商
|
||||
_, err = config.GetProviderConfig("nonexistent")
|
||||
if err == nil {
|
||||
t.Error("Should return error for non-existent provider")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_GetEnabledProviders(t *testing.T) {
|
||||
config := &Config{
|
||||
Providers: map[string]ProviderConfig{
|
||||
"siliconflow": {Enabled: true},
|
||||
"volcano": {Enabled: false},
|
||||
"qwen": {Enabled: true},
|
||||
},
|
||||
}
|
||||
|
||||
enabled := config.GetEnabledProviders()
|
||||
if len(enabled) != 2 {
|
||||
t.Errorf("Expected 2 enabled providers, got %d", len(enabled))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_String(t *testing.T) {
|
||||
config := &Config{
|
||||
DefaultProvider: "siliconflow",
|
||||
DefaultModel: "gpt-3.5-turbo",
|
||||
Timeout: 30,
|
||||
Providers: map[string]ProviderConfig{
|
||||
"siliconflow": {
|
||||
Enabled: true,
|
||||
Model: "siliconflow-base",
|
||||
},
|
||||
},
|
||||
Prompts: map[string]string{
|
||||
"simple": "test prompt",
|
||||
},
|
||||
}
|
||||
|
||||
str := config.String()
|
||||
if str == "" {
|
||||
t.Error("String representation should not be empty")
|
||||
}
|
||||
}
|
||||
99
internal/provider/factory.go
Normal file
99
internal/provider/factory.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ProviderFactory 厂商工厂
|
||||
type ProviderFactory struct {
|
||||
providers map[string]func(ProviderConfig) (Provider, error)
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewProviderFactory 创建工厂实例
|
||||
func NewProviderFactory() *ProviderFactory {
|
||||
factory := &ProviderFactory{
|
||||
providers: make(map[string]func(ProviderConfig) (Provider, error)),
|
||||
}
|
||||
|
||||
// 注册所有厂商(延迟注册,避免循环依赖)
|
||||
// 实际注册在init()函数中完成
|
||||
|
||||
return factory
|
||||
}
|
||||
|
||||
// Register 注册厂商构造函数
|
||||
func (f *ProviderFactory) Register(name string, creator func(ProviderConfig) (Provider, error)) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.providers[name] = creator
|
||||
}
|
||||
|
||||
// Create 创建厂商实例
|
||||
func (f *ProviderFactory) Create(name string, config ProviderConfig) (Provider, error) {
|
||||
f.mu.RLock()
|
||||
creator, exists := f.providers[name]
|
||||
f.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("不支持的厂商: %s", name)
|
||||
}
|
||||
|
||||
provider, err := creator(config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建厂商实例失败: %w", err)
|
||||
}
|
||||
|
||||
if err := provider.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("厂商配置验证失败: %w", err)
|
||||
}
|
||||
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
// HasProvider 检查是否支持指定厂商
|
||||
func (f *ProviderFactory) HasProvider(name string) bool {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
_, exists := f.providers[name]
|
||||
return exists
|
||||
}
|
||||
|
||||
// GetSupportedProviders 获取所有支持的厂商
|
||||
func (f *ProviderFactory) GetSupportedProviders() []string {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
|
||||
providers := make([]string, 0, len(f.providers))
|
||||
for name := range f.providers {
|
||||
providers = append(providers, name)
|
||||
}
|
||||
return providers
|
||||
}
|
||||
|
||||
// 全局工厂实例
|
||||
var defaultFactory *ProviderFactory
|
||||
|
||||
// init 初始化默认工厂
|
||||
func init() {
|
||||
defaultFactory = NewProviderFactory()
|
||||
|
||||
// 这里可以注册厂商,但由于循环依赖,实际注册会在各厂商的init()中完成
|
||||
// 或者在main包中手动注册
|
||||
}
|
||||
|
||||
// GetDefaultFactory 获取默认工厂
|
||||
func GetDefaultFactory() *ProviderFactory {
|
||||
return defaultFactory
|
||||
}
|
||||
|
||||
// CreateProvider 使用默认工厂创建厂商
|
||||
func CreateProvider(name string, config ProviderConfig) (Provider, error) {
|
||||
return defaultFactory.Create(name, config)
|
||||
}
|
||||
|
||||
// RegisterProvider 注册厂商到默认工厂
|
||||
func RegisterProvider(name string, creator func(ProviderConfig) (Provider, error)) {
|
||||
defaultFactory.Register(name, creator)
|
||||
}
|
||||
108
internal/provider/provider.go
Normal file
108
internal/provider/provider.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Provider 厂商接口
|
||||
type Provider interface {
|
||||
// Translate 调用厂商API进行翻译
|
||||
Translate(ctx context.Context, req *TranslateRequest) (*TranslateResponse, error)
|
||||
|
||||
// Name 返回厂商名称
|
||||
Name() string
|
||||
|
||||
// Validate 验证配置是否有效
|
||||
Validate() error
|
||||
}
|
||||
|
||||
// TranslateRequest 翻译请求
|
||||
type TranslateRequest struct {
|
||||
Text string `json:"text"`
|
||||
FromLang string `json:"from_lang"`
|
||||
ToLang string `json:"to_lang"`
|
||||
Prompt string `json:"prompt"`
|
||||
Model string `json:"model"`
|
||||
Options map[string]interface{} `json:"options"`
|
||||
}
|
||||
|
||||
// TranslateResponse 翻译响应
|
||||
type TranslateResponse struct {
|
||||
Text string `json:"text"`
|
||||
FromLang string `json:"from_lang"`
|
||||
ToLang string `json:"to_lang"`
|
||||
Model string `json:"model"`
|
||||
Usage *Usage `json:"usage"`
|
||||
RawResponse []byte `json:"raw_response,omitempty"`
|
||||
}
|
||||
|
||||
// Usage 用量统计
|
||||
type Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
// ProviderConfig 厂商配置
|
||||
type ProviderConfig struct {
|
||||
APIHost string `json:"api_host"`
|
||||
APIKey string `json:"api_key"`
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
// TranslateError 翻译错误
|
||||
type TranslateError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Details string `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
// Error 实现error接口
|
||||
func (e *TranslateError) Error() string {
|
||||
if e.Details != "" {
|
||||
return fmt.Sprintf("%s: %s (%s)", e.Code, e.Message, e.Details)
|
||||
}
|
||||
return fmt.Sprintf("%s: %s", e.Code, e.Message)
|
||||
}
|
||||
|
||||
// NewTranslateError 创建翻译错误
|
||||
func NewTranslateError(code, message string) *TranslateError {
|
||||
return &TranslateError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
}
|
||||
}
|
||||
|
||||
// NewTranslateErrorWithDetails 创建带详情的翻译错误
|
||||
func NewTranslateErrorWithDetails(code, message, details string) *TranslateError {
|
||||
return &TranslateError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
Details: details,
|
||||
}
|
||||
}
|
||||
|
||||
// IsNetworkError 检查是否为网络错误
|
||||
func IsNetworkError(err error) bool {
|
||||
if translateErr, ok := err.(*TranslateError); ok {
|
||||
return translateErr.Code == "NETWORK_ERROR"
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsRateLimitError 检查是否为限流错误
|
||||
func IsRateLimitError(err error) bool {
|
||||
if translateErr, ok := err.(*TranslateError); ok {
|
||||
return translateErr.Code == "RATE_LIMIT"
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsAuthError 检查是否为认证错误
|
||||
func IsAuthError(err error) bool {
|
||||
if translateErr, ok := err.(*TranslateError); ok {
|
||||
return translateErr.Code == "AUTH_ERROR" || translateErr.Code == "INVALID_API_KEY"
|
||||
}
|
||||
return false
|
||||
}
|
||||
185
internal/provider/siliconflow.go
Normal file
185
internal/provider/siliconflow.go
Normal file
@@ -0,0 +1,185 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SiliconFlowProvider 硅基流动厂商实现
|
||||
type SiliconFlowProvider struct {
|
||||
config ProviderConfig
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewSiliconFlowProvider 创建硅基流动厂商实例
|
||||
func NewSiliconFlowProvider(config ProviderConfig) (Provider, error) {
|
||||
return &SiliconFlowProvider{
|
||||
config: config,
|
||||
client: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Name 返回厂商名称
|
||||
func (p *SiliconFlowProvider) Name() string {
|
||||
return "siliconflow"
|
||||
}
|
||||
|
||||
// Validate 验证配置
|
||||
func (p *SiliconFlowProvider) Validate() error {
|
||||
if p.config.APIKey == "" {
|
||||
return fmt.Errorf("siliconflow: API key 不能为空")
|
||||
}
|
||||
if p.config.APIHost == "" {
|
||||
p.config.APIHost = "https://api.siliconflow.cn/v1"
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Translate 调用硅基流动API
|
||||
func (p *SiliconFlowProvider) Translate(ctx context.Context, req *TranslateRequest) (*TranslateResponse, error) {
|
||||
// 构建请求体
|
||||
requestBody := map[string]interface{}{
|
||||
"model": p.config.Model,
|
||||
"messages": []map[string]string{
|
||||
{
|
||||
"role": "user",
|
||||
"content": req.Text,
|
||||
},
|
||||
},
|
||||
"stream": false,
|
||||
}
|
||||
|
||||
// 如果有Prompt,添加到系统消息
|
||||
if req.Prompt != "" {
|
||||
messages := requestBody["messages"].([]map[string]string)
|
||||
requestBody["messages"] = append([]map[string]string{
|
||||
{
|
||||
"role": "system",
|
||||
"content": req.Prompt,
|
||||
},
|
||||
}, messages...)
|
||||
}
|
||||
|
||||
// 序列化请求体
|
||||
jsonData, err := json.Marshal(requestBody)
|
||||
if err != nil {
|
||||
return nil, NewTranslateErrorWithDetails("SERIALIZATION_ERROR", "请求序列化失败", err.Error())
|
||||
}
|
||||
|
||||
// 创建HTTP请求
|
||||
url := fmt.Sprintf("%s/chat/completions", p.config.APIHost)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, NewTranslateErrorWithDetails("REQUEST_ERROR", "创建请求失败", err.Error())
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", p.config.APIKey))
|
||||
|
||||
// 发送请求
|
||||
resp, err := p.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, NewTranslateErrorWithDetails("NETWORK_ERROR", "请求失败", err.Error())
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 读取响应
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, NewTranslateErrorWithDetails("RESPONSE_ERROR", "读取响应失败", err.Error())
|
||||
}
|
||||
|
||||
// 检查HTTP状态码
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, NewTranslateErrorWithDetails(
|
||||
"HTTP_ERROR",
|
||||
fmt.Sprintf("HTTP错误: %d", resp.StatusCode),
|
||||
string(body),
|
||||
)
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
var apiResp SiliconFlowResponse
|
||||
if err := json.Unmarshal(body, &apiResp); err != nil {
|
||||
return nil, NewTranslateErrorWithDetails("PARSE_ERROR", "解析响应失败", err.Error())
|
||||
}
|
||||
|
||||
// 检查API错误
|
||||
if apiResp.Error != nil {
|
||||
return nil, NewTranslateErrorWithDetails(
|
||||
"API_ERROR",
|
||||
apiResp.Error.Message,
|
||||
apiResp.Error.Code,
|
||||
)
|
||||
}
|
||||
|
||||
// 构建响应
|
||||
if len(apiResp.Choices) == 0 {
|
||||
return nil, NewTranslateError("NO_RESPONSE", "API返回空响应")
|
||||
}
|
||||
|
||||
translatedText := apiResp.Choices[0].Message.Content
|
||||
|
||||
return &TranslateResponse{
|
||||
Text: translatedText,
|
||||
FromLang: req.FromLang,
|
||||
ToLang: req.ToLang,
|
||||
Model: apiResp.Model,
|
||||
Usage: &Usage{
|
||||
PromptTokens: apiResp.Usage.PromptTokens,
|
||||
CompletionTokens: apiResp.Usage.CompletionTokens,
|
||||
TotalTokens: apiResp.Usage.TotalTokens,
|
||||
},
|
||||
RawResponse: body,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SiliconFlowResponse 硅基流动API响应
|
||||
type SiliconFlowResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []SiliconFlowChoice `json:"choices"`
|
||||
Usage SiliconFlowUsage `json:"usage"`
|
||||
Error *SiliconFlowError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// SiliconFlowChoice 选择项
|
||||
type SiliconFlowChoice struct {
|
||||
Index int `json:"index"`
|
||||
Message SiliconFlowMessage `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
// SiliconFlowMessage 消息
|
||||
type SiliconFlowMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// SiliconFlowUsage 用量
|
||||
type SiliconFlowUsage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
// SiliconFlowError 错误
|
||||
type SiliconFlowError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// init 注册硅基流动厂商
|
||||
func init() {
|
||||
RegisterProvider("siliconflow", NewSiliconFlowProvider)
|
||||
}
|
||||
131
internal/translator/prompt.go
Normal file
131
internal/translator/prompt.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package translator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// PromptManager Prompt管理器
|
||||
type PromptManager struct {
|
||||
prompts map[string]string
|
||||
}
|
||||
|
||||
// NewPromptManager 创建Prompt管理器
|
||||
func NewPromptManager(prompts map[string]string) *PromptManager {
|
||||
if prompts == nil {
|
||||
prompts = make(map[string]string)
|
||||
}
|
||||
return &PromptManager{
|
||||
prompts: prompts,
|
||||
}
|
||||
}
|
||||
|
||||
// GetPrompt 获取指定名称的Prompt
|
||||
func (pm *PromptManager) GetPrompt(name string) string {
|
||||
prompt, exists := pm.prompts[name]
|
||||
if !exists {
|
||||
return ""
|
||||
}
|
||||
return prompt
|
||||
}
|
||||
|
||||
// SetPrompt 设置Prompt
|
||||
func (pm *PromptManager) SetPrompt(name, content string) {
|
||||
pm.prompts[name] = content
|
||||
}
|
||||
|
||||
// DeletePrompt 删除Prompt
|
||||
func (pm *PromptManager) DeletePrompt(name string) {
|
||||
delete(pm.prompts, name)
|
||||
}
|
||||
|
||||
// HasPrompt 检查Prompt是否存在
|
||||
func (pm *PromptManager) HasPrompt(name string) bool {
|
||||
_, exists := pm.prompts[name]
|
||||
return exists
|
||||
}
|
||||
|
||||
// GetAllPrompts 获取所有Prompt名称
|
||||
func (pm *PromptManager) GetAllPrompts() []string {
|
||||
names := make([]string, 0, len(pm.prompts))
|
||||
for name := range pm.prompts {
|
||||
names = append(names, name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// GetPromptCount 获取Prompt数量
|
||||
func (pm *PromptManager) GetPromptCount() int {
|
||||
return len(pm.prompts)
|
||||
}
|
||||
|
||||
// FormatPrompt 格式化Prompt,替换变量
|
||||
func (pm *PromptManager) FormatPrompt(name string, vars map[string]string) string {
|
||||
prompt := pm.GetPrompt(name)
|
||||
if prompt == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 替换变量
|
||||
for key, value := range vars {
|
||||
placeholder := fmt.Sprintf("{{%s}}", key)
|
||||
prompt = strings.ReplaceAll(prompt, placeholder, value)
|
||||
}
|
||||
|
||||
return prompt
|
||||
}
|
||||
|
||||
// ValidatePrompt 验证Prompt是否包含必要变量
|
||||
func (pm *PromptManager) ValidatePrompt(name string, requiredVars []string) error {
|
||||
prompt := pm.GetPrompt(name)
|
||||
if prompt == "" {
|
||||
return fmt.Errorf("prompt '%s' 不存在", name)
|
||||
}
|
||||
|
||||
for _, varName := range requiredVars {
|
||||
placeholder := fmt.Sprintf("{{%s}}", varName)
|
||||
if !strings.Contains(prompt, placeholder) {
|
||||
return fmt.Errorf("prompt '%s' 缺少必要变量: %s", name, varName)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddDefaultPrompts 添加默认Prompts
|
||||
func (pm *PromptManager) AddDefaultPrompts() {
|
||||
defaults := map[string]string{
|
||||
"technical": "你是一位专业的技术翻译,请准确翻译以下技术文档,保持专业术语的准确性。",
|
||||
"creative": "你是一位富有创造力的翻译家,请用优美流畅的语言翻译以下内容。",
|
||||
"academic": "你是一位学术翻译专家,请用严谨的学术语言翻译以下内容。",
|
||||
"simple": "请用简单易懂的语言翻译以下内容。",
|
||||
"code": "你是一位专业的代码翻译专家,请准确翻译以下代码注释和文档,保持代码结构和注释格式。",
|
||||
}
|
||||
|
||||
for name, content := range defaults {
|
||||
if !pm.HasPrompt(name) {
|
||||
pm.SetPrompt(name, content)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetPromptWithFallback 获取Prompt,如果不存在则返回默认Prompt
|
||||
func (pm *PromptManager) GetPromptWithFallback(name, fallback string) string {
|
||||
prompt := pm.GetPrompt(name)
|
||||
if prompt == "" {
|
||||
return fallback
|
||||
}
|
||||
return prompt
|
||||
}
|
||||
|
||||
// MergePrompts 合并多个Prompt
|
||||
func (pm *PromptManager) MergePrompts(names []string, separator string) string {
|
||||
var prompts []string
|
||||
for _, name := range names {
|
||||
prompt := pm.GetPrompt(name)
|
||||
if prompt != "" {
|
||||
prompts = append(prompts, prompt)
|
||||
}
|
||||
}
|
||||
return strings.Join(prompts, separator)
|
||||
}
|
||||
177
internal/translator/translator.go
Normal file
177
internal/translator/translator.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package translator
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/titor/fanyi/internal/config"
|
||||
"github.com/titor/fanyi/internal/provider"
|
||||
)
|
||||
|
||||
// Translator 核心翻译类
|
||||
type Translator struct {
|
||||
config *config.Config
|
||||
provider provider.Provider
|
||||
prompt *PromptManager
|
||||
}
|
||||
|
||||
// NewTranslator 创建翻译器实例
|
||||
func NewTranslator(config *config.Config, provider provider.Provider) *Translator {
|
||||
return &Translator{
|
||||
config: config,
|
||||
provider: provider,
|
||||
prompt: NewPromptManager(config.Prompts),
|
||||
}
|
||||
}
|
||||
|
||||
// Translate 执行翻译
|
||||
func (t *Translator) Translate(ctx context.Context, text string, options *TranslateOptions) (*TranslateResult, error) {
|
||||
// 设置超时
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, time.Duration(t.config.Timeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 选择Prompt
|
||||
prompt := ""
|
||||
if options.PromptName != "" {
|
||||
prompt = t.prompt.GetPrompt(options.PromptName)
|
||||
}
|
||||
|
||||
// 构建请求
|
||||
req := &provider.TranslateRequest{
|
||||
Text: text,
|
||||
FromLang: options.FromLang,
|
||||
ToLang: options.ToLang,
|
||||
Prompt: prompt,
|
||||
Model: t.selectModel(options.Model),
|
||||
Options: options.ExtraOptions,
|
||||
}
|
||||
|
||||
// 调用厂商API
|
||||
resp, err := t.provider.Translate(timeoutCtx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("翻译失败: %w", err)
|
||||
}
|
||||
|
||||
// 构建结果
|
||||
return &TranslateResult{
|
||||
Original: text,
|
||||
Translated: resp.Text,
|
||||
FromLang: resp.FromLang,
|
||||
ToLang: resp.ToLang,
|
||||
Model: resp.Model,
|
||||
Usage: resp.Usage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TranslateWithProvider 使用指定厂商执行翻译
|
||||
func (t *Translator) TranslateWithProvider(ctx context.Context, text string, providerName string, options *TranslateOptions) (*TranslateResult, error) {
|
||||
// 创建指定厂商实例
|
||||
providerConfig, err := t.config.GetProviderConfig(providerName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取厂商配置失败: %w", err)
|
||||
}
|
||||
|
||||
// 创建厂商实例
|
||||
providerInstance, err := provider.CreateProvider(providerName, provider.ProviderConfig{
|
||||
APIHost: providerConfig.APIHost,
|
||||
APIKey: providerConfig.APIKey,
|
||||
Model: providerConfig.Model,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建厂商实例失败: %w", err)
|
||||
}
|
||||
|
||||
// 临时切换厂商
|
||||
originalProvider := t.provider
|
||||
t.provider = providerInstance
|
||||
defer func() {
|
||||
t.provider = originalProvider
|
||||
}()
|
||||
|
||||
// 执行翻译
|
||||
return t.Translate(ctx, text, options)
|
||||
}
|
||||
|
||||
// selectModel 选择模型
|
||||
func (t *Translator) selectModel(model string) string {
|
||||
if model != "" {
|
||||
return model
|
||||
}
|
||||
return t.config.DefaultModel
|
||||
}
|
||||
|
||||
// GetProvider 获取当前厂商
|
||||
func (t *Translator) GetProvider() provider.Provider {
|
||||
return t.provider
|
||||
}
|
||||
|
||||
// GetConfig 获取配置
|
||||
func (t *Translator) GetConfig() *config.Config {
|
||||
return t.config
|
||||
}
|
||||
|
||||
// GetPromptManager 获取Prompt管理器
|
||||
func (t *Translator) GetPromptManager() *PromptManager {
|
||||
return t.prompt
|
||||
}
|
||||
|
||||
// SetTimeout 设置超时时间
|
||||
func (t *Translator) SetTimeout(seconds int) {
|
||||
t.config.Timeout = seconds
|
||||
}
|
||||
|
||||
// TranslateOptions 翻译选项
|
||||
type TranslateOptions struct {
|
||||
FromLang string
|
||||
ToLang string
|
||||
PromptName string
|
||||
Model string
|
||||
Temperature float64
|
||||
ExtraOptions map[string]interface{}
|
||||
}
|
||||
|
||||
// TranslateResult 翻译结果
|
||||
type TranslateResult struct {
|
||||
Original string
|
||||
Translated string
|
||||
FromLang string
|
||||
ToLang string
|
||||
Model string
|
||||
Usage *provider.Usage
|
||||
}
|
||||
|
||||
// String 返回翻译结果的字符串表示
|
||||
func (r *TranslateResult) String() string {
|
||||
return r.Translated
|
||||
}
|
||||
|
||||
// TranslateResultWithInfo 带详细信息的翻译结果
|
||||
type TranslateResultWithInfo struct {
|
||||
Result *TranslateResult
|
||||
Duration time.Duration
|
||||
Provider string
|
||||
Timestamp time.Time
|
||||
}
|
||||
|
||||
// BatchTranslate 批量翻译结果
|
||||
type BatchTranslateRequest struct {
|
||||
Texts []string
|
||||
Options *TranslateOptions
|
||||
}
|
||||
|
||||
// BatchTranslateResult 批量翻译结果
|
||||
type BatchTranslateResult struct {
|
||||
Results []*TranslateResult
|
||||
Errors []error
|
||||
Summary BatchTranslateSummary
|
||||
}
|
||||
|
||||
// BatchTranslateSummary 批量翻译摘要
|
||||
type BatchTranslateSummary struct {
|
||||
Total int
|
||||
Success int
|
||||
Failed int
|
||||
Duration time.Duration
|
||||
AvgTokens int
|
||||
}
|
||||
40
taolun.md
40
taolun.md
@@ -86,3 +86,43 @@
|
||||
**关联文档**:
|
||||
- [AGENTS.md#文档管理](AGENTS.md#开发规范)
|
||||
- [changelog.md#0.0.1](changelog.md#001)
|
||||
|
||||
---
|
||||
|
||||
### [2026-03-28 23:50] 版本 0.0.2 - 实现核心架构
|
||||
**原因**: 开始实现项目核心功能
|
||||
**分析**:
|
||||
- 根据OOP设计模式实现三个核心类
|
||||
- 需要先实现配置加载和厂商接口
|
||||
- 创建基本的CLI入口点
|
||||
|
||||
**解决方案**:
|
||||
1. **Config类实现**:
|
||||
- 支持YAML配置文件加载
|
||||
- 环境变量替换
|
||||
- 配置验证和默认值
|
||||
|
||||
2. **Provider接口实现**:
|
||||
- 定义统一的翻译接口
|
||||
- 工厂模式创建厂商实例
|
||||
- 实现硅基流动厂商作为示例
|
||||
|
||||
3. **Translator类实现**:
|
||||
- 核心翻译逻辑
|
||||
- Prompt管理
|
||||
- 超时控制
|
||||
|
||||
4. **CLI入口点**:
|
||||
- 命令行参数解析
|
||||
- 配置加载
|
||||
- 翻译执行
|
||||
|
||||
**技术细节**:
|
||||
- 使用`gopkg.in/yaml.v3`处理YAML
|
||||
- 实现工厂模式注册机制
|
||||
- 使用context处理超时和取消
|
||||
- 添加基本单元测试
|
||||
|
||||
**关联文档**:
|
||||
- [AGENTS.md#OOP设计模式](AGENTS.md#oop设计模式)
|
||||
- [changelog.md#0.0.2](changelog.md#002)
|
||||
Reference in New Issue
Block a user