Fix API path style issues on UDM Pro API proxy

This commit is contained in:
Paul Tyng
2020-09-24 16:57:04 -04:00
parent 16c246525b
commit 3d37110380

View File

@@ -6,9 +6,20 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net/http" "net/http"
"net/http/cookiejar" "net/http/cookiejar"
"net/url" "net/url"
"path"
"strings"
)
const (
apiPath = "/api"
apiPathNew = "/proxy/network/api"
loginPath = "/api/login"
loginPathNew = "/api/auth/login"
) )
type NotFoundError struct{} type NotFoundError struct{}
@@ -29,6 +40,9 @@ func (err *APIError) Error() string {
type Client struct { type Client struct {
c *http.Client c *http.Client
baseURL *url.URL baseURL *url.URL
apiPath string
loginPath string
} }
func (c *Client) SetBaseURL(base string) error { func (c *Client) SetBaseURL(base string) error {
@@ -37,6 +51,12 @@ func (c *Client) SetBaseURL(base string) error {
if err != nil { if err != nil {
return err return err
} }
// error for people who are still passing hard coded old paths
if path := strings.TrimSuffix(c.baseURL.Path, "/"); path == apiPath {
return fmt.Errorf("expected a base URL without the `/api`, got: %q", c.baseURL)
}
return nil return nil
} }
@@ -45,6 +65,46 @@ func (c *Client) SetHTTPClient(hc *http.Client) error {
return nil return nil
} }
func (c *Client) setAPIUrlStyle(ctx context.Context) error {
// check if new style API
// this is modified from the unifi-poller (https://github.com/unifi-poller/unifi) implementation.
// see https://github.com/unifi-poller/unifi/blob/4dc44f11f61a2e08bf7ec5b20c71d5bced837b5d/unifi.go#L101-L104
// and https://github.com/unifi-poller/unifi/commit/43a6b225031a28f2b358f52d03a7217c7b524143
req, err := http.NewRequestWithContext(ctx, "GET", c.baseURL.String(), nil)
if err != nil {
return err
}
// We can't share these cookies with other requests, so make a new client.
// Checking the return code on the first request so don't follow a redirect.
client := &http.Client{
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
Transport: c.c.Transport,
}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
_, _ = io.Copy(ioutil.Discard, resp.Body)
if resp.StatusCode == http.StatusOK {
// the new API returns a 200 for a / request
c.apiPath = apiPathNew
c.loginPath = loginPathNew
return nil
}
// The old version returns a "302" (to /manage) for a / request
c.apiPath = apiPath
c.loginPath = loginPath
return nil
}
func (c *Client) Login(ctx context.Context, user, pass string) error { func (c *Client) Login(ctx context.Context, user, pass string) error {
if c.c == nil { if c.c == nil {
c.c = &http.Client{} c.c = &http.Client{}
@@ -53,7 +113,12 @@ func (c *Client) Login(ctx context.Context, user, pass string) error {
c.c.Jar = jar c.c.Jar = jar
} }
err := c.do(ctx, "POST", "login", &struct { err := c.setAPIUrlStyle(ctx)
if err != nil {
return fmt.Errorf("unable to determine API URL style: %w", err)
}
err = c.do(ctx, "POST", c.loginPath, &struct {
Username string `json:"username"` Username string `json:"username"`
Password string `json:"password"` Password string `json:"password"`
}{ }{
@@ -74,7 +139,6 @@ func (c *Client) do(ctx context.Context, method, relativeURL string, reqBody int
reqBytes []byte reqBytes []byte
) )
if reqBody != nil { if reqBody != nil {
reqBytes, err = json.Marshal(reqBody) reqBytes, err = json.Marshal(reqBody)
if err != nil { if err != nil {
return fmt.Errorf("unable to marshal JSON: %s %s %w", method, relativeURL, err) return fmt.Errorf("unable to marshal JSON: %s %s %w", method, relativeURL, err)
@@ -86,9 +150,11 @@ func (c *Client) do(ctx context.Context, method, relativeURL string, reqBody int
if err != nil { if err != nil {
return fmt.Errorf("unable to parse URL: %s %s %w", method, relativeURL, err) return fmt.Errorf("unable to parse URL: %s %s %w", method, relativeURL, err)
} }
if !strings.HasPrefix(relativeURL, "/") && !reqURL.IsAbs() {
reqURL.Path = path.Join(c.apiPath, reqURL.Path)
}
url := c.baseURL.ResolveReference(reqURL) url := c.baseURL.ResolveReference(reqURL)
req, err := http.NewRequestWithContext(ctx, method, url.String(), reqReader) req, err := http.NewRequestWithContext(ctx, method, url.String(), reqReader)
if err != nil { if err != nil {
return fmt.Errorf("unable to create request: %s %s %w", method, relativeURL, err) return fmt.Errorf("unable to create request: %s %s %w", method, relativeURL, err)