Files
BingPaper/internal/config/config.go

338 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package config
import (
	"fmt"
	"os"
	"path/filepath"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/fsnotify/fsnotify"
	"github.com/spf13/viper"
	"gopkg.in/yaml.v3"

	"BingPaper/internal/util"
)
// Config is the root application configuration.
//
// It is decoded from viper through the mapstructure tags and written back to
// YAML by yaml.v3 through the yaml tags (see SaveConfig). yaml.v3 keeps struct
// field order, so the order of the fields below is the key order of a saved
// config file.
type Config struct {
Server ServerConfig `mapstructure:"server" yaml:"server"`
Log LogConfig `mapstructure:"log" yaml:"log"`
API APIConfig `mapstructure:"api" yaml:"api"`
Cron CronConfig `mapstructure:"cron" yaml:"cron"`
Retention RetentionConfig `mapstructure:"retention" yaml:"retention"`
DB DBConfig `mapstructure:"db" yaml:"db"`
Storage StorageConfig `mapstructure:"storage" yaml:"storage"`
Admin AdminConfig `mapstructure:"admin" yaml:"admin"`
Token TokenConfig `mapstructure:"token" yaml:"token"`
Feature FeatureConfig `mapstructure:"feature" yaml:"feature"`
Web WebConfig `mapstructure:"web" yaml:"web"`
Fetcher FetcherConfig `mapstructure:"fetcher" yaml:"fetcher"`
}
// ServerConfig holds the server port (default 8080, set in Init) and base URL.
type ServerConfig struct {
Port int `mapstructure:"port" yaml:"port"`
BaseURL string `mapstructure:"base_url" yaml:"base_url"` // NOTE(review): presumably the externally visible base URL used to build links — confirm with its consumers
}
// LogConfig configures the application ("business") logger and the database
// logger, including file rotation and console output.
//
// Each field is also exposed through a Get* method below.
// NOTE(review): presumably so a logging package can depend on an interface
// instead of importing this package — confirm against the consumer.
type LogConfig struct {
	Level      string `mapstructure:"level" yaml:"level"`               // application log level
	Filename   string `mapstructure:"filename" yaml:"filename"`         // application (business) log file name
	DBFilename string `mapstructure:"db_filename" yaml:"db_filename"`   // database log file name
	MaxSize    int    `mapstructure:"max_size" yaml:"max_size"`         // maximum size of a single log file, in MB
	MaxBackups int    `mapstructure:"max_backups" yaml:"max_backups"`   // maximum number of old log files to keep
	MaxAge     int    `mapstructure:"max_age" yaml:"max_age"`           // maximum number of days to keep old log files
	Compress   bool   `mapstructure:"compress" yaml:"compress"`         // whether old log files are compressed
	LogConsole bool   `mapstructure:"log_console" yaml:"log_console"`   // whether logs are also written to the console
	ShowDBLog  bool   `mapstructure:"show_db_log" yaml:"show_db_log"`   // whether database logs are shown on the console
	DBLogLevel string `mapstructure:"db_log_level" yaml:"db_log_level"` // database log level: debug, info, warn, error
}

// GetLevel returns the application log level.
func (lc LogConfig) GetLevel() string {
	return lc.Level
}

// GetFilename returns the application log file name.
func (lc LogConfig) GetFilename() string {
	return lc.Filename
}

// GetDBFilename returns the database log file name.
func (lc LogConfig) GetDBFilename() string {
	return lc.DBFilename
}

// GetMaxSize returns the maximum size of one log file, in MB.
func (lc LogConfig) GetMaxSize() int {
	return lc.MaxSize
}

// GetMaxBackups returns how many old log files are kept.
func (lc LogConfig) GetMaxBackups() int {
	return lc.MaxBackups
}

// GetMaxAge returns how many days old log files are kept.
func (lc LogConfig) GetMaxAge() int {
	return lc.MaxAge
}

// GetCompress reports whether old log files are compressed.
func (lc LogConfig) GetCompress() bool {
	return lc.Compress
}

// GetLogConsole reports whether logs are also written to the console.
func (lc LogConfig) GetLogConsole() bool {
	return lc.LogConsole
}

// GetShowDBLog reports whether database logs are shown on the console.
func (lc LogConfig) GetShowDBLog() bool {
	return lc.ShowDBLog
}

