package storage

import (
	"context"
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"errors"
	"fmt"
	"io"
	"net/url"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/minio/minio-go/v7"
	"github.com/minio/minio-go/v7/pkg/credentials"
	"go.ipao.vip/atom/container"
	"go.ipao.vip/atom/opt"
)

// DefaultPrefix is the configuration prefix used by DefaultProvider.
const DefaultPrefix = "Storage"

// Sentinel errors returned by Storage. Callers may match them with errors.Is.
var (
	errStorageBucketNotFound       = errors.New("storage bucket not found")
	errStorageBucketCheckFailed    = errors.New("storage bucket check failed")
	errStorageEndpointRequired     = errors.New("storage endpoint is required")
	errStorageAccessKeyRequired    = errors.New("storage access key or secret key is required")
	errStorageBucketRequired       = errors.New("storage bucket is required")
	errStorageInvalidEndpoint      = errors.New("storage endpoint is invalid")
	errStorageUnsupportedMethod    = errors.New("unsupported method")
	errStorageSignedURLUnsupported = errors.New("s3 storage does not use signed local urls")
	errStorageInvalidExpiry        = errors.New("invalid expiry")
	errStorageExpired              = errors.New("expired")
	errStorageInvalidSignature     = errors.New("invalid signature")
	errStorageInvalidKey           = errors.New("storage key is invalid")
)

// DefaultProvider returns the container registration for Storage, reading
// its configuration under DefaultPrefix.
func DefaultProvider() container.ProviderContainer {
	return container.ProviderContainer{
		Provider: Provide,
		Options: []opt.Option{
			opt.Prefix(DefaultPrefix),
		},
	}
}

// Provide unmarshals the storage Config from opts and registers a *Storage
// constructor in the DI container.
//
// For S3-backed storage the minio client is created eagerly, so missing
// endpoint, credentials or bucket settings are reported when the container
// resolves *Storage rather than on first use. When Config.CheckOnBoot is set,
// the bucket's existence is verified as well.
func Provide(opts ...opt.Option) error {
	o := opt.New(opts...)
	var config Config
	if err := o.UnmarshalConfig(&config); err != nil {
		return err
	}
	return container.Container.Provide(func() (*Storage, error) {
		store := &Storage{Config: &config}
		if store.storageType() != "s3" {
			return store, nil
		}

		client, err := store.s3ClientForUse()
		if err != nil {
			return nil, err
		}
		if !store.Config.CheckOnBoot {
			return store, nil
		}

		// Optionally check at startup that the bucket is usable, so that
		// configuration problems surface early.
		exists, err := client.BucketExists(context.Background(), store.Config.Bucket)
		if err != nil {
			return nil, fmt.Errorf("%w: %w", errStorageBucketCheckFailed, err)
		}
		if !exists {
			return nil, fmt.Errorf("%w: %s", errStorageBucketNotFound, store.Config.Bucket)
		}
		return store, nil
	}, o.DiOptions()...)
}

// Storage stores objects either on the local filesystem (under
// Config.LocalPath) or in an S3-compatible bucket, depending on Config.Type.
//
// Storage holds a mutex and must not be copied after first use; pass *Storage.
type Storage struct {
	Config *Config

	// s3Mu guards the lazy initialisation of s3Client.
	s3Mu     sync.Mutex
	s3Client *minio.Client
}

// Download copies the object stored under key to filePath, creating the
// parent directories of filePath as needed.
//
// For local storage the file is copied out of the storage root; a partially
// written destination is removed if the copy fails. For S3 storage the object
// is fetched with FGetObject using ctx.
func (s *Storage) Download(ctx context.Context, key, filePath string) error {
	if s.storageType() == "local" {
		srcPath, err := s.localObjectPath(key)
		if err != nil {
			return err
		}
		if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil {
			return fmt.Errorf("create download dir: %w", err)
		}
		return copyLocalObject(srcPath, filePath)
	}

	client, err := s.s3ClientForUse()
	if err != nil {
		return err
	}
	if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil {
		return fmt.Errorf("create download dir: %w", err)
	}
	if err := client.FGetObject(ctx, s.Config.Bucket, key, filePath, minio.GetObjectOptions{}); err != nil {
		return fmt.Errorf("download object: %w", err)
	}
	return nil
}

