mirror of
https://git.fightbot.fun/hxuanyu/FileRelay.git
synced 2026-02-15 11:51:43 +08:00
Initial commit
This commit is contained in:
70
internal/api/admin/auth.go
Normal file
70
internal/api/admin/auth.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"FileRelay/internal/auth"
|
||||
"FileRelay/internal/config"
|
||||
"FileRelay/internal/model"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
type AuthHandler struct{}
|
||||
|
||||
func NewAuthHandler() *AuthHandler {
|
||||
return &AuthHandler{}
|
||||
}
|
||||
|
||||
type LoginRequest struct {
|
||||
Password string `json:"password" binding:"required" example:"admin"`
|
||||
}
|
||||
|
||||
type LoginResponse struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
// Login 管理员登录
|
||||
// @Summary 管理员登录
|
||||
// @Description 通过密码换取 JWT Token
|
||||
// @Tags Admin
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body LoginRequest true "登录请求"
|
||||
// @Success 200 {object} model.Response{data=LoginResponse}
|
||||
// @Failure 401 {object} model.Response
|
||||
// @Router /api/admin/login [post]
|
||||
func (h *AuthHandler) Login(c *gin.Context) {
|
||||
var req LoginRequest
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.ErrorResponse(model.CodeBadRequest, "Invalid request"))
|
||||
return
|
||||
}
|
||||
|
||||
passwordHash := config.GlobalConfig.Security.AdminPasswordHash
|
||||
if passwordHash == "" {
|
||||
c.JSON(http.StatusInternalServerError, model.ErrorResponse(model.CodeInternalError, "Admin password hash not configured"))
|
||||
return
|
||||
}
|
||||
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(req.Password)); err != nil {
|
||||
slog.Warn("Failed admin login attempt", "ip", c.ClientIP())
|
||||
c.JSON(http.StatusUnauthorized, model.ErrorResponse(model.CodeUnauthorized, "Incorrect password"))
|
||||
return
|
||||
}
|
||||
|
||||
// 使用固定 ID 1 代表管理员(因为不再有数据库记录)
|
||||
token, err := auth.GenerateToken(1)
|
||||
if err != nil {
|
||||
slog.Error("Failed to generate admin token", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, model.ErrorResponse(model.CodeInternalError, "Failed to generate token"))
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("Admin logged in", "ip", c.ClientIP())
|
||||
c.JSON(http.StatusOK, model.SuccessResponse(LoginResponse{
|
||||
Token: token,
|
||||
}))
|
||||
}
|
||||
256
internal/api/admin/batch.go
Normal file
256
internal/api/admin/batch.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"FileRelay/internal/bootstrap"
|
||||
"FileRelay/internal/model"
|
||||
"FileRelay/internal/service"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type BatchHandler struct {
|
||||
batchService *service.BatchService
|
||||
}
|
||||
|
||||
func NewBatchHandler() *BatchHandler {
|
||||
return &BatchHandler{
|
||||
batchService: service.NewBatchService(),
|
||||
}
|
||||
}
|
||||
|
||||
type ListBatchesResponse struct {
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
Data []model.FileBatch `json:"data"`
|
||||
}
|
||||
|
||||
type UpdateBatchRequest struct {
|
||||
Remark *string `json:"remark"`
|
||||
ExpireType *string `json:"expire_type"`
|
||||
ExpireAt *time.Time `json:"expire_at"`
|
||||
MaxDownloads *int `json:"max_downloads"`
|
||||
DownloadCount *int `json:"download_count"`
|
||||
Status *string `json:"status"`
|
||||
}
|
||||
|
||||
// ListBatches 获取批次列表
|
||||
// @Summary 获取批次列表
|
||||
// @Description 分页查询所有文件批次,支持按状态过滤和取件码模糊搜索
|
||||
// @Tags Admin
|
||||
// @Security AdminAuth
|
||||
// @Param page query int false "页码 (默认 1)"
|
||||
// @Param page_size query int false "每页数量 (默认 20)"
|
||||
// @Param status query string false "状态 (active/expired/deleted)"
|
||||
// @Param pickup_code query string false "取件码 (模糊搜索)"
|
||||
// @Produce json
|
||||
// @Success 200 {object} model.Response{data=ListBatchesResponse}
|
||||
// @Failure 401 {object} model.Response
|
||||
// @Router /api/admin/batches [get]
|
||||
func (h *BatchHandler) ListBatches(c *gin.Context) {
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize < 1 {
|
||||
pageSize = 20
|
||||
}
|
||||
status := c.Query("status")
|
||||
pickupCode := c.Query("pickup_code")
|
||||
|
||||
query := bootstrap.DB.Model(&model.FileBatch{})
|
||||
if status != "" {
|
||||
query = query.Where("status = ?", status)
|
||||
}
|
||||
if pickupCode != "" {
|
||||
query = query.Where("pickup_code LIKE ?", "%"+pickupCode+"%")
|
||||
}
|
||||
|
||||
var total int64
|
||||
query.Count(&total)
|
||||
|
||||
var batches []model.FileBatch
|
||||
err := query.Offset((page - 1) * pageSize).Limit(pageSize).Order("created_at DESC").Find(&batches).Error
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.ErrorResponse(model.CodeInternalError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.SuccessResponse(ListBatchesResponse{
|
||||
Total: total,
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
Data: batches,
|
||||
}))
|
||||
}
|
||||
|
||||
// GetBatch 获取批次详情
|
||||
// @Summary 获取批次详情
|
||||
// @Description 根据批次 ID 获取批次信息及关联的文件列表
|
||||
// @Tags Admin
|
||||
// @Security AdminAuth
|
||||
// @Param batch_id path string true "批次 ID (UUID)"
|
||||
// @Produce json
|
||||
// @Success 200 {object} model.Response{data=model.FileBatch}
|
||||
// @Failure 404 {object} model.Response
|
||||
// @Router /api/admin/batches/{batch_id} [get]
|
||||
func (h *BatchHandler) GetBatch(c *gin.Context) {
|
||||
id := c.Param("batch_id")
|
||||
var batch model.FileBatch
|
||||
if err := bootstrap.DB.Preload("FileItems").First(&batch, "id = ?", id).Error; err != nil {
|
||||
c.JSON(http.StatusNotFound, model.ErrorResponse(model.CodeNotFound, "batch not found"))
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, model.SuccessResponse(batch))
|
||||
}
|
||||
|
||||
// UpdateBatch 修改批次信息
|
||||
// @Summary 修改批次信息
|
||||
// @Description 允许修改备注、过期策略、最大下载次数、状态等
|
||||
// @Tags Admin
|
||||
// @Security AdminAuth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param batch_id path string true "批次 ID (UUID)"
|
||||
// @Param request body UpdateBatchRequest true "修改内容"
|
||||
// @Success 200 {object} model.Response{data=model.FileBatch}
|
||||
// @Failure 400 {object} model.Response
|
||||
// @Router /api/admin/batches/{batch_id} [put]
|
||||
func (h *BatchHandler) UpdateBatch(c *gin.Context) {
|
||||
id := c.Param("batch_id")
|
||||
var batch model.FileBatch
|
||||
if err := bootstrap.DB.First(&batch, "id = ?", id).Error; err != nil {
|
||||
c.JSON(http.StatusNotFound, model.ErrorResponse(model.CodeNotFound, "batch not found"))
|
||||
return
|
||||
}
|
||||
|
||||
rawBody, err := c.GetRawData()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.ErrorResponse(model.CodeBadRequest, "failed to read body"))
|
||||
return
|
||||
}
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(rawBody))
|
||||
|
||||
var input UpdateBatchRequest
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.ErrorResponse(model.CodeBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
var rawMap map[string]interface{}
|
||||
json.Unmarshal(rawBody, &rawMap)
|
||||
|
||||
updates := make(map[string]interface{})
|
||||
if input.Remark != nil {
|
||||
updates["remark"] = *input.Remark
|
||||
}
|
||||
if input.ExpireType != nil {
|
||||
newType := *input.ExpireType
|
||||
updates["expire_type"] = newType
|
||||
|
||||
// 如果类型发生变化,根据新类型清除不相关的配置
|
||||
if newType != batch.ExpireType {
|
||||
if newType == "download" {
|
||||
updates["expire_at"] = nil
|
||||
} else if newType == "time" {
|
||||
updates["max_downloads"] = 0
|
||||
} else if newType == "permanent" {
|
||||
updates["expire_at"] = nil
|
||||
updates["max_downloads"] = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 显式提供的值具有最高优先级,但仅在逻辑允许的情况下
|
||||
// 例如:如果切换到了 download 类型,用户可以同时提供一个新的 max_downloads
|
||||
if _, ok := rawMap["expire_at"]; ok {
|
||||
updates["expire_at"] = input.ExpireAt
|
||||
}
|
||||
if input.MaxDownloads != nil {
|
||||
updates["max_downloads"] = *input.MaxDownloads
|
||||
}
|
||||
|
||||
// 强制校验:如果最终结果是 permanent,确保限制被清空
|
||||
// 这样即使用户在请求中显式传了非零值,也会被修正
|
||||
finalType := batch.ExpireType
|
||||
if t, ok := updates["expire_type"].(string); ok {
|
||||
finalType = t
|
||||
}
|
||||
|
||||
if finalType == "permanent" {
|
||||
updates["expire_at"] = nil
|
||||
updates["max_downloads"] = 0
|
||||
} else if finalType == "time" {
|
||||
// 如果是时间过期,max_downloads 应该始终为 0
|
||||
updates["max_downloads"] = 0
|
||||
} else if finalType == "download" {
|
||||
// 如果是下载次数过期,expire_at 应该始终为 null
|
||||
updates["expire_at"] = nil
|
||||
}
|
||||
if input.DownloadCount != nil {
|
||||
updates["download_count"] = *input.DownloadCount
|
||||
}
|
||||
if input.Status != nil {
|
||||
updates["status"] = *input.Status
|
||||
}
|
||||
|
||||
if len(updates) > 0 {
|
||||
if err := bootstrap.DB.Model(&batch).Updates(updates).Error; err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.ErrorResponse(model.CodeInternalError, err.Error()))
|
||||
return
|
||||
}
|
||||
// 重新从数据库读取,确保返回的是完整且最新的数据
|
||||
bootstrap.DB.First(&batch, "id = ?", id)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.SuccessResponse(batch))
|
||||
}
|
||||
|
||||
// CleanBatches 手动触发清理过期或已删除的批次
|
||||
// @Summary 手动触发清理
|
||||
// @Description 手动扫描并物理删除所有已过期或标记为删除的文件批次及其关联文件
|
||||
// @Tags Admin
|
||||
// @Security AdminAuth
|
||||
// @Produce json
|
||||
// @Success 200 {object} model.Response
|
||||
// @Failure 500 {object} model.Response
|
||||
// @Router /api/admin/batches/clean [post]
|
||||
func (h *BatchHandler) CleanBatches(c *gin.Context) {
|
||||
slog.Info("Admin triggered manual cleanup")
|
||||
if err := h.batchService.Cleanup(c.Request.Context()); err != nil {
|
||||
slog.Error("Manual cleanup failed", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, model.ErrorResponse(model.CodeInternalError, "cleanup failed: "+err.Error()))
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, model.SuccessResponse(nil))
|
||||
}
|
||||
|
||||
// DeleteBatch 删除批次
|
||||
// @Summary 删除批次
|
||||
// @Description 标记批次为已删除,并物理删除关联的存储文件
|
||||
// @Tags Admin
|
||||
// @Security AdminAuth
|
||||
// @Param batch_id path string true "批次 ID (UUID)"
|
||||
// @Produce json
|
||||
// @Success 200 {object} model.Response
|
||||
// @Failure 500 {object} model.Response
|
||||
// @Router /api/admin/batches/{batch_id} [delete]
|
||||
func (h *BatchHandler) DeleteBatch(c *gin.Context) {
|
||||
id := c.Param("batch_id")
|
||||
|
||||
if err := h.batchService.DeleteBatch(c.Request.Context(), id); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.ErrorResponse(model.CodeInternalError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.SuccessResponse(map[string]interface{}{}))
|
||||
}
|
||||
107
internal/api/admin/config.go
Normal file
107
internal/api/admin/config.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"FileRelay/internal/bootstrap"
|
||||
"FileRelay/internal/config"
|
||||
"FileRelay/internal/model"
|
||||
"FileRelay/internal/service"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
type ConfigHandler struct{}
|
||||
|
||||
func NewConfigHandler() *ConfigHandler {
|
||||
return &ConfigHandler{}
|
||||
}
|
||||
|
||||
// GetConfig 获取当前完整配置
|
||||
// @Summary 获取完整配置
|
||||
// @Description 获取系统的完整配置文件内容(仅管理员)
|
||||
// @Tags Admin
|
||||
// @Security AdminAuth
|
||||
// @Produce json
|
||||
// @Success 200 {object} model.Response{data=config.Config}
|
||||
// @Router /api/admin/config [get]
|
||||
func (h *ConfigHandler) GetConfig(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, model.SuccessResponse(config.GlobalConfig))
|
||||
}
|
||||
|
||||
// UpdateConfig 更新配置
|
||||
// @Summary 更新配置
|
||||
// @Description 更新系统的配置文件内容(仅管理员)
|
||||
// @Tags Admin
|
||||
// @Security AdminAuth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param config body config.Config true "新配置内容"
|
||||
// @Success 200 {object} model.Response{data=config.Config}
|
||||
// @Failure 400 {object} model.Response
|
||||
// @Failure 500 {object} model.Response
|
||||
// @Router /api/admin/config [put]
|
||||
func (h *ConfigHandler) UpdateConfig(c *gin.Context) {
|
||||
var newConfig config.Config
|
||||
if err := c.ShouldBindJSON(&newConfig); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.ErrorResponse(model.CodeBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// 简单的校验,防止关键配置被改空
|
||||
if newConfig.Database.Path == "" {
|
||||
newConfig.Database.Path = config.GlobalConfig.Database.Path
|
||||
}
|
||||
if newConfig.Site.Port <= 0 || newConfig.Site.Port > 65535 {
|
||||
newConfig.Site.Port = 8080
|
||||
}
|
||||
|
||||
// 如果传入了明文密码,则重新生成 hash
|
||||
if newConfig.Security.AdminPassword != "" {
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(newConfig.Security.AdminPassword), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.ErrorResponse(model.CodeInternalError, "Failed to hash password: "+err.Error()))
|
||||
return
|
||||
}
|
||||
newConfig.Security.AdminPasswordHash = string(hash)
|
||||
}
|
||||
|
||||
// 检查取件码长度是否变化
|
||||
pickupCodeLengthChanged := newConfig.Security.PickupCodeLength != config.GlobalConfig.Security.PickupCodeLength && newConfig.Security.PickupCodeLength > 0
|
||||
// 检查数据库配置是否变化
|
||||
dbConfigChanged := newConfig.Database != config.GlobalConfig.Database
|
||||
|
||||
// 如果长度变化,同步更新现有取件码 (在可能切换数据库前,先处理旧库数据)
|
||||
if pickupCodeLengthChanged {
|
||||
batchService := service.NewBatchService()
|
||||
if err := batchService.UpdateAllPickupCodes(newConfig.Security.PickupCodeLength); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.ErrorResponse(model.CodeInternalError, "Failed to update existing pickup codes: "+err.Error()))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 更新内存配置
|
||||
config.UpdateGlobalConfig(&newConfig)
|
||||
|
||||
// 重新连接数据库并迁移数据(如果配置发生变化)
|
||||
if dbConfigChanged {
|
||||
if err := bootstrap.ReloadDB(newConfig.Database); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.ErrorResponse(model.CodeInternalError, "Failed to reload database: "+err.Error()))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 重新初始化存储(热更新业务逻辑)
|
||||
if err := bootstrap.ReloadStorage(); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.ErrorResponse(model.CodeInternalError, "Failed to reload storage: "+err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// 保存到文件
|
||||
if err := config.SaveConfig(); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.ErrorResponse(model.CodeInternalError, "Failed to save config: "+err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.SuccessResponse(config.GlobalConfig))
|
||||
}
|
||||
138
internal/api/admin/token.go
Normal file
138
internal/api/admin/token.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"FileRelay/internal/bootstrap"
|
||||
"FileRelay/internal/model"
|
||||
"FileRelay/internal/service"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type TokenHandler struct {
|
||||
tokenService *service.TokenService
|
||||
}
|
||||
|
||||
func NewTokenHandler() *TokenHandler {
|
||||
return &TokenHandler{
|
||||
tokenService: service.NewTokenService(),
|
||||
}
|
||||
}
|
||||
|
||||
type CreateTokenRequest struct {
|
||||
Name string `json:"name" binding:"required" example:"Test Token"`
|
||||
Scope string `json:"scope" example:"upload,pickup" enums:"upload,pickup,admin"`
|
||||
ExpireAt *time.Time `json:"expire_at"`
|
||||
}
|
||||
|
||||
type CreateTokenResponse struct {
|
||||
Token string `json:"token"`
|
||||
Data *model.APIToken `json:"data"`
|
||||
}
|
||||
|
||||
// ListTokens 获取 API Token 列表
|
||||
// @Summary 获取 API Token 列表
|
||||
// @Description 获取系统中所有 API Token 的详细信息(不包含哈希)
|
||||
// @Tags Admin
|
||||
// @Security AdminAuth
|
||||
// @Produce json
|
||||
// @Success 200 {object} model.Response{data=[]model.APIToken}
|
||||
// @Failure 401 {object} model.Response
|
||||
// @Router /api/admin/api-tokens [get]
|
||||
func (h *TokenHandler) ListTokens(c *gin.Context) {
|
||||
var tokens []model.APIToken
|
||||
if err := bootstrap.DB.Find(&tokens).Error; err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.ErrorResponse(model.CodeInternalError, err.Error()))
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, model.SuccessResponse(tokens))
|
||||
}
|
||||
|
||||
// CreateToken 创建 API Token
|
||||
// @Summary 创建 API Token
|
||||
// @Description 创建一个新的 API Token,返回原始 Token(仅显示一次)
|
||||
// @Tags Admin
|
||||
// @Security AdminAuth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body CreateTokenRequest true "Token 信息"
|
||||
// @Success 201 {object} model.Response{data=CreateTokenResponse}
|
||||
// @Failure 400 {object} model.Response
|
||||
// @Router /api/admin/api-tokens [post]
|
||||
func (h *TokenHandler) CreateToken(c *gin.Context) {
|
||||
var input CreateTokenRequest
|
||||
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.ErrorResponse(model.CodeBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
rawToken, token, err := h.tokenService.CreateToken(input.Name, input.Scope, input.ExpireAt)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.ErrorResponse(model.CodeInternalError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, model.SuccessResponse(CreateTokenResponse{
|
||||
Token: rawToken,
|
||||
Data: token,
|
||||
}))
|
||||
}
|
||||
|
||||
// RevokeToken 撤销 API Token
|
||||
// @Summary 撤销 API Token
|
||||
// @Description 将 API Token 标记为已撤销,使其失效但保留记录
|
||||
// @Tags Admin
|
||||
// @Security AdminAuth
|
||||
// @Param id path int true "Token ID"
|
||||
// @Produce json
|
||||
// @Success 200 {object} model.Response
|
||||
// @Failure 500 {object} model.Response
|
||||
// @Router /api/admin/api-tokens/{id}/revoke [post]
|
||||
func (h *TokenHandler) RevokeToken(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
if err := bootstrap.DB.Model(&model.APIToken{}).Where("id = ?", id).Update("revoked", true).Error; err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.ErrorResponse(model.CodeInternalError, err.Error()))
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, model.SuccessResponse(map[string]interface{}{}))
|
||||
}
|
||||
|
||||
// RecoverToken 恢复 API Token
|
||||
// @Summary 恢复 API Token
|
||||
// @Description 将已撤销的 API Token 恢复为有效状态
|
||||
// @Tags Admin
|
||||
// @Security AdminAuth
|
||||
// @Param id path int true "Token ID"
|
||||
// @Produce json
|
||||
// @Success 200 {object} model.Response
|
||||
// @Failure 500 {object} model.Response
|
||||
// @Router /api/admin/api-tokens/{id}/recover [post]
|
||||
func (h *TokenHandler) RecoverToken(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
if err := bootstrap.DB.Model(&model.APIToken{}).Where("id = ?", id).Update("revoked", false).Error; err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.ErrorResponse(model.CodeInternalError, err.Error()))
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, model.SuccessResponse(map[string]interface{}{}))
|
||||
}
|
||||
|
||||
// DeleteToken 删除 API Token
|
||||
// @Summary 删除 API Token
|
||||
// @Description 根据 ID 永久删除 API Token
|
||||
// @Tags Admin
|
||||
// @Security AdminAuth
|
||||
// @Param id path int true "Token ID"
|
||||
// @Produce json
|
||||
// @Success 200 {object} model.Response
|
||||
// @Failure 500 {object} model.Response
|
||||
// @Router /api/admin/api-tokens/{id} [delete]
|
||||
func (h *TokenHandler) DeleteToken(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
if err := bootstrap.DB.Delete(&model.APIToken{}, id).Error; err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.ErrorResponse(model.CodeInternalError, err.Error()))
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, model.SuccessResponse(map[string]interface{}{}))
|
||||
}
|
||||
119
internal/api/middleware/auth.go
Normal file
119
internal/api/middleware/auth.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"FileRelay/internal/auth"
|
||||
"FileRelay/internal/config"
|
||||
"FileRelay/internal/model"
|
||||
"FileRelay/internal/service"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func AdminAuth() gin.HandlerFunc {
|
||||
tokenService := service.NewTokenService()
|
||||
return func(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.JSON(http.StatusUnauthorized, model.ErrorResponse(model.CodeUnauthorized, "Authorization header required"))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if !(len(parts) == 2 && parts[0] == "Bearer") {
|
||||
c.JSON(http.StatusUnauthorized, model.ErrorResponse(model.CodeUnauthorized, "Invalid authorization format"))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
tokenStr := parts[1]
|
||||
|
||||
// 1. 尝试解析为管理员 JWT
|
||||
claims, err := auth.ParseToken(tokenStr)
|
||||
if err == nil {
|
||||
c.Set("admin_id", claims.AdminID)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 尝试解析为 API Token (如果配置允许)
|
||||
if config.GlobalConfig.APIToken.Enabled && config.GlobalConfig.APIToken.AllowAdminAPI {
|
||||
token, err := tokenService.ValidateToken(tokenStr, model.ScopeAdmin)
|
||||
if err == nil {
|
||||
c.Set("token_id", token.ID)
|
||||
c.Set("token_scope", token.Scope)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusUnauthorized, model.ErrorResponse(model.CodeUnauthorized, "Invalid or expired token"))
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
func APITokenAuth(requiredScope string, optional bool) gin.HandlerFunc {
|
||||
tokenService := service.NewTokenService()
|
||||
return func(c *gin.Context) {
|
||||
handleAPITokenAuth(c, tokenService, requiredScope, optional)
|
||||
}
|
||||
}
|
||||
|
||||
func UploadAuth() gin.HandlerFunc {
|
||||
tokenService := service.NewTokenService()
|
||||
return func(c *gin.Context) {
|
||||
// 动态获取配置
|
||||
optional := !config.GlobalConfig.Upload.RequireToken
|
||||
handleAPITokenAuth(c, tokenService, model.ScopeUpload, optional)
|
||||
}
|
||||
}
|
||||
|
||||
func handleAPITokenAuth(c *gin.Context, tokenService *service.TokenService, requiredScope string, optional bool) {
|
||||
// 如果是可选的,直接跳过校验,满足“未打开对应的开关时不需校验”的需求
|
||||
if optional {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.JSON(http.StatusUnauthorized, model.ErrorResponse(model.CodeUnauthorized, "Authorization header required"))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if !(len(parts) == 2 && parts[0] == "Bearer") {
|
||||
c.JSON(http.StatusUnauthorized, model.ErrorResponse(model.CodeUnauthorized, "Invalid authorization format"))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
tokenStr := parts[1]
|
||||
|
||||
// 1. 尝试解析为管理员 JWT
|
||||
if claims, err := auth.ParseToken(tokenStr); err == nil {
|
||||
c.Set("admin_id", claims.AdminID)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
if !config.GlobalConfig.APIToken.Enabled {
|
||||
c.JSON(http.StatusForbidden, model.ErrorResponse(model.CodeForbidden, "API Token is disabled"))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
token, err := tokenService.ValidateToken(tokenStr, requiredScope)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, model.ErrorResponse(model.CodeUnauthorized, err.Error()))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Set("token_id", token.ID)
|
||||
c.Set("token_scope", token.Scope)
|
||||
c.Next()
|
||||
}
|
||||
56
internal/api/middleware/limit.go
Normal file
56
internal/api/middleware/limit.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"FileRelay/internal/config"
|
||||
"FileRelay/internal/model"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var (
|
||||
pickupFailures = make(map[string]int)
|
||||
failureMutex sync.Mutex
|
||||
)
|
||||
|
||||
func PickupRateLimit() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
key := c.ClientIP()
|
||||
|
||||
failureMutex.Lock()
|
||||
count, exists := pickupFailures[key]
|
||||
failureMutex.Unlock()
|
||||
|
||||
if exists && count >= config.GlobalConfig.Security.PickupFailLimit {
|
||||
slog.Warn("Pickup rate limit exceeded", "ip", key, "count", count)
|
||||
c.JSON(http.StatusTooManyRequests, model.ErrorResponse(model.CodeTooManyRequests, "Too many failed attempts. Please try again later."))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func RecordPickupFailure(ip string) {
|
||||
key := ip
|
||||
failureMutex.Lock()
|
||||
pickupFailures[key]++
|
||||
|
||||
// 仅在第一次失败时启动清除记录的计时器
|
||||
if pickupFailures[key] == 1 {
|
||||
go func() {
|
||||
// 设置 1 分钟后清除记录 (简单实现)
|
||||
time.Sleep(1 * time.Hour)
|
||||
failureMutex.Lock()
|
||||
delete(pickupFailures, key)
|
||||
slog.Info("Pickup failure record cleared", "ip", key)
|
||||
failureMutex.Unlock()
|
||||
}()
|
||||
}
|
||||
|
||||
failureMutex.Unlock()
|
||||
}
|
||||
61
internal/api/public/config.go
Normal file
61
internal/api/public/config.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package public
|
||||
|
||||
import (
|
||||
"FileRelay/internal/config"
|
||||
"FileRelay/internal/model"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type ConfigHandler struct{}
|
||||
|
||||
func NewConfigHandler() *ConfigHandler {
|
||||
return &ConfigHandler{}
|
||||
}
|
||||
|
||||
// PublicConfig 公开配置结构
|
||||
type PublicConfig struct {
|
||||
Site config.SiteConfig `json:"site"`
|
||||
Security PublicSecurityConfig `json:"security"`
|
||||
Upload config.UploadConfig `json:"upload"`
|
||||
APIToken PublicAPITokenConfig `json:"api_token"`
|
||||
Storage PublicStorageConfig `json:"storage"`
|
||||
}
|
||||
|
||||
type PublicSecurityConfig struct {
|
||||
PickupCodeLength int `json:"pickup_code_length"`
|
||||
}
|
||||
|
||||
type PublicAPITokenConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
}
|
||||
|
||||
type PublicStorageConfig struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// GetPublicConfig 获取非敏感配置
|
||||
// @Summary 获取公共配置
|
||||
// @Description 获取前端展示所需的非敏感配置数据
|
||||
// @Tags Public
|
||||
// @Produce json
|
||||
// @Success 200 {object} model.Response{data=PublicConfig}
|
||||
// @Router /api/config [get]
|
||||
func (h *ConfigHandler) GetPublicConfig(c *gin.Context) {
|
||||
pub := PublicConfig{
|
||||
Site: config.GlobalConfig.Site,
|
||||
Security: PublicSecurityConfig{
|
||||
PickupCodeLength: config.GlobalConfig.Security.PickupCodeLength,
|
||||
},
|
||||
Upload: config.GlobalConfig.Upload,
|
||||
APIToken: PublicAPITokenConfig{
|
||||
Enabled: config.GlobalConfig.APIToken.Enabled,
|
||||
},
|
||||
Storage: PublicStorageConfig{
|
||||
Type: config.GlobalConfig.Storage.Type,
|
||||
},
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.SuccessResponse(pub))
|
||||
}
|
||||
314
internal/api/public/pickup.go
Normal file
314
internal/api/public/pickup.go
Normal file
@@ -0,0 +1,314 @@
|
||||
package public
|
||||
|
||||
import (
|
||||
"FileRelay/internal/api/middleware"
|
||||
"FileRelay/internal/bootstrap"
|
||||
"FileRelay/internal/config"
|
||||
"FileRelay/internal/model"
|
||||
"FileRelay/internal/service"
|
||||
"FileRelay/internal/storage"
|
||||
"archive/zip"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type PickupResponse struct {
|
||||
Remark string `json:"remark"`
|
||||
ExpireAt *time.Time `json:"expire_at"`
|
||||
ExpireType string `json:"expire_type"`
|
||||
DownloadCount int `json:"download_count"`
|
||||
MaxDownloads int `json:"max_downloads"`
|
||||
Type string `json:"type"`
|
||||
Content string `json:"content,omitempty"`
|
||||
Files []model.FileItem `json:"files,omitempty"`
|
||||
}
|
||||
|
||||
type DownloadCountResponse struct {
|
||||
DownloadCount int `json:"download_count"`
|
||||
MaxDownloads int `json:"max_downloads"`
|
||||
}
|
||||
|
||||
// DownloadBatch 批量下载文件 (ZIP)
|
||||
// @Summary 批量下载文件
|
||||
// @Description 根据取件码将批次内的所有文件打包为 ZIP 格式一次性下载。可选提供带 pickup scope 的 API Token。
|
||||
// @Tags Public
|
||||
// @Security APITokenAuth
|
||||
// @Param pickup_code path string true "取件码"
|
||||
// @Produce application/zip
|
||||
// @Success 200 {file} file
|
||||
// @Failure 404 {object} model.Response
|
||||
// @Router /api/batches/{pickup_code}/download [get]
|
||||
func (h *PickupHandler) DownloadBatch(c *gin.Context) {
|
||||
code := c.Param("pickup_code")
|
||||
batch, err := h.batchService.GetBatchByPickupCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, model.ErrorResponse(model.CodeNotFound, "batch not found or expired"))
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=\"batch_%s.zip\"", code))
|
||||
c.Header("Content-Type", "application/zip")
|
||||
|
||||
zw := zip.NewWriter(c.Writer)
|
||||
defer zw.Close()
|
||||
|
||||
for _, item := range batch.FileItems {
|
||||
reader, err := storage.GlobalStorage.Open(c.Request.Context(), item.StoragePath)
|
||||
if err != nil {
|
||||
continue // Skip failed files
|
||||
}
|
||||
|
||||
f, err := zw.Create(item.OriginalName)
|
||||
if err != nil {
|
||||
reader.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
_, _ = io.Copy(f, reader)
|
||||
reader.Close()
|
||||
}
|
||||
|
||||
// 增加下载次数
|
||||
if err := h.batchService.IncrementDownloadCount(batch.ID); err != nil {
|
||||
slog.Error("Failed to increment download count", "batch_id", batch.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
type PickupHandler struct {
|
||||
batchService *service.BatchService
|
||||
}
|
||||
|
||||
func NewPickupHandler() *PickupHandler {
|
||||
return &PickupHandler{
|
||||
batchService: service.NewBatchService(),
|
||||
}
|
||||
}
|
||||
|
||||
// Pickup 获取批次信息
|
||||
// @Summary 获取批次信息
|
||||
// @Description 根据取件码获取文件批次详细信息和文件列表。可选提供带 pickup scope 的 API Token。
|
||||
// @Tags Public
|
||||
// @Security APITokenAuth
|
||||
// @Produce json
|
||||
// @Param pickup_code path string true "取件码"
|
||||
// @Success 200 {object} model.Response{data=PickupResponse}
|
||||
// @Failure 404 {object} model.Response
|
||||
// @Router /api/batches/{pickup_code} [get]
|
||||
func (h *PickupHandler) Pickup(c *gin.Context) {
|
||||
code := c.Param("pickup_code")
|
||||
if code == "" {
|
||||
c.JSON(http.StatusBadRequest, model.ErrorResponse(model.CodeBadRequest, "pickup code required"))
|
||||
return
|
||||
}
|
||||
|
||||
batch, err := h.batchService.GetBatchByPickupCode(code)
|
||||
if err != nil {
|
||||
middleware.RecordPickupFailure(c.ClientIP())
|
||||
c.JSON(http.StatusNotFound, model.ErrorResponse(model.CodeNotFound, "batch not found or expired"))
|
||||
return
|
||||
}
|
||||
|
||||
if batch.Type == "text" {
|
||||
if err := h.batchService.IncrementDownloadCount(batch.ID); err != nil {
|
||||
slog.Error("Failed to increment download count for batch", "batch_id", batch.ID, "error", err)
|
||||
} else {
|
||||
batch.DownloadCount++
|
||||
}
|
||||
}
|
||||
|
||||
baseURL := getBaseURL(c)
|
||||
|
||||
for i := range batch.FileItems {
|
||||
batch.FileItems[i].DownloadURL = fmt.Sprintf("%s/api/files/%s/%s", baseURL, batch.FileItems[i].ID, batch.FileItems[i].OriginalName)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.SuccessResponse(PickupResponse{
|
||||
Remark: batch.Remark,
|
||||
ExpireAt: batch.ExpireAt,
|
||||
ExpireType: batch.ExpireType,
|
||||
DownloadCount: batch.DownloadCount,
|
||||
MaxDownloads: batch.MaxDownloads,
|
||||
Type: batch.Type,
|
||||
Content: batch.Content,
|
||||
Files: batch.FileItems,
|
||||
}))
|
||||
}
|
||||
|
||||
// GetDownloadCount 查询下载次数
|
||||
// @Summary 查询下载次数
|
||||
// @Description 根据取件码查询当前下载次数和最大允许下载次数。支持已过期的批次。
|
||||
// @Tags Public
|
||||
// @Produce json
|
||||
// @Param pickup_code path string true "取件码"
|
||||
// @Success 200 {object} model.Response{data=DownloadCountResponse}
|
||||
// @Failure 400 {object} model.Response
|
||||
// @Failure 404 {object} model.Response
|
||||
// @Router /api/batches/{pickup_code}/count [get]
|
||||
func (h *PickupHandler) GetDownloadCount(c *gin.Context) {
|
||||
code := c.Param("pickup_code")
|
||||
if code == "" {
|
||||
c.JSON(http.StatusBadRequest, model.ErrorResponse(model.CodeBadRequest, "pickup code required"))
|
||||
return
|
||||
}
|
||||
|
||||
count, max, err := h.batchService.GetDownloadCountByPickupCode(code)
|
||||
if err != nil {
|
||||
middleware.RecordPickupFailure(c.ClientIP())
|
||||
c.JSON(http.StatusNotFound, model.ErrorResponse(model.CodeNotFound, "batch not found"))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.SuccessResponse(DownloadCountResponse{
|
||||
DownloadCount: count,
|
||||
MaxDownloads: max,
|
||||
}))
|
||||
}
|
||||
|
||||
func getBaseURL(c *gin.Context) string {
|
||||
// 优先使用配置中的 BaseURL
|
||||
if config.GlobalConfig.Site.BaseURL != "" {
|
||||
return strings.TrimSuffix(config.GlobalConfig.Site.BaseURL, "/")
|
||||
}
|
||||
|
||||
// 自动检测逻辑
|
||||
scheme := "http"
|
||||
if c.Request.TLS != nil {
|
||||
scheme = "https"
|
||||
} else {
|
||||
// 检查常用的代理协议头 (优先)
|
||||
// 增加对用户提供的 :scheme (可能被某些代理转为普通 header) 的支持
|
||||
// 增加对 X-Forwarded-Proto 可能存在的逗号分隔列表的处理
|
||||
checkHeaders := []struct {
|
||||
name string
|
||||
values []string
|
||||
}{
|
||||
{"X-Forwarded-Proto", []string{"https"}},
|
||||
{"X-Forwarded-Protocol", []string{"https"}},
|
||||
{"X-Url-Scheme", []string{"https"}},
|
||||
{"Front-End-Https", []string{"on", "https"}},
|
||||
{"X-Forwarded-Ssl", []string{"on", "https"}},
|
||||
{":scheme", []string{"https"}},
|
||||
{"X-Scheme", []string{"https"}},
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, h := range checkHeaders {
|
||||
val := c.GetHeader(h.name)
|
||||
if val == "" {
|
||||
continue
|
||||
}
|
||||
// 处理可能的逗号分隔列表 (如 X-Forwarded-Proto: https, http)
|
||||
firstVal := strings.TrimSpace(strings.ToLower(strings.Split(val, ",")[0]))
|
||||
for _, target := range h.values {
|
||||
if firstVal == target {
|
||||
scheme = "https"
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if found {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 Forwarded 头部 (RFC 7239)
|
||||
if !found {
|
||||
if forwarded := c.GetHeader("Forwarded"); forwarded != "" {
|
||||
if strings.Contains(strings.ToLower(forwarded), "proto=https") {
|
||||
scheme = "https"
|
||||
found = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 启发式判断:如果上述头部都没有,但 Referer 是 https,则认为也是 https
|
||||
// 这在同域 API 请求时非常可靠
|
||||
if !found {
|
||||
if referer := c.GetHeader("Referer"); strings.HasPrefix(strings.ToLower(referer), "https://") {
|
||||
scheme = "https"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
host := c.Request.Host
|
||||
if forwardedHost := c.GetHeader("X-Forwarded-Host"); forwardedHost != "" {
|
||||
// 处理可能的逗号分隔列表
|
||||
host = strings.TrimSpace(strings.Split(forwardedHost, ",")[0])
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s://%s", scheme, host)
|
||||
}
|
||||
|
||||
// DownloadFile 下载单个文件
|
||||
// @Summary 下载单个文件
|
||||
// @Description 根据文件 ID 下载单个文件。支持直观的文件名结尾以方便下载工具识别。可选提供带 pickup scope 的 API Token。
|
||||
// @Tags Public
|
||||
// @Security APITokenAuth
|
||||
// @Param file_id path string true "文件 ID (UUID)"
|
||||
// @Param filename path string false "文件名"
|
||||
// @Produce application/octet-stream
|
||||
// @Success 200 {file} file
|
||||
// @Failure 404 {object} model.Response
|
||||
// @Failure 410 {object} model.Response
|
||||
// @Router /api/files/{file_id}/{filename} [get]
|
||||
// @Router /api/files/{file_id}/download [get]
|
||||
func (h *PickupHandler) DownloadFile(c *gin.Context) {
|
||||
fileID := c.Param("file_id")
|
||||
|
||||
var item model.FileItem
|
||||
if err := bootstrap.DB.First(&item, "id = ?", fileID).Error; err != nil {
|
||||
c.JSON(http.StatusNotFound, model.ErrorResponse(model.CodeNotFound, "file not found"))
|
||||
return
|
||||
}
|
||||
|
||||
var batch model.FileBatch
|
||||
if err := bootstrap.DB.First(&batch, "id = ?", item.BatchID).Error; err != nil {
|
||||
c.JSON(http.StatusNotFound, model.ErrorResponse(model.CodeNotFound, "batch not found"))
|
||||
return
|
||||
}
|
||||
|
||||
if h.batchService.IsExpired(&batch) {
|
||||
h.batchService.MarkAsExpired(&batch)
|
||||
// 按照需求,如果不存在(已在上面处理)或达到上限,返回 404
|
||||
if batch.ExpireType == "download" && batch.DownloadCount >= batch.MaxDownloads {
|
||||
c.JSON(http.StatusNotFound, model.ErrorResponse(model.CodeNotFound, "file not found or download limit reached"))
|
||||
} else {
|
||||
c.JSON(http.StatusGone, model.ErrorResponse(model.CodeGone, "batch expired"))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 打开文件
|
||||
reader, err := storage.GlobalStorage.Open(c.Request.Context(), item.StoragePath)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.ErrorResponse(model.CodeInternalError, "failed to open file"))
|
||||
return
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
// 增加下载次数
|
||||
if err := h.batchService.IncrementDownloadCount(batch.ID); err != nil {
|
||||
// 记录错误但不中断下载过程
|
||||
slog.Error("Failed to increment download count for batch", "batch_id", batch.ID, "error", err)
|
||||
}
|
||||
|
||||
c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", item.OriginalName))
|
||||
c.Header("Content-Type", item.MimeType)
|
||||
c.Header("Content-Length", strconv.FormatInt(item.Size, 10))
|
||||
|
||||
// 如果是 HEAD 请求,只返回 Header
|
||||
if c.Request.Method == http.MethodHead {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := io.Copy(c.Writer, reader); err != nil {
|
||||
slog.Error("Error during file download", "file_id", item.ID, "error", err)
|
||||
}
|
||||
}
|
||||
175
internal/api/public/upload.go
Normal file
175
internal/api/public/upload.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package public
|
||||
|
||||
import (
|
||||
"FileRelay/internal/config"
|
||||
"FileRelay/internal/model"
|
||||
"FileRelay/internal/service"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type UploadHandler struct {
|
||||
uploadService *service.UploadService
|
||||
}
|
||||
|
||||
func NewUploadHandler() *UploadHandler {
|
||||
return &UploadHandler{
|
||||
uploadService: service.NewUploadService(),
|
||||
}
|
||||
}
|
||||
|
||||
type UploadResponse struct {
|
||||
PickupCode string `json:"pickup_code"`
|
||||
ExpireAt *time.Time `json:"expire_at"`
|
||||
BatchID string `json:"batch_id"`
|
||||
}
|
||||
|
||||
// Upload 上传文件并生成取件码
|
||||
// @Summary 上传文件
|
||||
// @Description 上传一个或多个文件并创建一个提取批次。如果配置了 require_token,则必须提供带 upload scope 的 API Token。
|
||||
// @Tags Public
|
||||
// @Accept multipart/form-data
|
||||
// @Produce json
|
||||
// @Security APITokenAuth
|
||||
// @Param files formData file true "文件列表"
|
||||
// @Param remark formData string false "备注"
|
||||
// @Param expire_type formData string false "过期类型 (time/download/permanent)"
|
||||
// @Param expire_days formData int false "过期天数 (针对 time 类型)"
|
||||
// @Param max_downloads formData int false "最大下载次数 (针对 download 类型)"
|
||||
// @Success 200 {object} model.Response{data=UploadResponse}
|
||||
// @Failure 400 {object} model.Response
|
||||
// @Failure 500 {object} model.Response
|
||||
// @Router /api/batches [post]
|
||||
func (h *UploadHandler) Upload(c *gin.Context) {
|
||||
form, err := c.MultipartForm()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.ErrorResponse(model.CodeBadRequest, "invalid form"))
|
||||
return
|
||||
}
|
||||
|
||||
files := form.File["files"]
|
||||
if len(files) == 0 {
|
||||
c.JSON(http.StatusBadRequest, model.ErrorResponse(model.CodeBadRequest, "no files uploaded"))
|
||||
return
|
||||
}
|
||||
|
||||
if len(files) > config.GlobalConfig.Upload.MaxBatchFiles {
|
||||
c.JSON(http.StatusBadRequest, model.ErrorResponse(model.CodeBadRequest, "too many files"))
|
||||
return
|
||||
}
|
||||
|
||||
// 校验单个文件大小
|
||||
maxSize := config.GlobalConfig.Upload.MaxFileSizeMB * 1024 * 1024
|
||||
for _, file := range files {
|
||||
if file.Size > maxSize {
|
||||
c.JSON(http.StatusBadRequest, model.ErrorResponse(model.CodeBadRequest, fmt.Sprintf("文件 %s 超过最大限制 (%dMB)", file.Filename, config.GlobalConfig.Upload.MaxFileSizeMB)))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
remark := c.PostForm("remark")
|
||||
expireType := c.PostForm("expire_type") // time / download / permanent
|
||||
if expireType == "" {
|
||||
expireType = "time"
|
||||
}
|
||||
|
||||
var expireValue interface{}
|
||||
switch expireType {
|
||||
case "time":
|
||||
days, _ := strconv.Atoi(c.PostForm("expire_days"))
|
||||
if days <= 0 {
|
||||
days = config.GlobalConfig.Upload.MaxRetentionDays
|
||||
}
|
||||
expireValue = days
|
||||
case "download":
|
||||
max, _ := strconv.Atoi(c.PostForm("max_downloads"))
|
||||
if max <= 0 {
|
||||
max = 1
|
||||
}
|
||||
expireValue = max
|
||||
}
|
||||
|
||||
batch, err := h.uploadService.CreateBatch(c.Request.Context(), files, remark, expireType, expireValue)
|
||||
if err != nil {
|
||||
slog.Error("Upload handler failed to create batch", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, model.ErrorResponse(model.CodeInternalError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.SuccessResponse(UploadResponse{
|
||||
PickupCode: batch.PickupCode,
|
||||
ExpireAt: batch.ExpireAt,
|
||||
BatchID: batch.ID,
|
||||
}))
|
||||
}
|
||||
|
||||
type UploadTextRequest struct {
|
||||
Content string `json:"content" binding:"required" example:"这是一段长文本内容..."`
|
||||
Remark string `json:"remark" example:"文本备注"`
|
||||
ExpireType string `json:"expire_type" example:"time"`
|
||||
ExpireDays int `json:"expire_days" example:"7"`
|
||||
MaxDownloads int `json:"max_downloads" example:"5"`
|
||||
}
|
||||
|
||||
// UploadText 发送长文本并生成取件码
|
||||
// @Summary 发送长文本
|
||||
// @Description 中转一段长文本内容并创建一个提取批次。如果配置了 require_token,则必须提供带 upload scope 的 API Token。
|
||||
// @Tags Public
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security APITokenAuth
|
||||
// @Param request body UploadTextRequest true "文本内容及配置"
|
||||
// @Success 200 {object} model.Response{data=UploadResponse}
|
||||
// @Failure 400 {object} model.Response
|
||||
// @Failure 500 {object} model.Response
|
||||
// @Router /api/batches/text [post]
|
||||
func (h *UploadHandler) UploadText(c *gin.Context) {
|
||||
var req UploadTextRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.ErrorResponse(model.CodeBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// 校验文本长度
|
||||
maxSize := config.GlobalConfig.Upload.MaxFileSizeMB * 1024 * 1024
|
||||
if int64(len(req.Content)) > maxSize {
|
||||
c.JSON(http.StatusBadRequest, model.ErrorResponse(model.CodeBadRequest, fmt.Sprintf("文本内容超过最大限制 (%dMB)", config.GlobalConfig.Upload.MaxFileSizeMB)))
|
||||
return
|
||||
}
|
||||
|
||||
if req.ExpireType == "" {
|
||||
req.ExpireType = "time"
|
||||
}
|
||||
|
||||
var expireValue interface{}
|
||||
switch req.ExpireType {
|
||||
case "time":
|
||||
if req.ExpireDays <= 0 {
|
||||
req.ExpireDays = config.GlobalConfig.Upload.MaxRetentionDays
|
||||
}
|
||||
expireValue = req.ExpireDays
|
||||
case "download":
|
||||
if req.MaxDownloads <= 0 {
|
||||
req.MaxDownloads = 1
|
||||
}
|
||||
expireValue = req.MaxDownloads
|
||||
}
|
||||
|
||||
batch, err := h.uploadService.CreateTextBatch(c.Request.Context(), req.Content, req.Remark, req.ExpireType, expireValue)
|
||||
if err != nil {
|
||||
slog.Error("Upload handler failed to create text batch", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, model.ErrorResponse(model.CodeInternalError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.SuccessResponse(UploadResponse{
|
||||
PickupCode: batch.PickupCode,
|
||||
ExpireAt: batch.ExpireAt,
|
||||
BatchID: batch.ID,
|
||||
}))
|
||||
}
|
||||
42
internal/auth/jwt.go
Normal file
42
internal/auth/jwt.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"FileRelay/internal/config"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
type Claims struct {
|
||||
AdminID uint `json:"admin_id"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
func GenerateToken(adminID uint) (string, error) {
|
||||
claims := Claims{
|
||||
AdminID: adminID,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(config.GlobalConfig.Security.JWTSecret))
|
||||
}
|
||||
|
||||
func ParseToken(tokenString string) (*Claims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(config.GlobalConfig.Security.JWTSecret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
return nil, jwt.ErrSignatureInvalid
|
||||
}
|
||||
406
internal/bootstrap/init.go
Normal file
406
internal/bootstrap/init.go
Normal file
@@ -0,0 +1,406 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"FileRelay/internal/config"
|
||||
"FileRelay/internal/model"
|
||||
"FileRelay/internal/storage"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var DB *gorm.DB
|
||||
|
||||
func InitLog() {
|
||||
level := slog.LevelInfo
|
||||
switch strings.ToLower(config.GlobalConfig.Log.Level) {
|
||||
case "debug":
|
||||
level = slog.LevelDebug
|
||||
case "info":
|
||||
level = slog.LevelInfo
|
||||
case "warn":
|
||||
level = slog.LevelWarn
|
||||
case "error":
|
||||
level = slog.LevelError
|
||||
}
|
||||
|
||||
var handlers []slog.Handler
|
||||
|
||||
// 1. 控制台处理器 (简化格式)
|
||||
handlers = append(handlers, &ConsoleHandler{
|
||||
out: os.Stdout,
|
||||
level: level,
|
||||
})
|
||||
|
||||
// 2. 文件处理器 (结构化 Text 格式)
|
||||
if config.GlobalConfig.Log.FilePath != "" {
|
||||
logDir := filepath.Dir(config.GlobalConfig.Log.FilePath)
|
||||
if err := os.MkdirAll(logDir, 0755); err != nil {
|
||||
fmt.Printf("Warning: Failed to create log directory: %v\n", err)
|
||||
} else {
|
||||
file, err := os.OpenFile(config.GlobalConfig.Log.FilePath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
fmt.Printf("Warning: Failed to open log file: %v\n", err)
|
||||
} else {
|
||||
fileHandler := slog.NewTextHandler(file, &slog.HandlerOptions{
|
||||
Level: level,
|
||||
AddSource: true,
|
||||
ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr {
|
||||
if a.Key == slog.TimeKey {
|
||||
return slog.String(a.Key, a.Value.Time().Format("2006-01-02 15:04:05.000"))
|
||||
}
|
||||
if a.Key == slog.SourceKey {
|
||||
source := a.Value.Any().(*slog.Source)
|
||||
return slog.String(a.Key, fmt.Sprintf("%s:%d", filepath.Base(source.File), source.Line))
|
||||
}
|
||||
return a
|
||||
},
|
||||
})
|
||||
handlers = append(handlers, fileHandler)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var finalHandler slog.Handler
|
||||
if len(handlers) == 1 {
|
||||
finalHandler = handlers[0]
|
||||
} else {
|
||||
finalHandler = &MultiHandler{handlers: handlers}
|
||||
}
|
||||
|
||||
logger := slog.New(finalHandler).With("service", "filerelay")
|
||||
slog.SetDefault(logger)
|
||||
}
|
||||
|
||||
func InitDB() {
|
||||
var err error
|
||||
cfg := config.GlobalConfig.Database
|
||||
|
||||
DB, err = ConnectDB(cfg)
|
||||
if err != nil {
|
||||
slog.Error("Failed to initialize database", "type", cfg.Type, "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
slog.Info("Database initialized and migrated", "type", cfg.Type)
|
||||
|
||||
// 初始化存储
|
||||
if err := ReloadStorage(); err != nil {
|
||||
slog.Error("Failed to initialize storage", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// 初始化管理员 (如果不存在)
|
||||
initAdmin()
|
||||
}
|
||||
|
||||
func ConnectDB(cfg config.DatabaseConfig) (*gorm.DB, error) {
|
||||
var dialector gorm.Dialector
|
||||
|
||||
switch strings.ToLower(cfg.Type) {
|
||||
case "mysql":
|
||||
params := cfg.Config
|
||||
if params == "" {
|
||||
params = "parseTime=True&loc=Local&charset=utf8mb4"
|
||||
} else {
|
||||
if !strings.Contains(strings.ToLower(params), "parsetime") {
|
||||
params += "&parseTime=True"
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(params), "loc=") {
|
||||
params += "&loc=Local"
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(params), "charset") {
|
||||
params += "&charset=utf8mb4"
|
||||
}
|
||||
}
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?%s",
|
||||
cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.DBName, params)
|
||||
dialector = mysql.Open(dsn)
|
||||
case "postgres", "postgresql":
|
||||
dsn := fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%d %s",
|
||||
cfg.Host, cfg.User, cfg.Password, cfg.DBName, cfg.Port, cfg.Config)
|
||||
dialector = postgres.Open(dsn)
|
||||
case "sqlite", "sqlite3":
|
||||
fallthrough
|
||||
default:
|
||||
dbPath := cfg.Path
|
||||
if dbPath == "" {
|
||||
dbPath = "data/file_relay.db"
|
||||
}
|
||||
dialector = sqlite.Open(dbPath)
|
||||
}
|
||||
|
||||
db, err := gorm.Open(dialector, &gorm.Config{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 自动迁移
|
||||
err = db.AutoMigrate(
|
||||
&model.FileBatch{},
|
||||
&model.FileItem{},
|
||||
&model.APIToken{},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func ReloadDB(newCfg config.DatabaseConfig) error {
|
||||
newDB, err := ConnectDB(newCfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if DB != nil {
|
||||
// 检查是否真的是不同的数据库,避免自迁移导致冲突
|
||||
// 这里简单判断类型或连接串是否变化,或者直接让用户决定
|
||||
// 为了安全,我们只在连接参数确实变化时才尝试迁移
|
||||
slog.Info("Starting data migration to new database...")
|
||||
if err := MigrateData(DB, newDB); err != nil {
|
||||
slog.Error("Data migration failed", "error", err)
|
||||
// 迁移失败不一定需要中断,但需要记录
|
||||
} else {
|
||||
slog.Info("Data migration completed successfully")
|
||||
}
|
||||
}
|
||||
|
||||
DB = newDB
|
||||
return nil
|
||||
}
|
||||
|
||||
func MigrateData(sourceDB, targetDB *gorm.DB) error {
|
||||
// 迁移 APIToken
|
||||
var tokens []model.APIToken
|
||||
if err := sourceDB.Find(&tokens).Error; err == nil && len(tokens) > 0 {
|
||||
if err := targetDB.Save(&tokens).Error; err != nil {
|
||||
slog.Warn("Failed to migrate APITokens", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 迁移 FileBatch (分批处理以节省内存)
|
||||
var batches []model.FileBatch
|
||||
err := sourceDB.Model(&model.FileBatch{}).FindInBatches(&batches, 100, func(tx *gorm.DB, batch int) error {
|
||||
if err := targetDB.Save(&batches).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}).Error
|
||||
if err != nil {
|
||||
slog.Warn("Failed to migrate FileBatches", "error", err)
|
||||
}
|
||||
|
||||
// 迁移 FileItem (分批处理以节省内存)
|
||||
var items []model.FileItem
|
||||
err = sourceDB.Model(&model.FileItem{}).FindInBatches(&items, 100, func(tx *gorm.DB, batch int) error {
|
||||
if err := targetDB.Save(&items).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}).Error
|
||||
if err != nil {
|
||||
slog.Warn("Failed to migrate FileItems", "error", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ReloadStorage() error {
|
||||
storageType := config.GlobalConfig.Storage.Type
|
||||
switch storageType {
|
||||
case "local":
|
||||
storage.GlobalStorage = storage.NewLocalStorage(config.GlobalConfig.Storage.Local.Path)
|
||||
case "webdav":
|
||||
cfg := config.GlobalConfig.Storage.WebDAV
|
||||
storage.GlobalStorage = storage.NewWebDAVStorage(cfg.URL, cfg.Username, cfg.Password, cfg.Root)
|
||||
case "s3":
|
||||
cfg := config.GlobalConfig.Storage.S3
|
||||
s3Storage, err := storage.NewS3Storage(context.Background(), cfg.Endpoint, cfg.Region, cfg.AccessKey, cfg.SecretKey, cfg.Bucket, cfg.UseSSL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
storage.GlobalStorage = s3Storage
|
||||
default:
|
||||
return fmt.Errorf("unsupported storage type: %s", storageType)
|
||||
}
|
||||
slog.Info("Storage initialized", "type", storageType)
|
||||
return nil
|
||||
}
|
||||
|
||||
func initAdmin() {
|
||||
passwordHash := config.GlobalConfig.Security.AdminPasswordHash
|
||||
if passwordHash == "" {
|
||||
// 生成随机密码
|
||||
password := generateRandomPassword(12)
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
slog.Error("Failed to generate password hash", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
passwordHash = string(hash)
|
||||
fmt.Printf("**************************************************\n")
|
||||
fmt.Printf("NO ADMIN PASSWORD CONFIGURED. GENERATED RANDOM PASSWORD:\n")
|
||||
fmt.Printf("Password: %s\n", password)
|
||||
fmt.Printf("Please save this password or configure admin_password_hash in config.yaml\n")
|
||||
fmt.Printf("**************************************************\n")
|
||||
|
||||
// 将生成的哈希保存回配置文件
|
||||
config.GlobalConfig.Security.AdminPasswordHash = passwordHash
|
||||
if err := config.SaveConfig(); err != nil {
|
||||
slog.Warn("Failed to save generated password hash to config", "error", err)
|
||||
}
|
||||
}
|
||||
slog.Info("Admin authentication initialized")
|
||||
}
|
||||
|
||||
func generateRandomPassword(length int) string {
|
||||
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*"
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
|
||||
if err != nil {
|
||||
return "admin123" // 退路
|
||||
}
|
||||
b[i] = charset[num.Int64()]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// ConsoleHandler 实现 simplified 控制台日志
|
||||
type ConsoleHandler struct {
|
||||
out io.Writer
|
||||
level slog.Leveler
|
||||
attrs []slog.Attr
|
||||
}
|
||||
|
||||
func (h *ConsoleHandler) Enabled(_ context.Context, l slog.Level) bool {
|
||||
return l >= h.level.Level()
|
||||
}
|
||||
|
||||
func (h *ConsoleHandler) Handle(_ context.Context, r slog.Record) error {
|
||||
var b strings.Builder
|
||||
|
||||
// 时间: 15:04:05.000
|
||||
b.WriteString(r.Time.Format("15:04:05.000"))
|
||||
b.WriteString(" ")
|
||||
|
||||
// 级别: INFO
|
||||
level := r.Level.String()
|
||||
b.WriteString(fmt.Sprintf("%-5s", level))
|
||||
b.WriteString(" ")
|
||||
|
||||
// 源码: [main.go:123]
|
||||
pc := r.PC
|
||||
if pc == 0 {
|
||||
// 如果 Logger 没采集 PC (自定义 Handler 默认情况),我们手动采集
|
||||
// 跳过: 1:runtime.Callers, 2:Handle, 3:slog.(*Logger).log, 4:slog.(*Logger).Info
|
||||
var pcs [1]uintptr
|
||||
runtime.Callers(4, pcs[:])
|
||||
pc = pcs[0]
|
||||
}
|
||||
|
||||
if pc != 0 {
|
||||
fs := runtime.CallersFrames([]uintptr{pc})
|
||||
f, _ := fs.Next()
|
||||
b.WriteString(fmt.Sprintf("[%s:%d] ", filepath.Base(f.File), f.Line))
|
||||
}
|
||||
|
||||
// 消息
|
||||
b.WriteString(r.Message)
|
||||
|
||||
// 属性: 仅输出值
|
||||
writeAttr := func(a slog.Attr) {
|
||||
if a.Key == "service" {
|
||||
return
|
||||
}
|
||||
b.WriteString(" ")
|
||||
val := a.Value.Resolve().String()
|
||||
if strings.Contains(val, " ") {
|
||||
b.WriteString(fmt.Sprintf("%q", val))
|
||||
} else {
|
||||
b.WriteString(val)
|
||||
}
|
||||
}
|
||||
|
||||
for _, a := range h.attrs {
|
||||
writeAttr(a)
|
||||
}
|
||||
r.Attrs(func(a slog.Attr) bool {
|
||||
writeAttr(a)
|
||||
return true
|
||||
})
|
||||
|
||||
b.WriteString("\n")
|
||||
_, err := h.out.Write([]byte(b.String()))
|
||||
return err
|
||||
}
|
||||
|
||||
func (h *ConsoleHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
|
||||
newAttrs := make([]slog.Attr, len(h.attrs)+len(attrs))
|
||||
copy(newAttrs, h.attrs)
|
||||
copy(newAttrs[len(h.attrs):], attrs)
|
||||
return &ConsoleHandler{
|
||||
out: h.out,
|
||||
level: h.level,
|
||||
attrs: newAttrs,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ConsoleHandler) WithGroup(name string) slog.Handler {
|
||||
return h // 简化版暂不支持分组
|
||||
}
|
||||
|
||||
// MultiHandler 实现多路分发
|
||||
type MultiHandler struct {
|
||||
handlers []slog.Handler
|
||||
}
|
||||
|
||||
func (h *MultiHandler) Enabled(ctx context.Context, l slog.Level) bool {
|
||||
for _, hh := range h.handlers {
|
||||
if hh.Enabled(ctx, l) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *MultiHandler) Handle(ctx context.Context, r slog.Record) error {
|
||||
for _, hh := range h.handlers {
|
||||
if hh.Enabled(ctx, r.Level) {
|
||||
_ = hh.Handle(ctx, r)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *MultiHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
|
||||
newHandlers := make([]slog.Handler, len(h.handlers))
|
||||
for i, hh := range h.handlers {
|
||||
newHandlers[i] = hh.WithAttrs(attrs)
|
||||
}
|
||||
return &MultiHandler{handlers: newHandlers}
|
||||
}
|
||||
|
||||
func (h *MultiHandler) WithGroup(name string) slog.Handler {
|
||||
newHandlers := make([]slog.Handler, len(h.handlers))
|
||||
for i, hh := range h.handlers {
|
||||
newHandlers[i] = hh.WithGroup(name)
|
||||
}
|
||||
return &MultiHandler{handlers: newHandlers}
|
||||
}
|
||||
282
internal/config/config.go
Normal file
282
internal/config/config.go
Normal file
@@ -0,0 +1,282 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Site SiteConfig `yaml:"site" json:"site"` // 站点设置
|
||||
Security SecurityConfig `yaml:"security" json:"security"` // 安全设置
|
||||
Upload UploadConfig `yaml:"upload" json:"upload"` // 上传设置
|
||||
Storage StorageConfig `yaml:"storage" json:"storage"` // 存储设置
|
||||
APIToken APITokenConfig `yaml:"api_token" json:"api_token"` // API Token 设置
|
||||
Database DatabaseConfig `yaml:"database" json:"database"` // 数据库设置
|
||||
Web WebConfig `yaml:"web" json:"web"` // Web 前端设置
|
||||
Log LogConfig `yaml:"log" json:"log"` // 日志设置
|
||||
}
|
||||
|
||||
type LogConfig struct {
|
||||
Level string `yaml:"level" json:"level"` // 日志级别: debug, info, warn, error
|
||||
FilePath string `yaml:"file_path" json:"file_path"` // 日志文件路径,为空则仅输出到控制台
|
||||
}
|
||||
|
||||
type WebConfig struct {
|
||||
Path string `yaml:"path" json:"path"` // Web 前端资源路径
|
||||
}
|
||||
|
||||
type SiteConfig struct {
|
||||
Name string `yaml:"name" json:"name"` // 站点名称
|
||||
Description string `yaml:"description" json:"description"` // 站点描述
|
||||
Logo string `yaml:"logo" json:"logo"` // 站点 Logo URL
|
||||
BaseURL string `yaml:"base_url" json:"base_url"` // 站点外部访问地址 (例如: https://file.example.com)
|
||||
Port int `yaml:"port" json:"port"` // 监听端口
|
||||
}
|
||||
|
||||
type SecurityConfig struct {
|
||||
AdminPasswordHash string `yaml:"admin_password_hash" json:"admin_password_hash"` // 管理员密码哈希 (bcrypt)
|
||||
AdminPassword string `yaml:"-" json:"admin_password,omitempty"` // 管理员密码明文 (仅用于更新请求,不保存到文件)
|
||||
PickupCodeLength int `yaml:"pickup_code_length" json:"pickup_code_length"` // 取件码长度 (变更后将自动通过右侧补零或截取调整存量数据)
|
||||
PickupFailLimit int `yaml:"pickup_fail_limit" json:"pickup_fail_limit"` // 取件失败尝试限制
|
||||
JWTSecret string `yaml:"jwt_secret" json:"jwt_secret"` // JWT 签名密钥
|
||||
}
|
||||
|
||||
type UploadConfig struct {
|
||||
MaxFileSizeMB int64 `yaml:"max_file_size_mb" json:"max_file_size_mb"` // 单个文件最大大小 (MB)
|
||||
MaxBatchFiles int `yaml:"max_batch_files" json:"max_batch_files"` // 每个批次最大文件数
|
||||
MaxRetentionDays int `yaml:"max_retention_days" json:"max_retention_days"` // 最大保留天数
|
||||
RequireToken bool `yaml:"require_token" json:"require_token"` // 是否强制要求上传 Token
|
||||
}
|
||||
|
||||
type StorageConfig struct {
|
||||
Type string `yaml:"type" json:"type"` // 存储类型: local, webdav, s3
|
||||
Local struct {
|
||||
Path string `yaml:"path" json:"path"` // 本地存储路径
|
||||
} `yaml:"local" json:"local"`
|
||||
WebDAV struct {
|
||||
URL string `yaml:"url" json:"url"` // WebDAV 地址
|
||||
Username string `yaml:"username" json:"username"` // WebDAV 用户名
|
||||
Password string `yaml:"password" json:"password"` // WebDAV 密码
|
||||
Root string `yaml:"root" json:"root"` // WebDAV 根目录
|
||||
} `yaml:"webdav" json:"webdav"`
|
||||
S3 struct {
|
||||
Endpoint string `yaml:"endpoint" json:"endpoint"` // S3 端点
|
||||
Region string `yaml:"region" json:"region"` // S3 区域
|
||||
AccessKey string `yaml:"access_key" json:"access_key"` // S3 Access Key
|
||||
SecretKey string `yaml:"secret_key" json:"secret_key"` // S3 Secret Key
|
||||
Bucket string `yaml:"bucket" json:"bucket"` // S3 Bucket
|
||||
UseSSL bool `yaml:"use_ssl" json:"use_ssl"` // 是否使用 SSL
|
||||
} `yaml:"s3" json:"s3"`
|
||||
}
|
||||
|
||||
type APITokenConfig struct {
|
||||
Enabled bool `yaml:"enabled" json:"enabled"` // 是否启用 API Token
|
||||
AllowAdminAPI bool `yaml:"allow_admin_api" json:"allow_admin_api"` // 是否允许 API Token 访问管理接口
|
||||
MaxTokens int `yaml:"max_tokens" json:"max_tokens"` // 最大 Token 数量
|
||||
}
|
||||
|
||||
type DatabaseConfig struct {
|
||||
Type string `yaml:"type" json:"type"` // 数据库类型: sqlite, mysql, postgres
|
||||
Path string `yaml:"path" json:"path"` // SQLite 数据库文件路径
|
||||
Host string `yaml:"host" json:"host"` // 数据库地址
|
||||
Port int `yaml:"port" json:"port"` // 数据库端口
|
||||
User string `yaml:"user" json:"user"` // 数据库用户名
|
||||
Password string `yaml:"password" json:"password"` // 数据库密码
|
||||
DBName string `yaml:"dbname" json:"dbname"` // 数据库名称
|
||||
Config string `yaml:"config" json:"config"` // 额外配置参数 (DSN)
|
||||
}
|
||||
|
||||
var (
|
||||
GlobalConfig *Config
|
||||
ConfigPath string
|
||||
configLock sync.RWMutex
|
||||
)
|
||||
|
||||
func LoadConfig(path string) error {
|
||||
configLock.Lock()
|
||||
defer configLock.Unlock()
|
||||
|
||||
// 检查文件是否存在
|
||||
if _, err := os.Stat(path); os.IsNotExist(err) {
|
||||
// 创建默认配置
|
||||
cfg := GetDefaultConfig()
|
||||
data, err := yaml.Marshal(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 确保目录存在
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(path, data, 0644); err != nil {
|
||||
return err
|
||||
}
|
||||
cfg.overrideWithEnv()
|
||||
GlobalConfig = cfg
|
||||
ConfigPath = path
|
||||
return nil
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg.overrideWithEnv()
|
||||
GlobalConfig = &cfg
|
||||
ConfigPath = path
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cfg *Config) overrideWithEnv() {
|
||||
// Site settings
|
||||
if val := os.Getenv("FR_SITE_NAME"); val != "" {
|
||||
cfg.Site.Name = val
|
||||
}
|
||||
if val := os.Getenv("FR_SITE_BASE_URL"); val != "" {
|
||||
cfg.Site.BaseURL = val
|
||||
}
|
||||
if val := os.Getenv("FR_SITE_PORT"); val != "" {
|
||||
if port, err := strconv.Atoi(val); err == nil {
|
||||
cfg.Site.Port = port
|
||||
}
|
||||
}
|
||||
|
||||
// Security settings
|
||||
if val := os.Getenv("FR_SECURITY_JWT_SECRET"); val != "" {
|
||||
cfg.Security.JWTSecret = val
|
||||
}
|
||||
|
||||
// Upload settings
|
||||
if val := os.Getenv("FR_UPLOAD_MAX_SIZE"); val != "" {
|
||||
if size, err := strconv.ParseInt(val, 10, 64); err == nil {
|
||||
cfg.Upload.MaxFileSizeMB = size
|
||||
}
|
||||
}
|
||||
if val := os.Getenv("FR_UPLOAD_RETENTION_DAYS"); val != "" {
|
||||
if days, err := strconv.Atoi(val); err == nil {
|
||||
cfg.Upload.MaxRetentionDays = days
|
||||
}
|
||||
}
|
||||
|
||||
// Database settings
|
||||
if val := os.Getenv("FR_DB_TYPE"); val != "" {
|
||||
cfg.Database.Type = val
|
||||
}
|
||||
if val := os.Getenv("FR_DB_PATH"); val != "" {
|
||||
cfg.Database.Path = val
|
||||
}
|
||||
if val := os.Getenv("FR_DB_HOST"); val != "" {
|
||||
cfg.Database.Host = val
|
||||
}
|
||||
if val := os.Getenv("FR_DB_PORT"); val != "" {
|
||||
if port, err := strconv.Atoi(val); err == nil {
|
||||
cfg.Database.Port = port
|
||||
}
|
||||
}
|
||||
if val := os.Getenv("FR_DB_USER"); val != "" {
|
||||
cfg.Database.User = val
|
||||
}
|
||||
if val := os.Getenv("FR_DB_PASSWORD"); val != "" {
|
||||
cfg.Database.Password = val
|
||||
}
|
||||
if val := os.Getenv("FR_DB_NAME"); val != "" {
|
||||
cfg.Database.DBName = val
|
||||
}
|
||||
|
||||
// Storage settings
|
||||
if val := os.Getenv("FR_STORAGE_TYPE"); val != "" {
|
||||
cfg.Storage.Type = val
|
||||
}
|
||||
if val := os.Getenv("FR_STORAGE_LOCAL_PATH"); val != "" {
|
||||
cfg.Storage.Local.Path = val
|
||||
}
|
||||
|
||||
// Log settings
|
||||
if val := os.Getenv("FR_LOG_LEVEL"); val != "" {
|
||||
cfg.Log.Level = val
|
||||
}
|
||||
if val := os.Getenv("FR_LOG_FILE_PATH"); val != "" {
|
||||
cfg.Log.FilePath = val
|
||||
}
|
||||
|
||||
// Web settings
|
||||
if val := os.Getenv("FR_WEB_PATH"); val != "" {
|
||||
cfg.Web.Path = val
|
||||
}
|
||||
}
|
||||
|
||||
func GetDefaultConfig() *Config {
|
||||
return &Config{
|
||||
Site: SiteConfig{
|
||||
Name: "文件暂存柜",
|
||||
Description: "临时文件中转服务",
|
||||
Logo: "/favicon.png",
|
||||
BaseURL: "",
|
||||
Port: 8080,
|
||||
},
|
||||
Security: SecurityConfig{
|
||||
AdminPasswordHash: "$2a$10$Bm0TEmU4uj.bVHYiIPFBheUkcdg6XHpsanLvmpoAtgU1UnKbo9.vy", // 默认密码: admin123
|
||||
PickupCodeLength: 6,
|
||||
PickupFailLimit: 5,
|
||||
JWTSecret: "file-relay-secret",
|
||||
},
|
||||
Upload: UploadConfig{
|
||||
MaxFileSizeMB: 100,
|
||||
MaxBatchFiles: 20,
|
||||
MaxRetentionDays: 30,
|
||||
RequireToken: false,
|
||||
},
|
||||
Storage: StorageConfig{
|
||||
Type: "local",
|
||||
Local: struct {
|
||||
Path string `yaml:"path" json:"path"`
|
||||
}{
|
||||
Path: "data/storage_data",
|
||||
},
|
||||
},
|
||||
APIToken: APITokenConfig{
|
||||
Enabled: true,
|
||||
AllowAdminAPI: true,
|
||||
MaxTokens: 20,
|
||||
},
|
||||
Database: DatabaseConfig{
|
||||
Type: "sqlite",
|
||||
Path: "data/file_relay.db",
|
||||
},
|
||||
Web: WebConfig{
|
||||
Path: "web",
|
||||
},
|
||||
Log: LogConfig{
|
||||
Level: "info",
|
||||
FilePath: "data/logs/app.log",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func SaveConfig() error {
|
||||
configLock.RLock()
|
||||
defer configLock.RUnlock()
|
||||
|
||||
data, err := yaml.Marshal(GlobalConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(ConfigPath, data, 0644)
|
||||
}
|
||||
|
||||
func UpdateGlobalConfig(cfg *Config) {
|
||||
configLock.Lock()
|
||||
defer configLock.Unlock()
|
||||
GlobalConfig = cfg
|
||||
}
|
||||
43
internal/config/config_test.go
Normal file
43
internal/config/config_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetDefaultConfig(t *testing.T) {
|
||||
cfg := GetDefaultConfig()
|
||||
assert.Equal(t, "data/storage_data", cfg.Storage.Local.Path)
|
||||
assert.Equal(t, "data/file_relay.db", cfg.Database.Path)
|
||||
assert.Equal(t, "data/logs/app.log", cfg.Log.FilePath)
|
||||
}
|
||||
|
||||
func TestOverrideWithEnv(t *testing.T) {
|
||||
// 设置环境变量
|
||||
os.Setenv("FR_SITE_NAME", "EnvSiteName")
|
||||
os.Setenv("FR_SITE_PORT", "9999")
|
||||
os.Setenv("FR_DB_TYPE", "mysql")
|
||||
os.Setenv("FR_DB_PATH", "custom_db_path.db")
|
||||
os.Setenv("FR_UPLOAD_MAX_SIZE", "500")
|
||||
os.Setenv("FR_LOG_FILE_PATH", "custom_log_path.log")
|
||||
|
||||
cfg := GetDefaultConfig()
|
||||
cfg.overrideWithEnv()
|
||||
|
||||
assert.Equal(t, "EnvSiteName", cfg.Site.Name)
|
||||
assert.Equal(t, 9999, cfg.Site.Port)
|
||||
assert.Equal(t, "mysql", cfg.Database.Type)
|
||||
assert.Equal(t, "custom_db_path.db", cfg.Database.Path)
|
||||
assert.Equal(t, int64(500), cfg.Upload.MaxFileSizeMB)
|
||||
assert.Equal(t, "custom_log_path.log", cfg.Log.FilePath)
|
||||
|
||||
// 清理环境变量
|
||||
os.Unsetenv("FR_SITE_NAME")
|
||||
os.Unsetenv("FR_SITE_PORT")
|
||||
os.Unsetenv("FR_DB_TYPE")
|
||||
os.Unsetenv("FR_DB_PATH")
|
||||
os.Unsetenv("FR_UPLOAD_MAX_SIZE")
|
||||
os.Unsetenv("FR_LOG_FILE_PATH")
|
||||
}
|
||||
11
internal/model/admin.go
Normal file
11
internal/model/admin.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// AdminSession 管理员会话信息 (不再存库,仅用于 JWT 或 API 交互)
|
||||
type AdminSession struct {
|
||||
ID uint `json:"id"`
|
||||
LastLogin *time.Time `json:"last_login"`
|
||||
}
|
||||
22
internal/model/api_token.go
Normal file
22
internal/model/api_token.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
ScopeUpload = "upload" // 上传权限
|
||||
ScopePickup = "pickup" // 取件/下载权限
|
||||
ScopeAdmin = "admin" // 管理权限
|
||||
)
|
||||
|
||||
type APIToken struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
Name string `json:"name"`
|
||||
TokenHash string `gorm:"uniqueIndex;not null" json:"-"`
|
||||
Scope string `json:"scope"`
|
||||
ExpireAt *time.Time `json:"expire_at"`
|
||||
LastUsedAt *time.Time `json:"last_used_at"`
|
||||
Revoked bool `gorm:"default:false" json:"revoked"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
24
internal/model/file_batch.go
Normal file
24
internal/model/file_batch.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type FileBatch struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(36)" json:"id"`
|
||||
PickupCode string `gorm:"uniqueIndex;not null" json:"pickup_code"`
|
||||
Remark string `json:"remark"`
|
||||
ExpireType string `json:"expire_type"` // time / download / permanent
|
||||
ExpireAt *time.Time `json:"expire_at"`
|
||||
MaxDownloads int `json:"max_downloads"`
|
||||
DownloadCount int `gorm:"default:0" json:"download_count"`
|
||||
Status string `gorm:"default:'active'" json:"status"` // active / expired / deleted
|
||||
Type string `gorm:"default:'file'" json:"type"` // file / text
|
||||
Content string `json:"content,omitempty"`
|
||||
FileItems []FileItem `gorm:"foreignKey:BatchID" json:"file_items,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
|
||||
}
|
||||
16
internal/model/file_item.go
Normal file
16
internal/model/file_item.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type FileItem struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(36)" json:"id"`
|
||||
BatchID string `gorm:"index;not null;type:varchar(36)" json:"batch_id"`
|
||||
OriginalName string `json:"original_name"`
|
||||
StoragePath string `json:"storage_path"`
|
||||
Size int64 `json:"size"`
|
||||
MimeType string `json:"mime_type"`
|
||||
DownloadURL string `gorm:"-" json:"download_url,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
46
internal/model/response.go
Normal file
46
internal/model/response.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package model
|
||||
|
||||
// Response 统一响应模型
|
||||
type Response struct {
|
||||
Code int `json:"code" example:"200"`
|
||||
Msg string `json:"msg" example:"success"`
|
||||
Data interface{} `json:"data"`
|
||||
}
|
||||
|
||||
// 错误码定义
|
||||
const (
|
||||
CodeSuccess = 200
|
||||
CodeBadRequest = 400
|
||||
CodeUnauthorized = 401
|
||||
CodeForbidden = 403
|
||||
CodeNotFound = 404
|
||||
CodeGone = 410
|
||||
CodeInternalError = 500
|
||||
CodeTooManyRequests = 429
|
||||
)
|
||||
|
||||
// NewResponse 创建响应
|
||||
func NewResponse(code int, msg string, data interface{}) Response {
|
||||
return Response{
|
||||
Code: code,
|
||||
Msg: msg,
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
// SuccessResponse 成功响应
|
||||
func SuccessResponse(data interface{}) Response {
|
||||
return Response{
|
||||
Code: CodeSuccess,
|
||||
Msg: "success",
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorResponse 错误响应
|
||||
func ErrorResponse(code int, msg string) Response {
|
||||
return Response{
|
||||
Code: code,
|
||||
Msg: msg,
|
||||
}
|
||||
}
|
||||
276
internal/service/batch_service.go
Normal file
276
internal/service/batch_service.go
Normal file
@@ -0,0 +1,276 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"FileRelay/internal/bootstrap"
|
||||
"FileRelay/internal/model"
|
||||
"FileRelay/internal/storage"
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"math/big"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"crypto/rand"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type BatchService struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewBatchService() *BatchService {
|
||||
return &BatchService{db: bootstrap.DB}
|
||||
}
|
||||
|
||||
func (s *BatchService) GetBatchByPickupCode(code string) (*model.FileBatch, error) {
|
||||
var batch model.FileBatch
|
||||
err := s.db.Preload("FileItems").Where("pickup_code = ? AND status = ?", code, "active").First(&batch).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 检查是否过期
|
||||
if s.IsExpired(&batch) {
|
||||
s.MarkAsExpired(&batch)
|
||||
return nil, errors.New("batch expired")
|
||||
}
|
||||
|
||||
return &batch, nil
|
||||
}
|
||||
|
||||
func (s *BatchService) GetDownloadCountByPickupCode(code string) (int, int, error) {
|
||||
var batch model.FileBatch
|
||||
// 查询活跃或已过期的批次
|
||||
err := s.db.Where("pickup_code = ? AND (status = ? OR status = ?)", code, "active", "expired").First(&batch).Error
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
return batch.DownloadCount, batch.MaxDownloads, nil
|
||||
}
|
||||
|
||||
func (s *BatchService) IsExpired(batch *model.FileBatch) bool {
|
||||
if batch.Status != "active" {
|
||||
return true
|
||||
}
|
||||
|
||||
switch batch.ExpireType {
|
||||
case "time":
|
||||
if batch.ExpireAt != nil && time.Now().After(*batch.ExpireAt) {
|
||||
return true
|
||||
}
|
||||
case "download":
|
||||
if batch.MaxDownloads > 0 && batch.DownloadCount >= batch.MaxDownloads {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *BatchService) MarkAsExpired(batch *model.FileBatch) error {
|
||||
slog.Info("Marking batch as expired", "batch_id", batch.ID, "pickup_code", batch.PickupCode)
|
||||
return s.db.Model(batch).Update("status", "expired").Error
|
||||
}
|
||||
|
||||
func (s *BatchService) DeleteBatch(ctx context.Context, batchID string) error {
|
||||
var batch model.FileBatch
|
||||
// 使用 Unscoped 以确保即使是已软删除的批次也能找到并清理其物理文件
|
||||
if err := s.db.Unscoped().Preload("FileItems").First(&batch, "id = ?", batchID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
slog.Info("Deleting batch", "batch_id", batch.ID, "files_count", len(batch.FileItems))
|
||||
|
||||
// 删除物理文件
|
||||
for _, item := range batch.FileItems {
|
||||
if err := storage.GlobalStorage.Delete(ctx, item.StoragePath); err != nil {
|
||||
slog.Error("Failed to delete physical file", "path", item.StoragePath, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 删除数据库记录 (彻底删除,不再保留元数据以便清理任务不再扫描到它)
|
||||
return s.db.Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Where("batch_id = ?", batch.ID).Delete(&model.FileItem{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Unscoped().Delete(&batch).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BatchService) IncrementDownloadCount(batchID string) error {
|
||||
if batchID == "" {
|
||||
return errors.New("batch id is empty")
|
||||
}
|
||||
result := s.db.Model(&model.FileBatch{}).Where("id = ?", batchID).
|
||||
UpdateColumn("download_count", gorm.Expr("download_count + ?", 1))
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return errors.New("batch not found or already deleted")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BatchService) GeneratePickupCode(length int) (string, error) {
|
||||
const charset = "0123456789"
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
b[i] = charset[num.Int64()]
|
||||
}
|
||||
// 检查是否冲突 (排除已删除的,但包括活跃的和已过期的)
|
||||
var count int64
|
||||
s.db.Model(&model.FileBatch{}).Where("pickup_code = ?", string(b)).Count(&count)
|
||||
if count > 0 {
|
||||
return s.GeneratePickupCode(length) // 递归生成
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
func (s *BatchService) UpdateAllPickupCodes(newLength int) error {
|
||||
var batches []model.FileBatch
|
||||
// 只更新未删除的记录,包括 active 和 expired
|
||||
if err := s.db.Find(&batches).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.db.Transaction(func(tx *gorm.DB) error {
|
||||
for _, batch := range batches {
|
||||
oldCode := batch.PickupCode
|
||||
if len(oldCode) == newLength {
|
||||
continue
|
||||
}
|
||||
|
||||
var newCode string
|
||||
if len(oldCode) < newLength {
|
||||
// 右侧补零,方便用户输入原码后通过补 0 完成输入
|
||||
newCode = oldCode + strings.Repeat("0", newLength-len(oldCode))
|
||||
} else {
|
||||
// 截取前 newLength 位,保留原码头部
|
||||
newCode = oldCode[:newLength]
|
||||
}
|
||||
|
||||
// 检查是否冲突 (在事务中检查)
|
||||
var count int64
|
||||
tx.Model(&model.FileBatch{}).Where("pickup_code = ? AND id != ?", newCode, batch.ID).Count(&count)
|
||||
if count > 0 {
|
||||
// 如果冲突,生成一个新的随机码
|
||||
var err error
|
||||
newCode, err = s.generateUniquePickupCodeInTx(tx, newLength)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Model(&batch).Update("pickup_code", newCode).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BatchService) Cleanup(ctx context.Context) error {
|
||||
slog.Debug("Starting cleanup scan")
|
||||
// 1. 寻找并标记过期的 Active Batches
|
||||
var batches []model.FileBatch
|
||||
now := time.Now()
|
||||
// 同时检查时间过期 and 下载次数过期
|
||||
err := s.db.Where("status = ? AND ("+
|
||||
"(expire_type = 'time' AND expire_at < ?) OR "+
|
||||
"(expire_type = 'download' AND max_downloads > 0 AND download_count >= max_downloads)"+
|
||||
")", "active", now).Find(&batches).Error
|
||||
if err != nil {
|
||||
slog.Error("Failed to query active batches for cleanup", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if len(batches) > 0 {
|
||||
slog.Info("Found expired batches to mark", "count", len(batches))
|
||||
}
|
||||
|
||||
for _, batch := range batches {
|
||||
_ = s.MarkAsExpired(&batch)
|
||||
}
|
||||
|
||||
// 2. 检查异常文件 (物理文件缺失)
|
||||
var activeBatches []model.FileBatch
|
||||
if err := s.db.Preload("FileItems").Where("status = ?", "active").Find(&activeBatches).Error; err == nil {
|
||||
for _, batch := range activeBatches {
|
||||
var missingItems []model.FileItem
|
||||
for _, item := range batch.FileItems {
|
||||
exists, err := storage.GlobalStorage.Exists(ctx, item.StoragePath)
|
||||
if err != nil {
|
||||
slog.Error("Failed to check file existence", "path", item.StoragePath, "error", err)
|
||||
continue
|
||||
}
|
||||
if !exists {
|
||||
missingItems = append(missingItems, item)
|
||||
}
|
||||
}
|
||||
|
||||
if len(missingItems) > 0 {
|
||||
slog.Info("Removing missing files from batch", "batch_id", batch.ID, "count", len(missingItems))
|
||||
for _, item := range missingItems {
|
||||
if err := s.db.Delete(&item).Error; err != nil {
|
||||
slog.Error("Failed to remove missing file record", "item_id", item.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(missingItems) == len(batch.FileItems) {
|
||||
slog.Warn("All files missing for batch, marking as expired", "batch_id", batch.ID)
|
||||
_ = s.MarkAsExpired(&batch)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 彻底清理标记为 expired 或 deleted 的批次
|
||||
var toDelete []model.FileBatch
|
||||
// Unscoped 用于包含已软删除但尚未物理清理的记录
|
||||
err = s.db.Unscoped().Where("status IN ? OR deleted_at IS NOT NULL", []string{"expired", "deleted"}).Find(&toDelete).Error
|
||||
if err != nil {
|
||||
slog.Error("Failed to query batches for physical deletion", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if len(toDelete) > 0 {
|
||||
slog.Info("Found batches for physical deletion", "count", len(toDelete))
|
||||
}
|
||||
|
||||
for _, batch := range toDelete {
|
||||
if err := s.DeleteBatch(ctx, batch.ID); err != nil {
|
||||
slog.Error("Failed to physically delete batch", "batch_id", batch.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BatchService) generateUniquePickupCodeInTx(tx *gorm.DB, length int) (string, error) {
|
||||
const charset = "0123456789"
|
||||
for {
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
b[i] = charset[num.Int64()]
|
||||
}
|
||||
|
||||
var count int64
|
||||
tx.Model(&model.FileBatch{}).Where("pickup_code = ?", string(b)).Count(&count)
|
||||
if count == 0 {
|
||||
return string(b), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
83
internal/service/token_service.go
Normal file
83
internal/service/token_service.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"FileRelay/internal/bootstrap"
|
||||
"FileRelay/internal/model"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type TokenService struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewTokenService() *TokenService {
|
||||
return &TokenService{db: bootstrap.DB}
|
||||
}
|
||||
|
||||
func (s *TokenService) CreateToken(name string, scope string, expireAt *time.Time) (string, *model.APIToken, error) {
|
||||
rawToken := uuid.New().String()
|
||||
hash := s.hashToken(rawToken)
|
||||
|
||||
token := &model.APIToken{
|
||||
Name: name,
|
||||
TokenHash: hash,
|
||||
Scope: scope,
|
||||
ExpireAt: expireAt,
|
||||
}
|
||||
|
||||
if err := s.db.Create(token).Error; err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
return rawToken, token, nil
|
||||
}
|
||||
|
||||
func (s *TokenService) ValidateToken(rawToken string, requiredScope string) (*model.APIToken, error) {
|
||||
hash := s.hashToken(rawToken)
|
||||
var token model.APIToken
|
||||
if err := s.db.Where("token_hash = ? AND revoked = ?", hash, false).First(&token).Error; err != nil {
|
||||
return nil, errors.New("invalid token")
|
||||
}
|
||||
|
||||
if token.ExpireAt != nil && time.Now().After(*token.ExpireAt) {
|
||||
return nil, errors.New("token expired")
|
||||
}
|
||||
|
||||
// 检查 Scope (简单包含判断)
|
||||
// 在实际应用中可以实现更复杂的逻辑
|
||||
if requiredScope != "" && !s.checkScope(token.Scope, requiredScope) {
|
||||
return nil, errors.New("insufficient scope")
|
||||
}
|
||||
|
||||
// 更新最后使用时间
|
||||
now := time.Now()
|
||||
s.db.Model(&token).Update("last_used_at", &now)
|
||||
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
func (s *TokenService) hashToken(token string) string {
|
||||
h := sha256.New()
|
||||
h.Write([]byte(token))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
func (s *TokenService) checkScope(tokenScope, requiredScope string) bool {
|
||||
if requiredScope == "" {
|
||||
return true
|
||||
}
|
||||
scopes := strings.Split(tokenScope, ",")
|
||||
for _, s := range scopes {
|
||||
if strings.TrimSpace(s) == requiredScope {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
149
internal/service/upload_service.go
Normal file
149
internal/service/upload_service.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"FileRelay/internal/bootstrap"
|
||||
"FileRelay/internal/config"
|
||||
"FileRelay/internal/model"
|
||||
"FileRelay/internal/storage"
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"mime/multipart"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type UploadService struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewUploadService() *UploadService {
|
||||
return &UploadService{db: bootstrap.DB}
|
||||
}
|
||||
|
||||
func (s *UploadService) CreateBatch(ctx context.Context, files []*multipart.FileHeader, remark string, expireType string, expireValue interface{}) (*model.FileBatch, error) {
|
||||
// 1. 生成取件码
|
||||
batchService := NewBatchService()
|
||||
pickupCode, err := batchService.GeneratePickupCode(config.GlobalConfig.Security.PickupCodeLength)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2. 准备 Batch
|
||||
batch := &model.FileBatch{
|
||||
ID: uuid.New().String(),
|
||||
PickupCode: pickupCode,
|
||||
Remark: remark,
|
||||
ExpireType: expireType,
|
||||
Status: "active",
|
||||
Type: "file",
|
||||
}
|
||||
|
||||
s.applyExpire(batch, expireType, expireValue)
|
||||
|
||||
// 3. 处理文件上传
|
||||
err = s.db.Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Create(batch).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, fileHeader := range files {
|
||||
fileItem, err := s.processFile(ctx, tx, batch.ID, fileHeader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
batch.FileItems = append(batch.FileItems, *fileItem)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
slog.Info("File batch created", "batch_id", batch.ID, "pickup_code", batch.PickupCode, "files_count", len(files))
|
||||
} else {
|
||||
slog.Error("Failed to create file batch", "error", err)
|
||||
}
|
||||
|
||||
return batch, err
|
||||
}
|
||||
|
||||
func (s *UploadService) CreateTextBatch(ctx context.Context, content string, remark string, expireType string, expireValue interface{}) (*model.FileBatch, error) {
|
||||
// 1. 生成取件码
|
||||
batchService := NewBatchService()
|
||||
pickupCode, err := batchService.GeneratePickupCode(config.GlobalConfig.Security.PickupCodeLength)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2. 准备 Batch
|
||||
batch := &model.FileBatch{
|
||||
ID: uuid.New().String(),
|
||||
PickupCode: pickupCode,
|
||||
Remark: remark,
|
||||
ExpireType: expireType,
|
||||
Status: "active",
|
||||
Type: "text",
|
||||
Content: content,
|
||||
}
|
||||
|
||||
s.applyExpire(batch, expireType, expireValue)
|
||||
|
||||
// 3. 保存
|
||||
if err := s.db.Create(batch).Error; err != nil {
|
||||
slog.Error("Failed to create text batch", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
slog.Info("Text batch created", "batch_id", batch.ID, "pickup_code", batch.PickupCode)
|
||||
return batch, nil
|
||||
}
|
||||
|
||||
func (s *UploadService) applyExpire(batch *model.FileBatch, expireType string, expireValue interface{}) {
|
||||
switch expireType {
|
||||
case "time":
|
||||
if days, ok := expireValue.(int); ok {
|
||||
expireAt := time.Now().Add(time.Duration(days) * 24 * time.Hour)
|
||||
batch.ExpireAt = &expireAt
|
||||
}
|
||||
case "download":
|
||||
if max, ok := expireValue.(int); ok {
|
||||
batch.MaxDownloads = max
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UploadService) processFile(ctx context.Context, tx *gorm.DB, batchID string, fileHeader *multipart.FileHeader) (*model.FileItem, error) {
|
||||
file, err := fileHeader.Open()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// 生成唯一存储路径
|
||||
ext := filepath.Ext(fileHeader.Filename)
|
||||
fileID := uuid.New().String()
|
||||
storagePath := fmt.Sprintf("%s/%s%s", batchID, fileID, ext)
|
||||
|
||||
// 保存到存储层
|
||||
if err := storage.GlobalStorage.Save(ctx, storagePath, file); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 创建数据库记录
|
||||
item := &model.FileItem{
|
||||
ID: fileID,
|
||||
BatchID: batchID,
|
||||
OriginalName: fileHeader.Filename,
|
||||
StoragePath: storagePath,
|
||||
Size: fileHeader.Size,
|
||||
MimeType: fileHeader.Header.Get("Content-Type"),
|
||||
}
|
||||
|
||||
if err := tx.Create(item).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return item, nil
|
||||
}
|
||||
59
internal/storage/local.go
Normal file
59
internal/storage/local.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
type LocalStorage struct {
|
||||
RootPath string
|
||||
}
|
||||
|
||||
func NewLocalStorage(rootPath string) *LocalStorage {
|
||||
// 确保根目录存在
|
||||
if _, err := os.Stat(rootPath); os.IsNotExist(err) {
|
||||
os.MkdirAll(rootPath, 0755)
|
||||
}
|
||||
return &LocalStorage{RootPath: rootPath}
|
||||
}
|
||||
|
||||
func (s *LocalStorage) Save(ctx context.Context, path string, reader io.Reader) error {
|
||||
fullPath := filepath.Join(s.RootPath, path)
|
||||
dir := filepath.Dir(fullPath)
|
||||
if _, err := os.Stat(dir); os.IsNotExist(err) {
|
||||
os.MkdirAll(dir, 0755)
|
||||
}
|
||||
|
||||
file, err := os.Create(fullPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
_, err = io.Copy(file, reader)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *LocalStorage) Open(ctx context.Context, path string) (io.ReadCloser, error) {
|
||||
fullPath := filepath.Join(s.RootPath, path)
|
||||
return os.Open(fullPath)
|
||||
}
|
||||
|
||||
func (s *LocalStorage) Delete(ctx context.Context, path string) error {
|
||||
fullPath := filepath.Join(s.RootPath, path)
|
||||
return os.Remove(fullPath)
|
||||
}
|
||||
|
||||
func (s *LocalStorage) Exists(ctx context.Context, path string) (bool, error) {
|
||||
fullPath := filepath.Join(s.RootPath, path)
|
||||
_, err := os.Stat(fullPath)
|
||||
if err == nil {
|
||||
return true, nil
|
||||
}
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
88
internal/storage/s3.go
Normal file
88
internal/storage/s3.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
)
|
||||
|
||||
type S3Storage struct {
|
||||
client *s3.Client
|
||||
bucket string
|
||||
}
|
||||
|
||||
func NewS3Storage(ctx context.Context, endpoint, region, accessKey, secretKey, bucket string, useSSL bool) (*S3Storage, error) {
|
||||
customResolver := aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) {
|
||||
if endpoint != "" {
|
||||
return aws.Endpoint{
|
||||
URL: endpoint,
|
||||
SigningRegion: region,
|
||||
HostnameImmutable: true,
|
||||
}, nil
|
||||
}
|
||||
return aws.Endpoint{}, &aws.EndpointNotFoundError{}
|
||||
})
|
||||
|
||||
cfg, err := config.LoadDefaultConfig(ctx,
|
||||
config.WithRegion(region),
|
||||
config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(accessKey, secretKey, "")),
|
||||
config.WithEndpointResolverWithOptions(customResolver),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client := s3.NewFromConfig(cfg)
|
||||
return &S3Storage{
|
||||
client: client,
|
||||
bucket: bucket,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *S3Storage) Save(ctx context.Context, path string, reader io.Reader) error {
|
||||
_, err := s.client.PutObject(ctx, &s3.PutObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(path),
|
||||
Body: reader,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *S3Storage) Open(ctx context.Context, path string) (io.ReadCloser, error) {
|
||||
output, err := s.client.GetObject(ctx, &s3.GetObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(path),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return output.Body, nil
|
||||
}
|
||||
|
||||
func (s *S3Storage) Delete(ctx context.Context, path string) error {
|
||||
_, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(path),
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *S3Storage) Exists(ctx context.Context, path string) (bool, error) {
|
||||
_, err := s.client.HeadObject(ctx, &s3.HeadObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(path),
|
||||
})
|
||||
if err != nil {
|
||||
// In AWS SDK v2, we check if the error is 404
|
||||
if strings.Contains(err.Error(), "NotFound") || strings.Contains(err.Error(), "404") {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
15
internal/storage/storage.go
Normal file
15
internal/storage/storage.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
)
|
||||
|
||||
type Storage interface {
|
||||
Save(ctx context.Context, path string, reader io.Reader) error
|
||||
Open(ctx context.Context, path string) (io.ReadCloser, error)
|
||||
Delete(ctx context.Context, path string) error
|
||||
Exists(ctx context.Context, path string) (bool, error)
|
||||
}
|
||||
|
||||
var GlobalStorage Storage
|
||||
69
internal/storage/webdav.go
Normal file
69
internal/storage/webdav.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/studio-b12/gowebdav"
|
||||
)
|
||||
|
||||
type WebDAVStorage struct {
|
||||
client *gowebdav.Client
|
||||
root string
|
||||
}
|
||||
|
||||
func NewWebDAVStorage(url, username, password, root string) *WebDAVStorage {
|
||||
client := gowebdav.NewClient(url, username, password)
|
||||
return &WebDAVStorage{
|
||||
client: client,
|
||||
root: root,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WebDAVStorage) getFullPath(path string) string {
|
||||
return filepath.ToSlash(filepath.Join(s.root, path))
|
||||
}
|
||||
|
||||
func (s *WebDAVStorage) Save(ctx context.Context, path string, reader io.Reader) error {
|
||||
fullPath := s.getFullPath(path)
|
||||
|
||||
// Ensure directory exists
|
||||
dir := filepath.Dir(fullPath)
|
||||
if dir != "." && dir != "/" {
|
||||
parts := strings.Split(strings.Trim(dir, "/"), "/")
|
||||
current := ""
|
||||
for _, part := range parts {
|
||||
current += "/" + part
|
||||
_ = s.client.Mkdir(current, 0755)
|
||||
}
|
||||
}
|
||||
|
||||
return s.client.WriteStream(fullPath, reader, 0644)
|
||||
}
|
||||
|
||||
func (s *WebDAVStorage) Open(ctx context.Context, path string) (io.ReadCloser, error) {
|
||||
fullPath := s.getFullPath(path)
|
||||
return s.client.ReadStream(fullPath)
|
||||
}
|
||||
|
||||
func (s *WebDAVStorage) Delete(ctx context.Context, path string) error {
|
||||
fullPath := s.getFullPath(path)
|
||||
return s.client.Remove(fullPath)
|
||||
}
|
||||
|
||||
func (s *WebDAVStorage) Exists(ctx context.Context, path string) (bool, error) {
|
||||
fullPath := s.getFullPath(path)
|
||||
_, err := s.client.Stat(fullPath)
|
||||
if err != nil {
|
||||
// gowebdav's Stat returns error if not found
|
||||
// We could check for 404 but gowebdav doesn't export error types easily
|
||||
// Usually we check if it's a 404
|
||||
if strings.Contains(err.Error(), "404") || strings.Contains(err.Error(), "not found") {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
39
internal/task/cleaner.go
Normal file
39
internal/task/cleaner.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package task
|
||||
|
||||
import (
|
||||
"FileRelay/internal/service"
|
||||
"context"
|
||||
"log/slog"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Cleaner struct {
|
||||
batchService *service.BatchService
|
||||
}
|
||||
|
||||
func NewCleaner() *Cleaner {
|
||||
return &Cleaner{
|
||||
batchService: service.NewBatchService(),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Cleaner) Start(ctx context.Context) {
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
c.Clean()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Cleaner) Clean() {
|
||||
slog.Info("Running cleanup task")
|
||||
if err := c.batchService.Cleanup(context.Background()); err != nil {
|
||||
slog.Error("Error during cleanup", "error", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user