Files
FileRelay/internal/bootstrap/init.go
2026-01-28 20:44:34 +08:00

407 lines
10 KiB
Go

package bootstrap
import (
"FileRelay/internal/config"
"FileRelay/internal/model"
"FileRelay/internal/storage"
"context"
"crypto/rand"
"fmt"
"io"
"math/big"
"path/filepath"
"runtime"
"log/slog"
"os"
"strings"
"github.com/glebarez/sqlite"
"golang.org/x/crypto/bcrypt"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
var (
	// DB is the exported, package-wide GORM handle. InitDB assigns it at
	// startup and ReloadDB swaps it for a new connection when the database
	// settings change.
	//
	// NOTE(review): ReloadDB replaces this value without any locking;
	// confirm that reloads cannot run concurrently with readers.
	DB *gorm.DB
)
// InitLog configures the process-wide default slog logger from
// config.GlobalConfig.Log.
//
// Records always go to stdout through ConsoleHandler in a compact format.
// When Log.FilePath is set, they are additionally appended to that file in
// slog's text format, with "2006-01-02 15:04:05.000" timestamps and
// "file.go:line" sources. Failures to create the log directory or open the
// file are reported on stdout and file logging is skipped; they are never
// fatal.
//
// Log.Level accepts whatever slog.Level.UnmarshalText accepts ("debug",
// "info", "warn", "error", case-insensitive, optionally with an offset such
// as "warn+2"). An empty or unrecognized value falls back to info.
//
// Every record carries a service=filerelay attribute.
func InitLog() {
	logCfg := config.GlobalConfig.Log

	level := slog.LevelInfo
	if name := strings.TrimSpace(logCfg.Level); name != "" {
		var parsed slog.Level
		if err := parsed.UnmarshalText([]byte(name)); err != nil {
			fmt.Printf("Warning: Unknown log level %q, using info\n", logCfg.Level)
		} else {
			level = parsed
		}
	}

	// 1. Console handler (compact format).
	handlers := []slog.Handler{&ConsoleHandler{out: os.Stdout, level: level}}

	// 2. File handler (structured text format).
	if path := logCfg.FilePath; path != "" {
		if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
			fmt.Printf("Warning: Failed to create log directory: %v\n", err)
		} else if file, err := os.OpenFile(path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644); err != nil {
			fmt.Printf("Warning: Failed to open log file: %v\n", err)
		} else {
			// The file is not closed here; the handler keeps writing to it.
			handlers = append(handlers, slog.NewTextHandler(file, &slog.HandlerOptions{
				Level:     level,
				AddSource: true,
				ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr {
					// Only reformat the handler's own top-level time/source
					// attributes. User attributes that reuse these keys (or
					// live inside a group) must pass through untouched:
					// Value.Time() and the *slog.Source assertion would
					// panic on values of another kind.
					if len(groups) > 0 {
						return a
					}
					switch a.Key {
					case slog.TimeKey:
						if a.Value.Kind() == slog.KindTime {
							return slog.String(a.Key, a.Value.Time().Format("2006-01-02 15:04:05.000"))
						}
					case slog.SourceKey:
						if src, ok := a.Value.Any().(*slog.Source); ok && src != nil {
							return slog.String(a.Key, fmt.Sprintf("%s:%d", filepath.Base(src.File), src.Line))
						}
					}
					return a
				},
			}))
		}
	}

	var final slog.Handler = handlers[0]
	if len(handlers) > 1 {
		final = &MultiHandler{handlers: handlers}
	}
	slog.SetDefault(slog.New(final).With("service", "filerelay"))
}
// InitDB performs start-up initialization of persistence:
//   - connects to (and auto-migrates) the database configured in
//     config.GlobalConfig.Database and stores the handle in DB,
//   - builds the storage backend via ReloadStorage,
//   - ensures admin credentials exist via initAdmin.
//
// A failure to connect or to initialize storage is fatal: it is logged and
// the process exits with status 1.
func InitDB() {
	dbCfg := config.GlobalConfig.Database

	conn, connErr := ConnectDB(dbCfg)
	DB = conn
	if connErr != nil {
		slog.Error("Failed to initialize database", "type", dbCfg.Type, "error", connErr)
		os.Exit(1)
	}
	slog.Info("Database initialized and migrated", "type", dbCfg.Type)

	// Storage comes up only after the database is available.
	if storageErr := ReloadStorage(); storageErr != nil {
		slog.Error("Failed to initialize storage", "error", storageErr)
		os.Exit(1)
	}

	// Create admin credentials if none are configured.
	initAdmin()
}
// ConnectDB opens the database described by cfg and auto-migrates the
// FileBatch, FileItem and APIToken models.
//
// cfg.Type is matched case-insensitively:
//   - "mysql": TCP DSN built from User/Password/Host/Port/DBName. cfg.Config
//     holds extra DSN query parameters (a leading "?" is tolerated);
//     parseTime=True, loc=Local and charset=utf8mb4 are added unless already
//     present.
//   - "postgres" / "postgresql": keyword=value DSN. Host, user, password and
//     dbname are quoted so values with spaces, quotes or backslashes cannot
//     break the DSN; cfg.Config is appended verbatim (e.g. "sslmode=disable").
//   - anything else, including "sqlite", "sqlite3" and "": SQLite at cfg.Path,
//     defaulting to "data/file_relay.db". The parent directory is created if
//     missing. A non-empty unrecognized type is logged as a warning.
//
// If auto-migration fails, the freshly opened connection is closed before
// the error is returned. Returned errors are wrapped with the failing step.
func ConnectDB(cfg config.DatabaseConfig) (*gorm.DB, error) {
	var dialector gorm.Dialector
	switch dbType := strings.ToLower(cfg.Type); dbType {
	case "mysql":
		params := strings.TrimPrefix(strings.TrimSpace(cfg.Config), "?")
		if params == "" {
			params = "parseTime=True&loc=Local&charset=utf8mb4"
		} else {
			// Add the defaults the app relies on without overriding anything
			// the user set explicitly.
			lower := strings.ToLower(params)
			if !strings.Contains(lower, "parsetime") {
				params += "&parseTime=True"
			}
			if !strings.Contains(lower, "loc=") {
				params += "&loc=Local"
			}
			if !strings.Contains(lower, "charset") {
				params += "&charset=utf8mb4"
			}
		}
		dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?%s",
			cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.DBName, params)
		dialector = mysql.Open(dsn)

	case "postgres", "postgresql":
		// libpq keyword/value syntax: single-quote values, backslash-escape
		// backslashes and single quotes inside them.
		quote := func(s string) string {
			s = strings.ReplaceAll(s, `\`, `\\`)
			s = strings.ReplaceAll(s, `'`, `\'`)
			return "'" + s + "'"
		}
		dsn := fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%d",
			quote(cfg.Host), quote(cfg.User), quote(cfg.Password), quote(cfg.DBName), cfg.Port)
		if extra := strings.TrimSpace(cfg.Config); extra != "" {
			dsn += " " + extra
		}
		dialector = postgres.Open(dsn)

	default:
		if dbType != "" && dbType != "sqlite" && dbType != "sqlite3" {
			slog.Warn("Unknown database type, falling back to SQLite", "type", cfg.Type)
		}
		dbPath := cfg.Path
		if dbPath == "" {
			dbPath = "data/file_relay.db"
		}
		// SQLite creates the database file but not its directory. Skip
		// in-memory databases and "file:" URIs, which are not plain paths.
		if dbPath != ":memory:" && !strings.HasPrefix(dbPath, "file:") {
			pathOnly, _, _ := strings.Cut(dbPath, "?")
			if err := os.MkdirAll(filepath.Dir(pathOnly), 0755); err != nil {
				return nil, fmt.Errorf("create sqlite directory: %w", err)
			}
		}
		dialector = sqlite.Open(dbPath)
	}

	db, err := gorm.Open(dialector, &gorm.Config{})
	if err != nil {
		return nil, fmt.Errorf("open database: %w", err)
	}

	// Auto-migrate the schema.
	if err := db.AutoMigrate(
		&model.FileBatch{},
		&model.FileItem{},
		&model.APIToken{},
	); err != nil {
		// Don't leak the pool we just opened; the migration error is what
		// the caller needs, so a close failure is ignored.
		if sqlDB, dbErr := db.DB(); dbErr == nil {
			_ = sqlDB.Close()
		}
		return nil, fmt.Errorf("auto-migrate schema: %w", err)
	}
	return db, nil
}
// ReloadDB switches the package to the database described by newCfg.
//
// The new connection is opened and schema-migrated first via ConnectDB. If
// that fails, the error is returned and DB is left untouched. Otherwise, when
// a database is already active, its rows are copied across with MigrateData.
// A failed copy is logged but does not block the switch, so the function
// returns nil once DB points at the new connection.
//
// NOTE(review): nothing here checks that newCfg names a different database
// than the current one (reloading with the same settings re-saves every row
// onto itself), and the previous connection pool is never closed. Confirm
// how callers expect either case to be handled.
func ReloadDB(newCfg config.DatabaseConfig) error {
	target, err := ConnectDB(newCfg)
	if err != nil {
		return err
	}

	if current := DB; current != nil {
		slog.Info("Starting data migration to new database...")
		// A copy failure is recorded, but intentionally does not abort the switch.
		migrationErr := MigrateData(current, target)
		if migrationErr == nil {
			slog.Info("Data migration completed successfully")
		} else {
			slog.Error("Data migration failed", "error", migrationErr)
		}
	}

	DB = target
	return nil
}
// MigrateData copies all APIToken, FileBatch and FileItem rows from sourceDB
// into targetDB using gorm's Save, which upserts by primary key.
//
// Tables are copied in the order tokens, batches, items. Presumably this is
// so that batches exist before the items that reference them (TODO confirm
// against the model definitions). Tokens are loaded in one query; batches
// and items are streamed 100 rows at a time to bound memory use.
//
// The copy is best-effort: a failure on one table does not stop the others.
// Every failure is wrapped with the table it came from and combined into the
// returned error (errors.Is / errors.As can see each one). A nil return
// means all tables were copied.
func MigrateData(sourceDB, targetDB *gorm.DB) error {
	const chunkSize = 100

	var errs error
	// record folds a per-table failure into errs without discarding earlier ones.
	record := func(table string, err error) {
		if err == nil {
			return
		}
		err = fmt.Errorf("migrate %s: %w", table, err)
		if errs == nil {
			errs = err
		} else {
			errs = fmt.Errorf("%w; %w", errs, err)
		}
	}

	// APITokens: small table, read in one go.
	var tokens []model.APIToken
	if err := sourceDB.Find(&tokens).Error; err != nil {
		record("api tokens", err)
	} else if len(tokens) > 0 {
		record("api tokens", targetDB.Save(&tokens).Error)
	}

	// FileBatches, in chunks to save memory.
	var batches []model.FileBatch
	record("file batches", sourceDB.Model(&model.FileBatch{}).FindInBatches(&batches, chunkSize, func(*gorm.DB, int) error {
		return targetDB.Save(&batches).Error
	}).Error)

	// FileItems, in chunks to save memory.
	var items []model.FileItem
	record("file items", sourceDB.Model(&model.FileItem{}).FindInBatches(&items, chunkSize, func(*gorm.DB, int) error {
		return targetDB.Save(&items).Error
	}).Error)

	return errs
}
// ReloadStorage builds the storage backend selected by
// config.GlobalConfig.Storage.Type and installs it as storage.GlobalStorage.
//
// Supported types are "local", "webdav" and "s3". Like the database type in
// ConnectDB, matching ignores case (and surrounding whitespace).
//
// An unsupported type, or a failure to create the S3 backend, returns an
// error and leaves storage.GlobalStorage unchanged.
func ReloadStorage() error {
	storageCfg := config.GlobalConfig.Storage
	storageType := storageCfg.Type

	switch strings.ToLower(strings.TrimSpace(storageType)) {
	case "local":
		storage.GlobalStorage = storage.NewLocalStorage(storageCfg.Local.Path)
	case "webdav":
		dav := storageCfg.WebDAV
		storage.GlobalStorage = storage.NewWebDAVStorage(dav.URL, dav.Username, dav.Password, dav.Root)
	case "s3":
		s3Cfg := storageCfg.S3
		s3Storage, err := storage.NewS3Storage(context.Background(), s3Cfg.Endpoint, s3Cfg.Region, s3Cfg.AccessKey, s3Cfg.SecretKey, s3Cfg.Bucket, s3Cfg.UseSSL)
		if err != nil {
			return fmt.Errorf("init s3 storage: %w", err)
		}
		storage.GlobalStorage = s3Storage
	default:
		return fmt.Errorf("unsupported storage type: %s", storageType)
	}

	slog.Info("Storage initialized", "type", storageType)
	return nil
}
// initAdmin makes sure an admin password hash is configured.
//
// If config.GlobalConfig.Security.AdminPasswordHash is empty:
//   - a random 12-character password is generated and printed once to stdout;
//   - its bcrypt hash (default cost) is stored in the in-memory config;
//   - the hash is written back with config.SaveConfig.
//
// A save failure is only logged as a warning; the in-memory hash still
// applies for this run. A hashing failure is fatal (exit status 1).
func initAdmin() {
	if config.GlobalConfig.Security.AdminPasswordHash == "" {
		plain := generateRandomPassword(12)
		hashed, hashErr := bcrypt.GenerateFromPassword([]byte(plain), bcrypt.DefaultCost)
		if hashErr != nil {
			slog.Error("Failed to generate password hash", "error", hashErr)
			os.Exit(1)
		}

		fmt.Print("**************************************************\n")
		fmt.Print("NO ADMIN PASSWORD CONFIGURED. GENERATED RANDOM PASSWORD:\n")
		fmt.Printf("Password: %s\n", plain)
		fmt.Print("Please save this password or configure admin_password_hash in config.yaml\n")
		fmt.Print("**************************************************\n")

		// Save the generated hash back into the config file.
		config.GlobalConfig.Security.AdminPasswordHash = string(hashed)
		if saveErr := config.SaveConfig(); saveErr != nil {
			slog.Warn("Failed to save generated password hash to config", "error", saveErr)
		}
	}
	slog.Info("Admin authentication initialized")
}
// generateRandomPassword returns a password of `length` characters drawn
// uniformly from ASCII letters, digits and "!@#$%^&*", using crypto/rand.
// A length of zero or less yields "".
//
// If the secure random source fails, the error is logged and the process
// exits. Falling back to a fixed or predictable password would hand out a
// guessable admin credential.
func generateRandomPassword(length int) string {
	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*"
	if length <= 0 {
		return ""
	}

	limit := big.NewInt(int64(len(charset)))
	b := make([]byte, length)
	for i := range b {
		// rand.Int is uniform over [0, limit), so there is no modulo bias.
		num, err := rand.Int(rand.Reader, limit)
		if err != nil {
			slog.Error("Failed to generate random password", "error", err)
			os.Exit(1)
		}
		b[i] = charset[num.Int64()]
	}
	return string(b)
}
// ConsoleHandler is a slog.Handler that writes compact, human-readable,
// single-line records to out, for example:
//
//	15:04:05.000 INFO  [server.go:42] listening :8080
//
// Only attribute values are printed (keys are dropped), and the "service"
// attribute is suppressed. Some values are printed Go-quoted so that a
// record always stays on one line and values stay unambiguous: empty values,
// and values containing spaces, double quotes or control characters
// (newlines, tabs, ANSI escapes, ...).
//
// Groups are not supported: WithGroup returns the handler unchanged.
type ConsoleHandler struct {
	out   io.Writer    // destination; each record is written with a single write call
	level slog.Leveler // records below this level are dropped
	attrs []slog.Attr  // attributes added via WithAttrs, printed before record attributes
}