// Delete removes the object stored under key.
//
// The S3 request is issued with context.Background(), because this method
// takes no context.
func (s *Storage) Delete(key string) error {
	if s.storageType() == "local" {
		filePath, err := s.localObjectPath(key)
		if err != nil {
			return err
		}
		if err := os.Remove(filePath); err != nil {
			return fmt.Errorf("remove local object: %w", err)
		}
		return nil
	}

	client, err := s.s3ClientForUse()
	if err != nil {
		return err
	}
	if err := client.RemoveObject(context.Background(), s.Config.Bucket, key, minio.RemoveObjectOptions{}); err != nil {
		return fmt.Errorf("remove s3 object: %w", err)
	}
	return nil
}

// SignURL returns a time-limited URL for performing method on key.
//
// For S3 storage it returns a presigned GET or PUT URL; any other method
// yields errStorageUnsupportedMethod. For local storage it returns
// Config.BaseURL + "/" + key with "expires" (Unix seconds) and "sign"
// (hex HMAC-SHA256 keyed with Config.Secret) query parameters, which Verify
// checks.
//
// NOTE(review): with an empty Config.Secret, local signatures can be forged
// by anyone who knows the scheme. Consider requiring a secret.
func (s *Storage) SignURL(method, key string, expires time.Duration) (string, error) {
	if s.storageType() == "s3" {
		client, err := s.s3ClientForUse()
		if err != nil {
			return "", err
		}
		switch strings.ToUpper(method) {
		case "GET":
			u, err := client.PresignedGetObject(context.Background(), s.Config.Bucket, key, expires, nil)
			if err != nil {
				return "", fmt.Errorf("presign get object: %w", err)
			}
			return u.String(), nil
		case "PUT":
			u, err := client.PresignedPutObject(context.Background(), s.Config.Bucket, key, expires)
			if err != nil {
				return "", fmt.Errorf("presign put object: %w", err)
			}
			return u.String(), nil
		default:
			return "", errStorageUnsupportedMethod
		}
	}

	exp := time.Now().Add(expires).Unix()
	sign := s.signature(method, key, exp)

	u, err := url.Parse(s.Config.BaseURL)
	if err != nil {
		return "", fmt.Errorf("parse base url: %w", err)
	}
	// Append the key to the decoded path instead of concatenating raw text.
	// u.String() then percent-encodes characters such as '?', '#', '%' or
	// spaces in the key, rather than letting them be parsed as query or
	// fragment delimiters. Clearing RawPath keeps a stale encoding of the base
	// path from being reused.
	u.Path = strings.TrimRight(u.Path, "/") + "/" + key
	u.RawPath = ""

	q := u.Query()
	q.Set("expires", strconv.FormatInt(exp, 10))
	q.Set("sign", sign)
	u.RawQuery = q.Encode()
	return u.String(), nil
}

// Verify checks a signature produced by SignURL for local storage.
//
// It returns errStorageInvalidExpiry if expStr is not an integer,
// errStorageExpired once the expiry has passed, and errStorageInvalidSignature
// if sign does not match. For S3 storage it always returns
// errStorageSignedURLUnsupported.
func (s *Storage) Verify(method, key, expStr, sign string) error {
	if s.storageType() == "s3" {
		return errStorageSignedURLUnsupported
	}
	exp, err := strconv.ParseInt(expStr, 10, 64)
	if err != nil {
		return errStorageInvalidExpiry
	}
	if time.Now().Unix() > exp {
		return errStorageExpired
	}
	expected := s.signature(method, key, exp)
	// Constant-time comparison to avoid leaking how many bytes matched.
	if !hmac.Equal([]byte(expected), []byte(sign)) {
		return errStorageInvalidSignature
	}
	return nil
}

