Skip to content

Commit cfcb8e1

Browse files
authored
Enable Client.RateLimits to bypass the rate limit check (#1907)
Fixes #1899.
1 parent 9bd4751 commit cfcb8e1

File tree

2 files changed

+66
-6
lines changed

2 files changed

+66
-6
lines changed

github/github.go

Lines changed: 17 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -522,6 +522,12 @@ func parseRate(r *http.Response) Rate {
522522
return rate
523523
}
524524

525+
type requestContext uint8
526+
527+
const (
528+
bypassRateLimitCheck requestContext = iota
529+
)
530+
525531
// BareDo sends an API request and lets you handle the api response. If an error
526532
// or API Error occurs, the error will contain more information. Otherwise you
527533
// are supposed to read and close the response's Body. If rate limit is exceeded
@@ -538,12 +544,14 @@ func (c *Client) BareDo(ctx context.Context, req *http.Request) (*Response, erro
538544

539545
rateLimitCategory := category(req.URL.Path)
540546

541-
// If we've hit rate limit, don't make further requests before Reset time.
542-
if err := c.checkRateLimitBeforeDo(req, rateLimitCategory); err != nil {
543-
return &Response{
544-
Response: err.Response,
545-
Rate: err.Rate,
546-
}, err
547+
if bypass := ctx.Value(bypassRateLimitCheck); bypass == nil {
548+
// If we've hit rate limit, don't make further requests before Reset time.
549+
if err := c.checkRateLimitBeforeDo(req, rateLimitCategory); err != nil {
550+
return &Response{
551+
Response: err.Response,
552+
Rate: err.Rate,
553+
}, err
554+
}
547555
}
548556

549557
resp, err := c.client.Do(req)
@@ -1025,6 +1033,9 @@ func (c *Client) RateLimits(ctx context.Context) (*RateLimits, *Response, error)
10251033
response := new(struct {
10261034
Resources *RateLimits `json:"resources"`
10271035
})
1036+
1037+
// This resource is not subject to rate limits.
1038+
ctx = context.WithValue(ctx, bypassRateLimitCheck, true)
10281039
resp, err := c.Do(ctx, req, response)
10291040
if err != nil {
10301041
return nil, resp, err

github/github_test.go

Lines changed: 49 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -201,6 +201,9 @@ func testNewRequestAndDoFailure(t *testing.T, methodName string, client *Client,
201201
client.BaseURL.Path = "/api-v3/"
202202
client.rateLimits[0].Reset.Time = time.Now().Add(10 * time.Minute)
203203
resp, err = f()
204+
if bypass := resp.Request.Context().Value(bypassRateLimitCheck); bypass != nil {
205+
return
206+
}
204207
if want := http.StatusForbidden; resp == nil || resp.Response.StatusCode != want {
205208
if resp != nil {
206209
t.Errorf("rate.Reset.Time > now %v resp = %#v, want StatusCode=%v", methodName, resp.Response, want)
@@ -1678,6 +1681,52 @@ func TestRateLimits_coverage(t *testing.T) {
16781681
})
16791682
}
16801683

1684+
func TestRateLimits_overQuota(t *testing.T) {
1685+
client, mux, _, teardown := setup()
1686+
defer teardown()
1687+
1688+
client.rateLimits[coreCategory] = Rate{
1689+
Limit: 1,
1690+
Remaining: 0,
1691+
Reset: Timestamp{time.Now().Add(time.Hour).Local()},
1692+
}
1693+
mux.HandleFunc("/rate_limit", func(w http.ResponseWriter, r *http.Request) {
1694+
fmt.Fprint(w, `{"resources":{
1695+
"core": {"limit":2,"remaining":1,"reset":1372700873},
1696+
"search": {"limit":3,"remaining":2,"reset":1372700874}
1697+
}}`)
1698+
})
1699+
1700+
ctx := context.Background()
1701+
rate, _, err := client.RateLimits(ctx)
1702+
if err != nil {
1703+
t.Errorf("RateLimits returned error: %v", err)
1704+
}
1705+
1706+
want := &RateLimits{
1707+
Core: &Rate{
1708+
Limit: 2,
1709+
Remaining: 1,
1710+
Reset: Timestamp{time.Date(2013, time.July, 1, 17, 47, 53, 0, time.UTC).Local()},
1711+
},
1712+
Search: &Rate{
1713+
Limit: 3,
1714+
Remaining: 2,
1715+
Reset: Timestamp{time.Date(2013, time.July, 1, 17, 47, 54, 0, time.UTC).Local()},
1716+
},
1717+
}
1718+
if !cmp.Equal(rate, want) {
1719+
t.Errorf("RateLimits returned %+v, want %+v", rate, want)
1720+
}
1721+
1722+
if got, want := client.rateLimits[coreCategory], *want.Core; got != want {
1723+
t.Errorf("client.rateLimits[coreCategory] is %+v, want %+v", got, want)
1724+
}
1725+
if got, want := client.rateLimits[searchCategory], *want.Search; got != want {
1726+
t.Errorf("client.rateLimits[searchCategory] is %+v, want %+v", got, want)
1727+
}
1728+
}
1729+
16811730
func TestSetCredentialsAsHeaders(t *testing.T) {
16821731
req := new(http.Request)
16831732
id, secret := "id", "secret"

0 commit comments

Comments (0)