mirror of
https://git.fightbot.fun/hxuanyu/FileRelay.git
synced 2026-02-15 10:21:43 +08:00
Initial commit
This commit is contained in:
314
internal/api/public/pickup.go
Normal file
314
internal/api/public/pickup.go
Normal file
@@ -0,0 +1,314 @@
|
||||
package public
|
||||
|
||||
import (
|
||||
"FileRelay/internal/api/middleware"
|
||||
"FileRelay/internal/bootstrap"
|
||||
"FileRelay/internal/config"
|
||||
"FileRelay/internal/model"
|
||||
"FileRelay/internal/service"
|
||||
"FileRelay/internal/storage"
|
||||
"archive/zip"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type PickupResponse struct {
|
||||
Remark string `json:"remark"`
|
||||
ExpireAt *time.Time `json:"expire_at"`
|
||||
ExpireType string `json:"expire_type"`
|
||||
DownloadCount int `json:"download_count"`
|
||||
MaxDownloads int `json:"max_downloads"`
|
||||
Type string `json:"type"`
|
||||
Content string `json:"content,omitempty"`
|
||||
Files []model.FileItem `json:"files,omitempty"`
|
||||
}
|
||||
|
||||
type DownloadCountResponse struct {
|
||||
DownloadCount int `json:"download_count"`
|
||||
MaxDownloads int `json:"max_downloads"`
|
||||
}
|
||||
|
||||
// DownloadBatch 批量下载文件 (ZIP)
|
||||
// @Summary 批量下载文件
|
||||
// @Description 根据取件码将批次内的所有文件打包为 ZIP 格式一次性下载。可选提供带 pickup scope 的 API Token。
|
||||
// @Tags Public
|
||||
// @Security APITokenAuth
|
||||
// @Param pickup_code path string true "取件码"
|
||||
// @Produce application/zip
|
||||
// @Success 200 {file} file
|
||||
// @Failure 404 {object} model.Response
|
||||
// @Router /api/batches/{pickup_code}/download [get]
|
||||
func (h *PickupHandler) DownloadBatch(c *gin.Context) {
|
||||
code := c.Param("pickup_code")
|
||||
batch, err := h.batchService.GetBatchByPickupCode(code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, model.ErrorResponse(model.CodeNotFound, "batch not found or expired"))
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=\"batch_%s.zip\"", code))
|
||||
c.Header("Content-Type", "application/zip")
|
||||
|
||||
zw := zip.NewWriter(c.Writer)
|
||||
defer zw.Close()
|
||||
|
||||
for _, item := range batch.FileItems {
|
||||
reader, err := storage.GlobalStorage.Open(c.Request.Context(), item.StoragePath)
|
||||
if err != nil {
|
||||
continue // Skip failed files
|
||||
}
|
||||
|
||||
f, err := zw.Create(item.OriginalName)
|
||||
if err != nil {
|
||||
reader.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
_, _ = io.Copy(f, reader)
|
||||
reader.Close()
|
||||
}
|
||||
|
||||
// 增加下载次数
|
||||
if err := h.batchService.IncrementDownloadCount(batch.ID); err != nil {
|
||||
slog.Error("Failed to increment download count", "batch_id", batch.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
type PickupHandler struct {
|
||||
batchService *service.BatchService
|
||||
}
|
||||
|
||||
func NewPickupHandler() *PickupHandler {
|
||||
return &PickupHandler{
|
||||
batchService: service.NewBatchService(),
|
||||
}
|
||||
}
|
||||
|
||||
// Pickup 获取批次信息
|
||||
// @Summary 获取批次信息
|
||||
// @Description 根据取件码获取文件批次详细信息和文件列表。可选提供带 pickup scope 的 API Token。
|
||||
// @Tags Public
|
||||
// @Security APITokenAuth
|
||||
// @Produce json
|
||||
// @Param pickup_code path string true "取件码"
|
||||
// @Success 200 {object} model.Response{data=PickupResponse}
|
||||
// @Failure 404 {object} model.Response
|
||||
// @Router /api/batches/{pickup_code} [get]
|
||||
func (h *PickupHandler) Pickup(c *gin.Context) {
|
||||
code := c.Param("pickup_code")
|
||||
if code == "" {
|
||||
c.JSON(http.StatusBadRequest, model.ErrorResponse(model.CodeBadRequest, "pickup code required"))
|
||||
return
|
||||
}
|
||||
|
||||
batch, err := h.batchService.GetBatchByPickupCode(code)
|
||||
if err != nil {
|
||||
middleware.RecordPickupFailure(c.ClientIP())
|
||||
c.JSON(http.StatusNotFound, model.ErrorResponse(model.CodeNotFound, "batch not found or expired"))
|
||||
return
|
||||
}
|
||||
|
||||
if batch.Type == "text" {
|
||||
if err := h.batchService.IncrementDownloadCount(batch.ID); err != nil {
|
||||
slog.Error("Failed to increment download count for batch", "batch_id", batch.ID, "error", err)
|
||||
} else {
|
||||
batch.DownloadCount++
|
||||
}
|
||||
}
|
||||
|
||||
baseURL := getBaseURL(c)
|
||||
|
||||
for i := range batch.FileItems {
|
||||
batch.FileItems[i].DownloadURL = fmt.Sprintf("%s/api/files/%s/%s", baseURL, batch.FileItems[i].ID, batch.FileItems[i].OriginalName)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.SuccessResponse(PickupResponse{
|
||||
Remark: batch.Remark,
|
||||
ExpireAt: batch.ExpireAt,
|
||||
ExpireType: batch.ExpireType,
|
||||
DownloadCount: batch.DownloadCount,
|
||||
MaxDownloads: batch.MaxDownloads,
|
||||
Type: batch.Type,
|
||||
Content: batch.Content,
|
||||
Files: batch.FileItems,
|
||||
}))
|
||||
}
|
||||
|
||||
// GetDownloadCount 查询下载次数
|
||||
// @Summary 查询下载次数
|
||||
// @Description 根据取件码查询当前下载次数和最大允许下载次数。支持已过期的批次。
|
||||
// @Tags Public
|
||||
// @Produce json
|
||||
// @Param pickup_code path string true "取件码"
|
||||
// @Success 200 {object} model.Response{data=DownloadCountResponse}
|
||||
// @Failure 400 {object} model.Response
|
||||
// @Failure 404 {object} model.Response
|
||||
// @Router /api/batches/{pickup_code}/count [get]
|
||||
func (h *PickupHandler) GetDownloadCount(c *gin.Context) {
|
||||
code := c.Param("pickup_code")
|
||||
if code == "" {
|
||||
c.JSON(http.StatusBadRequest, model.ErrorResponse(model.CodeBadRequest, "pickup code required"))
|
||||
return
|
||||
}
|
||||
|
||||
count, max, err := h.batchService.GetDownloadCountByPickupCode(code)
|
||||
if err != nil {
|
||||
middleware.RecordPickupFailure(c.ClientIP())
|
||||
c.JSON(http.StatusNotFound, model.ErrorResponse(model.CodeNotFound, "batch not found"))
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, model.SuccessResponse(DownloadCountResponse{
|
||||
DownloadCount: count,
|
||||
MaxDownloads: max,
|
||||
}))
|
||||
}
|
||||
|
||||
func getBaseURL(c *gin.Context) string {
|
||||
// 优先使用配置中的 BaseURL
|
||||
if config.GlobalConfig.Site.BaseURL != "" {
|
||||
return strings.TrimSuffix(config.GlobalConfig.Site.BaseURL, "/")
|
||||
}
|
||||
|
||||
// 自动检测逻辑
|
||||
scheme := "http"
|
||||
if c.Request.TLS != nil {
|
||||
scheme = "https"
|
||||
} else {
|
||||
// 检查常用的代理协议头 (优先)
|
||||
// 增加对用户提供的 :scheme (可能被某些代理转为普通 header) 的支持
|
||||
// 增加对 X-Forwarded-Proto 可能存在的逗号分隔列表的处理
|
||||
checkHeaders := []struct {
|
||||
name string
|
||||
values []string
|
||||
}{
|
||||
{"X-Forwarded-Proto", []string{"https"}},
|
||||
{"X-Forwarded-Protocol", []string{"https"}},
|
||||
{"X-Url-Scheme", []string{"https"}},
|
||||
{"Front-End-Https", []string{"on", "https"}},
|
||||
{"X-Forwarded-Ssl", []string{"on", "https"}},
|
||||
{":scheme", []string{"https"}},
|
||||
{"X-Scheme", []string{"https"}},
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, h := range checkHeaders {
|
||||
val := c.GetHeader(h.name)
|
||||
if val == "" {
|
||||
continue
|
||||
}
|
||||
// 处理可能的逗号分隔列表 (如 X-Forwarded-Proto: https, http)
|
||||
firstVal := strings.TrimSpace(strings.ToLower(strings.Split(val, ",")[0]))
|
||||
for _, target := range h.values {
|
||||
if firstVal == target {
|
||||
scheme = "https"
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if found {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 Forwarded 头部 (RFC 7239)
|
||||
if !found {
|
||||
if forwarded := c.GetHeader("Forwarded"); forwarded != "" {
|
||||
if strings.Contains(strings.ToLower(forwarded), "proto=https") {
|
||||
scheme = "https"
|
||||
found = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 启发式判断:如果上述头部都没有,但 Referer 是 https,则认为也是 https
|
||||
// 这在同域 API 请求时非常可靠
|
||||
if !found {
|
||||
if referer := c.GetHeader("Referer"); strings.HasPrefix(strings.ToLower(referer), "https://") {
|
||||
scheme = "https"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
host := c.Request.Host
|
||||
if forwardedHost := c.GetHeader("X-Forwarded-Host"); forwardedHost != "" {
|
||||
// 处理可能的逗号分隔列表
|
||||
host = strings.TrimSpace(strings.Split(forwardedHost, ",")[0])
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s://%s", scheme, host)
|
||||
}
|
||||
|
||||
// DownloadFile 下载单个文件
|
||||
// @Summary 下载单个文件
|
||||
// @Description 根据文件 ID 下载单个文件。支持直观的文件名结尾以方便下载工具识别。可选提供带 pickup scope 的 API Token。
|
||||
// @Tags Public
|
||||
// @Security APITokenAuth
|
||||
// @Param file_id path string true "文件 ID (UUID)"
|
||||
// @Param filename path string false "文件名"
|
||||
// @Produce application/octet-stream
|
||||
// @Success 200 {file} file
|
||||
// @Failure 404 {object} model.Response
|
||||
// @Failure 410 {object} model.Response
|
||||
// @Router /api/files/{file_id}/{filename} [get]
|
||||
// @Router /api/files/{file_id}/download [get]
|
||||
func (h *PickupHandler) DownloadFile(c *gin.Context) {
|
||||
fileID := c.Param("file_id")
|
||||
|
||||
var item model.FileItem
|
||||
if err := bootstrap.DB.First(&item, "id = ?", fileID).Error; err != nil {
|
||||
c.JSON(http.StatusNotFound, model.ErrorResponse(model.CodeNotFound, "file not found"))
|
||||
return
|
||||
}
|
||||
|
||||
var batch model.FileBatch
|
||||
if err := bootstrap.DB.First(&batch, "id = ?", item.BatchID).Error; err != nil {
|
||||
c.JSON(http.StatusNotFound, model.ErrorResponse(model.CodeNotFound, "batch not found"))
|
||||
return
|
||||
}
|
||||
|
||||
if h.batchService.IsExpired(&batch) {
|
||||
h.batchService.MarkAsExpired(&batch)
|
||||
// 按照需求,如果不存在(已在上面处理)或达到上限,返回 404
|
||||
if batch.ExpireType == "download" && batch.DownloadCount >= batch.MaxDownloads {
|
||||
c.JSON(http.StatusNotFound, model.ErrorResponse(model.CodeNotFound, "file not found or download limit reached"))
|
||||
} else {
|
||||
c.JSON(http.StatusGone, model.ErrorResponse(model.CodeGone, "batch expired"))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 打开文件
|
||||
reader, err := storage.GlobalStorage.Open(c.Request.Context(), item.StoragePath)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, model.ErrorResponse(model.CodeInternalError, "failed to open file"))
|
||||
return
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
// 增加下载次数
|
||||
if err := h.batchService.IncrementDownloadCount(batch.ID); err != nil {
|
||||
// 记录错误但不中断下载过程
|
||||
slog.Error("Failed to increment download count for batch", "batch_id", batch.ID, "error", err)
|
||||
}
|
||||
|
||||
c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", item.OriginalName))
|
||||
c.Header("Content-Type", item.MimeType)
|
||||
c.Header("Content-Length", strconv.FormatInt(item.Size, 10))
|
||||
|
||||
// 如果是 HEAD 请求,只返回 Header
|
||||
if c.Request.Method == http.MethodHead {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := io.Copy(c.Writer, reader); err != nil {
|
||||
slog.Error("Error during file download", "file_id", item.ID, "error", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user