Files
quyun-v2/backend/providers/storage/provider.go

293 lines
6.8 KiB
Go

package storage
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
"go.ipao.vip/atom/container"
"go.ipao.vip/atom/opt"
)
// DefaultPrefix is the option prefix passed to opt.Prefix when the storage
// provider is registered via DefaultProvider.
// NOTE(review): presumably this names the configuration section that
// Provide's UnmarshalConfig reads — confirm against the opt package.
const DefaultPrefix = "Storage"

// DefaultProvider returns the container registration for the storage
// provider: Provide as the constructor, with DefaultPrefix as its prefix
// option.
func DefaultProvider() container.ProviderContainer {
	prefixOpt := opt.Prefix(DefaultPrefix)

	return container.ProviderContainer{
		Provider: Provide,
		Options:  []opt.Option{prefixOpt},
	}
}
// Provide registers a *Storage constructor with the DI container.
//
// The configuration is unmarshalled once from opts. The constructor accepts
// only the "local" and "s3" backends (see storageType for normalisation).
// Any other Config.Type is rejected, so a typo cannot leave the methods
// disagreeing about which backend is active. For "s3" the client is built
// eagerly, so credential and endpoint problems surface at boot. When
// Config.CheckOnBoot is set, the bucket's existence is also verified.
func Provide(opts ...opt.Option) error {
	o := opt.New(opts...)
	var config Config
	if err := o.UnmarshalConfig(&config); err != nil {
		return err
	}
	return container.Container.Provide(func() (*Storage, error) {
		store := &Storage{Config: &config}

		switch typ := store.storageType(); typ {
		case "local":
			return store, nil
		case "s3":
			// Validated below.
		default:
			// Some methods branch on == "local" and others on == "s3".
			// An unknown type would make them pick different backends.
			return nil, fmt.Errorf("unsupported storage type %q (expected \"local\" or \"s3\")", typ)
		}

		client, err := store.s3ClientForUse()
		if err != nil {
			return nil, err
		}
		if store.Config.CheckOnBoot {
			// Optionally check at startup that the bucket is reachable, so
			// configuration problems surface early.
			exists, err := client.BucketExists(context.Background(), store.Config.Bucket)
			if err != nil {
				return nil, fmt.Errorf("storage bucket check failed: %w", err)
			}
			if !exists {
				return nil, fmt.Errorf("storage bucket not found: %s", store.Config.Bucket)
			}
		}
		return store, nil
	}, o.DiOptions()...)
}
// Storage reads and writes objects either on the local filesystem or in an
// S3-compatible bucket, depending on Config.Type (see storageType).
type Storage struct {
	// Config holds the storage settings; dereferenced by most methods.
	Config *Config
	// s3Client caches the MinIO client built on first use by s3ClientForUse.
	// NOTE(review): that lazy initialisation is unsynchronised; Provide builds
	// the client eagerly for the "s3" type, but a Storage constructed another
	// way and used concurrently would race — confirm no such path exists.
	s3Client *minio.Client
}
// Download copies the object stored under key to the local file filePath,
// creating filePath's parent directories as needed.
//
// Local backend: the object is read from LocalPath (default "./storage")
// joined with key. A cancelled ctx is honoured before any work starts. If
// copying or closing the destination fails, the partially written filePath
// is removed and the error is returned.
//
// S3 backend: the object is fetched from the configured bucket with
// FGetObject.
func (s *Storage) Download(ctx context.Context, key, filePath string) error {
	if s.storageType() == "local" {
		if err := ctx.Err(); err != nil {
			return err
		}
		localPath := s.Config.LocalPath
		if localPath == "" {
			localPath = "./storage"
		}
		srcPath := filepath.Join(localPath, key)
		if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil {
			return err
		}
		src, err := os.Open(srcPath)
		if err != nil {
			return err
		}
		defer src.Close()

		dst, err := os.Create(filePath)
		if err != nil {
			return err
		}
		if _, err := io.Copy(dst, src); err != nil {
			_ = dst.Close()
			_ = os.Remove(filePath) // don't leave a truncated file behind
			return err
		}
		// Close can report deferred write errors; a deferred Close would
		// silently drop them.
		if err := dst.Close(); err != nil {
			_ = os.Remove(filePath)
			return err
		}
		return nil
	}

	client, err := s.s3ClientForUse()
	if err != nil {
		return err
	}
	if err := os.MkdirAll(filepath.Dir(filePath), 0o755); err != nil {
		return err
	}
	return client.FGetObject(ctx, s.Config.Bucket, key, filePath, minio.GetObjectOptions{})
}
// Delete removes the object stored under key.
//
// For any backend other than "local", the object is removed from the
// configured bucket with RemoveObject, using a background context. For the
// local backend, the file LocalPath/key (LocalPath defaults to "./storage")
// is removed. os.Remove errors, including a missing file, are returned as-is.
func (s *Storage) Delete(key string) error {
	if s.storageType() != "local" {
		client, err := s.s3ClientForUse()
		if err != nil {
			return err
		}
		return client.RemoveObject(context.Background(), s.Config.Bucket, key, minio.RemoveObjectOptions{})
	}

	root := s.Config.LocalPath
	if root == "" {
		root = "./storage"
	}
	return os.Remove(filepath.Join(root, key))
}
// SignURL returns a time-limited URL that grants method access to key.
//
// S3 backend: a presigned GET or PUT URL from the bucket. The method is
// matched case-insensitively; any other method is an error.
//
// Local backend: BaseURL + "/" + key, with "expires" (Unix seconds) and
// "sign" (an HMAC over method, key and expiry; see signature) query
// parameters. The key is percent-escaped as a URL path, so keys containing
// '?', '#', '%' or spaces round-trip intact. Keys made only of URL-safe
// path characters produce the same URL as before. These URLs are checked by
// Verify.
func (s *Storage) SignURL(method, key string, expires time.Duration) (string, error) {
	if s.storageType() == "s3" {
		client, err := s.s3ClientForUse()
		if err != nil {
			return "", err
		}
		var u *url.URL
		switch strings.ToUpper(method) {
		case "GET":
			u, err = client.PresignedGetObject(context.Background(), s.Config.Bucket, key, expires, nil)
		case "PUT":
			u, err = client.PresignedPutObject(context.Background(), s.Config.Bucket, key, expires)
		default:
			return "", fmt.Errorf("unsupported method: %q", method)
		}
		if err != nil {
			return "", err
		}
		return u.String(), nil
	}

	exp := time.Now().Add(expires).Unix()
	sign := s.signature(method, key, exp)

	u, err := url.Parse(strings.TrimRight(s.Config.BaseURL, "/"))
	if err != nil {
		return "", err
	}
	// Append the key to the decoded path and drop RawPath, so String()
	// re-escapes the whole path. Concatenating the raw key into the URL
	// string would let '?' or '#' start a query/fragment, and a bare '%'
	// would make url.Parse fail.
	u.Path += "/" + key
	u.RawPath = ""

	q := u.Query()
	q.Set("expires", strconv.FormatInt(exp, 10))
	q.Set("sign", sign)
	u.RawQuery = q.Encode()
	return u.String(), nil
}
// Verify validates the "expires" and "sign" values of a local signed URL
// for the given method and key.
//
// It fails when:
//   - the backend is S3 (which uses presigned URLs instead);
//   - expStr is not a base-10 int64;
//   - the current Unix time is past the expiry;
//   - sign does not match the expected HMAC.
//
// The signature comparison uses hmac.Equal (constant-time).
func (s *Storage) Verify(method, key, expStr, sign string) error {
	if s.storageType() == "s3" {
		return fmt.Errorf("s3 storage does not use signed local urls")
	}

	expiresAt, err := strconv.ParseInt(expStr, 10, 64)
	switch {
	case err != nil:
		return fmt.Errorf("invalid expiry")
	case expiresAt < time.Now().Unix():
		return fmt.Errorf("expired")
	}

	want := []byte(s.signature(method, key, expiresAt))
	if !hmac.Equal(want, []byte(sign)) {
		return fmt.Errorf("invalid signature")
	}
	return nil
}
// signature returns the hex-encoded HMAC-SHA256 of "METHOD\nkey\nexp",
// keyed with Config.Secret.
//
// The method is upper-cased first. SignURL("get", ...) and a Verify call
// fed r.Method == "GET" then agree, matching the case-insensitive method
// handling of SignURL's S3 branch. Signatures for already-upper-case methods
// are unchanged.
func (s *Storage) signature(method, key string, exp int64) string {
	payload := strings.ToUpper(method) + "\n" + key + "\n" + strconv.FormatInt(exp, 10)
	mac := hmac.New(sha256.New, []byte(s.Config.Secret))
	mac.Write([]byte(payload)) // hash.Hash.Write never returns an error
	return hex.EncodeToString(mac.Sum(nil))
}
// PutObject stores the file at filePath under key.
//
// Local backend: the file is moved to LocalPath/key (LocalPath defaults to
// "./storage"), creating parent directories. contentType is ignored. The
// move falls back to copy-and-remove when a plain rename is impossible (see
// renameOrCopyFile), so sources on another filesystem such as /tmp work.
//
// S3 backend: the file is uploaded with FPutObject, using contentType when
// it is non-empty. The source file is left in place.
func (s *Storage) PutObject(ctx context.Context, key, filePath, contentType string) error {
	if s.storageType() == "local" {
		localPath := s.Config.LocalPath
		if localPath == "" {
			localPath = "./storage"
		}
		dstPath := filepath.Join(localPath, key)
		if err := os.MkdirAll(filepath.Dir(dstPath), 0o755); err != nil {
			return err
		}
		return renameOrCopyFile(filePath, dstPath)
	}

	client, err := s.s3ClientForUse()
	if err != nil {
		return err
	}
	opts := minio.PutObjectOptions{}
	if contentType != "" {
		opts.ContentType = contentType
	}
	_, err = client.FPutObject(ctx, s.Config.Bucket, key, filePath, opts)
	return err
}

