diff --git a/templates/project/providers/req/client.go.tpl b/templates/project/providers/req/client.go.tpl index 4ada235..dec20e8 100644 --- a/templates/project/providers/req/client.go.tpl +++ b/templates/project/providers/req/client.go.tpl @@ -1,19 +1,21 @@ package req import ( - "net/http" - "net/url" - "os" - "path/filepath" - "strconv" - "strings" - "time" + "context" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "strconv" + "strings" + "time" - "{{.ModuleName}}/providers/req/cookiejar" + "test/providers/req/cookiejar" - "github.com/imroc/req/v3" - "go.ipao.vip/atom/container" - "go.ipao.vip/atom/opt" + "github.com/imroc/req/v3" + "go.ipao.vip/atom/container" + "go.ipao.vip/atom/opt" ) type Client struct { @@ -62,16 +64,25 @@ func Provide(opts ...opt.Option) error { client.EnableInsecureSkipVerify() } - if config.UserAgent != "" { - client.SetUserAgent(config.UserAgent) - } - if config.Timeout > 0 { - client.SetTimeout(time.Duration(config.Timeout) * time.Second) - } + if config.UserAgent != "" { + client.SetUserAgent(config.UserAgent) + } + if config.BaseURL != "" { + client.SetBaseURL(config.BaseURL) + } + if config.Timeout > 0 { + client.SetTimeout(time.Duration(config.Timeout) * time.Second) + } - if config.CommonHeaders != nil { - client.SetCommonHeaders(config.CommonHeaders) - } + if config.CommonHeaders != nil { + client.SetCommonHeaders(config.CommonHeaders) + } + if config.CommonQuery != nil { + client.SetCommonQueryParams(config.CommonQuery) + } + if config.ContentType != "" { + client.SetCommonContentType(config.ContentType) + } if config.AuthBasic.Username != "" && config.AuthBasic.Password != "" { client.SetCommonBasicAuth(config.AuthBasic.Username, config.AuthBasic.Password) @@ -89,9 +100,12 @@ func Provide(opts ...opt.Option) error { client.SetRedirectPolicy(parsePolicies(config.RedirectPolicy)...) } - c.client = client - return c, nil - }, o.DiOptions()...) + c.client = client + if c.jar != nil { + container.AddCloseAble(func() { _ = c.jar.Save() }) + } + return c, nil + }, o.DiOptions()...) } func parsePolicies(policies []string) []req.RedirectPolicy { @@ -126,7 +140,11 @@ func parsePolicies(policies []string) []req.RedirectPolicy { } func (c *Client) R() *req.Request { - return c.client.R() + return c.client.R() +} + +func (c *Client) RWithCtx(ctx context.Context) *req.Request { + return c.client.R().SetContext(ctx) } func (c *Client) SaveCookJar() error { @@ -148,5 +166,78 @@ func (c *Client) AllCookiesKV() map[string]string { } func (c *Client) SetCookie(u *url.URL, cookies []*http.Cookie) { - c.jar.SetCookies(u, cookies) + c.jar.SetCookies(u, cookies) +} + +func (c *Client) DoJSON(ctx context.Context, method, url string, in any, out any) error { + r := c.RWithCtx(ctx) + if in != nil { + r.SetBody(in) + } + if out != nil { + r.SetSuccessResult(out) + } + var resp *req.Response + var err error + switch strings.ToUpper(method) { + case http.MethodGet: + resp, err = r.Get(url) + case http.MethodPost: + resp, err = r.Post(url) + case http.MethodPut: + resp, err = r.Put(url) + case http.MethodPatch: + resp, err = r.Patch(url) + case http.MethodDelete: + resp, err = r.Delete(url) + default: + resp, err = r.Send(method, url) + } + if err != nil { + return err + } + return resp.Err +} + +func (c *Client) GetJSON(ctx context.Context, url string, out any, query map[string]string) error { + r := c.RWithCtx(ctx) + if query != nil { + r.SetQueryParams(query) + } + r.SetSuccessResult(out) + resp, err := r.Get(url) + if err != nil { + return err + } + return resp.Err +} + +func (c *Client) PostJSON(ctx context.Context, url string, in any, out any) error { + r := c.RWithCtx(ctx) + if in != nil { + r.SetBody(in) + } + if out != nil { + r.SetSuccessResult(out) + } + resp, err := r.Post(url) + if err != nil { + return err + } + return resp.Err +} + +func (c *Client) Download(ctx context.Context, url, filepath string) error { + r := c.RWithCtx(ctx) + resp, err := r.Get(url) + if err != nil { + return err + } + f, err := os.Create(filepath) + if err != nil { + return err + } + defer f.Close() + _, err = io.Copy(f, resp.Body) + return err } diff --git a/templates/project/providers/req/config.go.tpl b/templates/project/providers/req/config.go.tpl index 6967681..63361a8 100644 --- a/templates/project/providers/req/config.go.tpl +++ b/templates/project/providers/req/config.go.tpl @@ -17,18 +17,21 @@ func DefaultProvider() container.ProviderContainer { } type Config struct { - DevMode bool - CookieJarFile string - RootCa []string - UserAgent string - InsecureSkipVerify bool - CommonHeaders map[string]string - Timeout uint - AuthBasic struct { - Username string - Password string - } - AuthBearerToken string - ProxyURL string - RedirectPolicy []string // "Max:10;No;SameDomain;SameHost;AllowedHost:x,x,x,x,x,AllowedDomain:x,x,x,x,x" + DevMode bool + CookieJarFile string + RootCa []string + UserAgent string + InsecureSkipVerify bool + CommonHeaders map[string]string + CommonQuery map[string]string + BaseURL string + ContentType string + Timeout uint + AuthBasic struct { + Username string + Password string + } + AuthBearerToken string + ProxyURL string + RedirectPolicy []string // "Max:10;No;SameDomain;SameHost;AllowedHost:x,x,x,x,x,AllowedDomain:x,x,x,x,x" }