// GetDBLogLevel returns the database log level.
func (lc LogConfig) GetDBLogLevel() string {
	return lc.DBLogLevel
}
// APIConfig holds API behaviour settings (defaults set in Init: Mode "local",
// EnableMktFallback true).
type APIConfig struct {
Mode string `mapstructure:"mode" yaml:"mode"` // local | redirect
EnableMktFallback bool `mapstructure:"enable_mkt_fallback" yaml:"enable_mkt_fallback"` // whether to fall back to the default region when the requested region does not exist
}
// CronConfig controls the scheduled job: whether it runs (default true) and its
// schedule spec (default "20 8-23/4 * * *", set in Init).
type CronConfig struct {
Enabled bool `mapstructure:"enabled" yaml:"enabled"`
DailySpec string `mapstructure:"daily_spec" yaml:"daily_spec"`
}
// RetentionConfig holds the retention period in days (default 0, set in Init).
// NOTE(review): presumably 0 means "keep everything" — confirm against the
// code that applies retention.
type RetentionConfig struct {
Days int `mapstructure:"days" yaml:"days"`
}
// DBConfig selects the database backend and its connection string
// (defaults: sqlite, "data/bing_paper.db"). When a config-file change alters
// Type or DSN, Init's change handler invokes OnDBConfigChange.
type DBConfig struct {
Type string `mapstructure:"type" yaml:"type"` // sqlite/mysql/postgres
DSN string `mapstructure:"dsn" yaml:"dsn"`
}
// StorageConfig selects the storage backend via Type (default "local") and
// holds the settings of every supported backend; presumably only the one named
// by Type is used — confirm in the storage factory.
type StorageConfig struct {
Type string `mapstructure:"type" yaml:"type"` // local/s3/webdav
Local LocalConfig `mapstructure:"local" yaml:"local"`
S3 S3Config `mapstructure:"s3" yaml:"s3"`
WebDAV WebDAVConfig `mapstructure:"webdav" yaml:"webdav"`
}
// LocalConfig configures local-disk storage; Root defaults to "data/picture".
type LocalConfig struct {
Root string `mapstructure:"root" yaml:"root"`
}
// S3Config configures S3-compatible object storage.
// NOTE: AccessKey and SecretKey are part of the struct SaveConfig marshals, so
// they end up in plaintext in the YAML config file.
type S3Config struct {
Endpoint string `mapstructure:"endpoint" yaml:"endpoint"`
Region string `mapstructure:"region" yaml:"region"`
Bucket string `mapstructure:"bucket" yaml:"bucket"`
AccessKey string `mapstructure:"access_key" yaml:"access_key"`
SecretKey string `mapstructure:"secret_key" yaml:"secret_key"`
PublicURLPrefix string `mapstructure:"public_url_prefix" yaml:"public_url_prefix"`
ForcePathStyle bool `mapstructure:"force_path_style" yaml:"force_path_style"`
}
// WebDAVConfig configures WebDAV storage.
// NOTE: Password is part of the struct SaveConfig marshals, so it ends up in
// plaintext in the YAML config file.
type WebDAVConfig struct {
URL string `mapstructure:"url" yaml:"url"`
Username string `mapstructure:"username" yaml:"username"`
Password string `mapstructure:"password" yaml:"password"`
PublicURLPrefix string `mapstructure:"public_url_prefix" yaml:"public_url_prefix"`
}
// AdminConfig holds the admin credential as a bcrypt hash. The built-in
// default set in Init is documented there as the hash of "admin123", so it
// should be replaced in any real deployment.
type AdminConfig struct {
PasswordBcrypt string `mapstructure:"password_bcrypt" yaml:"password_bcrypt"`
}
// TokenConfig holds token settings. DefaultTTL is a Go duration string
// (default "168h") parsed by GetTokenTTL.
type TokenConfig struct {
DefaultTTL string `mapstructure:"default_ttl" yaml:"default_ttl"`
}
// FeatureConfig holds feature toggles; WriteDailyFiles defaults to true.
type FeatureConfig struct {
WriteDailyFiles bool `mapstructure:"write_daily_files" yaml:"write_daily_files"`
}
// WebConfig holds web front-end settings. Path defaults to "web";
// NOTE(review): presumably the directory of the static web assets — confirm.
type WebConfig struct {
Path string `mapstructure:"path" yaml:"path"`
}
// FetcherConfig lists the regions (markets) to fetch. Init defaults it to every
// value in util.AllRegions, and GetDefaultMkt treats the first entry as the
// default market.
type FetcherConfig struct {
Regions []string `mapstructure:"regions" yaml:"regions"`
}
// Built-in Bing defaults. These are constants, not configuration keys.
const (
// BingMkt is the fallback market returned by Config.GetDefaultMkt when
// fetcher.regions is empty.
BingMkt = "zh-CN"
// BingFetchN — NOTE(review): presumably the number of images requested per
// archive call; confirm at the call site.
BingFetchN = 8
// BingAPIBase is the Bing HPImageArchive endpoint URL.
BingAPIBase = "https://www.bing.com/HPImageArchive.aspx"
)
var (
// GlobalConfig is the active configuration. It is assigned by Init, by the
// config-file change handler and by SaveConfig (the latter two hold configLock).
// NOTE(review): it is exported, so direct reads from other packages bypass
// configLock; GetConfig is the lock-protected accessor.
GlobalConfig *Config
// configLock guards reads and writes of GlobalConfig.
configLock sync.RWMutex
// v is the viper instance created by Init; it is nil until Init runs.
v *viper.Viper
// OnDBConfigChange is the callback invoked when the database configuration changes.
OnDBConfigChange func(newCfg *Config)
)
// Init loads the configuration into GlobalConfig and starts hot-reloading it.
//
// configPath names the config file explicitly. When it is empty, a YAML file
// named "config" is searched for in ./data and then in the working directory.
// Effective values come from the built-in defaults, the config file, and
// environment variables prefixed with BINGPAPER_. The variable name is the
// upper-cased key with "." replaced by "_", e.g. BINGPAPER_SERVER_PORT for
// server.port.
//
// A missing config file is not an error. Init writes a default file (to
// configPath, or to data/config.yaml when none was given) and continues with
// the defaults. Any other read error, or a failure to decode the merged
// settings into Config, is returned.
//
// On success the config file is watched. Each change is decoded into a fresh
// Config that replaces GlobalConfig, and OnDBConfigChange is invoked when the
// database type or DSN changes.
func Init(configPath string) error {
	v = viper.New()
	if configPath != "" {
		v.SetConfigFile(configPath)
	} else {
		v.SetConfigName("config")
		v.SetConfigType("yaml")
		v.AddConfigPath("./data")
		v.AddConfigPath(".")
	}

	registerConfigDefaults(v)

	// Environment overrides: BINGPAPER_ + key with "." -> "_".
	v.SetEnvPrefix("BINGPAPER")
	v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
	v.AutomaticEnv()

	if err := v.ReadInConfig(); err != nil {
		if !configFileMissing(err, configPath) {
			return err
		}
		writeDefaultConfigFile(configPath)
	}

	var cfg Config
	if err := v.Unmarshal(&cfg); err != nil {
		return err
	}
	configLock.Lock()
	GlobalConfig = &cfg
	configLock.Unlock()

	v.OnConfigChange(reloadConfigOnChange)
	// Watches whichever file viper now points at: the one it loaded, or the
	// default file written by writeDefaultConfigFile.
	v.WatchConfig()
	return nil
}