// signature returns the hex-encoded HMAC-SHA256, keyed with Config.Secret, of
// "METHOD\nkey\nexp".
//
// The method is upper-cased so that signing and verification agree regardless
// of the case a caller used (e.g. SignURL("get", ...) vs an HTTP request's
// "GET"). This matches the case-insensitive handling in the S3 branch.
func (s *Storage) signature(method, key string, exp int64) string {
	payload := strings.ToUpper(method) + "\n" + key + "\n" + strconv.FormatInt(exp, 10)
	h := hmac.New(sha256.New, []byte(s.Config.Secret))
	h.Write([]byte(payload)) // hash.Hash.Write never returns an error
	return hex.EncodeToString(h.Sum(nil))
}

// PutObject stores the file at filePath under key.
//
// For local storage the file is moved into the storage root; if a rename is
// impossible (for example across filesystems) it is copied and the source is
// removed. For S3 storage the file is uploaded with FPutObject and left in
// place. contentType is applied to S3 uploads only when non-empty.
func (s *Storage) PutObject(ctx context.Context, key, filePath, contentType string) error {
	if s.storageType() == "local" {
		dstPath, err := s.localObjectPath(key)
		if err != nil {
			return err
		}
		if err := os.MkdirAll(filepath.Dir(dstPath), 0o755); err != nil {
			return fmt.Errorf("create object dir: %w", err)
		}
		if err := moveLocalObject(filePath, dstPath); err != nil {
			return fmt.Errorf("move object file: %w", err)
		}
		return nil
	}

	client, err := s.s3ClientForUse()
	if err != nil {
		return err
	}
	opts := minio.PutObjectOptions{}
	if contentType != "" {
		opts.ContentType = contentType
	}
	if _, err := client.FPutObject(ctx, s.Config.Bucket, key, filePath, opts); err != nil {
		return fmt.Errorf("upload object: %w", err)
	}
	return nil
}

// Provider reports the active backend: "s3" or "local".
func (s *Storage) Provider() string {
	return s.storageType()
}

// Bucket returns the configured bucket for S3 storage, or "default" for local
// storage or when no bucket is configured.
func (s *Storage) Bucket() string {
	if s.storageType() == "s3" && s.Config.Bucket != "" {
		return s.Config.Bucket
	}
	return "default"
}

// storageType returns the normalised backend kind.
//
// It returns "local" when Config is nil or Config.Type is empty or "local"
// (case-insensitive, surrounding whitespace ignored), and "s3" for any other
// value. Collapsing to exactly two values keeps every method on the same
// backend. Otherwise a value such as "minio" would upload to S3 while SignURL
// and Verify still produced local signed URLs.
func (s *Storage) storageType() string {
	if s.Config == nil {
		return "local"
	}
	switch strings.TrimSpace(strings.ToLower(s.Config.Type)) {
	case "", "local":
		return "local"
	default:
		return "s3"
	}
}

// localRoot returns the directory holding local objects, defaulting to
// "./storage" when Config.LocalPath is empty.
func (s *Storage) localRoot() string {
	if s.Config != nil && s.Config.LocalPath != "" {
		return s.Config.LocalPath
	}
	return "./storage"
}

// localObjectPath maps key to a filesystem path under the local storage root.
// It rejects keys that resolve to the root itself (e.g. "") or escape it
// (e.g. "../x", "a/../../b") with errStorageInvalidKey.
func (s *Storage) localObjectPath(key string) (string, error) {
	root := s.localRoot()
	full := filepath.Join(root, key)
	rel, err := filepath.Rel(root, full)
	if err != nil || rel == "." || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
		return "", fmt.Errorf("%w: %q", errStorageInvalidKey, key)
	}
	return full, nil
}