// Enabled reports whether records at level l should be written.
func (h *ConsoleHandler) Enabled(_ context.Context, l slog.Level) bool {
	return l >= h.level.Level()
}

// Handle formats r as one line and writes it to h.out.
func (h *ConsoleHandler) Handle(_ context.Context, r slog.Record) error {
	var b strings.Builder

	// Time as 15:04:05.000. A zero time is omitted, as the slog.Handler
	// contract asks.
	if !r.Time.IsZero() {
		b.WriteString(r.Time.Format("15:04:05.000"))
		b.WriteByte(' ')
	}

	// Level, left-aligned in 5 columns so messages line up.
	fmt.Fprintf(&b, "%-5s ", r.Level.String())

	// Caller as [file.go:123]. Records without a PC (e.g. ones bridged from
	// the standard log package) get no source: any frame guessed from inside
	// this method would point into slog/log internals, not the real caller.
	if r.PC != 0 {
		frame, _ := runtime.CallersFrames([]uintptr{r.PC}).Next()
		fmt.Fprintf(&b, "[%s:%d] ", filepath.Base(frame.File), frame.Line)
	}

	b.WriteString(r.Message)

	// unsafeRune reports runes that make an unquoted value ambiguous (space,
	// double quote) or could break/forge log lines: C0/C1 control characters
	// such as '\n' or ESC, DEL, Unicode line/paragraph separators, and
	// U+FFFD (which is also what invalid UTF-8 decodes to).
	unsafeRune := func(c rune) bool {
		switch {
		case c == ' ', c == '"':
			return true
		case c < 0x20, c >= 0x7f && c <= 0x9f:
			return true
		case c == '\u2028', c == '\u2029', c == '\uFFFD':
			return true
		}
		return false
	}

	// Attributes: values only.
	writeAttr := func(a slog.Attr) {
		// Skip empty attrs (slog.Handler contract) and the constant service tag.
		if a.Equal(slog.Attr{}) || a.Key == "service" {
			return
		}
		val := a.Value.Resolve().String()
		b.WriteByte(' ')
		if val == "" || strings.IndexFunc(val, unsafeRune) >= 0 {
			fmt.Fprintf(&b, "%q", val)
		} else {
			b.WriteString(val)
		}
	}
	for _, a := range h.attrs {
		writeAttr(a)
	}
	r.Attrs(func(a slog.Attr) bool {
		writeAttr(a)
		return true
	})

	b.WriteByte('\n')
	_, err := io.WriteString(h.out, b.String())
	return err
}