// registerConfigDefaults installs a default for every key that Config decodes.
//
// Keys whose default is simply the zero value are registered too. viper's
// Unmarshal only consults AutomaticEnv for keys it already knows about (from
// defaults, the config file or explicit BindEnv). Without these defaults,
// variables such as BINGPAPER_STORAGE_S3_BUCKET would be ignored whenever the
// key is absent from the config file.
func registerConfigDefaults(vp *viper.Viper) {
	vp.SetDefault("server.port", 8080)
	vp.SetDefault("server.base_url", "")

	vp.SetDefault("log.level", "info")
	vp.SetDefault("log.filename", "data/logs/app.log")
	vp.SetDefault("log.db_filename", "data/logs/db.log")
	vp.SetDefault("log.max_size", 100)
	vp.SetDefault("log.max_backups", 3)
	vp.SetDefault("log.max_age", 7)
	vp.SetDefault("log.compress", true)
	vp.SetDefault("log.log_console", true)
	vp.SetDefault("log.show_db_log", false)
	vp.SetDefault("log.db_log_level", "info")

	vp.SetDefault("api.mode", "local")
	vp.SetDefault("api.enable_mkt_fallback", true)

	vp.SetDefault("cron.enabled", true)
	vp.SetDefault("cron.daily_spec", "20 8-23/4 * * *")

	vp.SetDefault("retention.days", 0)

	vp.SetDefault("db.type", "sqlite")
	vp.SetDefault("db.dsn", "data/bing_paper.db")

	vp.SetDefault("storage.type", "local")
	vp.SetDefault("storage.local.root", "data/picture")
	vp.SetDefault("storage.s3.endpoint", "")
	vp.SetDefault("storage.s3.region", "")
	vp.SetDefault("storage.s3.bucket", "")
	vp.SetDefault("storage.s3.access_key", "")
	vp.SetDefault("storage.s3.secret_key", "")
	vp.SetDefault("storage.s3.public_url_prefix", "")
	vp.SetDefault("storage.s3.force_path_style", false)
	vp.SetDefault("storage.webdav.url", "")
	vp.SetDefault("storage.webdav.username", "")
	vp.SetDefault("storage.webdav.password", "")
	vp.SetDefault("storage.webdav.public_url_prefix", "")

	vp.SetDefault("token.default_ttl", "168h")
	vp.SetDefault("feature.write_daily_files", true)
	vp.SetDefault("web.path", "web")

	// Fetch every supported region by default.
	var defaultRegions []string
	for _, r := range util.AllRegions {
		defaultRegions = append(defaultRegions, r.Value)
	}
	vp.SetDefault("fetcher.regions", defaultRegions)

	// Default admin password: admin123 (bcrypt hash).
	vp.SetDefault("admin.password_bcrypt", "$2a$10$fYHPeWHmwObephJvtlyH1O8DIgaLk5TINbi9BOezo2M8cSjmJchka")
}

