Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 81 additions & 33 deletions github/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ var ErrPathForbidden = errors.New("path must not contain '..' due to auth vulner
type Client struct {
client *http.Client // HTTP client used to communicate with the API.
clientIgnoreRedirects *http.Client // HTTP client used to communicate with the API on endpoints where we don't want to follow redirects.
transport http.RoundTripper
token *string

// Base URL for API requests. Defaults to the public GitHub API, but can be
// set to a domain endpoint to use with GitHub Enterprise. baseURL should
Expand Down Expand Up @@ -255,8 +257,8 @@ type service struct {
}

// Client returns the http.Client used by this GitHub client.
// This should only be used for requests to the GitHub API because
// request headers will contain an authorization token.
// This should only be used for requests to the configured GitHub API or upload
// URLs because requests to those origins may contain an authorization token.
func (c *Client) Client() *http.Client {
clientCopy := *c.client
return &clientCopy
Expand Down Expand Up @@ -434,7 +436,8 @@ func WithEnvProxy() ClientOptionsFunc {
}

// WithAuthToken returns a ClientOptionsFunc that sets the authentication token
// for a Client. If not set, the client will make unauthenticated requests.
// for requests to the Client's configured API and upload URLs. If not set, the
// client will make unauthenticated requests.
func WithAuthToken(token string) ClientOptionsFunc {
return func(o *clientOptions) error {
if token == "" {
Expand Down Expand Up @@ -602,25 +605,6 @@ func newClient(opts clientOptions) (*Client, error) {
c.client.Transport = t2
}

if opts.token != nil {
transport := c.client.Transport
if transport == nil {
transport = http.DefaultTransport
}
c.client.Transport = roundTripperFunc(func(req *http.Request) (*http.Response, error) {
req = req.Clone(req.Context())
req.Header.Set("Authorization", fmt.Sprintf("Bearer %v", *opts.token))
return transport.RoundTrip(req)
})
}

c.clientIgnoreRedirects = &http.Client{
Transport: c.client.Transport,
Timeout: c.client.Timeout,
Jar: c.client.Jar,
CheckRedirect: func(*http.Request, []*http.Request) error { return http.ErrUseLastResponse },
}

if opts.apiVersionMin != nil {
c.apiVersionMin = *opts.apiVersionMin
}
Expand All @@ -647,6 +631,33 @@ func newClient(opts clientOptions) (*Client, error) {
c.uploadURL, _ = url.Parse(uploadBaseURL)
}

c.transport = c.client.Transport
if opts.token != nil {
transport := c.client.Transport
if transport == nil {
transport = http.DefaultTransport
}
token := "Bearer " + *opts.token
baseURL := c.baseURL
uploadURL := c.uploadURL
c.token = Ptr(*opts.token)
c.transport = transport
c.client.Transport = roundTripperFunc(func(req *http.Request) (*http.Response, error) {
if shouldAuthorizeURL(req.URL, baseURL, uploadURL) {
req = req.Clone(req.Context())
req.Header.Set("Authorization", token)
}
return transport.RoundTrip(req)
})
}

c.clientIgnoreRedirects = &http.Client{
Transport: c.client.Transport,
Timeout: c.client.Timeout,
Jar: c.client.Jar,
CheckRedirect: func(*http.Request, []*http.Request) error { return http.ErrUseLastResponse },
}

c.disableRateLimitCheck = opts.disableRateLimitCheck

if !c.disableRateLimitCheck {
Expand Down Expand Up @@ -705,6 +716,38 @@ func newClient(opts clientOptions) (*Client, error) {
return c, nil
}

func shouldAuthorizeURL(u, baseURL, uploadURL *url.URL) bool {
if u == nil {
return false
}

return sameOrigin(u, baseURL) || sameOrigin(u, uploadURL)
}

func sameOrigin(u, base *url.URL) bool {
if u == nil || base == nil {
return false
}

return strings.EqualFold(u.Scheme, base.Scheme) &&
strings.EqualFold(u.Hostname(), base.Hostname()) &&
defaultPort(u) == defaultPort(base)
}

func defaultPort(u *url.URL) string {
if port := u.Port(); port != "" {
return port
}
switch strings.ToLower(u.Scheme) {
case "http":
return "80"
case "https":
return "443"
default:
return ""
}
}

// UserAgent returns the User-Agent header value for the client.
func (c *Client) UserAgent() string {
return c.userAgent
Expand Down Expand Up @@ -744,6 +787,7 @@ func (c *Client) Clone(opts ...ClientOptionsFunc) (*Client, error) {
userAgent: &c.userAgent,
baseURL: Ptr(*c.baseURL),
uploadURL: Ptr(*c.uploadURL),
token: c.token,
disableRateLimitCheck: c.disableRateLimitCheck,
rateLimitRedirectionalEndpoints: c.rateLimitRedirectionalEndpoints,
maxSecondaryRateLimitRetryAfterDuration: &c.maxSecondaryRateLimitRetryAfterDuration,
Expand All @@ -760,8 +804,12 @@ func (c *Client) Clone(opts ...ClientOptionsFunc) (*Client, error) {
}

if o.httpClient == nil {
transport := c.client.Transport
if c.token != nil {
transport = c.transport
}
o.httpClient = &http.Client{
Transport: c.client.Transport,
Transport: transport,
CheckRedirect: c.client.CheckRedirect,
Jar: c.client.Jar,
Timeout: c.client.Timeout,
Expand Down Expand Up @@ -1366,11 +1414,11 @@ func (c *Client) bareDoUntilFound(req *http.Request, maxRedirects int) (*url.URL
return nil, nil, errInvalidLocation
}
newURL := c.baseURL.ResolveReference(rerr.Location)
// Refuse to follow a permanent redirect to a different host:
// Refuse to follow a permanent redirect to a different origin:
// req.Clone preserves Authorization headers added by the auth
// transport, so a cross-host target would leak credentials.
if newURL.Host != c.baseURL.Host {
return nil, response, fmt.Errorf("refusing to follow cross-host redirect from %q to %q", c.baseURL.Host, newURL.Host)
// transport, so a cross-origin target would leak credentials.
if !sameOrigin(newURL, c.baseURL) {
return nil, response, fmt.Errorf("refusing to follow cross-origin redirect from %q to %q", c.baseURL.Host, newURL.Host)
}
newRequest := req.Clone(req.Context())
newRequest.URL = newURL
Expand Down Expand Up @@ -2137,20 +2185,20 @@ func (c *Client) roundTripWithOptionalFollowRedirect(ctx context.Context, u stri
if maxRedirects > 0 && resp.StatusCode == http.StatusMovedPermanently {
_ = resp.Body.Close()
u = resp.Header.Get("Location")
if err := c.checkRedirectHost(u); err != nil {
if err := c.checkRedirectOrigin(u); err != nil {
return nil, err
}
resp, err = c.roundTripWithOptionalFollowRedirect(ctx, u, maxRedirects-1, opts...)
}
return resp, err
}

// checkRedirectHost returns an error if the redirect target is on a different
// host than the client's configured BaseURL. This prevents credentials attached
// checkRedirectOrigin returns an error if the redirect target is on a different
// origin than the client's configured BaseURL. This prevents credentials attached
// by the auth transport from being sent to an attacker-controlled host when a
// compromised or malicious API response returns a cross-origin Location header.
// An empty Location is also rejected.
func (c *Client) checkRedirectHost(location string) error {
func (c *Client) checkRedirectOrigin(location string) error {
if location == "" {
return errInvalidLocation
}
Expand All @@ -2160,8 +2208,8 @@ func (c *Client) checkRedirectHost(location string) error {
}
// Resolve relative locations against BaseURL so relative paths are allowed.
target = c.baseURL.ResolveReference(target)
if target.Host != c.baseURL.Host {
return fmt.Errorf("refusing to follow cross-host redirect from %q to %q", c.baseURL.Host, target.Host)
if !sameOrigin(target, c.baseURL) {
return fmt.Errorf("refusing to follow cross-origin redirect from %q to %q", c.baseURL.Host, target.Host)
}
return nil
}
Expand Down
Loading
Loading