// copyLocalObject copies srcPath to dstPath, creating or truncating dstPath.
// The destination is closed explicitly so write errors reported at close are
// not lost, and it is removed again if the copy does not complete.
func copyLocalObject(srcPath, dstPath string) error {
	src, err := os.Open(srcPath)
	if err != nil {
		return fmt.Errorf("open source file: %w", err)
	}
	defer src.Close()

	dst, err := os.Create(dstPath)
	if err != nil {
		return fmt.Errorf("create destination file: %w", err)
	}
	if _, err := io.Copy(dst, src); err != nil {
		// Best-effort cleanup of the partial file; the copy error is reported.
		_ = dst.Close()
		_ = os.Remove(dstPath)
		return fmt.Errorf("copy file content: %w", err)
	}
	if err := dst.Close(); err != nil {
		_ = os.Remove(dstPath)
		return fmt.Errorf("close destination file: %w", err)
	}
	return nil
}

// moveLocalObject moves srcPath to dstPath. It tries os.Rename first. If that
// fails (for example because the paths are on different filesystems, where a
// rename is not possible), it falls back to copying the file and then
// removing the source.
func moveLocalObject(srcPath, dstPath string) error {
	renameErr := os.Rename(srcPath, dstPath)
	if renameErr == nil {
		return nil
	}
	if err := copyLocalObject(srcPath, dstPath); err != nil {
		return errors.Join(renameErr, err)
	}
	if err := os.Remove(srcPath); err != nil {
		return fmt.Errorf("remove source file: %w", err)
	}
	return nil
}

// s3ClientForUse returns the cached minio client, creating it on first use
// from Config. It validates that endpoint, access/secret key and bucket are
// set. Creation is serialised by s3Mu so concurrent callers share one client.
func (s *Storage) s3ClientForUse() (*minio.Client, error) {
	s.s3Mu.Lock()
	defer s.s3Mu.Unlock()

	if s.s3Client != nil {
		return s.s3Client, nil
	}
	if strings.TrimSpace(s.Config.Endpoint) == "" {
		return nil, errStorageEndpointRequired
	}
	if strings.TrimSpace(s.Config.AccessKey) == "" || strings.TrimSpace(s.Config.SecretKey) == "" {
		return nil, errStorageAccessKeyRequired
	}
	if strings.TrimSpace(s.Config.Bucket) == "" {
		return nil, errStorageBucketRequired
	}

	endpoint, secure, err := parseEndpoint(s.Config.Endpoint)
	if err != nil {
		return nil, err
	}
	opts := &minio.Options{
		Creds:  credentials.NewStaticV4(s.Config.AccessKey, s.Config.SecretKey, ""),
		Secure: secure,
		Region: s.Config.Region,
	}
	if s.Config.PathStyle {
		opts.BucketLookup = minio.BucketLookupPath
	} else {
		opts.BucketLookup = minio.BucketLookupDNS
	}

	client, err := minio.New(endpoint, opts)
	if err != nil {
		return nil, err
	}
	s.s3Client = client
	return client, nil
}

// parseEndpoint splits a configured endpoint into the host[:port] form
// expected by minio.New and reports whether TLS should be used.
//
// An "http://" or "https://" prefix (case-insensitive) selects the scheme; any
// path is discarded. An endpoint without a scheme is returned as-is with TLS
// disabled. Surrounding whitespace is ignored.
func parseEndpoint(endpoint string) (string, bool, error) {
	endpoint = strings.TrimSpace(endpoint)
	lower := strings.ToLower(endpoint)
	if !strings.HasPrefix(lower, "http://") && !strings.HasPrefix(lower, "https://") {
		return endpoint, false, nil
	}

	u, err := url.Parse(endpoint)
	if err != nil {
		return "", false, fmt.Errorf("parse endpoint: %w", err)
	}
	if u.Host == "" {
		return "", false, errStorageInvalidEndpoint
	}
	// url.Parse lower-cases the scheme, so a plain comparison is sufficient.
	return u.Host, u.Scheme == "https", nil
}