// configFileMissing reports whether err from ReadInConfig means "there is no
// config file", as opposed to an unreadable or malformed one. With an explicit
// configPath, viper may return an os.PathError rather than
// ConfigFileNotFoundError, so the path is stat'ed as a fallback.
func configFileMissing(err error, configPath string) bool {
	if _, ok := err.(viper.ConfigFileNotFoundError); ok {
		return true
	}
	if configPath == "" {
		return false
	}
	_, statErr := os.Stat(configPath)
	return os.IsNotExist(statErr)
}

// writeDefaultConfigFile saves the current settings as a YAML template. It
// writes to configPath, or to data/config.yaml when configPath is empty, and
// then points viper at that file so WatchConfig picks up later edits.
//
// Failures are only reported as warnings, because the process can still run
// on the in-memory defaults.
//
// NOTE(review): the settings are decoded with AutomaticEnv active, so any
// BINGPAPER_* overrides (including secrets) present at first start are
// persisted into the generated file.
func writeDefaultConfigFile(configPath string) {
	target := configPath
	if target == "" {
		target = "data/config.yaml"
	}
	fmt.Printf("Config file not found, creating default config at %s\n", target)

	warn := func(err error) {
		fmt.Printf("Warning: Failed to create default config file: %v\n", err)
	}

	var defaultCfg Config
	if err := v.Unmarshal(&defaultCfg); err != nil {
		warn(err)
		return
	}
	data, err := yaml.Marshal(&defaultCfg)
	if err != nil {
		warn(err)
		return
	}
	// On a fresh install the data/ directory may not exist yet.
	if err := os.MkdirAll(filepath.Dir(target), 0755); err != nil {
		warn(err)
		return
	}
	if err := os.WriteFile(target, data, 0644); err != nil {
		warn(err)
		return
	}
	v.SetConfigFile(target)
}

