From 5b511a5ea18127680391331cfdbc98d859124546 Mon Sep 17 00:00:00 2001 From: Rogee Date: Fri, 17 Jan 2025 15:48:53 +0800 Subject: [PATCH] feat: fix upload file --- backend/app/http/medias/controller.go | 17 +++-- backend/pkg/storage/upload.go | 106 +++++++++----------------- test/go.mod | 10 ++- test/go.sum | 12 +++ test/main_test.go | 30 ++++++++ 5 files changed, 99 insertions(+), 76 deletions(-) diff --git a/backend/app/http/medias/controller.go b/backend/app/http/medias/controller.go index 8b55de8..bc993b2 100644 --- a/backend/app/http/medias/controller.go +++ b/backend/app/http/medias/controller.go @@ -33,17 +33,22 @@ func (ctl *Controller) Prepare() error { // @Bind file file // @Bind claim local func (ctl *Controller) Upload(ctx fiber.Ctx, claim *jwt.Claims, file *multipart.FileHeader, req *UploadReq) (*storage.UploadedFile, error) { - defaultStorage, err := ctl.storageSvc.GetDefault(ctx.Context()) - if err != nil { - return nil, err - } - uploader, err := storage.NewUploader(req.FileName, req.ChunkNumber, req.TotalChunks, req.FileMD5) if err != nil { return nil, err } - uploadedFile, err := uploader.Save(ctx, file) + defaultStorage, err := ctl.storageSvc.GetDefault(ctx.Context()) + if err != nil { + return nil, err + } + + fs, err := ctl.storageSvc.BuildFS(defaultStorage) + if err != nil { + return nil, err + } + + uploadedFile, err := uploader.Save(ctx, fs, file) if err != nil { return nil, err } diff --git a/backend/pkg/storage/upload.go b/backend/pkg/storage/upload.go index b514b14..412bfd0 100644 --- a/backend/pkg/storage/upload.go +++ b/backend/pkg/storage/upload.go @@ -1,6 +1,7 @@ package storage import ( + "bytes" "crypto/md5" "encoding/hex" "errors" @@ -12,6 +13,7 @@ import ( "time" "github.com/gofiber/fiber/v3" + "github.com/spf13/afero" ) type Uploader struct { @@ -22,7 +24,6 @@ type Uploader struct { totalChunks int fileMD5 string - dst string ext string finalPath string } @@ -52,49 +53,27 @@ func NewUploader(fileName string, chunkNumber, totalChunks int, fileMD5 string) totalChunks: totalChunks, fileMD5: fileMD5, ext: filepath.Ext(fileName), - finalPath: filepath.Join(os.TempDir(), fileMD5+filepath.Ext(fileName)), + finalPath: filepath.Join("uploads", time.Now().Format("2006/01/02"), fileMD5+filepath.Ext(fileName)), }, nil } -func (up *Uploader) Save(ctx fiber.Ctx, file *multipart.FileHeader) (*UploadedFile, error) { +func (up *Uploader) Save(ctx fiber.Ctx, fs afero.Fs, file *multipart.FileHeader) (*UploadedFile, error) { if up.chunkNumber != up.totalChunks-1 { return nil, ctx.SaveFile(file, up.chunkPath) } + defer os.RemoveAll(up.tmpDir) // 如果是最后一个分片 // 生成唯一的文件存储路径 - storageDir := filepath.Join(up.dst, time.Now().Format("2006/01/02")) - if err := os.MkdirAll(storageDir, 0o755); err != nil { - os.RemoveAll(filepath.Join(os.TempDir(), up.fileMD5)) - return nil, err - } - - // 计算所有分片的实际大小总和 - totalSize, err := calculateTotalSize(up.tmpDir, up.totalChunks) - if err != nil { - os.RemoveAll(up.tmpDir) - return nil, fmt.Errorf("计算文件大小失败: %w", err) - } // 合并文件 - if err := combineChunks(up.tmpDir, up.finalPath, up.totalChunks); err != nil { - os.RemoveAll(up.tmpDir) + totalSize, err := up.combineChunks(fs) + if err != nil { return nil, fmt.Errorf("合并文件失败: %w", err) } - // 验证MD5 - calculatedMD5, err := calculateFileMD5(up.finalPath) - if err != nil || calculatedMD5 != up.fileMD5 { - os.RemoveAll(up.tmpDir) - os.Remove(up.finalPath) - return nil, errors.New("文件MD5验证失败") - } - - // 清理临时目录 - os.RemoveAll(up.tmpDir) - return &UploadedFile{ - Hash: calculatedMD5, + Hash: up.fileMD5, Name: up.fileName, Path: up.finalPath, Size: totalSize, @@ -102,52 +81,41 @@ func (up *Uploader) Save(ctx fiber.Ctx, file *multipart.FileHeader) (*UploadedFi }, nil } -// 计算所有分片的实际大小总和 -func calculateTotalSize(tempDir string, totalChunks int) (int64, error) { - var totalSize int64 - for i := 0; i < totalChunks; i++ { - chunkPath := filepath.Join(tempDir, fmt.Sprintf("chunk_%d", i)) - info, err := os.Stat(chunkPath) +func (up *Uploader) combineChunks(fs afero.Fs) (int64, error) { + if err := fs.MkdirAll(filepath.Dir(up.finalPath), os.ModePerm); err != nil { + return 0, err + } + + f, err := fs.Create(up.finalPath) + if err != nil { + return 0, err + } + defer f.Close() + + hash := md5.New() + size := int64(0) + for i := 0; i < up.totalChunks; i++ { + chunkPath := fmt.Sprintf("%s/chunk_%d", up.tmpDir, i) + chunk, err := os.ReadFile(chunkPath) if err != nil { return 0, err } - totalSize += info.Size() - } - return totalSize, nil -} + size += int64(len(chunk)) -func combineChunks(tempDir, finalPath string, totalChunks int) error { - finalFile, err := os.Create(finalPath) - if err != nil { - return err - } - defer finalFile.Close() - - for i := 0; i < totalChunks; i++ { - chunkPath := fmt.Sprintf("%s/chunk_%d", tempDir, i) - chunk, err := os.ReadFile(chunkPath) - if err != nil { - return err + if _, err := f.Write(chunk); err != nil { + return 0, err } - if _, err := finalFile.Write(chunk); err != nil { - return err + + if _, err := io.Copy(hash, bytes.NewBuffer(chunk)); err != nil { + return 0, err } + } - return nil -} - -func calculateFileMD5(filePath string) (string, error) { - file, err := os.Open(filePath) - if err != nil { - return "", err - } - defer file.Close() - - hash := md5.New() - if _, err := io.Copy(hash, file); err != nil { - return "", err - } - - return hex.EncodeToString(hash.Sum(nil)), nil + md5 := hex.EncodeToString(hash.Sum(nil)) + if md5 != up.fileMD5 { + return 0, errors.New("文件MD5验证失败") + } + + return size, nil } diff --git a/test/go.mod b/test/go.mod index 823ef1d..f1956f7 100644 --- a/test/go.mod +++ b/test/go.mod @@ -2,12 +2,20 @@ module test go 1.23.2 -require github.com/go-jet/jet/v2 v2.12.0 +require ( + github.com/go-jet/jet/v2 v2.12.0 + github.com/smartystreets/goconvey v1.8.1 + github.com/spf13/afero v1.12.0 +) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/google/uuid v1.6.0 // indirect + github.com/gopherjs/gopherjs v1.17.2 // indirect + github.com/jtolds/gls v4.20.0+incompatible // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/smarty/assertions v1.15.0 // indirect github.com/stretchr/testify v1.9.0 // indirect + golang.org/x/text v0.21.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/test/go.sum b/test/go.sum index 8ba944a..17d58b4 100644 --- a/test/go.sum +++ b/test/go.sum @@ -6,10 +6,22 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g= +github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= +github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= +github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY= +github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec= +github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY= +github.com/smartystreets/goconvey v1.8.1/go.mod h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9CzCZL03bI28W60= +github.com/spf13/afero v1.12.0 h1:UcOPyRBYczmFn6yvphxkn9ZEOY65cpwGKb5mL36mrqs= +github.com/spf13/afero v1.12.0/go.mod h1:ZTlWwG4/ahT8W7T0WQ5uYmjI9duaLQGy3Q2OAl4sk/4= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/test/main_test.go b/test/main_test.go index 20679bf..6bf700f 100644 --- a/test/main_test.go +++ b/test/main_test.go @@ -1,11 +1,15 @@ package main import ( + "os" + "path/filepath" "testing" "test/database/models/qvyun_v2/public/table" . "github.com/go-jet/jet/v2/postgres" + . "github.com/smartystreets/goconvey/convey" + "github.com/spf13/afero" ) func Test_Join(t *testing.T) { @@ -18,3 +22,29 @@ func Test_Join(t *testing.T) { ) t.Log(stmt.DebugSql()) } + +func Test_Afero(t *testing.T) { + Convey("Test afero", t, func() { + fs := afero.NewBasePathFs(afero.NewOsFs(), "/tmp") + + path := "test/a/b/c/test.txt" + err := fs.MkdirAll(filepath.Dir(path), os.ModePerm) + So(err, ShouldBeNil) + + f, err := fs.OpenFile(path, os.O_CREATE|os.O_RDWR, 0o644) + So(err, ShouldBeNil) + + f.Write([]byte("hello ")) + f.Write([]byte("world")) + + f.Close() + + ff, err := fs.Open(path) + So(err, ShouldBeNil) + + b := make([]byte, 1024) + n, err := ff.Read(b) + So(err, ShouldBeNil) + So(string(b[:n]), ShouldEqual, "hello world") + }) +}