mirror of
https://git.fightbot.fun/hxuanyu/BingPaper.git
synced 2026-02-15 07:19:33 +08:00
340 lines
11 KiB
Go
340 lines
11 KiB
Go
package config
|
||
|
||
import (
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/fsnotify/fsnotify"
	"github.com/spf13/viper"
	"gopkg.in/yaml.v3"

	"BingPaper/internal/util"
)
|
||
|
||
type Config struct {
|
||
Server ServerConfig `mapstructure:"server" yaml:"server"`
|
||
Log LogConfig `mapstructure:"log" yaml:"log"`
|
||
API APIConfig `mapstructure:"api" yaml:"api"`
|
||
Cron CronConfig `mapstructure:"cron" yaml:"cron"`
|
||
Retention RetentionConfig `mapstructure:"retention" yaml:"retention"`
|
||
DB DBConfig `mapstructure:"db" yaml:"db"`
|
||
Storage StorageConfig `mapstructure:"storage" yaml:"storage"`
|
||
Admin AdminConfig `mapstructure:"admin" yaml:"admin"`
|
||
Token TokenConfig `mapstructure:"token" yaml:"token"`
|
||
Feature FeatureConfig `mapstructure:"feature" yaml:"feature"`
|
||
Web WebConfig `mapstructure:"web" yaml:"web"`
|
||
Fetcher FetcherConfig `mapstructure:"fetcher" yaml:"fetcher"`
|
||
}
|
||
|
||
// ServerConfig is the "server" section of the configuration.
type ServerConfig struct {
	// Port is the server.port setting (presumably the HTTP listen port —
	// confirm against the server bootstrap code).
	Port int `mapstructure:"port" yaml:"port"`
	// BaseURL is the server.base_url setting.
	BaseURL string `mapstructure:"base_url" yaml:"base_url"`
}
|
||
|
||
// LogConfig is the "log" section of the configuration. The Get* methods
// below expose each field through a method set so the value can be passed to
// code that expects accessor methods rather than fields.
type LogConfig struct {
	// Level is the log.level setting.
	Level string `mapstructure:"level" yaml:"level"`
	// Filename is the application (business) log file name.
	Filename string `mapstructure:"filename" yaml:"filename"`
	// DBFilename is the database log file name.
	DBFilename string `mapstructure:"db_filename" yaml:"db_filename"`
	// MaxSize is the maximum size of each log file, in MB.
	MaxSize int `mapstructure:"max_size" yaml:"max_size"`
	// MaxBackups is the maximum number of old log files to keep.
	MaxBackups int `mapstructure:"max_backups" yaml:"max_backups"`
	// MaxAge is the maximum number of days to keep old log files.
	MaxAge int `mapstructure:"max_age" yaml:"max_age"`
	// Compress reports whether old log files are compressed.
	Compress bool `mapstructure:"compress" yaml:"compress"`
	// LogConsole reports whether logs are also written to the console.
	LogConsole bool `mapstructure:"log_console" yaml:"log_console"`
	// ShowDBLog reports whether database logs are shown on the console.
	ShowDBLog bool `mapstructure:"show_db_log" yaml:"show_db_log"`
	// DBLogLevel is the database log level: debug, info, warn, error.
	DBLogLevel string `mapstructure:"db_log_level" yaml:"db_log_level"`
}

// GetLevel returns the Level field.
func (lc LogConfig) GetLevel() string {
	return lc.Level
}

// GetFilename returns the Filename field.
func (lc LogConfig) GetFilename() string {
	return lc.Filename
}

// GetDBFilename returns the DBFilename field.
func (lc LogConfig) GetDBFilename() string {
	return lc.DBFilename
}

// GetMaxSize returns the MaxSize field.
func (lc LogConfig) GetMaxSize() int {
	return lc.MaxSize
}

// GetMaxBackups returns the MaxBackups field.
func (lc LogConfig) GetMaxBackups() int {
	return lc.MaxBackups
}

