337 lines
8.1 KiB
Go
337 lines
8.1 KiB
Go
package storage
|
|
|
|
import (
|
|
"context"
|
|
"crypto/hmac"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/minio/minio-go/v7"
|
|
"github.com/minio/minio-go/v7/pkg/credentials"
|
|
"go.ipao.vip/atom/container"
|
|
"go.ipao.vip/atom/opt"
|
|
)
|
|
|
|
// DefaultPrefix is the prefix DefaultProvider passes to opt.Prefix
// (presumably the configuration section holding Storage settings — confirm
// against the opt package).
const DefaultPrefix = "Storage"

// Configuration / boot-time errors returned while building the S3 client or
// probing the bucket.
var (
	errStorageBucketNotFound     = errors.New("storage bucket not found")
	errStorageBucketCheckFailed  = errors.New("storage bucket check failed")
	errStorageEndpointRequired   = errors.New("storage endpoint is required")
	errStorageAccessKeyRequired  = errors.New("storage access key or secret key is required")
	errStorageBucketRequired     = errors.New("storage bucket is required")
	errStorageInvalidEndpoint    = errors.New("storage endpoint is invalid")
)

// Errors returned by SignURL / Verify.
var (
	errStorageUnsupportedMethod    = errors.New("unsupported method")
	errStorageSignedURLUnsupported = errors.New("s3 storage does not use signed local urls")
	errStorageInvalidExpiry        = errors.New("invalid expiry")
	errStorageExpired              = errors.New("expired")
	errStorageInvalidSignature     = errors.New("invalid signature")
)
|
|
|
|
func DefaultProvider() container.ProviderContainer {
|
|
return container.ProviderContainer{
|
|
Provider: Provide,
|
|
Options: []opt.Option{
|
|
opt.Prefix(DefaultPrefix),
|
|
},
|
|
}
|
|
}
|
|
|
|
func Provide(opts ...opt.Option) error {
|
|
o := opt.New(opts...)
|
|
var config Config
|
|
if err := o.UnmarshalConfig(&config); err != nil {
|
|
return err
|
|
}
|
|
|
|
return container.Container.Provide(func() (*Storage, error) {
|
|
store := &Storage{Config: &config}
|
|
if store.storageType() == "s3" {
|
|
client, err := store.s3ClientForUse()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if store.Config.CheckOnBoot {
|
|
// 启动时可选检查 bucket 是否可用,便于尽早暴露配置问题。
|
|
exists, err := client.BucketExists(context.Background(), store.Config.Bucket)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%w: %w", errStorageBucketCheckFailed, err)
|
|
}
|
|
if !exists {
|
|
return nil, fmt.Errorf("%w: %s", errStorageBucketNotFound, store.Config.Bucket)
|
|
}
|
|
}
|
|
}
|
|
|
|
return store, nil
|
|
}, o.DiOptions()...)
|
|
}
|
|
|
|
type Storage struct {
|
|
Config *Config
|
|
s3Client *minio.Client
|
|
}
|
|
|
|
func (s *Storage) Download(ctx context.Context, key, filePath string) error {
|
|
if s.storageType() == "local" {
|
|
localPath := s.Config.LocalPath
|
|
if localPath == "" {
|
|
localPath = "./storage"
|
|
}
|
|
srcPath := filepath.Join(localPath, key)
|
|
if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil {
|
|
return fmt.Errorf("create download dir: %w", err)
|
|
}
|
|
src, err := os.Open(srcPath)
|
|
if err != nil {
|
|
return fmt.Errorf("open source file: %w", err)
|
|
}
|
|
defer src.Close()
|
|
|
|
dst, err := os.Create(filePath)
|
|
if err != nil {
|
|
return fmt.Errorf("create destination file: %w", err)
|
|
}
|
|
defer dst.Close()
|
|
|
|
if _, err := io.Copy(dst, src); err != nil {
|
|
return fmt.Errorf("copy file content: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
client, err := s.s3ClientForUse()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil {
|
|
return fmt.Errorf("create download dir: %w", err)
|
|
}
|
|
if err := client.FGetObject(ctx, s.Config.Bucket, key, filePath, minio.GetObjectOptions{}); err != nil {
|
|
return fmt.Errorf("download object: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Storage) Delete(key string) error {
|
|
if s.storageType() == "local" {
|
|
localPath := s.Config.LocalPath
|
|
if localPath == "" {
|
|
localPath = "./storage"
|
|
}
|
|
filePath := filepath.Join(localPath, key)
|
|
if err := os.Remove(filePath); err != nil {
|
|
return fmt.Errorf("remove local object: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
client, err := s.s3ClientForUse()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := client.RemoveObject(context.Background(), s.Config.Bucket, key, minio.RemoveObjectOptions{}); err != nil {
|
|
return fmt.Errorf("remove s3 object: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Storage) SignURL(method, key string, expires time.Duration) (string, error) {
|
|
if s.storageType() == "s3" {
|
|
client, err := s.s3ClientForUse()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
switch strings.ToUpper(method) {
|
|
case "GET":
|
|
u, err := client.PresignedGetObject(context.Background(), s.Config.Bucket, key, expires, nil)
|
|
if err != nil {
|
|
return "", fmt.Errorf("presign get object: %w", err)
|
|
}
|
|
|
|
return u.String(), nil
|
|
case "PUT":
|
|
u, err := client.PresignedPutObject(context.Background(), s.Config.Bucket, key, expires)
|
|
if err != nil {
|
|
return "", fmt.Errorf("presign put object: %w", err)
|
|
}
|
|
|
|
return u.String(), nil
|
|
default:
|
|
return "", errStorageUnsupportedMethod
|
|
}
|
|
}
|
|
|
|
exp := time.Now().Add(expires).Unix()
|
|
sign := s.signature(method, key, exp)
|
|
|
|
baseURL := strings.TrimRight(s.Config.BaseURL, "/")
|
|
|
|
u, err := url.Parse(baseURL + "/" + key)
|
|
if err != nil {
|
|
return "", fmt.Errorf("parse base url: %w", err)
|
|
}
|
|
|
|
q := u.Query()
|
|
q.Set("expires", strconv.FormatInt(exp, 10))
|
|
q.Set("sign", sign)
|
|
u.RawQuery = q.Encode()
|
|
|
|
return u.String(), nil
|
|
}
|
|
|
|
func (s *Storage) Verify(method, key, expStr, sign string) error {
|
|
if s.storageType() == "s3" {
|
|
return errStorageSignedURLUnsupported
|
|
}
|
|
exp, err := strconv.ParseInt(expStr, 10, 64)
|
|
if err != nil {
|
|
return errStorageInvalidExpiry
|
|
}
|
|
if time.Now().Unix() > exp {
|
|
return errStorageExpired
|
|
}
|
|
|
|
expected := s.signature(method, key, exp)
|
|
if !hmac.Equal([]byte(expected), []byte(sign)) {
|
|
return errStorageInvalidSignature
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Storage) signature(method, key string, exp int64) string {
|
|
str := fmt.Sprintf("%s\n%s\n%d", method, key, exp)
|
|
h := hmac.New(sha256.New, []byte(s.Config.Secret))
|
|
h.Write([]byte(str))
|
|
|
|
return hex.EncodeToString(h.Sum(nil))
|
|
}
|
|
|
|
func (s *Storage) PutObject(ctx context.Context, key, filePath, contentType string) error {
|
|
if s.storageType() == "local" {
|
|
localPath := s.Config.LocalPath
|
|
if localPath == "" {
|
|
localPath = "./storage"
|
|
}
|
|
dstPath := filepath.Join(localPath, key)
|
|
if err := os.MkdirAll(filepath.Dir(dstPath), 0o755); err != nil {
|
|
return fmt.Errorf("create object dir: %w", err)
|
|
}
|
|
if err := os.Rename(filePath, dstPath); err != nil {
|
|
return fmt.Errorf("move object file: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
client, err := s.s3ClientForUse()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
opts := minio.PutObjectOptions{}
|
|
if contentType != "" {
|
|
opts.ContentType = contentType
|
|
}
|
|
if _, err := client.FPutObject(ctx, s.Config.Bucket, key, filePath, opts); err != nil {
|
|
return fmt.Errorf("upload object: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Storage) Provider() string {
|
|
if s.storageType() == "s3" {
|
|
return "s3"
|
|
}
|
|
|
|
return "local"
|
|
}
|
|
|
|
func (s *Storage) Bucket() string {
|
|
if s.storageType() == "s3" && s.Config.Bucket != "" {
|
|
return s.Config.Bucket
|
|
}
|
|
|
|
return "default"
|
|
}
|
|
|
|
func (s *Storage) storageType() string {
|
|
if s.Config == nil {
|
|
return "local"
|
|
}
|
|
typ := strings.TrimSpace(strings.ToLower(s.Config.Type))
|
|
if typ == "" {
|
|
return "local"
|
|
}
|
|
|
|
return typ
|
|
}
|
|
|
|
func (s *Storage) s3ClientForUse() (*minio.Client, error) {
|
|
if s.s3Client != nil {
|
|
return s.s3Client, nil
|
|
}
|
|
if strings.TrimSpace(s.Config.Endpoint) == "" {
|
|
return nil, errStorageEndpointRequired
|
|
}
|
|
if strings.TrimSpace(s.Config.AccessKey) == "" || strings.TrimSpace(s.Config.SecretKey) == "" {
|
|
return nil, errStorageAccessKeyRequired
|
|
}
|
|
if strings.TrimSpace(s.Config.Bucket) == "" {
|
|
return nil, errStorageBucketRequired
|
|
}
|
|
|
|
endpoint, secure, err := parseEndpoint(s.Config.Endpoint)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
opts := &minio.Options{
|
|
Creds: credentials.NewStaticV4(s.Config.AccessKey, s.Config.SecretKey, ""),
|
|
Secure: secure,
|
|
Region: s.Config.Region,
|
|
}
|
|
if s.Config.PathStyle {
|
|
opts.BucketLookup = minio.BucketLookupPath
|
|
} else {
|
|
opts.BucketLookup = minio.BucketLookupDNS
|
|
}
|
|
|
|
client, err := minio.New(endpoint, opts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
s.s3Client = client
|
|
|
|
return client, nil
|
|
}
|
|
|
|
func parseEndpoint(endpoint string) (string, bool, error) {
|
|
if strings.HasPrefix(endpoint, "http://") || strings.HasPrefix(endpoint, "https://") {
|
|
u, err := url.Parse(endpoint)
|
|
if err != nil {
|
|
return "", false, fmt.Errorf("parse endpoint: %w", err)
|
|
}
|
|
if u.Host == "" {
|
|
return "", false, errStorageInvalidEndpoint
|
|
}
|
|
|
|
return u.Host, u.Scheme == "https", nil
|
|
}
|
|
|
|
return endpoint, false, nil
|
|
}
|