// reloadConfigOnChange is the viper OnConfigChange callback. It decodes the
// re-read file and swaps it into GlobalConfig. It fires OnDBConfigChange when
// the database type or DSN differs from the previous configuration.
func reloadConfigOnChange(e fsnotify.Event) {
	fmt.Println("Config file changed:", e.Name)

	var newCfg Config
	if err := v.Unmarshal(&newCfg); err != nil {
		// Keep serving the previous configuration rather than a partial one.
		fmt.Printf("Warning: Failed to reload config, keeping previous values: %v\n", err)
		return
	}

	configLock.Lock()
	prev := GlobalConfig
	var oldDB DBConfig
	if prev != nil {
		oldDB = prev.DB
	}
	GlobalConfig = &newCfg
	configLock.Unlock()

	if prev == nil || OnDBConfigChange == nil {
		return
	}
	// The database layer cannot be imported here (import cycle), so the
	// migration/reload is delegated to the registered callback.
	if oldDB.Type != newCfg.DB.Type || oldDB.DSN != newCfg.DB.DSN {
		OnDBConfigChange(&newCfg)
	}
}
// GetConfig returns the current configuration, reading the GlobalConfig
// pointer under configLock's read lock.
//
// Reloads and SaveConfig install a new pointer rather than editing the old
// value, so callers should treat the returned Config as read-only.
func GetConfig() *Config {
	configLock.RLock()
	current := GlobalConfig
	configLock.RUnlock()
	return current
}
// SaveConfig persists cfg as YAML and installs it as the in-memory GlobalConfig.
//
// The YAML is produced with yaml.v3 and written directly, bypassing viper's
// writer, which would sort keys alphabetically. yaml.v3 keeps the struct's
// field order and yaml tags. The file written is the one viper loaded, or
// data/config.yaml when there is none (including when Init has not run).
// Missing parent directories are created.
//
// It returns an error, leaving GlobalConfig untouched, if cfg is nil or the
// file cannot be marshaled or written. Underlying errors are wrapped with %w.
//
// NOTE(review): os.WriteFile truncates before writing, so the watcher started
// by Init may briefly observe a partially written file. Consider writing a
// temp file and renaming it into place.
func SaveConfig(cfg *Config) error {
	if cfg == nil {
		// Marshaling nil would write "null" and install a nil GlobalConfig.
		return fmt.Errorf("failed to save config: config is nil")
	}

	configLock.Lock()
	defer configLock.Unlock()

	data, err := yaml.Marshal(cfg)
	if err != nil {
		return fmt.Errorf("failed to marshal config: %w", err)
	}

	targetPath := ""
	if v != nil {
		targetPath = v.ConfigFileUsed()
	}
	if targetPath == "" {
		targetPath = "data/config.yaml" // default fallback path
	}

	if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil {
		return fmt.Errorf("failed to create config directory: %w", err)
	}
	if err := os.WriteFile(targetPath, data, 0644); err != nil {
		return fmt.Errorf("failed to write config file: %w", err)
	}

	GlobalConfig = cfg
	return nil
}
// GetRawViper exposes the package's underlying viper instance for callers
// that need direct access. It returns nil if Init has not been called.
func GetRawViper() *viper.Viper {
	instance := v
	return instance
}
// GetAllSettings returns every effective configuration value, as reported by
// the viper instance's AllSettings.
func GetAllSettings() map[string]interface{} {
	all := v.AllSettings()
	return all
}
// GetFormattedSettings renders every effective setting as one "key: value"
// line. Keys are sorted lexically, so the output is stable across runs.
//
// NOTE(review): values are printed verbatim, including secrets such as
// admin.password_bcrypt. Avoid logging the result where that matters.
func GetFormattedSettings() string {
	keys := v.AllKeys()
	sort.Strings(keys)

	var sb strings.Builder
	for _, k := range keys {
		fmt.Fprintf(&sb, "%s: %v\n", k, v.Get(k))
	}
	return sb.String()
}
// GetEnvOverrides lists the settings currently overridden by environment
// variables, one "key: ENV_NAME=value" entry per setting, sorted by key.
//
// The variable name is derived the same way Init configures viper: the
// BINGPAPER prefix, then the key upper-cased with "." replaced by "_"
// (server.port -> BINGPAPER_SERVER_PORT).
//
// Variables that are set but empty are skipped. viper ignores empty
// environment values unless AllowEmptyEnv is enabled, and Init does not enable
// it, so such a variable does not actually override anything.
//
// Only keys viper knows about (v.AllKeys) are inspected.
func GetEnvOverrides() []string {
	keys := v.AllKeys()
	sort.Strings(keys)

	var overrides []string
	for _, key := range keys {
		envKey := "BINGPAPER_" + strings.ToUpper(strings.ReplaceAll(key, ".", "_"))
		val, ok := os.LookupEnv(envKey)
		if !ok || val == "" {
			continue
		}
		overrides = append(overrides, fmt.Sprintf("%s: %s=%s", key, envKey, val))
	}
	return overrides
}
// GetTokenTTL returns the default token lifetime configured in
// token.default_ttl. If that value is not a valid Go duration string,
// 168h (one week) is used instead.
func GetTokenTTL() time.Duration {
	const fallback = 168 * time.Hour
	raw := GetConfig().Token.DefaultTTL
	if parsed, err := time.ParseDuration(raw); err == nil {
		return parsed
	}
	return fallback
}
// GetDefaultMkt returns the effective default market (region code): the first
// entry of fetcher.regions, or the built-in BingMkt when that list is empty.
func (c *Config) GetDefaultMkt() string {
	regions := c.Fetcher.Regions
	if len(regions) == 0 {
		return BingMkt
	}
	return regions[0]
}