package config import ( "fmt" "os" "path/filepath" "strings" "gopkg.in/yaml.v3" ) // Config 全局配置结构 type Config struct { // 全局设置 DefaultProvider string `yaml:"default_provider"` DefaultModel string `yaml:"default_model"` Timeout int `yaml:"timeout"` // 秒 // 厂商配置 Providers map[string]ProviderConfig `yaml:"providers"` // Prompt配置 Prompts map[string]string `yaml:"prompts"` } // ProviderConfig 厂商配置 type ProviderConfig struct { APIHost string `yaml:"api_host"` APIKey string `yaml:"api_key"` Model string `yaml:"model"` Enabled bool `yaml:"enabled"` } // ConfigLoader 配置加载器接口 type ConfigLoader interface { Load(path string) (*Config, error) Save(config *Config, path string) error } // YAMLConfigLoader YAML配置加载器实现 type YAMLConfigLoader struct{} // Load 加载YAML配置文件 func (l *YAMLConfigLoader) Load(path string) (*Config, error) { // 读取文件 data, err := os.ReadFile(path) if err != nil { return nil, fmt.Errorf("读取配置文件失败: %w", err) } // 替换环境变量 content := string(data) content = os.ExpandEnv(content) // 解析YAML config := &Config{} if err := yaml.Unmarshal([]byte(content), config); err != nil { return nil, fmt.Errorf("解析配置文件失败: %w", err) } // 设置默认值 config.setDefaults() return config, nil } // Save 保存配置到文件 func (l *YAMLConfigLoader) Save(config *Config, path string) error { // 确保目录存在 dir := filepath.Dir(path) if err := os.MkdirAll(dir, 0755); err != nil { return fmt.Errorf("创建配置目录失败: %w", err) } // 序列化为YAML data, err := yaml.Marshal(config) if err != nil { return fmt.Errorf("序列化配置失败: %w", err) } // 写入文件 if err := os.WriteFile(path, data, 0644); err != nil { return fmt.Errorf("写入配置文件失败: %w", err) } return nil } // setDefaults 设置默认值 func (c *Config) setDefaults() { if c.DefaultProvider == "" { c.DefaultProvider = "siliconflow" } if c.Timeout <= 0 { c.Timeout = 30 } if c.DefaultModel == "" { c.DefaultModel = "gpt-3.5-turbo" } // 为每个厂商设置默认值 for name, provider := range c.Providers { if provider.Model == "" { provider.Model = c.DefaultModel c.Providers[name] = provider } // 替换环境变量 provider.APIKey = os.ExpandEnv(provider.APIKey) c.Providers[name] = provider } // 确保Prompts映射存在 if c.Prompts == nil { c.Prompts = make(map[string]string) } } // GetProviderConfig 获取指定厂商的配置 func (c *Config) GetProviderConfig(name string) (ProviderConfig, error) { config, exists := c.Providers[name] if !exists { return ProviderConfig{}, fmt.Errorf("未找到厂商配置: %s", name) } if !config.Enabled { return ProviderConfig{}, fmt.Errorf("厂商未启用: %s", name) } return config, nil } // GetPrompt 获取指定名称的Prompt func (c *Config) GetPrompt(name string) (string, bool) { prompt, exists := c.Prompts[name] return prompt, exists } // ExpandEnv 扩展环境变量(辅助函数) func ExpandEnv(s string) string { return os.ExpandEnv(s) } // Validate 验证配置是否有效 func (c *Config) Validate() error { if c.DefaultProvider == "" { return fmt.Errorf("默认厂商不能为空") } // 检查默认厂商是否在配置中 if _, exists := c.Providers[c.DefaultProvider]; !exists { return fmt.Errorf("默认厂商 '%s' 未在配置中定义", c.DefaultProvider) } // 检查每个厂商配置 for name, provider := range c.Providers { if provider.APIKey == "" { return fmt.Errorf("厂商 '%s' 的API密钥不能为空", name) } if provider.APIHost == "" { return fmt.Errorf("厂商 '%s' 的API主机不能为空", name) } } return nil } // GetDefaultProvider 获取默认厂商配置 func (c *Config) GetDefaultProvider() (ProviderConfig, error) { return c.GetProviderConfig(c.DefaultProvider) } // IsProviderEnabled 检查厂商是否启用 func (c *Config) IsProviderEnabled(name string) bool { config, exists := c.Providers[name] return exists && config.Enabled } // GetEnabledProviders 获取所有启用的厂商 func (c *Config) GetEnabledProviders() []string { var enabled []string for name, config := range c.Providers { if config.Enabled { enabled = append(enabled, name) } } return enabled } // String 返回配置的字符串表示(隐藏敏感信息) func (c *Config) String() string { var builder strings.Builder builder.WriteString(fmt.Sprintf("DefaultProvider: %s\n", c.DefaultProvider)) builder.WriteString(fmt.Sprintf("DefaultModel: %s\n", c.DefaultModel)) builder.WriteString(fmt.Sprintf("Timeout: %d seconds\n", c.Timeout)) builder.WriteString("Providers:\n") for name, provider := range c.Providers { builder.WriteString(fmt.Sprintf(" %s: enabled=%v, model=%s\n", name, provider.Enabled, provider.Model)) } builder.WriteString("Prompts:\n") for name := range c.Prompts { builder.WriteString(fmt.Sprintf(" %s\n", name)) } return builder.String() }