diff --git a/03-web-crawler/solution/solution.go b/03-web-crawler/solution/solution.go index 227550e..738a5d2 100644 --- a/03-web-crawler/solution/solution.go +++ b/03-web-crawler/solution/solution.go @@ -1,7 +1,6 @@ package main import ( - "fmt" "sync" ) @@ -18,17 +17,14 @@ var ( // Crawl uses fetcher to recursively crawl // pages starting with url, to a maximum of depth. -func Crawl(url string, depth int, fetcher Fetcher) { - // TODO: Fetch URLs in parallel. - // TODO: Don't fetch the same URL twice. - // This implementation doesn't do either: +func Crawl(url string, depth int, fetcher Fetcher) ([]string, error) { if depth <= 0 { - return + return nil, nil } mu.RLock() if m[url] { mu.RUnlock() - return + return nil, nil } mu.RUnlock() @@ -38,72 +34,32 @@ func Crawl(url string, depth int, fetcher Fetcher) { body, urls, err := fetcher.Fetch(url) if err != nil { - fmt.Println(err) - return + return nil, err } - fmt.Printf("found: %s %q\n", url, body) + result := []string{body} + ch := make(chan string) var wg sync.WaitGroup for _, u := range urls { wg.Add(1) + u := u // shadow loop variable for the goroutine (required pre-Go 1.22) go func() { defer wg.Done() - Crawl(u, depth-1, fetcher) + if res, err := Crawl(u, depth-1, fetcher); err == nil { + for _, b := range res { + ch <- b + } + } }() } - wg.Wait() - return -} + go func() { + wg.Wait() + close(ch) + }() -func main() { - Crawl("https://golang.org/", 4, fetcher) -} - -// fakeFetcher is Fetcher that returns canned results. -type fakeFetcher map[string]*fakeResult - -type fakeResult struct { - body string - urls []string -} - -func (f fakeFetcher) Fetch(url string) (string, []string, error) { - if res, ok := f[url]; ok { - return res.body, res.urls, nil + for b := range ch { + result = append(result, b) } - return "", nil, fmt.Errorf("not found: %s", url) -} -// fetcher is a populated fakeFetcher. 
-var fetcher = fakeFetcher{ - "https://golang.org/": &fakeResult{ - "The Go Programming Language", - []string{ - "https://golang.org/pkg/", - "https://golang.org/cmd/", - }, - }, - "https://golang.org/pkg/": &fakeResult{ - "Packages", - []string{ - "https://golang.org/", - "https://golang.org/cmd/", - "https://golang.org/pkg/fmt/", - "https://golang.org/pkg/os/", - }, - }, - "https://golang.org/pkg/fmt/": &fakeResult{ - "Package fmt", - []string{ - "https://golang.org/", - "https://golang.org/pkg/", - }, - }, - "https://golang.org/pkg/os/": &fakeResult{ - "Package os", - []string{ - "https://golang.org/", - "https://golang.org/pkg/", - }, - }, + return result, nil } diff --git a/03-web-crawler/task.go b/03-web-crawler/task.go index e4c342f..ea7e4c4 100644 --- a/03-web-crawler/task.go +++ b/03-web-crawler/task.go @@ -1,9 +1,5 @@ package main -import ( - "fmt" -) - type Fetcher interface { // Fetch returns the body of URL and // a slice of URLs found on that page. @@ -12,74 +8,22 @@ type Fetcher interface { // Crawl uses fetcher to recursively crawl // pages starting with url, to a maximum of depth. -func Crawl(url string, depth int, fetcher Fetcher) { +func Crawl(url string, depth int, fetcher Fetcher) ([]string, error) { // TODO: Fetch URLs in parallel. // TODO: Don't fetch the same URL twice. // This implementation doesn't do either: if depth <= 0 { - return + return nil, nil } body, urls, err := fetcher.Fetch(url) if err != nil { - fmt.Println(err) - return + return nil, err } - fmt.Printf("found: %s %q\n", url, body) + result := []string{body} for _, u := range urls { - Crawl(u, depth-1, fetcher) + if res, err := Crawl(u, depth-1, fetcher); err == nil { + result = append(result, res...) + } } - return -} - -func main() { - Crawl("https://golang.org/", 4, fetcher) -} - -// fakeFetcher is Fetcher that returns canned results. 
-type fakeFetcher map[string]*fakeResult - -type fakeResult struct { - body string - urls []string -} - -func (f fakeFetcher) Fetch(url string) (string, []string, error) { - if res, ok := f[url]; ok { - return res.body, res.urls, nil - } - return "", nil, fmt.Errorf("not found: %s", url) -} - -// fetcher is a populated fakeFetcher. -var fetcher = fakeFetcher{ - "https://golang.org/": &fakeResult{ - "The Go Programming Language", - []string{ - "https://golang.org/pkg/", - "https://golang.org/cmd/", - }, - }, - "https://golang.org/pkg/": &fakeResult{ - "Packages", - []string{ - "https://golang.org/", - "https://golang.org/cmd/", - "https://golang.org/pkg/fmt/", - "https://golang.org/pkg/os/", - }, - }, - "https://golang.org/pkg/fmt/": &fakeResult{ - "Package fmt", - []string{ - "https://golang.org/", - "https://golang.org/pkg/", - }, - }, - "https://golang.org/pkg/os/": &fakeResult{ - "Package os", - []string{ - "https://golang.org/", - "https://golang.org/pkg/", - }, - }, + return result, nil } diff --git a/03-web-crawler/task_test.go b/03-web-crawler/task_test.go new file mode 100644 index 0000000..173ee5e --- /dev/null +++ b/03-web-crawler/task_test.go @@ -0,0 +1,93 @@ +package main + +import ( + "fmt" + "reflect" + "sort" + "testing" +) + +// fakeFetcher is Fetcher that returns canned results. 
+type fakeFetcher map[string]*fakeResult + +type fakeResult struct { + body string + urls []string +} + +func (f fakeFetcher) Fetch(url string) (string, []string, error) { + if res, ok := f[url]; ok { + return res.body, res.urls, nil + } + return "", nil, fmt.Errorf("not found: %s", url) +} + +func TestCrawl(t *testing.T) { + tests := []struct { + name string + url string + depth int + fetcher fakeFetcher + result []string + err error + }{ + { + name: "default", + url: "https://golang.org/", + depth: 4, + fetcher: fakeFetcher{ + "https://golang.org/": &fakeResult{ + "The Go Programming Language", + []string{ + "https://golang.org/pkg/", + "https://golang.org/cmd/", + }, + }, + "https://golang.org/pkg/": &fakeResult{ + "Packages", + []string{ + "https://golang.org/", + "https://golang.org/cmd/", + "https://golang.org/pkg/fmt/", + "https://golang.org/pkg/os/", + }, + }, + "https://golang.org/pkg/fmt/": &fakeResult{ + "Package fmt", + []string{ + "https://golang.org/", + "https://golang.org/pkg/", + }, + }, + "https://golang.org/pkg/os/": &fakeResult{ + "Package os", + []string{ + "https://golang.org/", + "https://golang.org/pkg/", + }, + }, + }, + result: []string{ + "The Go Programming Language", + "Packages", + "Package fmt", + "Package os", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := Crawl(tt.url, tt.depth, tt.fetcher) + if err != tt.err { + t.Error(err) + } + sort.Strings(result) + sort.Strings(tt.result) + + if !reflect.DeepEqual(tt.result, result) { + t.Errorf("Wrong result. Expected: %+q, Got: %+q", tt.result, result) + } + }) + } +}