// WithAttrs returns a handler that also prints attrs (after any attributes
// already attached) on every record. The receiver is not modified.
func (h *ConsoleHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
	if len(attrs) == 0 {
		return h
	}
	merged := make([]slog.Attr, 0, len(h.attrs)+len(attrs))
	merged = append(merged, h.attrs...)
	merged = append(merged, attrs...)
	return &ConsoleHandler{
		out:   h.out,
		level: h.level,
		attrs: merged,
	}
}

// WithGroup is a no-op: this simplified handler prints attributes from all
// groups without any group prefix.
func (h *ConsoleHandler) WithGroup(name string) slog.Handler {
	return h
}
// MultiHandler is a slog.Handler that fans each record out to several
// handlers. A record is delivered to every wrapped handler that is enabled
// for its level.
type MultiHandler struct {
	handlers []slog.Handler
}

// Enabled reports whether at least one wrapped handler accepts level l.
func (h *MultiHandler) Enabled(ctx context.Context, l slog.Level) bool {
	for _, hh := range h.handlers {
		if hh.Enabled(ctx, l) {
			return true
		}
	}
	return false
}

// Handle passes r to every enabled handler. Each handler gets its own
// r.Clone(), so a handler that retains or modifies the record cannot
// affect the others. Every handler is tried even if an earlier one fails.
// Any failures are combined into the returned error (visible to
// errors.Is / errors.As).
func (h *MultiHandler) Handle(ctx context.Context, r slog.Record) error {
	var errs error
	for _, hh := range h.handlers {
		if !hh.Enabled(ctx, r.Level) {
			continue
		}
		if err := hh.Handle(ctx, r.Clone()); err != nil {
			if errs == nil {
				errs = err
			} else {
				errs = fmt.Errorf("%w; %w", errs, err)
			}
		}
	}
	return errs
}

// WithAttrs returns a MultiHandler whose wrapped handlers all carry attrs.
func (h *MultiHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
	newHandlers := make([]slog.Handler, len(h.handlers))
	for i, hh := range h.handlers {
		newHandlers[i] = hh.WithAttrs(attrs)
	}
	return &MultiHandler{handlers: newHandlers}
}

// WithGroup returns a MultiHandler whose wrapped handlers all open group name.
func (h *MultiHandler) WithGroup(name string) slog.Handler {
	newHandlers := make([]slog.Handler, len(h.handlers))
	for i, hh := range h.handlers {
		newHandlers[i] = hh.WithGroup(name)
	}
	return &MultiHandler{handlers: newHandlers}
}