项目初始化
This commit is contained in:
70
internal/api/admin/auth.go
Normal file
70
internal/api/admin/auth.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"FileRelay/internal/auth"
|
||||
"FileRelay/internal/bootstrap"
|
||||
"FileRelay/internal/model"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
type AuthHandler struct{}
|
||||
|
||||
func NewAuthHandler() *AuthHandler {
|
||||
return &AuthHandler{}
|
||||
}
|
||||
|
||||
type LoginRequest struct {
|
||||
Password string `json:"password" binding:"required" example:"admin"`
|
||||
}
|
||||
|
||||
type LoginResponse struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
// Login 管理员登录
|
||||
// @Summary 管理员登录
|
||||
// @Description 通过密码换取 JWT Token
|
||||
// @Tags Admin
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body LoginRequest true "登录请求"
|
||||
// @Success 200 {object} model.Response{data=LoginResponse}
|
||||
// @Failure 401 {object} model.Response
|
||||
// @Router /admin/login [post]
|
||||
func (h *AuthHandler) Login(c *gin.Context) {
|
||||
var req LoginRequest
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.ErrorResponse(model.CodeBadRequest, "Invalid request"))
|
||||
return
|
||||
}
|
||||
|
||||
var admin model.Admin
|
||||
if err := bootstrap.DB.First(&admin).Error; err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.ErrorResponse(model.CodeInternalError, "Admin not found"))
|
||||
return
|
||||
}
|
||||
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(admin.PasswordHash), []byte(req.Password)); err != nil {
|
||||
c.JSON(http.StatusUnauthorized, model.ErrorResponse(model.CodeUnauthorized, "Incorrect password"))
|
||||
return
|
||||
}
|
||||
|
||||
token, err := auth.GenerateToken(admin.ID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.ErrorResponse(model.CodeInternalError, "Failed to generate token"))
|
||||
return
|
||||
}
|
||||
|
||||
// 更新登录时间
|
||||
now := time.Now()
|
||||
bootstrap.DB.Model(&admin).Update("last_login", &now)
|
||||
|
||||
c.JSON(http.StatusOK, model.SuccessResponse(LoginResponse{
|
||||
Token: token,
|
||||
}))
|
||||
}
|
||||
173
internal/api/admin/batch.go
Normal file
173
internal/api/admin/batch.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"FileRelay/internal/bootstrap"
|
||||
"FileRelay/internal/model"
|
||||
"FileRelay/internal/service"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type BatchHandler struct {
|
||||
batchService *service.BatchService
|
||||
}
|
||||
|
||||
func NewBatchHandler() *BatchHandler {
|
||||
return &BatchHandler{
|
||||
batchService: service.NewBatchService(),
|
||||
}
|
||||
}
|
||||
|
||||
type ListBatchesResponse struct {
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
Data []model.FileBatch `json:"data"`
|
||||
}
|
||||
|
||||
type UpdateBatchRequest struct {
|
||||
Remark string `json:"remark"`
|
||||
ExpireType string `json:"expire_type"`
|
||||
ExpireAt *time.Time `json:"expire_at"`
|
||||
MaxDownloads int `json:"max_downloads"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
// ListBatches 获取批次列表
|
||||
// @Summary 获取批次列表
|
||||
// @Description 分页查询所有文件批次,支持按状态过滤和取件码模糊搜索
|
||||
// @Tags Admin
|
||||
// @Security AdminAuth
|
||||
// @Param page query int false "页码 (默认 1)"
|
||||
// @Param page_size query int false "每页数量 (默认 20)"
|
||||
// @Param status query string false "状态 (active/expired/deleted)"
|
||||
// @Param pickup_code query string false "取件码 (模糊搜索)"
|
||||
// @Produce json
|
||||
// @Success 200 {object} model.Response{data=ListBatchesResponse}
|
||||
// @Failure 401 {object} model.Response
|
||||
// @Router /admin/batches [get]
|
||||
func (h *BatchHandler) ListBatches(c *gin.Context) {
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20"))
|
||||
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize < 1 {
|
||||
pageSize = 20
|
||||
}
|
||||
status := c.Query("status")
|
||||
pickupCode := c.Query("pickup_code")
|
||||
|
||||
query := bootstrap.DB.Model(&model.FileBatch{})
|
||||
if status != "" {
|
||||
query = query.Where("status = ?", status)
|
||||
}
|
||||
if pickupCode != "" {
|
||||
query = query.Where("pickup_code LIKE ?", "%"+pickupCode+"%")
|
||||
}
|
||||
|
||||
var total int64
|
||||
query.Count(&total)
|
||||
|
||||
var batches []model.FileBatch
|
||||
err := query.Offset((page - 1) * pageSize).Limit(pageSize).Order("created_at DESC").Find(&batches).Error
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.ErrorResponse(model.CodeInternalError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.SuccessResponse(ListBatchesResponse{
|
||||
Total: total,
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
Data: batches,
|
||||
}))
|
||||
}
|
||||
|
||||
// GetBatch 获取批次详情
|
||||
// @Summary 获取批次详情
|
||||
// @Description 根据批次 ID 获取批次信息及关联的文件列表
|
||||
// @Tags Admin
|
||||
// @Security AdminAuth
|
||||
// @Param batch_id path int true "批次 ID"
|
||||
// @Produce json
|
||||
// @Success 200 {object} model.Response{data=model.FileBatch}
|
||||
// @Failure 404 {object} model.Response
|
||||
// @Router /admin/batch/{batch_id} [get]
|
||||
func (h *BatchHandler) GetBatch(c *gin.Context) {
|
||||
id := c.Param("batch_id")
|
||||
var batch model.FileBatch
|
||||
if err := bootstrap.DB.Preload("FileItems").First(&batch, id).Error; err != nil {
|
||||
c.JSON(http.StatusNotFound, model.ErrorResponse(model.CodeNotFound, "batch not found"))
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, model.SuccessResponse(batch))
|
||||
}
|
||||
|
||||
// UpdateBatch 修改批次信息
|
||||
// @Summary 修改批次信息
|
||||
// @Description 允许修改备注、过期策略、最大下载次数、状态等
|
||||
// @Tags Admin
|
||||
// @Security AdminAuth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param batch_id path int true "批次 ID"
|
||||
// @Param request body UpdateBatchRequest true "修改内容"
|
||||
// @Success 200 {object} model.Response{data=model.FileBatch}
|
||||
// @Failure 400 {object} model.Response
|
||||
// @Router /admin/batch/{batch_id} [put]
|
||||
func (h *BatchHandler) UpdateBatch(c *gin.Context) {
|
||||
id := c.Param("batch_id")
|
||||
var batch model.FileBatch
|
||||
if err := bootstrap.DB.First(&batch, id).Error; err != nil {
|
||||
c.JSON(http.StatusNotFound, model.ErrorResponse(model.CodeNotFound, "batch not found"))
|
||||
return
|
||||
}
|
||||
|
||||
var input UpdateBatchRequest
|
||||
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.ErrorResponse(model.CodeBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
updates := make(map[string]interface{})
|
||||
updates["remark"] = input.Remark
|
||||
updates["expire_type"] = input.ExpireType
|
||||
updates["expire_at"] = input.ExpireAt
|
||||
updates["max_downloads"] = input.MaxDownloads
|
||||
updates["status"] = input.Status
|
||||
|
||||
if err := bootstrap.DB.Model(&batch).Updates(updates).Error; err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.ErrorResponse(model.CodeInternalError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.SuccessResponse(batch))
|
||||
}
|
||||
|
||||
// DeleteBatch 删除批次
|
||||
// @Summary 删除批次
|
||||
// @Description 标记批次为已删除,并物理删除关联的存储文件
|
||||
// @Tags Admin
|
||||
// @Security AdminAuth
|
||||
// @Param batch_id path int true "批次 ID"
|
||||
// @Produce json
|
||||
// @Success 200 {object} model.Response
|
||||
// @Failure 500 {object} model.Response
|
||||
// @Router /admin/batch/{batch_id} [delete]
|
||||
func (h *BatchHandler) DeleteBatch(c *gin.Context) {
|
||||
idStr := c.Param("batch_id")
|
||||
id, _ := strconv.ParseUint(idStr, 10, 32)
|
||||
|
||||
if err := h.batchService.DeleteBatch(c.Request.Context(), uint(id)); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.ErrorResponse(model.CodeInternalError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.SuccessResponse(map[string]interface{}{}))
|
||||
}
|
||||
109
internal/api/admin/token.go
Normal file
109
internal/api/admin/token.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"FileRelay/internal/bootstrap"
|
||||
"FileRelay/internal/model"
|
||||
"FileRelay/internal/service"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type TokenHandler struct {
|
||||
tokenService *service.TokenService
|
||||
}
|
||||
|
||||
func NewTokenHandler() *TokenHandler {
|
||||
return &TokenHandler{
|
||||
tokenService: service.NewTokenService(),
|
||||
}
|
||||
}
|
||||
|
||||
type CreateTokenRequest struct {
|
||||
Name string `json:"name" binding:"required" example:"Test Token"`
|
||||
Scope string `json:"scope" example:"upload,pickup"`
|
||||
ExpireAt *time.Time `json:"expire_at"`
|
||||
}
|
||||
|
||||
type CreateTokenResponse struct {
|
||||
Token string `json:"token"`
|
||||
Data *model.APIToken `json:"data"`
|
||||
}
|
||||
|
||||
// ListTokens 获取 API Token 列表
|
||||
// @Summary 获取 API Token 列表
|
||||
// @Description 获取系统中所有 API Token 的详详信息(不包含哈希)
|
||||
// @Tags Admin
|
||||
// @Security AdminAuth
|
||||
// @Produce json
|
||||
// @Success 200 {object} model.Response{data=[]model.APIToken}
|
||||
// @Failure 401 {object} model.Response
|
||||
// @Router /admin/api-tokens [get]
|
||||
func (h *TokenHandler) ListTokens(c *gin.Context) {
|
||||
var tokens []model.APIToken
|
||||
if err := bootstrap.DB.Find(&tokens).Error; err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.ErrorResponse(model.CodeInternalError, err.Error()))
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, model.SuccessResponse(tokens))
|
||||
}
|
||||
|
||||
// CreateToken 创建 API Token
|
||||
// @Summary 创建 API Token
|
||||
// @Description 创建一个新的 API Token,返回原始 Token(仅显示一次)
|
||||
// @Tags Admin
|
||||
// @Security AdminAuth
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body CreateTokenRequest true "Token 信息"
|
||||
// @Success 201 {object} model.Response{data=CreateTokenResponse}
|
||||
// @Failure 400 {object} model.Response
|
||||
// @Router /admin/api-tokens [post]
|
||||
func (h *TokenHandler) CreateToken(c *gin.Context) {
|
||||
var input CreateTokenRequest
|
||||
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.ErrorResponse(model.CodeBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
rawToken, token, err := h.tokenService.CreateToken(input.Name, input.Scope, input.ExpireAt)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.ErrorResponse(model.CodeInternalError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, model.SuccessResponse(CreateTokenResponse{
|
||||
Token: rawToken,
|
||||
Data: token,
|
||||
}))
|
||||
}
|
||||
|
||||
func (h *TokenHandler) RevokeToken(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
if err := bootstrap.DB.Model(&model.APIToken{}).Where("id = ?", id).Update("revoked", true).Error; err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.ErrorResponse(model.CodeInternalError, err.Error()))
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, model.SuccessResponse(map[string]interface{}{}))
|
||||
}
|
||||
|
||||
// DeleteToken 删除 API Token
|
||||
// @Summary 删除 API Token
|
||||
// @Description 根据 ID 永久删除 API Token
|
||||
// @Tags Admin
|
||||
// @Security AdminAuth
|
||||
// @Param id path int true "Token ID"
|
||||
// @Produce json
|
||||
// @Success 200 {object} model.Response
|
||||
// @Failure 500 {object} model.Response
|
||||
// @Router /admin/api-tokens/{id} [delete]
|
||||
func (h *TokenHandler) DeleteToken(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
if err := bootstrap.DB.Delete(&model.APIToken{}, id).Error; err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.ErrorResponse(model.CodeInternalError, err.Error()))
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, model.SuccessResponse(map[string]interface{}{}))
|
||||
}
|
||||
76
internal/api/middleware/auth.go
Normal file
76
internal/api/middleware/auth.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"FileRelay/internal/auth"
|
||||
"FileRelay/internal/config"
|
||||
"FileRelay/internal/model"
|
||||
"FileRelay/internal/service"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func AdminAuth() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.JSON(http.StatusUnauthorized, model.ErrorResponse(model.CodeUnauthorized, "Authorization header required"))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if !(len(parts) == 2 && parts[0] == "Bearer") {
|
||||
c.JSON(http.StatusUnauthorized, model.ErrorResponse(model.CodeUnauthorized, "Invalid authorization format"))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
claims, err := auth.ParseToken(parts[1])
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, model.ErrorResponse(model.CodeUnauthorized, "Invalid or expired token"))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Set("admin_id", claims.AdminID)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func APITokenAuth(requiredScope string) gin.HandlerFunc {
|
||||
tokenService := service.NewTokenService()
|
||||
return func(c *gin.Context) {
|
||||
if !config.GlobalConfig.APIToken.Enabled {
|
||||
c.JSON(http.StatusForbidden, model.ErrorResponse(model.CodeForbidden, "API Token is disabled"))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.JSON(http.StatusUnauthorized, model.ErrorResponse(model.CodeUnauthorized, "Authorization header required"))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if !(len(parts) == 2 && parts[0] == "Bearer") {
|
||||
c.JSON(http.StatusUnauthorized, model.ErrorResponse(model.CodeUnauthorized, "Invalid authorization format"))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
token, err := tokenService.ValidateToken(parts[1], requiredScope)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, model.ErrorResponse(model.CodeUnauthorized, err.Error()))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Set("token_id", token.ID)
|
||||
c.Set("token_scope", token.Scope)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
52
internal/api/middleware/limit.go
Normal file
52
internal/api/middleware/limit.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"FileRelay/internal/config"
|
||||
"FileRelay/internal/model"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var (
|
||||
pickupFailures = make(map[string]int)
|
||||
failureMutex sync.Mutex
|
||||
)
|
||||
|
||||
func PickupRateLimit() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ip := c.ClientIP()
|
||||
code := c.Param("pickup_code")
|
||||
key := ip + ":" + code
|
||||
|
||||
failureMutex.Lock()
|
||||
count, exists := pickupFailures[key]
|
||||
failureMutex.Unlock()
|
||||
|
||||
if exists && count >= config.GlobalConfig.Security.PickupFailLimit {
|
||||
c.JSON(http.StatusTooManyRequests, model.ErrorResponse(http.StatusTooManyRequests, "Too many failed attempts. Please try again later."))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func RecordPickupFailure(ip, code string) {
|
||||
key := ip + ":" + code
|
||||
failureMutex.Lock()
|
||||
pickupFailures[key]++
|
||||
|
||||
// 设置 1 小时后清除记录 (简单实现)
|
||||
go func() {
|
||||
time.Sleep(1 * time.Hour)
|
||||
failureMutex.Lock()
|
||||
delete(pickupFailures, key)
|
||||
failureMutex.Unlock()
|
||||
}()
|
||||
|
||||
failureMutex.Unlock()
|
||||
}
|
||||
162
internal/api/public/pickup.go
Normal file
162
internal/api/public/pickup.go
Normal file
@@ -0,0 +1,162 @@
|
||||
package public
|
||||
|
||||
import (
|
||||
"FileRelay/internal/api/middleware"
|
||||
"FileRelay/internal/bootstrap"
|
||||
"FileRelay/internal/model"
|
||||
"FileRelay/internal/service"
|
||||
"FileRelay/internal/storage"
|
||||
"archive/zip"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type PickupResponse struct {
|
||||
Remark string `json:"remark"`
|
||||
ExpireAt *time.Time `json:"expire_at"`
|
||||
ExpireType string `json:"expire_type"`
|
||||
DownloadCount int `json:"download_count"`
|
||||
MaxDownloads int `json:"max_downloads"`
|
||||
Files []model.FileItem `json:"files"`
|
||||
}
|
||||
|
||||
// DownloadBatch 批量下载文件 (ZIP)
|
||||
// @Summary 批量下载文件
|
||||
// @Description 根据取件码将批次内的所有文件打包为 ZIP 格式一次性下载
|
||||
// @Tags Public
|
||||
// @Param pickup_code path string true "取件码"
|
||||
// @Produce application/zip
|
||||
// @Success 200 {file} file
|
||||
// @Failure 404 {object} model.Response
|
||||
// @Router /api/download/batch/{pickup_code} [get]
|
||||
func (h *PickupHandler) DownloadBatch(c *gin.Context) {
|
||||
code := c.Param("pickup_code")
|
||||
batch, err := h.batchService.GetBatchByPickupCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, model.ErrorResponse(model.CodeNotFound, "batch not found or expired"))
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=\"batch_%s.zip\"", code))
|
||||
c.Header("Content-Type", "application/zip")
|
||||
|
||||
zw := zip.NewWriter(c.Writer)
|
||||
defer zw.Close()
|
||||
|
||||
for _, item := range batch.FileItems {
|
||||
reader, err := storage.GlobalStorage.Open(c.Request.Context(), item.StoragePath)
|
||||
if err != nil {
|
||||
continue // Skip failed files
|
||||
}
|
||||
|
||||
f, err := zw.Create(item.OriginalName)
|
||||
if err != nil {
|
||||
reader.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
_, _ = io.Copy(f, reader)
|
||||
reader.Close()
|
||||
}
|
||||
|
||||
// 增加下载次数
|
||||
h.batchService.IncrementDownloadCount(batch.ID)
|
||||
}
|
||||
|
||||
type PickupHandler struct {
|
||||
batchService *service.BatchService
|
||||
}
|
||||
|
||||
func NewPickupHandler() *PickupHandler {
|
||||
return &PickupHandler{
|
||||
batchService: service.NewBatchService(),
|
||||
}
|
||||
}
|
||||
|
||||
// Pickup 获取批次信息
|
||||
// @Summary 获取批次信息
|
||||
// @Description 根据取件码获取文件批次详详情和文件列表
|
||||
// @Tags Public
|
||||
// @Produce json
|
||||
// @Param pickup_code path string true "取件码"
|
||||
// @Success 200 {object} model.Response{data=PickupResponse}
|
||||
// @Failure 404 {object} model.Response
|
||||
// @Router /api/pickup/{pickup_code} [get]
|
||||
func (h *PickupHandler) Pickup(c *gin.Context) {
|
||||
code := c.Param("pickup_code")
|
||||
if code == "" {
|
||||
c.JSON(http.StatusBadRequest, model.ErrorResponse(model.CodeBadRequest, "pickup code required"))
|
||||
return
|
||||
}
|
||||
|
||||
batch, err := h.batchService.GetBatchByPickupCode(code)
|
||||
if err != nil {
|
||||
middleware.RecordPickupFailure(c.ClientIP(), code)
|
||||
c.JSON(http.StatusNotFound, model.ErrorResponse(model.CodeNotFound, "batch not found or expired"))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.SuccessResponse(PickupResponse{
|
||||
Remark: batch.Remark,
|
||||
ExpireAt: batch.ExpireAt,
|
||||
ExpireType: batch.ExpireType,
|
||||
DownloadCount: batch.DownloadCount,
|
||||
MaxDownloads: batch.MaxDownloads,
|
||||
Files: batch.FileItems,
|
||||
}))
|
||||
}
|
||||
|
||||
// DownloadFile 下载单个文件
|
||||
// @Summary 下载单个文件
|
||||
// @Description 根据文件 ID 下载单个文件
|
||||
// @Tags Public
|
||||
// @Param file_id path int true "文件 ID"
|
||||
// @Produce application/octet-stream
|
||||
// @Success 200 {file} file
|
||||
// @Failure 404 {object} model.Response
|
||||
// @Failure 410 {object} model.Response
|
||||
// @Router /api/download/file/{file_id} [get]
|
||||
func (h *PickupHandler) DownloadFile(c *gin.Context) {
|
||||
fileIDStr := c.Param("file_id")
|
||||
fileID, _ := strconv.ParseUint(fileIDStr, 10, 32)
|
||||
|
||||
var item model.FileItem
|
||||
if err := bootstrap.DB.First(&item, fileID).Error; err != nil {
|
||||
c.JSON(http.StatusNotFound, model.ErrorResponse(model.CodeNotFound, "file not found"))
|
||||
return
|
||||
}
|
||||
|
||||
var batch model.FileBatch
|
||||
if err := bootstrap.DB.First(&batch, item.BatchID).Error; err != nil {
|
||||
c.JSON(http.StatusNotFound, model.ErrorResponse(model.CodeNotFound, "batch not found"))
|
||||
return
|
||||
}
|
||||
|
||||
if h.batchService.IsExpired(&batch) {
|
||||
h.batchService.MarkAsExpired(&batch)
|
||||
c.JSON(http.StatusGone, model.ErrorResponse(model.CodeGone, "batch expired"))
|
||||
return
|
||||
}
|
||||
|
||||
// 打开文件
|
||||
reader, err := storage.GlobalStorage.Open(c.Request.Context(), item.StoragePath)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.ErrorResponse(model.CodeInternalError, "failed to open file"))
|
||||
return
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
// 增加下载次数
|
||||
h.batchService.IncrementDownloadCount(batch.ID)
|
||||
|
||||
c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", item.OriginalName))
|
||||
c.Header("Content-Type", item.MimeType)
|
||||
c.Header("Content-Length", strconv.FormatInt(item.Size, 10))
|
||||
|
||||
io.Copy(c.Writer, reader)
|
||||
}
|
||||
96
internal/api/public/upload.go
Normal file
96
internal/api/public/upload.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package public
|
||||
|
||||
import (
|
||||
"FileRelay/internal/config"
|
||||
"FileRelay/internal/model"
|
||||
"FileRelay/internal/service"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type UploadHandler struct {
|
||||
uploadService *service.UploadService
|
||||
}
|
||||
|
||||
func NewUploadHandler() *UploadHandler {
|
||||
return &UploadHandler{
|
||||
uploadService: service.NewUploadService(),
|
||||
}
|
||||
}
|
||||
|
||||
type UploadResponse struct {
|
||||
PickupCode string `json:"pickup_code"`
|
||||
ExpireAt *time.Time `json:"expire_at"`
|
||||
BatchID uint `json:"batch_id"`
|
||||
}
|
||||
|
||||
// Upload 上传文件并生成取件码
|
||||
// @Summary 上传文件
|
||||
// @Description 上传一个或多个文件并创建一个提取批次
|
||||
// @Tags Public
|
||||
// @Accept multipart/form-data
|
||||
// @Produce json
|
||||
// @Param files formData file true "文件列表"
|
||||
// @Param remark formData string false "备注"
|
||||
// @Param expire_type formData string false "过期类型 (time/download/permanent)"
|
||||
// @Param expire_days formData int false "过期天数 (针对 time 类型)"
|
||||
// @Param max_downloads formData int false "最大下载次数 (针对 download 类型)"
|
||||
// @Success 200 {object} model.Response{data=UploadResponse}
|
||||
// @Failure 400 {object} model.Response
|
||||
// @Failure 500 {object} model.Response
|
||||
// @Router /api/upload [post]
|
||||
func (h *UploadHandler) Upload(c *gin.Context) {
|
||||
form, err := c.MultipartForm()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, model.ErrorResponse(model.CodeBadRequest, "invalid form"))
|
||||
return
|
||||
}
|
||||
|
||||
files := form.File["files"]
|
||||
if len(files) == 0 {
|
||||
c.JSON(http.StatusBadRequest, model.ErrorResponse(model.CodeBadRequest, "no files uploaded"))
|
||||
return
|
||||
}
|
||||
|
||||
if len(files) > config.GlobalConfig.Upload.MaxBatchFiles {
|
||||
c.JSON(http.StatusBadRequest, model.ErrorResponse(model.CodeBadRequest, "too many files"))
|
||||
return
|
||||
}
|
||||
|
||||
remark := c.PostForm("remark")
|
||||
expireType := c.PostForm("expire_type") // time / download / permanent
|
||||
if expireType == "" {
|
||||
expireType = "time"
|
||||
}
|
||||
|
||||
var expireValue interface{}
|
||||
switch expireType {
|
||||
case "time":
|
||||
days, _ := strconv.Atoi(c.PostForm("expire_days"))
|
||||
if days <= 0 {
|
||||
days = config.GlobalConfig.Upload.MaxRetentionDays
|
||||
}
|
||||
expireValue = days
|
||||
case "download":
|
||||
max, _ := strconv.Atoi(c.PostForm("max_downloads"))
|
||||
if max <= 0 {
|
||||
max = 1
|
||||
}
|
||||
expireValue = max
|
||||
}
|
||||
|
||||
batch, err := h.uploadService.CreateBatch(c.Request.Context(), files, remark, expireType, expireValue)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.ErrorResponse(model.CodeInternalError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.SuccessResponse(UploadResponse{
|
||||
PickupCode: batch.PickupCode,
|
||||
ExpireAt: batch.ExpireAt,
|
||||
BatchID: batch.ID,
|
||||
}))
|
||||
}
|
||||
42
internal/auth/jwt.go
Normal file
42
internal/auth/jwt.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"FileRelay/internal/config"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
type Claims struct {
|
||||
AdminID uint `json:"admin_id"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
func GenerateToken(adminID uint) (string, error) {
|
||||
claims := Claims{
|
||||
AdminID: adminID,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(config.GlobalConfig.Security.JWTSecret))
|
||||
}
|
||||
|
||||
func ParseToken(tokenString string) (*Claims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(config.GlobalConfig.Security.JWTSecret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
return nil, jwt.ErrSignatureInvalid
|
||||
}
|
||||
112
internal/bootstrap/init.go
Normal file
112
internal/bootstrap/init.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"FileRelay/internal/config"
|
||||
"FileRelay/internal/model"
|
||||
"FileRelay/internal/storage"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/big"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var DB *gorm.DB
|
||||
|
||||
func InitDB() {
|
||||
var err error
|
||||
dbPath := config.GlobalConfig.Database.Path
|
||||
if dbPath == "" {
|
||||
dbPath = "file_relay.db"
|
||||
}
|
||||
|
||||
DB, err = gorm.Open(sqlite.Open(dbPath), &gorm.Config{})
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to connect to database: %v", err)
|
||||
}
|
||||
|
||||
// 自动迁移
|
||||
err = DB.AutoMigrate(
|
||||
&model.FileBatch{},
|
||||
&model.FileItem{},
|
||||
&model.APIToken{},
|
||||
&model.Admin{},
|
||||
)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to migrate database: %v", err)
|
||||
}
|
||||
|
||||
fmt.Println("Database initialized and migrated.")
|
||||
|
||||
// 初始化存储
|
||||
initStorage()
|
||||
|
||||
// 初始化管理员 (如果不存在)
|
||||
initAdmin()
|
||||
}
|
||||
|
||||
func initStorage() {
|
||||
storageType := config.GlobalConfig.Storage.Type
|
||||
switch storageType {
|
||||
case "local":
|
||||
storage.GlobalStorage = storage.NewLocalStorage(config.GlobalConfig.Storage.Local.Path)
|
||||
case "webdav":
|
||||
cfg := config.GlobalConfig.Storage.WebDAV
|
||||
storage.GlobalStorage = storage.NewWebDAVStorage(cfg.URL, cfg.Username, cfg.Password, cfg.Root)
|
||||
case "s3":
|
||||
cfg := config.GlobalConfig.Storage.S3
|
||||
s3Storage, err := storage.NewS3Storage(context.Background(), cfg.Endpoint, cfg.Region, cfg.AccessKey, cfg.SecretKey, cfg.Bucket, cfg.UseSSL)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to initialize S3 storage: %v", err)
|
||||
}
|
||||
storage.GlobalStorage = s3Storage
|
||||
default:
|
||||
log.Fatalf("Unsupported storage type: %s", storageType)
|
||||
}
|
||||
fmt.Printf("Storage initialized with type: %s\n", storageType)
|
||||
}
|
||||
|
||||
func initAdmin() {
|
||||
var count int64
|
||||
DB.Model(&model.Admin{}).Count(&count)
|
||||
if count == 0 {
|
||||
passwordHash := config.GlobalConfig.Security.AdminPasswordHash
|
||||
if passwordHash == "" {
|
||||
// 生成随机密码
|
||||
password := generateRandomPassword(12)
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to generate password hash: %v", err)
|
||||
}
|
||||
passwordHash = string(hash)
|
||||
fmt.Printf("**************************************************\n")
|
||||
fmt.Printf("NO ADMIN PASSWORD CONFIGURED. GENERATED RANDOM PASSWORD:\n")
|
||||
fmt.Printf("Password: %s\n", password)
|
||||
fmt.Printf("Please save this password or configure admin_password_hash in config.yaml\n")
|
||||
fmt.Printf("**************************************************\n")
|
||||
}
|
||||
|
||||
admin := &model.Admin{
|
||||
PasswordHash: passwordHash,
|
||||
}
|
||||
DB.Create(admin)
|
||||
fmt.Println("Admin account initialized.")
|
||||
}
|
||||
}
|
||||
|
||||
func generateRandomPassword(length int) string {
|
||||
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*"
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
|
||||
if err != nil {
|
||||
return "admin123" // 退路
|
||||
}
|
||||
b[i] = charset[num.Int64()]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
82
internal/config/config.go
Normal file
82
internal/config/config.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Site SiteConfig `yaml:"site"`
|
||||
Security SecurityConfig `yaml:"security"`
|
||||
Upload UploadConfig `yaml:"upload"`
|
||||
Storage StorageConfig `yaml:"storage"`
|
||||
APIToken APITokenConfig `yaml:"api_token"`
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
}
|
||||
|
||||
type SiteConfig struct {
|
||||
Name string `yaml:"name"`
|
||||
Description string `yaml:"description"`
|
||||
}
|
||||
|
||||
type SecurityConfig struct {
|
||||
AdminPasswordHash string `yaml:"admin_password_hash"`
|
||||
PickupCodeLength int `yaml:"pickup_code_length"`
|
||||
PickupFailLimit int `yaml:"pickup_fail_limit"`
|
||||
JWTSecret string `yaml:"jwt_secret"`
|
||||
}
|
||||
|
||||
type UploadConfig struct {
|
||||
MaxFileSizeMB int64 `yaml:"max_file_size_mb"`
|
||||
MaxBatchFiles int `yaml:"max_batch_files"`
|
||||
MaxRetentionDays int `yaml:"max_retention_days"`
|
||||
}
|
||||
|
||||
type StorageConfig struct {
|
||||
Type string `yaml:"type"`
|
||||
Local struct {
|
||||
Path string `yaml:"path"`
|
||||
} `yaml:"local"`
|
||||
WebDAV struct {
|
||||
URL string `yaml:"url"`
|
||||
Username string `yaml:"username"`
|
||||
Password string `yaml:"password"`
|
||||
Root string `yaml:"root"`
|
||||
} `yaml:"webdav"`
|
||||
S3 struct {
|
||||
Endpoint string `yaml:"endpoint"`
|
||||
Region string `yaml:"region"`
|
||||
AccessKey string `yaml:"access_key"`
|
||||
SecretKey string `yaml:"secret_key"`
|
||||
Bucket string `yaml:"bucket"`
|
||||
UseSSL bool `yaml:"use_ssl"`
|
||||
} `yaml:"s3"`
|
||||
}
|
||||
|
||||
type APITokenConfig struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
AllowAdminAPI bool `yaml:"allow_admin_api"`
|
||||
MaxTokens int `yaml:"max_tokens"`
|
||||
}
|
||||
|
||||
type DatabaseConfig struct {
|
||||
Path string `yaml:"path"`
|
||||
}
|
||||
|
||||
var GlobalConfig *Config
|
||||
|
||||
func LoadConfig(path string) error {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
GlobalConfig = &cfg
|
||||
return nil
|
||||
}
|
||||
11
internal/model/admin.go
Normal file
11
internal/model/admin.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type Admin struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
PasswordHash string `json:"-"`
|
||||
LastLogin *time.Time `json:"last_login"`
|
||||
}
|
||||
16
internal/model/api_token.go
Normal file
16
internal/model/api_token.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type APIToken struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
Name string `json:"name"`
|
||||
TokenHash string `gorm:"uniqueIndex;not null" json:"-"`
|
||||
Scope string `json:"scope"`
|
||||
ExpireAt *time.Time `json:"expire_at"`
|
||||
LastUsedAt *time.Time `json:"last_used_at"`
|
||||
Revoked bool `gorm:"default:false" json:"revoked"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
22
internal/model/file_batch.go
Normal file
22
internal/model/file_batch.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type FileBatch struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
PickupCode string `gorm:"uniqueIndex;not null" json:"pickup_code"`
|
||||
Remark string `json:"remark"`
|
||||
ExpireType string `json:"expire_type"` // time / download / permanent
|
||||
ExpireAt *time.Time `json:"expire_at"`
|
||||
MaxDownloads int `json:"max_downloads"`
|
||||
DownloadCount int `gorm:"default:0" json:"download_count"`
|
||||
Status string `gorm:"default:'active'" json:"status"` // active / expired / deleted
|
||||
FileItems []FileItem `gorm:"foreignKey:BatchID" json:"file_items,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
|
||||
}
|
||||
15
internal/model/file_item.go
Normal file
15
internal/model/file_item.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type FileItem struct {
|
||||
ID uint `gorm:"primaryKey" json:"id"`
|
||||
BatchID uint `gorm:"index;not null" json:"batch_id"`
|
||||
OriginalName string `json:"original_name"`
|
||||
StoragePath string `json:"storage_path"`
|
||||
Size int64 `json:"size"`
|
||||
MimeType string `json:"mime_type"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
45
internal/model/response.go
Normal file
45
internal/model/response.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package model
|
||||
|
||||
// Response 统一响应模型
|
||||
type Response struct {
|
||||
Code int `json:"code" example:"200"`
|
||||
Msg string `json:"msg" example:"success"`
|
||||
Data interface{} `json:"data"`
|
||||
}
|
||||
|
||||
// 错误码定义
|
||||
const (
|
||||
CodeSuccess = 200
|
||||
CodeBadRequest = 400
|
||||
CodeUnauthorized = 401
|
||||
CodeForbidden = 403
|
||||
CodeNotFound = 404
|
||||
CodeGone = 410
|
||||
CodeInternalError = 500
|
||||
)
|
||||
|
||||
// NewResponse 创建响应
|
||||
func NewResponse(code int, msg string, data interface{}) Response {
|
||||
return Response{
|
||||
Code: code,
|
||||
Msg: msg,
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
// SuccessResponse 成功响应
|
||||
func SuccessResponse(data interface{}) Response {
|
||||
return Response{
|
||||
Code: CodeSuccess,
|
||||
Msg: "success",
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorResponse 错误响应
|
||||
func ErrorResponse(code int, msg string) Response {
|
||||
return Response{
|
||||
Code: code,
|
||||
Msg: msg,
|
||||
}
|
||||
}
|
||||
86
internal/service/batch_service.go
Normal file
86
internal/service/batch_service.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"FileRelay/internal/bootstrap"
|
||||
"FileRelay/internal/model"
|
||||
"FileRelay/internal/storage"
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type BatchService struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewBatchService() *BatchService {
|
||||
return &BatchService{db: bootstrap.DB}
|
||||
}
|
||||
|
||||
func (s *BatchService) GetBatchByPickupCode(code string) (*model.FileBatch, error) {
|
||||
var batch model.FileBatch
|
||||
err := s.db.Preload("FileItems").Where("pickup_code = ? AND status = ?", code, "active").First(&batch).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 检查是否过期
|
||||
if s.IsExpired(&batch) {
|
||||
s.MarkAsExpired(&batch)
|
||||
return nil, errors.New("batch expired")
|
||||
}
|
||||
|
||||
return &batch, nil
|
||||
}
|
||||
|
||||
func (s *BatchService) IsExpired(batch *model.FileBatch) bool {
|
||||
if batch.Status != "active" {
|
||||
return true
|
||||
}
|
||||
|
||||
switch batch.ExpireType {
|
||||
case "time":
|
||||
if batch.ExpireAt != nil && time.Now().After(*batch.ExpireAt) {
|
||||
return true
|
||||
}
|
||||
case "download":
|
||||
if batch.MaxDownloads > 0 && batch.DownloadCount >= batch.MaxDownloads {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *BatchService) MarkAsExpired(batch *model.FileBatch) error {
|
||||
return s.db.Model(batch).Update("status", "expired").Error
|
||||
}
|
||||
|
||||
func (s *BatchService) DeleteBatch(ctx context.Context, batchID uint) error {
|
||||
var batch model.FileBatch
|
||||
if err := s.db.Preload("FileItems").First(&batch, batchID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 删除物理文件
|
||||
for _, item := range batch.FileItems {
|
||||
_ = storage.GlobalStorage.Delete(ctx, item.StoragePath)
|
||||
}
|
||||
|
||||
// 删除数据库记录 (软删除 Batch)
|
||||
return s.db.Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Where("batch_id = ?", batch.ID).Delete(&model.FileItem{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Delete(&batch).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BatchService) IncrementDownloadCount(batchID uint) error {
|
||||
return s.db.Model(&model.FileBatch{}).Where("id = ?", batchID).
|
||||
UpdateColumn("download_count", gorm.Expr("download_count + ?", 1)).Error
|
||||
}
|
||||
83
internal/service/token_service.go
Normal file
83
internal/service/token_service.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"FileRelay/internal/bootstrap"
|
||||
"FileRelay/internal/model"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type TokenService struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewTokenService() *TokenService {
|
||||
return &TokenService{db: bootstrap.DB}
|
||||
}
|
||||
|
||||
func (s *TokenService) CreateToken(name string, scope string, expireAt *time.Time) (string, *model.APIToken, error) {
|
||||
rawToken := uuid.New().String()
|
||||
hash := s.hashToken(rawToken)
|
||||
|
||||
token := &model.APIToken{
|
||||
Name: name,
|
||||
TokenHash: hash,
|
||||
Scope: scope,
|
||||
ExpireAt: expireAt,
|
||||
}
|
||||
|
||||
if err := s.db.Create(token).Error; err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
return rawToken, token, nil
|
||||
}
|
||||
|
||||
func (s *TokenService) ValidateToken(rawToken string, requiredScope string) (*model.APIToken, error) {
|
||||
hash := s.hashToken(rawToken)
|
||||
var token model.APIToken
|
||||
if err := s.db.Where("token_hash = ? AND revoked = ?", hash, false).First(&token).Error; err != nil {
|
||||
return nil, errors.New("invalid token")
|
||||
}
|
||||
|
||||
if token.ExpireAt != nil && time.Now().After(*token.ExpireAt) {
|
||||
return nil, errors.New("token expired")
|
||||
}
|
||||
|
||||
// 检查 Scope (简单包含判断)
|
||||
// 在实际应用中可以实现更复杂的逻辑
|
||||
if requiredScope != "" && !s.checkScope(token.Scope, requiredScope) {
|
||||
return nil, errors.New("insufficient scope")
|
||||
}
|
||||
|
||||
// 更新最后使用时间
|
||||
now := time.Now()
|
||||
s.db.Model(&token).Update("last_used_at", &now)
|
||||
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
func (s *TokenService) hashToken(token string) string {
|
||||
h := sha256.New()
|
||||
h.Write([]byte(token))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
func (s *TokenService) checkScope(tokenScope, requiredScope string) bool {
|
||||
if requiredScope == "" {
|
||||
return true
|
||||
}
|
||||
scopes := strings.Split(tokenScope, ",")
|
||||
for _, s := range scopes {
|
||||
if strings.TrimSpace(s) == requiredScope {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
121
internal/service/upload_service.go
Normal file
121
internal/service/upload_service.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"FileRelay/internal/bootstrap"
|
||||
"FileRelay/internal/config"
|
||||
"FileRelay/internal/model"
|
||||
"FileRelay/internal/storage"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"mime/multipart"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type UploadService struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewUploadService() *UploadService {
|
||||
return &UploadService{db: bootstrap.DB}
|
||||
}
|
||||
|
||||
func (s *UploadService) CreateBatch(ctx context.Context, files []*multipart.FileHeader, remark string, expireType string, expireValue interface{}) (*model.FileBatch, error) {
|
||||
// 1. 生成取件码
|
||||
pickupCode, err := s.generatePickupCode(config.GlobalConfig.Security.PickupCodeLength)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2. 准备 Batch
|
||||
batch := &model.FileBatch{
|
||||
PickupCode: pickupCode,
|
||||
Remark: remark,
|
||||
ExpireType: expireType,
|
||||
Status: "active",
|
||||
}
|
||||
|
||||
switch expireType {
|
||||
case "time":
|
||||
if days, ok := expireValue.(int); ok {
|
||||
expireAt := time.Now().Add(time.Duration(days) * 24 * time.Hour)
|
||||
batch.ExpireAt = &expireAt
|
||||
}
|
||||
case "download":
|
||||
if max, ok := expireValue.(int); ok {
|
||||
batch.MaxDownloads = max
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 处理文件上传
|
||||
return batch, s.db.Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Create(batch).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, fileHeader := range files {
|
||||
fileItem, err := s.processFile(ctx, tx, batch.ID, fileHeader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
batch.FileItems = append(batch.FileItems, *fileItem)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (s *UploadService) processFile(ctx context.Context, tx *gorm.DB, batchID uint, fileHeader *multipart.FileHeader) (*model.FileItem, error) {
|
||||
file, err := fileHeader.Open()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// 生成唯一存储路径
|
||||
ext := filepath.Ext(fileHeader.Filename)
|
||||
storagePath := fmt.Sprintf("%d/%s%s", batchID, uuid.New().String(), ext)
|
||||
|
||||
// 保存到存储层
|
||||
if err := storage.GlobalStorage.Save(ctx, storagePath, file); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 创建数据库记录
|
||||
item := &model.FileItem{
|
||||
BatchID: batchID,
|
||||
OriginalName: fileHeader.Filename,
|
||||
StoragePath: storagePath,
|
||||
Size: fileHeader.Size,
|
||||
MimeType: fileHeader.Header.Get("Content-Type"),
|
||||
}
|
||||
|
||||
if err := tx.Create(item).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return item, nil
|
||||
}
|
||||
|
||||
func (s *UploadService) generatePickupCode(length int) (string, error) {
|
||||
const charset = "0123456789"
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
b[i] = charset[num.Int64()]
|
||||
}
|
||||
// 检查是否冲突
|
||||
var count int64
|
||||
s.db.Model(&model.FileBatch{}).Where("pickup_code = ? AND status = ?", string(b), "active").Count(&count)
|
||||
if count > 0 {
|
||||
return s.generatePickupCode(length) // 递归生成
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
59
internal/task/cleaner.go
Normal file
59
internal/task/cleaner.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package task
|
||||
|
||||
import (
|
||||
"FileRelay/internal/bootstrap"
|
||||
"FileRelay/internal/model"
|
||||
"FileRelay/internal/service"
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Cleaner struct {
|
||||
batchService *service.BatchService
|
||||
}
|
||||
|
||||
func NewCleaner() *Cleaner {
|
||||
return &Cleaner{
|
||||
batchService: service.NewBatchService(),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Cleaner) Start(ctx context.Context) {
|
||||
ticker := time.NewTicker(1 * time.Hour)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
c.Clean()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Cleaner) Clean() {
|
||||
fmt.Println("Running cleanup task...")
|
||||
|
||||
// 1. 寻找过期的 Active Batches
|
||||
var batches []model.FileBatch
|
||||
now := time.Now()
|
||||
bootstrap.DB.Where("status = ? AND expire_type = ? AND expire_at < ?", "active", "time", now).Find(&batches)
|
||||
|
||||
for _, batch := range batches {
|
||||
c.batchService.MarkAsExpired(&batch)
|
||||
}
|
||||
|
||||
// 2. 寻找标记为 expired 或 deleted 的批次并彻底清理文件和记录
|
||||
// 这里可以根据业务需求决定是否立即物理删除,或者等待一段时间
|
||||
// 按照需求:扫描 expired / deleted 批次,批量删除文件,清理数据库记录
|
||||
|
||||
var toDelete []model.FileBatch
|
||||
bootstrap.DB.Unscoped().Where("status IN ? OR deleted_at IS NOT NULL", []string{"expired", "deleted"}).Find(&toDelete)
|
||||
|
||||
for _, batch := range toDelete {
|
||||
fmt.Printf("Deep cleaning batch: %d\n", batch.ID)
|
||||
c.batchService.DeleteBatch(context.Background(), batch.ID)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user