feat: implement core architecture (v0.0.2)
- Implement Config class with YAML loading and environment variable support - Implement Provider interface and factory pattern - Implement SiliconFlow provider as example - Implement Translator core class with prompt management - Create CLI entry point - Add configuration template and unit tests - Update changelog and discussion records Version: 0.0.2
This commit is contained in:
202
internal/config/config.go
Normal file
202
internal/config/config.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// Config 全局配置结构
|
||||
type Config struct {
|
||||
// 全局设置
|
||||
DefaultProvider string `yaml:"default_provider"`
|
||||
DefaultModel string `yaml:"default_model"`
|
||||
Timeout int `yaml:"timeout"` // 秒
|
||||
|
||||
// 厂商配置
|
||||
Providers map[string]ProviderConfig `yaml:"providers"`
|
||||
|
||||
// Prompt配置
|
||||
Prompts map[string]string `yaml:"prompts"`
|
||||
}
|
||||
|
||||
// ProviderConfig 厂商配置
|
||||
type ProviderConfig struct {
|
||||
APIHost string `yaml:"api_host"`
|
||||
APIKey string `yaml:"api_key"`
|
||||
Model string `yaml:"model"`
|
||||
Enabled bool `yaml:"enabled"`
|
||||
}
|
||||
|
||||
// ConfigLoader 配置加载器接口
|
||||
type ConfigLoader interface {
|
||||
Load(path string) (*Config, error)
|
||||
Save(config *Config, path string) error
|
||||
}
|
||||
|
||||
// YAMLConfigLoader YAML配置加载器实现
|
||||
type YAMLConfigLoader struct{}
|
||||
|
||||
// Load 加载YAML配置文件
|
||||
func (l *YAMLConfigLoader) Load(path string) (*Config, error) {
|
||||
// 读取文件
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
// 替换环境变量
|
||||
content := string(data)
|
||||
content = os.ExpandEnv(content)
|
||||
|
||||
// 解析YAML
|
||||
config := &Config{}
|
||||
if err := yaml.Unmarshal([]byte(content), config); err != nil {
|
||||
return nil, fmt.Errorf("解析配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
// 设置默认值
|
||||
config.setDefaults()
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// Save 保存配置到文件
|
||||
func (l *YAMLConfigLoader) Save(config *Config, path string) error {
|
||||
// 确保目录存在
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return fmt.Errorf("创建配置目录失败: %w", err)
|
||||
}
|
||||
|
||||
// 序列化为YAML
|
||||
data, err := yaml.Marshal(config)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化配置失败: %w", err)
|
||||
}
|
||||
|
||||
// 写入文件
|
||||
if err := os.WriteFile(path, data, 0644); err != nil {
|
||||
return fmt.Errorf("写入配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setDefaults 设置默认值
|
||||
func (c *Config) setDefaults() {
|
||||
if c.DefaultProvider == "" {
|
||||
c.DefaultProvider = "siliconflow"
|
||||
}
|
||||
if c.Timeout <= 0 {
|
||||
c.Timeout = 30
|
||||
}
|
||||
if c.DefaultModel == "" {
|
||||
c.DefaultModel = "gpt-3.5-turbo"
|
||||
}
|
||||
|
||||
// 为每个厂商设置默认值
|
||||
for name, provider := range c.Providers {
|
||||
if provider.Model == "" {
|
||||
provider.Model = c.DefaultModel
|
||||
c.Providers[name] = provider
|
||||
}
|
||||
// 替换环境变量
|
||||
provider.APIKey = os.ExpandEnv(provider.APIKey)
|
||||
c.Providers[name] = provider
|
||||
}
|
||||
|
||||
// 确保Prompts映射存在
|
||||
if c.Prompts == nil {
|
||||
c.Prompts = make(map[string]string)
|
||||
}
|
||||
}
|
||||
|
||||
// GetProviderConfig 获取指定厂商的配置
|
||||
func (c *Config) GetProviderConfig(name string) (ProviderConfig, error) {
|
||||
config, exists := c.Providers[name]
|
||||
if !exists {
|
||||
return ProviderConfig{}, fmt.Errorf("未找到厂商配置: %s", name)
|
||||
}
|
||||
if !config.Enabled {
|
||||
return ProviderConfig{}, fmt.Errorf("厂商未启用: %s", name)
|
||||
}
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// GetPrompt 获取指定名称的Prompt
|
||||
func (c *Config) GetPrompt(name string) (string, bool) {
|
||||
prompt, exists := c.Prompts[name]
|
||||
return prompt, exists
|
||||
}
|
||||
|
||||
// ExpandEnv 扩展环境变量(辅助函数)
|
||||
func ExpandEnv(s string) string {
|
||||
return os.ExpandEnv(s)
|
||||
}
|
||||
|
||||
// Validate 验证配置是否有效
|
||||
func (c *Config) Validate() error {
|
||||
if c.DefaultProvider == "" {
|
||||
return fmt.Errorf("默认厂商不能为空")
|
||||
}
|
||||
|
||||
// 检查默认厂商是否在配置中
|
||||
if _, exists := c.Providers[c.DefaultProvider]; !exists {
|
||||
return fmt.Errorf("默认厂商 '%s' 未在配置中定义", c.DefaultProvider)
|
||||
}
|
||||
|
||||
// 检查每个厂商配置
|
||||
for name, provider := range c.Providers {
|
||||
if provider.APIKey == "" {
|
||||
return fmt.Errorf("厂商 '%s' 的API密钥不能为空", name)
|
||||
}
|
||||
if provider.APIHost == "" {
|
||||
return fmt.Errorf("厂商 '%s' 的API主机不能为空", name)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDefaultProvider 获取默认厂商配置
|
||||
func (c *Config) GetDefaultProvider() (ProviderConfig, error) {
|
||||
return c.GetProviderConfig(c.DefaultProvider)
|
||||
}
|
||||
|
||||
// IsProviderEnabled 检查厂商是否启用
|
||||
func (c *Config) IsProviderEnabled(name string) bool {
|
||||
config, exists := c.Providers[name]
|
||||
return exists && config.Enabled
|
||||
}
|
||||
|
||||
// GetEnabledProviders 获取所有启用的厂商
|
||||
func (c *Config) GetEnabledProviders() []string {
|
||||
var enabled []string
|
||||
for name, config := range c.Providers {
|
||||
if config.Enabled {
|
||||
enabled = append(enabled, name)
|
||||
}
|
||||
}
|
||||
return enabled
|
||||
}
|
||||
|
||||
// String 返回配置的字符串表示(隐藏敏感信息)
|
||||
func (c *Config) String() string {
|
||||
var builder strings.Builder
|
||||
builder.WriteString(fmt.Sprintf("DefaultProvider: %s\n", c.DefaultProvider))
|
||||
builder.WriteString(fmt.Sprintf("DefaultModel: %s\n", c.DefaultModel))
|
||||
builder.WriteString(fmt.Sprintf("Timeout: %d seconds\n", c.Timeout))
|
||||
builder.WriteString("Providers:\n")
|
||||
for name, provider := range c.Providers {
|
||||
builder.WriteString(fmt.Sprintf(" %s: enabled=%v, model=%s\n", name, provider.Enabled, provider.Model))
|
||||
}
|
||||
builder.WriteString("Prompts:\n")
|
||||
for name := range c.Prompts {
|
||||
builder.WriteString(fmt.Sprintf(" %s\n", name))
|
||||
}
|
||||
return builder.String()
|
||||
}
|
||||
177
internal/config/config_test.go
Normal file
177
internal/config/config_test.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConfigLoader_Load(t *testing.T) {
|
||||
// 创建临时配置文件
|
||||
tmpDir := t.TempDir()
|
||||
configFile := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
configContent := `
|
||||
default_provider: "siliconflow"
|
||||
default_model: "gpt-3.5-turbo"
|
||||
timeout: 30
|
||||
|
||||
providers:
|
||||
siliconflow:
|
||||
api_host: "https://api.siliconflow.cn/v1"
|
||||
api_key: "${TEST_API_KEY}"
|
||||
model: "siliconflow-base"
|
||||
enabled: true
|
||||
|
||||
prompts:
|
||||
simple: "请用简单易懂的语言翻译以下内容。"
|
||||
`
|
||||
|
||||
if err := os.WriteFile(configFile, []byte(configContent), 0644); err != nil {
|
||||
t.Fatalf("创建配置文件失败: %v", err)
|
||||
}
|
||||
|
||||
// 设置环境变量
|
||||
os.Setenv("TEST_API_KEY", "test-key-123")
|
||||
defer os.Unsetenv("TEST_API_KEY")
|
||||
|
||||
// 加载配置
|
||||
loader := &YAMLConfigLoader{}
|
||||
config, err := loader.Load(configFile)
|
||||
if err != nil {
|
||||
t.Fatalf("加载配置失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证配置
|
||||
if config.DefaultProvider != "siliconflow" {
|
||||
t.Errorf("Expected default provider 'siliconflow', got '%s'", config.DefaultProvider)
|
||||
}
|
||||
|
||||
if config.DefaultModel != "gpt-3.5-turbo" {
|
||||
t.Errorf("Expected default model 'gpt-3.5-turbo', got '%s'", config.DefaultModel)
|
||||
}
|
||||
|
||||
if config.Timeout != 30 {
|
||||
t.Errorf("Expected timeout 30, got %d", config.Timeout)
|
||||
}
|
||||
|
||||
// 验证厂商配置
|
||||
providerConfig, err := config.GetProviderConfig("siliconflow")
|
||||
if err != nil {
|
||||
t.Fatalf("获取厂商配置失败: %v", err)
|
||||
}
|
||||
|
||||
if providerConfig.APIKey != "test-key-123" {
|
||||
t.Errorf("Expected API key 'test-key-123', got '%s'", providerConfig.APIKey)
|
||||
}
|
||||
|
||||
if providerConfig.APIHost != "https://api.siliconflow.cn/v1" {
|
||||
t.Errorf("Expected API host 'https://api.siliconflow.cn/v1', got '%s'", providerConfig.APIHost)
|
||||
}
|
||||
|
||||
// 验证Prompt
|
||||
prompt, exists := config.GetPrompt("simple")
|
||||
if !exists {
|
||||
t.Error("Prompt 'simple' should exist")
|
||||
}
|
||||
if prompt != "请用简单易懂的语言翻译以下内容。" {
|
||||
t.Errorf("Unexpected prompt content: %s", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_Validate(t *testing.T) {
|
||||
config := &Config{
|
||||
DefaultProvider: "siliconflow",
|
||||
DefaultModel: "gpt-3.5-turbo",
|
||||
Timeout: 30,
|
||||
Providers: map[string]ProviderConfig{
|
||||
"siliconflow": {
|
||||
APIHost: "https://api.siliconflow.cn/v1",
|
||||
APIKey: "test-key",
|
||||
Model: "siliconflow-base",
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
Prompts: map[string]string{
|
||||
"simple": "test prompt",
|
||||
},
|
||||
}
|
||||
|
||||
if err := config.Validate(); err != nil {
|
||||
t.Errorf("Config validation failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_GetProviderConfig(t *testing.T) {
|
||||
config := &Config{
|
||||
Providers: map[string]ProviderConfig{
|
||||
"siliconflow": {
|
||||
APIHost: "https://api.siliconflow.cn/v1",
|
||||
APIKey: "test-key",
|
||||
Model: "siliconflow-base",
|
||||
Enabled: true,
|
||||
},
|
||||
"volcano": {
|
||||
APIHost: "https://api.volcengine.com/v1",
|
||||
APIKey: "test-key",
|
||||
Model: "volcano-chat",
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 测试获取启用的厂商
|
||||
_, err := config.GetProviderConfig("siliconflow")
|
||||
if err != nil {
|
||||
t.Errorf("Should get enabled provider: %v", err)
|
||||
}
|
||||
|
||||
// 测试获取禁用的厂商
|
||||
_, err = config.GetProviderConfig("volcano")
|
||||
if err == nil {
|
||||
t.Error("Should return error for disabled provider")
|
||||
}
|
||||
|
||||
// 测试获取不存在的厂商
|
||||
_, err = config.GetProviderConfig("nonexistent")
|
||||
if err == nil {
|
||||
t.Error("Should return error for non-existent provider")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_GetEnabledProviders(t *testing.T) {
|
||||
config := &Config{
|
||||
Providers: map[string]ProviderConfig{
|
||||
"siliconflow": {Enabled: true},
|
||||
"volcano": {Enabled: false},
|
||||
"qwen": {Enabled: true},
|
||||
},
|
||||
}
|
||||
|
||||
enabled := config.GetEnabledProviders()
|
||||
if len(enabled) != 2 {
|
||||
t.Errorf("Expected 2 enabled providers, got %d", len(enabled))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig_String(t *testing.T) {
|
||||
config := &Config{
|
||||
DefaultProvider: "siliconflow",
|
||||
DefaultModel: "gpt-3.5-turbo",
|
||||
Timeout: 30,
|
||||
Providers: map[string]ProviderConfig{
|
||||
"siliconflow": {
|
||||
Enabled: true,
|
||||
Model: "siliconflow-base",
|
||||
},
|
||||
},
|
||||
Prompts: map[string]string{
|
||||
"simple": "test prompt",
|
||||
},
|
||||
}
|
||||
|
||||
str := config.String()
|
||||
if str == "" {
|
||||
t.Error("String representation should not be empty")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user