Skip to content

Commit 5965217

Browse files
authored
Merge pull request #681 from WGH-/fix-revisit-on-redirects
Fix redirects ignoring AllowURLRevisit=false
2 parents cf68133 + 0be3b71 commit 5965217

File tree

2 files changed

+94
-10
lines changed

2 files changed

+94
-10
lines changed

colly.go

Lines changed: 57 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -157,6 +157,26 @@ type ScrapedCallback func(*Response)
157157
// ProxyFunc is a type alias for proxy setter functions.
158158
type ProxyFunc func(*http.Request) (*url.URL, error)
159159

160+
// AlreadyVisitedError is the error type for already visited URLs.
161+
//
162+
// It's returned synchronously by Visit when the URL passed to Visit
163+
// has already been visited.
164+
//
165+
// When an already visited URL is encountered after following
166+
// redirects, this error appears in the OnError callback and, if
167+
// Async mode is not enabled, is also returned by Visit.
168+
type AlreadyVisitedError struct {
169+
// Destination is the URL that was attempted to be visited.
170+
// It might not match the URL passed to Visit if a redirect
171+
// was followed.
172+
Destination *url.URL
173+
}
174+
175+
// Error implements the error interface; the message quotes Destination.
176+
func (e *AlreadyVisitedError) Error() string {
177+
return fmt.Sprintf("%q already visited", e.Destination)
178+
}
179+
160180
type htmlCallbackContainer struct {
161181
Selector string
162182
Function HTMLCallback
@@ -196,8 +216,6 @@ var (
196216
// ErrNoURLFiltersMatch is the error thrown if visiting
197217
// a URL which is not allowed by URLFilters
198218
ErrNoURLFiltersMatch = errors.New("No URLFilters match")
199-
// ErrAlreadyVisited is the error type for already visited URLs
200-
ErrAlreadyVisited = errors.New("URL already visited")
201219
// ErrRobotsTxtBlocked is the error type for robots.txt errors
202220
ErrRobotsTxtBlocked = errors.New("URL blocked by robots.txt")
203221
// ErrNoCookieJar is the error type for missing cookie jar
@@ -603,7 +621,7 @@ func (c *Collector) scrape(u, method string, depth int, requestData io.Reader, c
603621
// note: once 1.13 is minimum supported Go version,
604622
// replace this with http.NewRequestWithContext
605623
req = req.WithContext(c.Context)
606-
if err := c.requestCheck(u, parsedURL, method, req.GetBody, depth, checkRevisit); err != nil {
624+
if err := c.requestCheck(parsedURL, method, req.GetBody, depth, checkRevisit); err != nil {
607625
return err
608626
}
609627
u = parsedURL.String()
@@ -694,10 +712,8 @@ func (c *Collector) fetch(u, method string, depth int, requestData io.Reader, ct
694712
return err
695713
}
696714

697-
func (c *Collector) requestCheck(u string, parsedURL *url.URL, method string, getBody func() (io.ReadCloser, error), depth int, checkRevisit bool) error {
698-
if u == "" {
699-
return ErrMissingURL
700-
}
715+
func (c *Collector) requestCheck(parsedURL *url.URL, method string, getBody func() (io.ReadCloser, error), depth int, checkRevisit bool) error {
716+
u := parsedURL.String()
701717
if c.MaxDepth > 0 && c.MaxDepth < depth {
702718
return ErrMaxDepth
703719
}
@@ -732,7 +748,7 @@ func (c *Collector) requestCheck(u string, parsedURL *url.URL, method string, ge
732748
return err
733749
}
734750
if visited {
735-
return ErrAlreadyVisited
751+
return &AlreadyVisitedError{parsedURL}
736752
}
737753
return c.store.Visited(uHash)
738754
}
@@ -1292,6 +1308,31 @@ func (c *Collector) checkRedirectFunc() func(req *http.Request, via []*http.Requ
12921308
if err := c.checkFilters(req.URL.String(), req.URL.Hostname()); err != nil {
12931309
return fmt.Errorf("Not following redirect to %q: %w", req.URL, err)
12941310
}
1311+
1312+
if !c.AllowURLRevisit {
1313+
var body io.ReadCloser
1314+
if req.GetBody != nil {
1315+
var err error
1316+
body, err = req.GetBody()
1317+
if err != nil {
1318+
return err
1319+
}
1320+
defer body.Close()
1321+
}
1322+
uHash := requestHash(req.URL.String(), body)
1323+
visited, err := c.store.IsVisited(uHash)
1324+
if err != nil {
1325+
return err
1326+
}
1327+
if visited {
1328+
return &AlreadyVisitedError{req.URL}
1329+
}
1330+
err = c.store.Visited(uHash)
1331+
if err != nil {
1332+
return err
1333+
}
1334+
}
1335+
12951336
if c.redirectHandler != nil {
12961337
return c.redirectHandler(req, via)
12971338
}
@@ -1442,7 +1483,14 @@ func isMatchingFilter(fs []*regexp.Regexp, d []byte) bool {
14421483

14431484
func requestHash(url string, body io.Reader) uint64 {
14441485
h := fnv.New64a()
1445-
h.Write([]byte(url))
1486+
// reparse the url to fix ambiguities such as
1487+
// "http://example.com" vs "http://example.com/"
1488+
parsedWhatwgURL, err := whatwgUrl.Parse(url)
1489+
if err == nil {
1490+
h.Write([]byte(parsedWhatwgURL.String()))
1491+
} else {
1492+
h.Write([]byte(url))
1493+
}
14461494
if body != nil {
14471495
io.Copy(h, body)
14481496
}

colly_test.go

Lines changed: 37 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -101,7 +101,11 @@ func newTestServer() *httptest.Server {
101101
})
102102

103103
mux.Handle("/redirect", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
104-
http.Redirect(w, r, "/redirected/", http.StatusSeeOther)
104+
destination := "/redirected/"
105+
if d := r.URL.Query().Get("d"); d != "" {
106+
destination = d
107+
}
108+
http.Redirect(w, r, destination, http.StatusSeeOther)
105109

106110
}))
107111

@@ -674,6 +678,38 @@ func TestCollectorURLRevisitCheck(t *testing.T) {
674678
if visited != true {
675679
t.Error("Expected URL to have been visited")
676680
}
681+
682+
errorTestCases := []struct {
683+
Path string
684+
DestinationError string
685+
}{
686+
{"/", "/"},
687+
{"/redirect?d=/", "/"},
688+
// now that /redirect?d=/ itself is recorded as visited,
689+
// it's now returned in error
690+
{"/redirect?d=/", "/redirect?d=/"},
691+
{"/redirect?d=/redirect%3Fd%3D/", "/redirect?d=/"},
692+
{"/redirect?d=/redirect%3Fd%3D/", "/redirect?d=/redirect%3Fd%3D/"},
693+
{"/redirect?d=/redirect%3Fd%3D/&foo=bar", "/redirect?d=/"},
694+
}
695+
696+
for i, testCase := range errorTestCases {
697+
err := c.Visit(ts.URL + testCase.Path)
698+
if testCase.DestinationError == "" {
699+
if err != nil {
700+
t.Errorf("got unexpected error in test %d: %q", i, err)
701+
}
702+
} else {
703+
var ave *AlreadyVisitedError
704+
if !errors.As(err, &ave) {
705+
t.Errorf("err=%q returned when trying to revisit, expected AlreadyVisitedError", err)
706+
} else {
707+
if got, want := ave.Destination.String(), ts.URL+testCase.DestinationError; got != want {
708+
t.Errorf("wrong destination in AlreadyVisitedError in test %d, got=%q want=%q", i, got, want)
709+
}
710+
}
711+
}
712+
}
677713
}
678714

679715
func TestCollectorPostURLRevisitCheck(t *testing.T) {

0 commit comments

Comments
 (0)