// renameOrCopyFile moves src to dst.
//
// It tries os.Rename first. If that fails (e.g. EXDEV because src is on a
// different filesystem) and src is readable, src is copied into a temporary
// file in dst's directory. That file is renamed over dst, so readers never
// see a partial object, and then src is removed. If src cannot be opened,
// the original rename error is returned.
func renameOrCopyFile(src, dst string) error {
	renameErr := os.Rename(src, dst)
	if renameErr == nil {
		return nil
	}

	in, err := os.Open(src)
	if err != nil {
		return renameErr
	}
	tmp, err := os.CreateTemp(filepath.Dir(dst), ".upload-*")
	if err != nil {
		_ = in.Close()
		return err
	}
	tmpName := tmp.Name()

	_, err = io.Copy(tmp, in)
	// Read-only handle: close it before removing src (required on Windows).
	_ = in.Close()
	if cerr := tmp.Close(); err == nil {
		err = cerr
	}
	if err == nil {
		err = os.Rename(tmpName, dst)
	}
	if err != nil {
		_ = os.Remove(tmpName)
		return err
	}

	// The object is stored. Failing to delete the source only leaves a
	// stray file behind, so it is deliberately not reported as an error.
	_ = os.Remove(src)
	return nil
}
// Provider reports the active backend name: "s3" when the normalised
// storage type is "s3", and "local" for every other value.
func (s *Storage) Provider() string {
	switch s.storageType() {
	case "s3":
		return "s3"
	default:
		return "local"
	}
}
// Bucket returns the configured bucket name for the S3 backend. It returns
// "default" for any other backend, or when no bucket is configured.
func (s *Storage) Bucket() string {
	name := "default"
	if s.storageType() == "s3" && s.Config.Bucket != "" {
		name = s.Config.Bucket
	}
	return name
}
// storageType returns Config.Type lower-cased and trimmed of surrounding
// whitespace. It falls back to "local" when Config is nil or the type is
// blank. Any other value is returned unchanged; no validation happens here.
func (s *Storage) storageType() string {
	const fallback = "local"
	if s.Config == nil {
		return fallback
	}
	if normalized := strings.TrimSpace(strings.ToLower(s.Config.Type)); normalized != "" {
		return normalized
	}
	return fallback
}
// s3ClientForUse returns the cached MinIO client, building and caching it on
// first call.
//
// Endpoint, access key, secret key and bucket must all be non-blank. The
// endpoint is split into host and TLS flag by parseEndpoint. Path-style
// bucket lookup is used when Config.PathStyle is set; DNS-style otherwise.
// The cache write is not synchronised.
func (s *Storage) s3ClientForUse() (*minio.Client, error) {
	if cached := s.s3Client; cached != nil {
		return cached, nil
	}

	cfg := s.Config
	switch {
	case strings.TrimSpace(cfg.Endpoint) == "":
		return nil, fmt.Errorf("storage endpoint is required")
	case strings.TrimSpace(cfg.AccessKey) == "" || strings.TrimSpace(cfg.SecretKey) == "":
		return nil, fmt.Errorf("storage access key or secret key is required")
	case strings.TrimSpace(cfg.Bucket) == "":
		return nil, fmt.Errorf("storage bucket is required")
	}

	host, useTLS, err := parseEndpoint(cfg.Endpoint)
	if err != nil {
		return nil, err
	}

	lookup := minio.BucketLookupDNS
	if cfg.PathStyle {
		lookup = minio.BucketLookupPath
	}

	client, err := minio.New(host, &minio.Options{
		Creds:        credentials.NewStaticV4(cfg.AccessKey, cfg.SecretKey, ""),
		Secure:       useTLS,
		Region:       cfg.Region,
		BucketLookup: lookup,
	})
	if err != nil {
		return nil, err
	}
	s.s3Client = client
	return client, nil
}
// parseEndpoint converts a configured S3 endpoint into the host[:port] form
// that minio.New expects, plus whether TLS should be used.
//
// Surrounding whitespace is trimmed. If the endpoint starts with "http://"
// or "https://" (case-insensitive), it is parsed as a URL and its host is
// returned, with TLS enabled only for https. A URL without a host is an
// error. Any other value is returned as-is with TLS disabled.
func parseEndpoint(endpoint string) (string, bool, error) {
	endpoint = strings.TrimSpace(endpoint)
	lower := strings.ToLower(endpoint)
	if !strings.HasPrefix(lower, "http://") && !strings.HasPrefix(lower, "https://") {
		return endpoint, false, nil
	}
	u, err := url.Parse(endpoint)
	if err != nil {
		return "", false, err
	}
	if u.Host == "" {
		return "", false, fmt.Errorf("invalid endpoint %q", endpoint)
	}
	// url.Parse lower-cases the scheme, so "HTTPS://" also selects TLS.
	return u.Host, u.Scheme == "https", nil
}