// GetMaxAge returns the MaxAge field.
func (lc LogConfig) GetMaxAge() int {
	return lc.MaxAge
}

// GetCompress returns the Compress field.
func (lc LogConfig) GetCompress() bool {
	return lc.Compress
}

// GetLogConsole returns the LogConsole field.
func (lc LogConfig) GetLogConsole() bool {
	return lc.LogConsole
}

// GetShowDBLog returns the ShowDBLog field.
func (lc LogConfig) GetShowDBLog() bool {
	return lc.ShowDBLog
}

// GetDBLogLevel returns the DBLogLevel field.
func (lc LogConfig) GetDBLogLevel() string {
	return lc.DBLogLevel
}
|
||
|
||
// APIConfig is the "api" section of the configuration.
type APIConfig struct {
	// Mode is either "local" or "redirect".
	Mode string `mapstructure:"mode" yaml:"mode"`
	// EnableMktFallback controls whether a request for a region that does not
	// exist falls back to the default region.
	EnableMktFallback bool `mapstructure:"enable_mkt_fallback" yaml:"enable_mkt_fallback"`
	// EnableOnDemandFetch controls whether on-demand fetching is enabled.
	EnableOnDemandFetch bool `mapstructure:"enable_on_demand_fetch" yaml:"enable_on_demand_fetch"`
}
|
||
|
||
// CronConfig is the "cron" section of the configuration.
type CronConfig struct {
	// Enabled is the cron.enabled setting.
	Enabled bool `mapstructure:"enabled" yaml:"enabled"`
	// DailySpec is the cron.daily_spec setting (a cron-style schedule
	// expression, judging by the default registered in Init).
	DailySpec string `mapstructure:"daily_spec" yaml:"daily_spec"`
}
|
||
|
||
// RetentionConfig is the "retention" section of the configuration.
type RetentionConfig struct {
	// Days is the retention.days setting.
	// NOTE(review): the meaning of 0 (the default set in Init) is not visible
	// in this file — confirm against the cleanup code.
	Days int `mapstructure:"days" yaml:"days"`
}
|
||
|
||
// DBConfig is the "db" section of the configuration. A change of either
// field during a config reload triggers OnDBConfigChange.
type DBConfig struct {
	// Type is the database driver kind: sqlite, mysql or postgres.
	Type string `mapstructure:"type" yaml:"type"`
	// DSN is the db.dsn setting (the data source name / connection string).
	DSN string `mapstructure:"dsn" yaml:"dsn"`
}
|
||
|
||
type StorageConfig struct {
|
||
Type string `mapstructure:"type" yaml:"type"` // local/s3/webdav
|
||
Local LocalConfig `mapstructure:"local" yaml:"local"`
|
||
S3 S3Config `mapstructure:"s3" yaml:"s3"`
|
||
WebDAV WebDAVConfig `mapstructure:"webdav" yaml:"webdav"`
|
||
}
|
||
|
||
// LocalConfig is the "storage.local" section of the configuration.
type LocalConfig struct {
	// Root is the storage.local.root setting (a directory path, judging by
	// the default "data/picture" registered in Init).
	Root string `mapstructure:"root" yaml:"root"`
}
|
||
|
||
// S3Config is the "storage.s3" section of the configuration.
// AccessKey and SecretKey are credentials; treat serialized copies of this
// struct as sensitive.
type S3Config struct {
	// Endpoint is the storage.s3.endpoint setting.
	Endpoint string `mapstructure:"endpoint" yaml:"endpoint"`
	// Region is the storage.s3.region setting.
	Region string `mapstructure:"region" yaml:"region"`
	// Bucket is the storage.s3.bucket setting.
	Bucket string `mapstructure:"bucket" yaml:"bucket"`
	// AccessKey is the storage.s3.access_key setting.
	AccessKey string `mapstructure:"access_key" yaml:"access_key"`
	// SecretKey is the storage.s3.secret_key setting.
	SecretKey string `mapstructure:"secret_key" yaml:"secret_key"`
	// PublicURLPrefix is the storage.s3.public_url_prefix setting.
	PublicURLPrefix string `mapstructure:"public_url_prefix" yaml:"public_url_prefix"`
	// ForcePathStyle is the storage.s3.force_path_style setting.
	ForcePathStyle bool `mapstructure:"force_path_style" yaml:"force_path_style"`
}
|
||
|
||
// WebDAVConfig is the "storage.webdav" section of the configuration.
// Username and Password are credentials; treat serialized copies of this
// struct as sensitive.
type WebDAVConfig struct {
	// URL is the storage.webdav.url setting.
	URL string `mapstructure:"url" yaml:"url"`
	// Username is the storage.webdav.username setting.
	Username string `mapstructure:"username" yaml:"username"`
	// Password is the storage.webdav.password setting.
	Password string `mapstructure:"password" yaml:"password"`
	// PublicURLPrefix is the storage.webdav.public_url_prefix setting.
	PublicURLPrefix string `mapstructure:"public_url_prefix" yaml:"public_url_prefix"`
}
|
||
|
||
// AdminConfig is the "admin" section of the configuration.
type AdminConfig struct {
	// PasswordBcrypt is the admin.password_bcrypt setting: the bcrypt hash of
	// the admin password (Init registers a built-in default hash).
	PasswordBcrypt string `mapstructure:"password_bcrypt" yaml:"password_bcrypt"`
}
|
||
|
||
// TokenConfig is the "token" section of the configuration.
type TokenConfig struct {
	// DefaultTTL is the token.default_ttl setting. It is a duration string
	// that GetTokenTTL parses with time.ParseDuration (for example "168h").
	DefaultTTL string `mapstructure:"default_ttl" yaml:"default_ttl"`
}
|
||
|
||
// FeatureConfig is the "feature" section of the configuration.
type FeatureConfig struct {
	// WriteDailyFiles is the feature.write_daily_files toggle.
	WriteDailyFiles bool `mapstructure:"write_daily_files" yaml:"write_daily_files"`
}
|
||
|
||
// WebConfig is the "web" section of the configuration.
type WebConfig struct {
	// Path is the web.path setting.
	// NOTE(review): presumably the directory holding the web front-end
	// assets (default "web") — confirm against the HTTP server setup.
	Path string `mapstructure:"path" yaml:"path"`
}
|
||
|
||
// FetcherConfig is the "fetcher" section of the configuration.
type FetcherConfig struct {
	// Regions is the fetcher.regions list of region codes. Its first entry
	// is what Config.GetDefaultRegion reports as the default region.
	Regions []string `mapstructure:"regions" yaml:"regions"`
}
|
||
|
||
// Built-in Bing defaults. These are compiled in and not read from the
// configuration file.
const (
	// BingMkt is the fallback region code returned by
	// Config.GetDefaultRegion when no fetcher regions are configured.
	BingMkt = "zh-CN"

	// BingFetchN is a built-in fetch count.
	// NOTE(review): presumably the number of images requested per Bing API
	// call — confirm against the fetcher.
	BingFetchN = 8

	// BingAPIBase is the URL of Bing's HPImageArchive endpoint.
	BingAPIBase = "https://www.bing.com/HPImageArchive.aspx"
)
|
||
|
||
var (
|
||
GlobalConfig *Config
|
||
configLock sync.RWMutex
|
||
v *viper.Viper
|
||
|
||
// OnDBConfigChange 当数据库配置发生变更时的回调函数
|
||
OnDBConfigChange func(newCfg *Config)
|
||
)
|
||
|
||
func Init(configPath string) error {
|
||
v = viper.New()
|
||
if configPath != "" {
|
||
v.SetConfigFile(configPath)
|
||
} else {
|
||
v.SetConfigName("config")
|
||
v.SetConfigType("yaml")
|
||
v.AddConfigPath("./data")
|
||
v.AddConfigPath(".")
|
||
}
|
||
|
||
v.SetDefault("server.port", 8080)
|
||
v.SetDefault("log.level", "info")
|
||
v.SetDefault("log.filename", "data/logs/app.log")
|
||
v.SetDefault("log.db_filename", "data/logs/db.log")
|
||
v.SetDefault("log.max_size", 100)
|
||
v.SetDefault("log.max_backups", 3)
|
||
v.SetDefault("log.max_age", 7)
|
||
v.SetDefault("log.compress", true)
|
||
v.SetDefault("log.log_console", true)
|
||
v.SetDefault("log.show_db_log", false)
|
||
v.SetDefault("log.db_log_level", "info")
|
||
v.SetDefault("api.mode", "redirect")
|
||
v.SetDefault("api.enable_mkt_fallback", false)
|
||
v.SetDefault("api.enable_on_demand_fetch", false)
|
||
v.SetDefault("cron.enabled", true)
|
||
v.SetDefault("cron.daily_spec", "20 8-23/4 * * *")
|
||
v.SetDefault("retention.days", 0)
|
||
v.SetDefault("db.type", "sqlite")
|
||
v.SetDefault("db.dsn", "data/bing_paper.db")
|
||
v.SetDefault("storage.type", "local")
|
||
v.SetDefault("storage.local.root", "data/picture")
|
||
v.SetDefault("token.default_ttl", "168h")
|
||
v.SetDefault("feature.write_daily_files", true)
|
||
v.SetDefault("web.path", "web")
|
||
|
||
// 默认抓取所有支持的地区
|
||
var defaultRegions []string
|
||
for _, r := range util.AllRegions {
|
||
defaultRegions = append(defaultRegions, r.Value)
|
||
}
|
||
v.SetDefault("fetcher.regions", defaultRegions)
|
||
v.SetDefault("admin.password_bcrypt", "$2a$10$fYHPeWHmwObephJvtlyH1O8DIgaLk5TINbi9BOezo2M8cSjmJchka") // 默认密码: admin123
|
||
|
||
// 绑定环境变量
|
||
v.SetEnvPrefix("BINGPAPER")
|
||
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||
v.AutomaticEnv()
|
||
|
||
if err := v.ReadInConfig(); err != nil {
|
||
// 如果指定了配置文件但读取失败(且不是找不到文件的错误),或者没指定但也没找到
|
||
_, isNotFound := err.(viper.ConfigFileNotFoundError)
|
||
// 如果显式指定了文件,viper 报错可能不是 ConfigFileNotFoundError 而是 os.PathError
|
||
if !isNotFound && configPath != "" {
|
||
if _, statErr := os.Stat(configPath); os.IsNotExist(statErr) {
|
||
isNotFound = true
|
||
}
|
||
}
|
||
|
||
if !isNotFound {
|
||
return err
|
||
}
|
||
|
||
// 如果文件不存在,我们使用默认值并尝试创建一个默认配置文件
|
||
targetConfigPath := configPath
|
||
if targetConfigPath == "" {
|
||
targetConfigPath = "data/config.yaml"
|
||
}
|
||
fmt.Printf("Config file not found, creating default config at %s\n", targetConfigPath)
|
||
|
||
var defaultCfg Config
|
||
if err := v.Unmarshal(&defaultCfg); err == nil {
|
||
data, _ := yaml.Marshal(&defaultCfg)
|
||
if err := os.WriteFile(targetConfigPath, data, 0644); err != nil {
|
||
fmt.Printf("Warning: Failed to create default config file: %v\n", err)
|
||
}
|
||
}
|
||
}
|
||
|
||
var cfg Config
|
||
if err := v.Unmarshal(&cfg); err != nil {
|
||
return err
|
||
}
|
||
|
||
GlobalConfig = &cfg
|
||
|
||
v.OnConfigChange(func(e fsnotify.Event) {
|
||
fmt.Println("Config file changed:", e.Name)
|
||
var newCfg Config
|
||
if err := v.Unmarshal(&newCfg); err == nil {
|
||
configLock.Lock()
|
||
oldDBConfig := GlobalConfig.DB
|
||
GlobalConfig = &newCfg
|
||
newDBConfig := newCfg.DB
|
||
configLock.Unlock()
|
||
|
||
// 检查数据库配置是否发生变更
|
||
if oldDBConfig.Type != newDBConfig.Type || oldDBConfig.DSN != newDBConfig.DSN {
|
||
// 触发数据库迁移逻辑
|
||
// 这里由于循环依赖问题,我们可能需要通过回调或者一个统一的 Reload 函数来处理
|
||
if OnDBConfigChange != nil {
|
||
OnDBConfigChange(&newCfg)
|
||
}
|
||
}
|
||
}
|
||
})
|
||
v.WatchConfig()
|
||
|
||
return nil
|
||
}
|
||
|
||
func GetConfig() *Config {
|
||
configLock.RLock()
|
||
defer configLock.RUnlock()
|
||
return GlobalConfig
|
||
}
|
||
|
||
func SaveConfig(cfg *Config) error {
|
||
configLock.Lock()
|
||
defer configLock.Unlock()
|
||
|
||
// 1. 使用 yaml.v3 序列化,它会尊重结构体字段顺序及 yaml 标签
|
||
data, err := yaml.Marshal(cfg)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to marshal config: %v", err)
|
||
}
|
||
|
||
// 2. 获取当前使用的配置文件路径
|
||
targetPath := v.ConfigFileUsed()
|
||
if targetPath == "" {
|
||
targetPath = "data/config.yaml" // 默认回退路径
|
||
}
|
||
|
||
// 3. 直接写入文件,绕过 viper 的字母序排序逻辑
|
||
if err := os.WriteFile(targetPath, data, 0644); err != nil {
|
||
return fmt.Errorf("failed to write config file: %v", err)
|
||
}
|
||
|
||
// 4. 同步更新内存中的全局配置对象
|
||
GlobalConfig = cfg
|
||
return nil
|
||
}
|
||
|
||
func GetRawViper() *viper.Viper {
|
||
return v
|
||
}
|
||
|
||
// GetAllSettings 返回所有生效配置项
|
||
func GetAllSettings() map[string]interface{} {
|
||
return v.AllSettings()
|
||
}
|
||
|
||
// GetFormattedSettings 以 key: value 形式返回所有配置项的字符串
|
||
func GetFormattedSettings() string {
|
||
keys := v.AllKeys()
|
||
sort.Strings(keys)
|
||
var sb strings.Builder
|
||
for _, k := range keys {
|
||
sb.WriteString(fmt.Sprintf("%s: %v\n", k, v.Get(k)))
|
||
}
|
||
return sb.String()
|
||
}
|
||
|
||
// GetEnvOverrides 返回环境变量覆盖详情(已排序)
|
||
func GetEnvOverrides() []string {
|
||
var overrides []string
|
||
keys := v.AllKeys()
|
||
sort.Strings(keys)
|
||
for _, key := range keys {
|
||
// 根据 viper 的配置生成对应的环境变量名
|
||
// Prefix: BINGPAPER, KeyReplacer: . -> _
|
||
envKey := strings.ToUpper(fmt.Sprintf("BINGPAPER_%s", strings.ReplaceAll(key, ".", "_")))
|
||
if val, ok := os.LookupEnv(envKey); ok {
|
||
overrides = append(overrides, fmt.Sprintf("%s: %s=%s", key, envKey, val))
|
||
}
|
||
}
|
||
return overrides
|
||
}
|
||
|
||
func GetTokenTTL() time.Duration {
|
||
ttl, err := time.ParseDuration(GetConfig().Token.DefaultTTL)
|
||
if err != nil {
|
||
return 168 * time.Hour
|
||
}
|
||
return ttl
|
||
}
|
||
|
||
// GetDefaultRegion 返回生效的默认地区编码
|
||
func (c *Config) GetDefaultRegion() string {
|
||
if len(c.Fetcher.Regions) > 0 {
|
||
return c.Fetcher.Regions[0]
|
||
}
|
||
return BingMkt
|
||
}
|