Files
FileRelay/internal/api/public/pickup.go
2026-01-28 20:44:34 +08:00

315 lines
9.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package public
import (
"FileRelay/internal/api/middleware"
"FileRelay/internal/bootstrap"
"FileRelay/internal/config"
"FileRelay/internal/model"
"FileRelay/internal/service"
"FileRelay/internal/storage"
"archive/zip"
"fmt"
"io"
"log/slog"
"net/http"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
)
type PickupResponse struct {
Remark string `json:"remark"`
ExpireAt *time.Time `json:"expire_at"`
ExpireType string `json:"expire_type"`
DownloadCount int `json:"download_count"`
MaxDownloads int `json:"max_downloads"`
Type string `json:"type"`
Content string `json:"content,omitempty"`
Files []model.FileItem `json:"files,omitempty"`
}
type DownloadCountResponse struct {
DownloadCount int `json:"download_count"`
MaxDownloads int `json:"max_downloads"`
}
// DownloadBatch 批量下载文件 (ZIP)
// @Summary 批量下载文件
// @Description 根据取件码将批次内的所有文件打包为 ZIP 格式一次性下载。可选提供带 pickup scope 的 API Token。
// @Tags Public
// @Security APITokenAuth
// @Param pickup_code path string true "取件码"
// @Produce application/zip
// @Success 200 {file} file
// @Failure 404 {object} model.Response
// @Router /api/batches/{pickup_code}/download [get]
func (h *PickupHandler) DownloadBatch(c *gin.Context) {
code := c.Param("pickup_code")
batch, err := h.batchService.GetBatchByPickupCode(code)
if err != nil {
c.JSON(http.StatusNotFound, model.ErrorResponse(model.CodeNotFound, "batch not found or expired"))
return
}
c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=\"batch_%s.zip\"", code))
c.Header("Content-Type", "application/zip")
zw := zip.NewWriter(c.Writer)
defer zw.Close()
for _, item := range batch.FileItems {
reader, err := storage.GlobalStorage.Open(c.Request.Context(), item.StoragePath)
if err != nil {
continue // Skip failed files
}
f, err := zw.Create(item.OriginalName)
if err != nil {
reader.Close()
continue
}
_, _ = io.Copy(f, reader)
reader.Close()
}
// 增加下载次数
if err := h.batchService.IncrementDownloadCount(batch.ID); err != nil {
slog.Error("Failed to increment download count", "batch_id", batch.ID, "error", err)
}
}
type PickupHandler struct {
batchService *service.BatchService
}
func NewPickupHandler() *PickupHandler {
return &PickupHandler{
batchService: service.NewBatchService(),
}
}
// Pickup 获取批次信息
// @Summary 获取批次信息
// @Description 根据取件码获取文件批次详细信息和文件列表。可选提供带 pickup scope 的 API Token。
// @Tags Public
// @Security APITokenAuth
// @Produce json
// @Param pickup_code path string true "取件码"
// @Success 200 {object} model.Response{data=PickupResponse}
// @Failure 404 {object} model.Response
// @Router /api/batches/{pickup_code} [get]
func (h *PickupHandler) Pickup(c *gin.Context) {
code := c.Param("pickup_code")
if code == "" {
c.JSON(http.StatusBadRequest, model.ErrorResponse(model.CodeBadRequest, "pickup code required"))
return
}
batch, err := h.batchService.GetBatchByPickupCode(code)
if err != nil {
middleware.RecordPickupFailure(c.ClientIP())
c.JSON(http.StatusNotFound, model.ErrorResponse(model.CodeNotFound, "batch not found or expired"))
return
}
if batch.Type == "text" {
if err := h.batchService.IncrementDownloadCount(batch.ID); err != nil {
slog.Error("Failed to increment download count for batch", "batch_id", batch.ID, "error", err)
} else {
batch.DownloadCount++
}
}
baseURL := getBaseURL(c)
for i := range batch.FileItems {
batch.FileItems[i].DownloadURL = fmt.Sprintf("%s/api/files/%s/%s", baseURL, batch.FileItems[i].ID, batch.FileItems[i].OriginalName)
}
c.JSON(http.StatusOK, model.SuccessResponse(PickupResponse{
Remark: batch.Remark,
ExpireAt: batch.ExpireAt,
ExpireType: batch.ExpireType,
DownloadCount: batch.DownloadCount,
MaxDownloads: batch.MaxDownloads,
Type: batch.Type,
Content: batch.Content,
Files: batch.FileItems,
}))
}
// GetDownloadCount 查询下载次数
// @Summary 查询下载次数
// @Description 根据取件码查询当前下载次数和最大允许下载次数。支持已过期的批次。
// @Tags Public
// @Produce json
// @Param pickup_code path string true "取件码"
// @Success 200 {object} model.Response{data=DownloadCountResponse}
// @Failure 400 {object} model.Response
// @Failure 404 {object} model.Response
// @Router /api/batches/{pickup_code}/count [get]
func (h *PickupHandler) GetDownloadCount(c *gin.Context) {
code := c.Param("pickup_code")
if code == "" {
c.JSON(http.StatusBadRequest, model.ErrorResponse(model.CodeBadRequest, "pickup code required"))
return
}
count, max, err := h.batchService.GetDownloadCountByPickupCode(code)
if err != nil {
middleware.RecordPickupFailure(c.ClientIP())
c.JSON(http.StatusNotFound, model.ErrorResponse(model.CodeNotFound, "batch not found"))
return
}
c.JSON(http.StatusOK, model.SuccessResponse(DownloadCountResponse{
DownloadCount: count,
MaxDownloads: max,
}))
}
func getBaseURL(c *gin.Context) string {
// 优先使用配置中的 BaseURL
if config.GlobalConfig.Site.BaseURL != "" {
return strings.TrimSuffix(config.GlobalConfig.Site.BaseURL, "/")
}
// 自动检测逻辑
scheme := "http"
if c.Request.TLS != nil {
scheme = "https"
} else {
// 检查常用的代理协议头 (优先)
// 增加对用户提供的 :scheme (可能被某些代理转为普通 header) 的支持
// 增加对 X-Forwarded-Proto 可能存在的逗号分隔列表的处理
checkHeaders := []struct {
name string
values []string
}{
{"X-Forwarded-Proto", []string{"https"}},
{"X-Forwarded-Protocol", []string{"https"}},
{"X-Url-Scheme", []string{"https"}},
{"Front-End-Https", []string{"on", "https"}},
{"X-Forwarded-Ssl", []string{"on", "https"}},
{":scheme", []string{"https"}},
{"X-Scheme", []string{"https"}},
}
found := false
for _, h := range checkHeaders {
val := c.GetHeader(h.name)
if val == "" {
continue
}
// 处理可能的逗号分隔列表 (如 X-Forwarded-Proto: https, http)
firstVal := strings.TrimSpace(strings.ToLower(strings.Split(val, ",")[0]))
for _, target := range h.values {
if firstVal == target {
scheme = "https"
found = true
break
}
}
if found {
break
}
}
// 检查 Forwarded 头部 (RFC 7239)
if !found {
if forwarded := c.GetHeader("Forwarded"); forwarded != "" {
if strings.Contains(strings.ToLower(forwarded), "proto=https") {
scheme = "https"
found = true
}
}
}
// 启发式判断:如果上述头部都没有,但 Referer 是 https则认为也是 https
// 这在同域 API 请求时非常可靠
if !found {
if referer := c.GetHeader("Referer"); strings.HasPrefix(strings.ToLower(referer), "https://") {
scheme = "https"
}
}
}
host := c.Request.Host
if forwardedHost := c.GetHeader("X-Forwarded-Host"); forwardedHost != "" {
// 处理可能的逗号分隔列表
host = strings.TrimSpace(strings.Split(forwardedHost, ",")[0])
}
return fmt.Sprintf("%s://%s", scheme, host)
}
// DownloadFile 下载单个文件
// @Summary 下载单个文件
// @Description 根据文件 ID 下载单个文件。支持直观的文件名结尾以方便下载工具识别。可选提供带 pickup scope 的 API Token。
// @Tags Public
// @Security APITokenAuth
// @Param file_id path string true "文件 ID (UUID)"
// @Param filename path string false "文件名"
// @Produce application/octet-stream
// @Success 200 {file} file
// @Failure 404 {object} model.Response
// @Failure 410 {object} model.Response
// @Router /api/files/{file_id}/{filename} [get]
// @Router /api/files/{file_id}/download [get]
func (h *PickupHandler) DownloadFile(c *gin.Context) {
fileID := c.Param("file_id")
var item model.FileItem
if err := bootstrap.DB.First(&item, "id = ?", fileID).Error; err != nil {
c.JSON(http.StatusNotFound, model.ErrorResponse(model.CodeNotFound, "file not found"))
return
}
var batch model.FileBatch
if err := bootstrap.DB.First(&batch, "id = ?", item.BatchID).Error; err != nil {
c.JSON(http.StatusNotFound, model.ErrorResponse(model.CodeNotFound, "batch not found"))
return
}
if h.batchService.IsExpired(&batch) {
h.batchService.MarkAsExpired(&batch)
// 按照需求,如果不存在(已在上面处理)或达到上限,返回 404
if batch.ExpireType == "download" && batch.DownloadCount >= batch.MaxDownloads {
c.JSON(http.StatusNotFound, model.ErrorResponse(model.CodeNotFound, "file not found or download limit reached"))
} else {
c.JSON(http.StatusGone, model.ErrorResponse(model.CodeGone, "batch expired"))
}
return
}
// 打开文件
reader, err := storage.GlobalStorage.Open(c.Request.Context(), item.StoragePath)
if err != nil {
c.JSON(http.StatusInternalServerError, model.ErrorResponse(model.CodeInternalError, "failed to open file"))
return
}
defer reader.Close()
// 增加下载次数
if err := h.batchService.IncrementDownloadCount(batch.ID); err != nil {
// 记录错误但不中断下载过程
slog.Error("Failed to increment download count for batch", "batch_id", batch.ID, "error", err)
}
c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", item.OriginalName))
c.Header("Content-Type", item.MimeType)
c.Header("Content-Length", strconv.FormatInt(item.Size, 10))
// 如果是 HEAD 请求,只返回 Header
if c.Request.Method == http.MethodHead {
return
}
if _, err := io.Copy(c.Writer, reader); err != nil {
slog.Error("Error during file download", "file_id", item.ID, "error", err)
}
}