package storage import ( "context" "crypto/hmac" "crypto/sha256" "encoding/hex" "fmt" "io" "net/url" "os" "path/filepath" "strconv" "strings" "time" "github.com/minio/minio-go/v7" "github.com/minio/minio-go/v7/pkg/credentials" "go.ipao.vip/atom/container" "go.ipao.vip/atom/opt" ) const DefaultPrefix = "Storage" func DefaultProvider() container.ProviderContainer { return container.ProviderContainer{ Provider: Provide, Options: []opt.Option{ opt.Prefix(DefaultPrefix), }, } } func Provide(opts ...opt.Option) error { o := opt.New(opts...) var config Config if err := o.UnmarshalConfig(&config); err != nil { return err } return container.Container.Provide(func() (*Storage, error) { store := &Storage{Config: &config} if store.storageType() == "s3" { client, err := store.s3ClientForUse() if err != nil { return nil, err } if store.Config.CheckOnBoot { // 启动时可选检查 bucket 是否可用,便于尽早暴露配置问题。 exists, err := client.BucketExists(context.Background(), store.Config.Bucket) if err != nil { return nil, fmt.Errorf("storage bucket check failed: %w", err) } if !exists { return nil, fmt.Errorf("storage bucket not found: %s", store.Config.Bucket) } } } return store, nil }, o.DiOptions()...) } type Storage struct { Config *Config s3Client *minio.Client } func (s *Storage) Download(ctx context.Context, key, filePath string) error { if s.storageType() == "local" { localPath := s.Config.LocalPath if localPath == "" { localPath = "./storage" } srcPath := filepath.Join(localPath, key) if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil { return err } src, err := os.Open(srcPath) if err != nil { return err } defer src.Close() dst, err := os.Create(filePath) if err != nil { return err } defer dst.Close() if _, err := io.Copy(dst, src); err != nil { return err } return nil } client, err := s.s3ClientForUse() if err != nil { return err } if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil { return err } return client.FGetObject(ctx, s.Config.Bucket, key, filePath, minio.GetObjectOptions{}) } func (s *Storage) Delete(key string) error { if s.storageType() == "local" { localPath := s.Config.LocalPath if localPath == "" { localPath = "./storage" } path := filepath.Join(localPath, key) return os.Remove(path) } client, err := s.s3ClientForUse() if err != nil { return err } return client.RemoveObject(context.Background(), s.Config.Bucket, key, minio.RemoveObjectOptions{}) } func (s *Storage) SignURL(method, key string, expires time.Duration) (string, error) { if s.storageType() == "s3" { client, err := s.s3ClientForUse() if err != nil { return "", err } switch strings.ToUpper(method) { case "GET": u, err := client.PresignedGetObject(context.Background(), s.Config.Bucket, key, expires, nil) if err != nil { return "", err } return u.String(), nil case "PUT": u, err := client.PresignedPutObject(context.Background(), s.Config.Bucket, key, expires) if err != nil { return "", err } return u.String(), nil default: return "", fmt.Errorf("unsupported method") } } exp := time.Now().Add(expires).Unix() sign := s.signature(method, key, exp) baseURL := strings.TrimRight(s.Config.BaseURL, "/") // Ensure BaseURL doesn't end with slash if we add one // Simplified: assume standard /v1/storage prefix in BaseURL or append it // We'll append / u, err := url.Parse(baseURL + "/" + key) if err != nil { return "", err } q := u.Query() q.Set("expires", strconv.FormatInt(exp, 10)) q.Set("sign", sign) u.RawQuery = q.Encode() return u.String(), nil } func (s *Storage) Verify(method, key, expStr, sign string) error { if s.storageType() == "s3" { return fmt.Errorf("s3 storage does not use signed local urls") } exp, err := strconv.ParseInt(expStr, 10, 64) if err != nil { return fmt.Errorf("invalid expiry") } if time.Now().Unix() > exp { return fmt.Errorf("expired") } expected := s.signature(method, key, exp) if !hmac.Equal([]byte(expected), []byte(sign)) { return fmt.Errorf("invalid signature") } return nil } func (s *Storage) signature(method, key string, exp int64) string { str := fmt.Sprintf("%s\n%s\n%d", method, key, exp) h := hmac.New(sha256.New, []byte(s.Config.Secret)) h.Write([]byte(str)) return hex.EncodeToString(h.Sum(nil)) } func (s *Storage) PutObject(ctx context.Context, key, filePath, contentType string) error { if s.storageType() == "local" { localPath := s.Config.LocalPath if localPath == "" { localPath = "./storage" } dstPath := filepath.Join(localPath, key) if err := os.MkdirAll(filepath.Dir(dstPath), 0o755); err != nil { return err } return os.Rename(filePath, dstPath) } client, err := s.s3ClientForUse() if err != nil { return err } opts := minio.PutObjectOptions{} if contentType != "" { opts.ContentType = contentType } _, err = client.FPutObject(ctx, s.Config.Bucket, key, filePath, opts) return err } func (s *Storage) Provider() string { if s.storageType() == "s3" { return "s3" } return "local" } func (s *Storage) Bucket() string { if s.storageType() == "s3" && s.Config.Bucket != "" { return s.Config.Bucket } return "default" } func (s *Storage) storageType() string { if s.Config == nil { return "local" } typ := strings.TrimSpace(strings.ToLower(s.Config.Type)) if typ == "" { return "local" } return typ } func (s *Storage) s3ClientForUse() (*minio.Client, error) { if s.s3Client != nil { return s.s3Client, nil } if strings.TrimSpace(s.Config.Endpoint) == "" { return nil, fmt.Errorf("storage endpoint is required") } if strings.TrimSpace(s.Config.AccessKey) == "" || strings.TrimSpace(s.Config.SecretKey) == "" { return nil, fmt.Errorf("storage access key or secret key is required") } if strings.TrimSpace(s.Config.Bucket) == "" { return nil, fmt.Errorf("storage bucket is required") } endpoint, secure, err := parseEndpoint(s.Config.Endpoint) if err != nil { return nil, err } opts := &minio.Options{ Creds: credentials.NewStaticV4(s.Config.AccessKey, s.Config.SecretKey, ""), Secure: secure, Region: s.Config.Region, } if s.Config.PathStyle { opts.BucketLookup = minio.BucketLookupPath } else { opts.BucketLookup = minio.BucketLookupDNS } client, err := minio.New(endpoint, opts) if err != nil { return nil, err } s.s3Client = client return client, nil } func parseEndpoint(endpoint string) (string, bool, error) { if strings.HasPrefix(endpoint, "http://") || strings.HasPrefix(endpoint, "https://") { u, err := url.Parse(endpoint) if err != nil { return "", false, err } if u.Host == "" { return "", false, fmt.Errorf("invalid endpoint") } return u.Host, u.Scheme == "https", nil } return endpoint, false, nil }