mirror of
https://git.fightbot.fun/hxuanyu/FileRelay.git
synced 2026-02-15 06:01:43 +08:00
315 lines
9.9 KiB
Go
315 lines
9.9 KiB
Go
package public
|
||
|
||
import (
	"archive/zip"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"time"

	"github.com/gin-gonic/gin"

	"FileRelay/internal/api/middleware"
	"FileRelay/internal/bootstrap"
	"FileRelay/internal/config"
	"FileRelay/internal/model"
	"FileRelay/internal/service"
	"FileRelay/internal/storage"
)
|
||
|
||
type PickupResponse struct {
|
||
Remark string `json:"remark"`
|
||
ExpireAt *time.Time `json:"expire_at"`
|
||
ExpireType string `json:"expire_type"`
|
||
DownloadCount int `json:"download_count"`
|
||
MaxDownloads int `json:"max_downloads"`
|
||
Type string `json:"type"`
|
||
Content string `json:"content,omitempty"`
|
||
Files []model.FileItem `json:"files,omitempty"`
|
||
}
|
||
|
||
// DownloadCountResponse is the JSON payload returned by GetDownloadCount.
type DownloadCountResponse struct {
	// DownloadCount is the number of downloads recorded so far.
	DownloadCount int `json:"download_count"`
	// MaxDownloads is the maximum number of downloads allowed.
	MaxDownloads int `json:"max_downloads"`
}
|
||
|
||
// DownloadBatch 批量下载文件 (ZIP)
|
||
// @Summary 批量下载文件
|
||
// @Description 根据取件码将批次内的所有文件打包为 ZIP 格式一次性下载。可选提供带 pickup scope 的 API Token。
|
||
// @Tags Public
|
||
// @Security APITokenAuth
|
||
// @Param pickup_code path string true "取件码"
|
||
// @Produce application/zip
|
||
// @Success 200 {file} file
|
||
// @Failure 404 {object} model.Response
|
||
// @Router /api/batches/{pickup_code}/download [get]
|
||
func (h *PickupHandler) DownloadBatch(c *gin.Context) {
|
||
code := c.Param("pickup_code")
|
||
batch, err := h.batchService.GetBatchByPickupCode(code)
|
||
if err != nil {
|
||
c.JSON(http.StatusNotFound, model.ErrorResponse(model.CodeNotFound, "batch not found or expired"))
|
||
return
|
||
}
|
||
|
||
c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=\"batch_%s.zip\"", code))
|
||
c.Header("Content-Type", "application/zip")
|
||
|
||
zw := zip.NewWriter(c.Writer)
|
||
defer zw.Close()
|
||
|
||
for _, item := range batch.FileItems {
|
||
reader, err := storage.GlobalStorage.Open(c.Request.Context(), item.StoragePath)
|
||
if err != nil {
|
||
continue // Skip failed files
|
||
}
|
||
|
||
f, err := zw.Create(item.OriginalName)
|
||
if err != nil {
|
||
reader.Close()
|
||
continue
|
||
}
|
||
|
||
_, _ = io.Copy(f, reader)
|
||
reader.Close()
|
||
}
|
||
|
||
// 增加下载次数
|
||
if err := h.batchService.IncrementDownloadCount(batch.ID); err != nil {
|
||
slog.Error("Failed to increment download count", "batch_id", batch.ID, "error", err)
|
||
}
|
||
}
|
||
|
||
type PickupHandler struct {
|
||
batchService *service.BatchService
|
||
}
|
||
|
||
func NewPickupHandler() *PickupHandler {
|
||
return &PickupHandler{
|
||
batchService: service.NewBatchService(),
|
||
}
|
||
}
|
||
|
||
// Pickup 获取批次信息
|
||
// @Summary 获取批次信息
|
||
// @Description 根据取件码获取文件批次详细信息和文件列表。可选提供带 pickup scope 的 API Token。
|
||
// @Tags Public
|
||
// @Security APITokenAuth
|
||
// @Produce json
|
||
// @Param pickup_code path string true "取件码"
|
||
// @Success 200 {object} model.Response{data=PickupResponse}
|
||
// @Failure 404 {object} model.Response
|
||
// @Router /api/batches/{pickup_code} [get]
|
||
func (h *PickupHandler) Pickup(c *gin.Context) {
|
||
code := c.Param("pickup_code")
|
||
if code == "" {
|
||
c.JSON(http.StatusBadRequest, model.ErrorResponse(model.CodeBadRequest, "pickup code required"))
|
||
return
|
||
}
|
||
|
||
batch, err := h.batchService.GetBatchByPickupCode(code)
|
||
if err != nil {
|
||
middleware.RecordPickupFailure(c.ClientIP())
|
||
c.JSON(http.StatusNotFound, model.ErrorResponse(model.CodeNotFound, "batch not found or expired"))
|
||
return
|
||
}
|
||
|
||
if batch.Type == "text" {
|
||
if err := h.batchService.IncrementDownloadCount(batch.ID); err != nil {
|
||
slog.Error("Failed to increment download count for batch", "batch_id", batch.ID, "error", err)
|
||
} else {
|
||
batch.DownloadCount++
|
||
}
|
||
}
|
||
|
||
baseURL := getBaseURL(c)
|
||
|
||
for i := range batch.FileItems {
|
||
batch.FileItems[i].DownloadURL = fmt.Sprintf("%s/api/files/%s/%s", baseURL, batch.FileItems[i].ID, batch.FileItems[i].OriginalName)
|
||
}
|
||
|
||
c.JSON(http.StatusOK, model.SuccessResponse(PickupResponse{
|
||
Remark: batch.Remark,
|
||
ExpireAt: batch.ExpireAt,
|
||
ExpireType: batch.ExpireType,
|
||
DownloadCount: batch.DownloadCount,
|
||
MaxDownloads: batch.MaxDownloads,
|
||
Type: batch.Type,
|
||
Content: batch.Content,
|
||
Files: batch.FileItems,
|
||
}))
|
||
}
|
||
|
||
// GetDownloadCount 查询下载次数
|
||
// @Summary 查询下载次数
|
||
// @Description 根据取件码查询当前下载次数和最大允许下载次数。支持已过期的批次。
|
||
// @Tags Public
|
||
// @Produce json
|
||
// @Param pickup_code path string true "取件码"
|
||
// @Success 200 {object} model.Response{data=DownloadCountResponse}
|
||
// @Failure 400 {object} model.Response
|
||
// @Failure 404 {object} model.Response
|
||
// @Router /api/batches/{pickup_code}/count [get]
|
||
func (h *PickupHandler) GetDownloadCount(c *gin.Context) {
|
||
code := c.Param("pickup_code")
|
||
if code == "" {
|
||
c.JSON(http.StatusBadRequest, model.ErrorResponse(model.CodeBadRequest, "pickup code required"))
|
||
return
|
||
}
|
||
|
||
count, max, err := h.batchService.GetDownloadCountByPickupCode(code)
|
||
if err != nil {
|
||
middleware.RecordPickupFailure(c.ClientIP())
|
||
c.JSON(http.StatusNotFound, model.ErrorResponse(model.CodeNotFound, "batch not found"))
|
||
return
|
||
}
|
||
|
||
c.JSON(http.StatusOK, model.SuccessResponse(DownloadCountResponse{
|
||
DownloadCount: count,
|
||
MaxDownloads: max,
|
||
}))
|
||
}
|
||
|
||
func getBaseURL(c *gin.Context) string {
|
||
// 优先使用配置中的 BaseURL
|
||
if config.GlobalConfig.Site.BaseURL != "" {
|
||
return strings.TrimSuffix(config.GlobalConfig.Site.BaseURL, "/")
|
||
}
|
||
|
||
// 自动检测逻辑
|
||
scheme := "http"
|
||
if c.Request.TLS != nil {
|
||
scheme = "https"
|
||
} else {
|
||
// 检查常用的代理协议头 (优先)
|
||
// 增加对用户提供的 :scheme (可能被某些代理转为普通 header) 的支持
|
||
// 增加对 X-Forwarded-Proto 可能存在的逗号分隔列表的处理
|
||
checkHeaders := []struct {
|
||
name string
|
||
values []string
|
||
}{
|
||
{"X-Forwarded-Proto", []string{"https"}},
|
||
{"X-Forwarded-Protocol", []string{"https"}},
|
||
{"X-Url-Scheme", []string{"https"}},
|
||
{"Front-End-Https", []string{"on", "https"}},
|
||
{"X-Forwarded-Ssl", []string{"on", "https"}},
|
||
{":scheme", []string{"https"}},
|
||
{"X-Scheme", []string{"https"}},
|
||
}
|
||
|
||
found := false
|
||
for _, h := range checkHeaders {
|
||
val := c.GetHeader(h.name)
|
||
if val == "" {
|
||
continue
|
||
}
|
||
// 处理可能的逗号分隔列表 (如 X-Forwarded-Proto: https, http)
|
||
firstVal := strings.TrimSpace(strings.ToLower(strings.Split(val, ",")[0]))
|
||
for _, target := range h.values {
|
||
if firstVal == target {
|
||
scheme = "https"
|
||
found = true
|
||
break
|
||
}
|
||
}
|
||
if found {
|
||
break
|
||
}
|
||
}
|
||
|
||
// 检查 Forwarded 头部 (RFC 7239)
|
||
if !found {
|
||
if forwarded := c.GetHeader("Forwarded"); forwarded != "" {
|
||
if strings.Contains(strings.ToLower(forwarded), "proto=https") {
|
||
scheme = "https"
|
||
found = true
|
||
}
|
||
}
|
||
}
|
||
|
||
// 启发式判断:如果上述头部都没有,但 Referer 是 https,则认为也是 https
|
||
// 这在同域 API 请求时非常可靠
|
||
if !found {
|
||
if referer := c.GetHeader("Referer"); strings.HasPrefix(strings.ToLower(referer), "https://") {
|
||
scheme = "https"
|
||
}
|
||
}
|
||
}
|
||
|
||
host := c.Request.Host
|
||
if forwardedHost := c.GetHeader("X-Forwarded-Host"); forwardedHost != "" {
|
||
// 处理可能的逗号分隔列表
|
||
host = strings.TrimSpace(strings.Split(forwardedHost, ",")[0])
|
||
}
|
||
|
||
return fmt.Sprintf("%s://%s", scheme, host)
|
||
}
|
||
|
||
// DownloadFile 下载单个文件
|
||
// @Summary 下载单个文件
|
||
// @Description 根据文件 ID 下载单个文件。支持直观的文件名结尾以方便下载工具识别。可选提供带 pickup scope 的 API Token。
|
||
// @Tags Public
|
||
// @Security APITokenAuth
|
||
// @Param file_id path string true "文件 ID (UUID)"
|
||
// @Param filename path string false "文件名"
|
||
// @Produce application/octet-stream
|
||
// @Success 200 {file} file
|
||
// @Failure 404 {object} model.Response
|
||
// @Failure 410 {object} model.Response
|
||
// @Router /api/files/{file_id}/{filename} [get]
|
||
// @Router /api/files/{file_id}/download [get]
|
||
func (h *PickupHandler) DownloadFile(c *gin.Context) {
|
||
fileID := c.Param("file_id")
|
||
|
||
var item model.FileItem
|
||
if err := bootstrap.DB.First(&item, "id = ?", fileID).Error; err != nil {
|
||
c.JSON(http.StatusNotFound, model.ErrorResponse(model.CodeNotFound, "file not found"))
|
||
return
|
||
}
|
||
|
||
var batch model.FileBatch
|
||
if err := bootstrap.DB.First(&batch, "id = ?", item.BatchID).Error; err != nil {
|
||
c.JSON(http.StatusNotFound, model.ErrorResponse(model.CodeNotFound, "batch not found"))
|
||
return
|
||
}
|
||
|
||
if h.batchService.IsExpired(&batch) {
|
||
h.batchService.MarkAsExpired(&batch)
|
||
// 按照需求,如果不存在(已在上面处理)或达到上限,返回 404
|
||
if batch.ExpireType == "download" && batch.DownloadCount >= batch.MaxDownloads {
|
||
c.JSON(http.StatusNotFound, model.ErrorResponse(model.CodeNotFound, "file not found or download limit reached"))
|
||
} else {
|
||
c.JSON(http.StatusGone, model.ErrorResponse(model.CodeGone, "batch expired"))
|
||
}
|
||
return
|
||
}
|
||
|
||
// 打开文件
|
||
reader, err := storage.GlobalStorage.Open(c.Request.Context(), item.StoragePath)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, model.ErrorResponse(model.CodeInternalError, "failed to open file"))
|
||
return
|
||
}
|
||
defer reader.Close()
|
||
|
||
// 增加下载次数
|
||
if err := h.batchService.IncrementDownloadCount(batch.ID); err != nil {
|
||
// 记录错误但不中断下载过程
|
||
slog.Error("Failed to increment download count for batch", "batch_id", batch.ID, "error", err)
|
||
}
|
||
|
||
c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", item.OriginalName))
|
||
c.Header("Content-Type", item.MimeType)
|
||
c.Header("Content-Length", strconv.FormatInt(item.Size, 10))
|
||
|
||
// 如果是 HEAD 请求,只返回 Header
|
||
if c.Request.Method == http.MethodHead {
|
||
return
|
||
}
|
||
|
||
if _, err := io.Copy(c.Writer, reader); err != nil {
|
||
slog.Error("Error during file download", "file_id", item.ID, "error", err)
|
||
}
|
||
}
|