diff options
| field | value | date |
|---|---|---|
| author | Fuwn <[email protected]> | 2026-02-10 01:59:01 -0800 |
| committer | Fuwn <[email protected]> | 2026-02-10 01:59:01 -0800 |
| commit | 871985bc9eb42c6a088563e7c34db181f603f407 (patch) | |
| tree | 31299597a9f246d332b3bf6d5e2bed177648b577 /services/worker | |
| parent | feat: reorder feature grid by attention-grabbing impact (diff) | |
| download | asa.news-871985bc9eb42c6a088563e7c34db181f603f407.tar.xz asa.news-871985bc9eb42c6a088563e7c34db181f603f407.zip | |
fix: harden CI and close remaining test/security gaps
- Make webhook URL tests deterministic with injectable DNS resolver
- Wire tier parity checker into CI and root scripts
- Add rate_limits cleanup cron job (hourly, >1hr retention)
- Change rate limiter to fail closed on RPC error
- Add Go worker tests: parser, SSRF protection, error classification,
authentication, and worker pool (48 test functions)
Diffstat (limited to 'services/worker')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | services/worker/go.mod | 2 |
| -rw-r--r-- | services/worker/internal/fetcher/authentication_test.go | 90 |
| -rw-r--r-- | services/worker/internal/fetcher/errors_test.go | 169 |
| -rw-r--r-- | services/worker/internal/fetcher/ssrf_protection_test.go | 114 |
| -rw-r--r-- | services/worker/internal/parser/parser_test.go | 287 |
| -rw-r--r-- | services/worker/internal/pool/pool_test.go | 125 |
| -rw-r--r-- | services/worker/internal/writer/writer.go | 4 |
7 files changed, 788 insertions, 3 deletions
diff --git a/services/worker/go.mod b/services/worker/go.mod index 2588959..332a44d 100644 --- a/services/worker/go.mod +++ b/services/worker/go.mod @@ -8,6 +8,7 @@ require ( github.com/craigpastro/pgmq-go v0.6.0 github.com/jackc/pgx/v5 v5.7.2 github.com/mmcdole/gofeed v1.3.0 + github.com/stretchr/testify v1.11.1 ) require ( @@ -58,7 +59,6 @@ require ( github.com/shirou/gopsutil/v3 v3.24.5 // indirect github.com/shoenig/go-m1cpu v0.1.6 // indirect github.com/sirupsen/logrus v1.9.3 // indirect - github.com/stretchr/testify v1.11.1 // indirect github.com/testcontainers/testcontainers-go v0.35.0 // indirect github.com/tklauser/go-sysconf v0.3.14 // indirect github.com/tklauser/numcpus v0.9.0 // indirect diff --git a/services/worker/internal/fetcher/authentication_test.go b/services/worker/internal/fetcher/authentication_test.go new file mode 100644 index 0000000..bc840b9 --- /dev/null +++ b/services/worker/internal/fetcher/authentication_test.go @@ -0,0 +1,90 @@ +package fetcher + +import ( + "encoding/base64" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "net/http" + "testing" +) + +func TestApplyBearerAuthentication(test *testing.T) { + request, _ := http.NewRequest(http.MethodGet, "https://example.com/feed", nil) + authenticationError := ApplyAuthentication(request, AuthenticationConfiguration{ + AuthenticationType: "bearer", + AuthenticationValue: "my-secret-token", + }) + + require.NoError(test, authenticationError) + assert.Equal(test, "Bearer my-secret-token", request.Header.Get("Authorization")) +} + +func TestApplyBasicAuthentication(test *testing.T) { + request, _ := http.NewRequest(http.MethodGet, "https://example.com/feed", nil) + authenticationError := ApplyAuthentication(request, AuthenticationConfiguration{ + AuthenticationType: "basic", + AuthenticationValue: "user:pass", + }) + + require.NoError(test, authenticationError) + + expectedEncoded := base64.StdEncoding.EncodeToString([]byte("user:pass")) + + 
assert.Equal(test, "Basic "+expectedEncoded, request.Header.Get("Authorization")) +} + +func TestApplyQueryParamAuthentication(test *testing.T) { + request, _ := http.NewRequest(http.MethodGet, "https://example.com/feed?existing=value", nil) + authenticationError := ApplyAuthentication(request, AuthenticationConfiguration{ + AuthenticationType: "query_param", + AuthenticationValue: "api_key=abc123", + }) + + require.NoError(test, authenticationError) + assert.Equal(test, "abc123", request.URL.Query().Get("api_key")) + assert.Equal(test, "value", request.URL.Query().Get("existing")) +} + +func TestApplyEmptyAuthenticationType(test *testing.T) { + request, _ := http.NewRequest(http.MethodGet, "https://example.com/feed", nil) + authenticationError := ApplyAuthentication(request, AuthenticationConfiguration{ + AuthenticationType: "", + AuthenticationValue: "", + }) + + require.NoError(test, authenticationError) + assert.Empty(test, request.Header.Get("Authorization")) +} + +func TestApplyUnsupportedAuthenticationType(test *testing.T) { + request, _ := http.NewRequest(http.MethodGet, "https://example.com/feed", nil) + authenticationError := ApplyAuthentication(request, AuthenticationConfiguration{ + AuthenticationType: "oauth2", + AuthenticationValue: "token", + }) + + assert.Error(test, authenticationError) + assert.Contains(test, authenticationError.Error(), "unsupported") +} + +func TestApplyQueryParamInvalidFormat(test *testing.T) { + request, _ := http.NewRequest(http.MethodGet, "https://example.com/feed", nil) + authenticationError := ApplyAuthentication(request, AuthenticationConfiguration{ + AuthenticationType: "query_param", + AuthenticationValue: "no-equals-sign", + }) + + assert.Error(test, authenticationError) + assert.Contains(test, authenticationError.Error(), "param_name=value") +} + +func TestApplyQueryParamWithEqualsInValue(test *testing.T) { + request, _ := http.NewRequest(http.MethodGet, "https://example.com/feed", nil) + authenticationError := 
ApplyAuthentication(request, AuthenticationConfiguration{ + AuthenticationType: "query_param", + AuthenticationValue: "token=abc=def", + }) + + require.NoError(test, authenticationError) + assert.Equal(test, "abc=def", request.URL.Query().Get("token")) +} diff --git a/services/worker/internal/fetcher/errors_test.go b/services/worker/internal/fetcher/errors_test.go new file mode 100644 index 0000000..e81251b --- /dev/null +++ b/services/worker/internal/fetcher/errors_test.go @@ -0,0 +1,169 @@ +package fetcher + +import ( + "fmt" + "github.com/stretchr/testify/assert" + "net" + "net/url" + "testing" +) + +func TestClassifyHTTPStatus401(test *testing.T) { + fetchError := ClassifyError(nil, 401) + + assert.Equal(test, 401, fetchError.StatusCode) + assert.False(test, fetchError.Retryable) + assert.Contains(test, fetchError.UserMessage, "authentication") +} + +func TestClassifyHTTPStatus403(test *testing.T) { + fetchError := ClassifyError(nil, 403) + + assert.Equal(test, 403, fetchError.StatusCode) + assert.False(test, fetchError.Retryable) + assert.Contains(test, fetchError.UserMessage, "forbidden") +} + +func TestClassifyHTTPStatus404(test *testing.T) { + fetchError := ClassifyError(nil, 404) + + assert.Equal(test, 404, fetchError.StatusCode) + assert.False(test, fetchError.Retryable) + assert.Contains(test, fetchError.UserMessage, "not found") +} + +func TestClassifyHTTPStatus410(test *testing.T) { + fetchError := ClassifyError(nil, 410) + + assert.Equal(test, 410, fetchError.StatusCode) + assert.False(test, fetchError.Retryable) + assert.Contains(test, fetchError.UserMessage, "permanently removed") +} + +func TestClassifyHTTPStatus429(test *testing.T) { + fetchError := ClassifyError(nil, 429) + + assert.Equal(test, 429, fetchError.StatusCode) + assert.True(test, fetchError.Retryable) + assert.Contains(test, fetchError.UserMessage, "rate limited") +} + +func TestClassifyHTTPServerErrors(test *testing.T) { + serverErrorCodes := []int{500, 502, 503, 504} + + for _, 
statusCode := range serverErrorCodes { + test.Run(fmt.Sprintf("status_%d", statusCode), func(test *testing.T) { + fetchError := ClassifyError(nil, statusCode) + + assert.Equal(test, statusCode, fetchError.StatusCode) + assert.True(test, fetchError.Retryable) + assert.Contains(test, fetchError.UserMessage, "server error") + }) + } +} + +func TestClassifyHTTPUnknownClientError(test *testing.T) { + fetchError := ClassifyError(nil, 418) + + assert.Equal(test, 418, fetchError.StatusCode) + assert.False(test, fetchError.Retryable) + assert.Contains(test, fetchError.UserMessage, "418") +} + +func TestClassifyDNSNotFoundError(test *testing.T) { + dnsError := &net.DNSError{ + Name: "nonexistent.example.com", + IsNotFound: true, + } + fetchError := ClassifyError(dnsError, 0) + + assert.False(test, fetchError.Retryable) + assert.Contains(test, fetchError.UserMessage, "DNS") +} + +func TestClassifyDNSTemporaryError(test *testing.T) { + dnsError := &net.DNSError{ + Name: "flaky.example.com", + IsNotFound: false, + } + fetchError := ClassifyError(dnsError, 0) + + assert.True(test, fetchError.Retryable) + assert.Contains(test, fetchError.UserMessage, "DNS") +} + +func TestClassifyTimeoutError(test *testing.T) { + timeoutError := &url.Error{ + Op: "Get", + URL: "https://slow.example.com", + Err: &timeoutErr{}, + } + fetchError := ClassifyError(timeoutError, 0) + + assert.True(test, fetchError.Retryable) + assert.Contains(test, fetchError.UserMessage, "timed out") +} + +type timeoutErr struct{} + +func (timeoutError *timeoutErr) Error() string { return "connection timed out" } +func (timeoutError *timeoutErr) Timeout() bool { return true } +func (timeoutError *timeoutErr) Temporary() bool { return true } + +func TestClassifyNetworkOpError(test *testing.T) { + opError := &net.OpError{ + Op: "dial", + Net: "tcp", + Err: fmt.Errorf("connection refused"), + } + fetchError := ClassifyError(opError, 0) + + assert.True(test, fetchError.Retryable) + assert.Contains(test, 
fetchError.UserMessage, "network") +} + +func TestClassifyTLSError(test *testing.T) { + tlsError := fmt.Errorf("tls: handshake failure") + fetchError := ClassifyError(tlsError, 0) + + assert.False(test, fetchError.Retryable) + assert.Contains(test, fetchError.UserMessage, "TLS") +} + +func TestClassifyCertificateError(test *testing.T) { + certError := fmt.Errorf("x509: certificate has expired") + fetchError := ClassifyError(certError, 0) + + assert.False(test, fetchError.Retryable) + assert.Contains(test, fetchError.UserMessage, "certificate") +} + +func TestClassifyNilError(test *testing.T) { + fetchError := ClassifyError(nil, 0) + + assert.True(test, fetchError.Retryable) + assert.Contains(test, fetchError.UserMessage, "unknown") +} + +func TestFetchErrorInterface(test *testing.T) { + underlyingError := fmt.Errorf("some network issue") + fetchError := &FetchError{ + StatusCode: 0, + UserMessage: "test error", + Retryable: true, + UnderlyingError: underlyingError, + } + + assert.Contains(test, fetchError.Error(), "test error") + assert.Contains(test, fetchError.Error(), "some network issue") + assert.Equal(test, underlyingError, fetchError.Unwrap()) +} + +func TestFetchErrorWithoutUnderlying(test *testing.T) { + fetchError := &FetchError{ + UserMessage: "standalone error", + } + + assert.Equal(test, "standalone error", fetchError.Error()) + assert.Nil(test, fetchError.Unwrap()) +} diff --git a/services/worker/internal/fetcher/ssrf_protection_test.go b/services/worker/internal/fetcher/ssrf_protection_test.go new file mode 100644 index 0000000..3e78380 --- /dev/null +++ b/services/worker/internal/fetcher/ssrf_protection_test.go @@ -0,0 +1,114 @@ +package fetcher + +import ( + "github.com/stretchr/testify/assert" + "net" + "testing" +) + +func TestIsReservedAddress(test *testing.T) { + reservedAddresses := []struct { + name string + address string + }{ + {"loopback ipv4", "127.0.0.1"}, + {"loopback ipv4 alternate", "127.0.0.2"}, + {"private 10.x", "10.0.0.1"}, + 
{"private 10.x deep", "10.255.255.255"}, + {"private 172.16.x", "172.16.0.1"}, + {"private 172.31.x", "172.31.255.255"}, + {"private 192.168.x", "192.168.1.1"}, + {"link-local", "169.254.1.1"}, + {"null address", "0.0.0.1"}, + {"ipv6 loopback", "::1"}, + {"ipv6 unique local fc", "fc00::1"}, + {"ipv6 unique local fd", "fd00::1"}, + {"ipv6 link-local", "fe80::1"}, + } + + for _, testCase := range reservedAddresses { + test.Run(testCase.name, func(test *testing.T) { + parsedIP := net.ParseIP(testCase.address) + + assert.True(test, isReservedAddress(parsedIP), "expected %s to be reserved", testCase.address) + }) + } +} + +func TestIsNotReservedAddress(test *testing.T) { + publicAddresses := []struct { + name string + address string + }{ + {"google dns", "8.8.8.8"}, + {"cloudflare dns", "1.1.1.1"}, + {"random public", "93.184.216.34"}, + {"public 172.32", "172.32.0.1"}, + {"public ipv6", "2001:db8::1"}, + } + + for _, testCase := range publicAddresses { + test.Run(testCase.name, func(test *testing.T) { + parsedIP := net.ParseIP(testCase.address) + + assert.False(test, isReservedAddress(parsedIP), "expected %s to not be reserved", testCase.address) + }) + } +} + +func TestValidateFeedURLRejectsUnsupportedSchemes(test *testing.T) { + unsupportedURLs := []string{ + "ftp://example.com/feed", + "file:///etc/passwd", + "javascript:alert(1)", + "data:text/html,hello", + } + + for _, feedURL := range unsupportedURLs { + test.Run(feedURL, func(test *testing.T) { + assert.Error(test, ValidateFeedURL(feedURL)) + }) + } +} + +func TestValidateFeedURLRejectsEmptyHostname(test *testing.T) { + assert.Error(test, ValidateFeedURL("http:///path")) +} + +func TestValidateFeedURLRejectsPrivateIPs(test *testing.T) { + privateURLs := []string{ + "http://127.0.0.1/feed", + "https://10.0.0.1/feed", + "http://192.168.1.1/feed", + "http://172.16.0.1/feed", + "http://[::1]/feed", + } + + for _, feedURL := range privateURLs { + test.Run(feedURL, func(test *testing.T) { + assert.Error(test, 
ValidateFeedURL(feedURL)) + }) + } +} + +func TestValidateFeedURLAcceptsPublicIPs(test *testing.T) { + publicURLs := []string{ + "http://93.184.216.34/feed", + "https://8.8.8.8/feed", + } + + for _, feedURL := range publicURLs { + test.Run(feedURL, func(test *testing.T) { + assert.NoError(test, ValidateFeedURL(feedURL)) + }) + } +} + +func TestValidateFeedURLAcceptsHTTPAndHTTPS(test *testing.T) { + assert.NoError(test, ValidateFeedURL("http://93.184.216.34/feed")) + assert.NoError(test, ValidateFeedURL("https://93.184.216.34/feed")) +} + +func TestValidateFeedURLRejectsInvalidURL(test *testing.T) { + assert.Error(test, ValidateFeedURL("://broken")) +} diff --git a/services/worker/internal/parser/parser_test.go b/services/worker/internal/parser/parser_test.go new file mode 100644 index 0000000..7a28132 --- /dev/null +++ b/services/worker/internal/parser/parser_test.go @@ -0,0 +1,287 @@ +package parser + +import ( + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "testing" + "time" +) + +const validRSS = `<?xml version="1.0" encoding="UTF-8"?> +<rss version="2.0"> + <channel> + <title>Test Feed</title> + <link>https://example.com</link> + <item> + <guid>entry-1</guid> + <title>First Entry</title> + <link>https://example.com/1</link> + <description>Summary of the first entry.</description> + <content:encoded xmlns:content="http://purl.org/rss/1.0/modules/content/"><p>Full content here.</p></content:encoded> + <author>Alice</author> + <pubDate>Mon, 01 Jan 2024 12:00:00 GMT</pubDate> + </item> + <item> + <guid>entry-2</guid> + <title>Second Entry</title> + <link>https://example.com/2</link> + <description>Summary of the second entry.</description> + <pubDate>Tue, 02 Jan 2024 12:00:00 GMT</pubDate> + </item> + </channel> +</rss>` +const validAtom = `<?xml version="1.0" encoding="UTF-8"?> +<feed xmlns="http://www.w3.org/2005/Atom"> + <title>Atom Feed</title> + <link href="https://atom.example.com"/> + <entry> + <id>atom-1</id> + <title>Atom 
Entry</title> + <link href="https://atom.example.com/1"/> + <summary>An atom summary.</summary> + <updated>2024-01-01T12:00:00Z</updated> + <author><name>Bob</name></author> + </entry> +</feed>` +const podcastRSS = `<?xml version="1.0" encoding="UTF-8"?> +<rss version="2.0"> + <channel> + <title>My Podcast</title> + <link>https://podcast.example.com</link> + <item> + <guid>ep-1</guid> + <title>Episode 1</title> + <enclosure url="https://cdn.example.com/ep1.mp3" type="audio/mpeg" length="12345678"/> + <pubDate>Mon, 01 Jan 2024 12:00:00 GMT</pubDate> + </item> + <item> + <guid>ep-2</guid> + <title>Episode 2</title> + <enclosure url="https://cdn.example.com/ep2.mp3" type="audio/mpeg" length="87654321"/> + <pubDate>Tue, 02 Jan 2024 12:00:00 GMT</pubDate> + </item> + </channel> +</rss>` +const mixedEnclosureRSS = `<?xml version="1.0" encoding="UTF-8"?> +<rss version="2.0"> + <channel> + <title>Mixed Feed</title> + <link>https://mixed.example.com</link> + <item> + <guid>item-audio</guid> + <title>Audio Item</title> + <enclosure url="https://cdn.example.com/audio.mp3" type="audio/mpeg" length="1000"/> + </item> + <item> + <guid>item-text</guid> + <title>Text Item</title> + <description>Just a text item.</description> + </item> + </channel> +</rss>` +const noGUIDRSS = `<?xml version="1.0" encoding="UTF-8"?> +<rss version="2.0"> + <channel> + <title>No GUID Feed</title> + <link>https://example.com</link> + <item> + <title>No GUID but has link</title> + <link>https://example.com/no-guid</link> + </item> + <item> + <title>No GUID no link</title> + <description>Only title and description.</description> + </item> + </channel> +</rss>` + +func TestParseValidRSS(test *testing.T) { + feedParser := NewParser() + result, parseError := feedParser.Parse("feed-123", nil, []byte(validRSS)) + + require.NoError(test, parseError) + assert.Equal(test, "Test Feed", result.FeedTitle) + assert.Equal(test, "https://example.com", result.SiteURL) + assert.Len(test, result.Entries, 2) + 
assert.Equal(test, 0.0, result.AudioEnclosureRatio) + + first := result.Entries[0] + + assert.Equal(test, "entry-1", first.GUID) + assert.Equal(test, "feed-123", first.FeedIdentifier) + require.NotNil(test, first.Title) + assert.Equal(test, "First Entry", *first.Title) + require.NotNil(test, first.URL) + assert.Equal(test, "https://example.com/1", *first.URL) + require.NotNil(test, first.Summary) + assert.Equal(test, "Summary of the first entry.", *first.Summary) + require.NotNil(test, first.ContentHTML) + assert.Contains(test, *first.ContentHTML, "Full content here.") + require.NotNil(test, first.PublishedAt) + assert.Nil(test, first.OwnerIdentifier) +} + +func TestParseValidAtom(test *testing.T) { + feedParser := NewParser() + result, parseError := feedParser.Parse("atom-feed", nil, []byte(validAtom)) + + require.NoError(test, parseError) + assert.Equal(test, "Atom Feed", result.FeedTitle) + assert.Len(test, result.Entries, 1) + + entry := result.Entries[0] + + assert.Equal(test, "atom-1", entry.GUID) + require.NotNil(test, entry.Title) + assert.Equal(test, "Atom Entry", *entry.Title) + require.NotNil(test, entry.Author) + assert.Equal(test, "Bob", *entry.Author) +} + +func TestParsePodcastFeed(test *testing.T) { + feedParser := NewParser() + result, parseError := feedParser.Parse("podcast-1", nil, []byte(podcastRSS)) + + require.NoError(test, parseError) + assert.Equal(test, "My Podcast", result.FeedTitle) + assert.Len(test, result.Entries, 2) + assert.InDelta(test, 1.0, result.AudioEnclosureRatio, 0.01) + + firstEpisode := result.Entries[0] + + require.NotNil(test, firstEpisode.EnclosureURL) + assert.Equal(test, "https://cdn.example.com/ep1.mp3", *firstEpisode.EnclosureURL) + require.NotNil(test, firstEpisode.EnclosureType) + assert.Equal(test, "audio/mpeg", *firstEpisode.EnclosureType) + require.NotNil(test, firstEpisode.EnclosureLength) + assert.Equal(test, int64(12345678), *firstEpisode.EnclosureLength) +} + +func TestParseMixedEnclosureFeed(test *testing.T) 
{ + feedParser := NewParser() + result, parseError := feedParser.Parse("mixed-1", nil, []byte(mixedEnclosureRSS)) + + require.NoError(test, parseError) + assert.Len(test, result.Entries, 2) + assert.InDelta(test, 0.5, result.AudioEnclosureRatio, 0.01) + + audioItem := result.Entries[0] + + require.NotNil(test, audioItem.EnclosureURL) + + textItem := result.Entries[1] + + assert.Nil(test, textItem.EnclosureURL) + assert.Nil(test, textItem.EnclosureType) + assert.Nil(test, textItem.EnclosureLength) +} + +func TestParseGUIDFallback(test *testing.T) { + feedParser := NewParser() + result, parseError := feedParser.Parse("no-guid-feed", nil, []byte(noGUIDRSS)) + + require.NoError(test, parseError) + assert.Len(test, result.Entries, 2) + + entryWithLink := result.Entries[0] + + assert.Equal(test, "https://example.com/no-guid", entryWithLink.GUID) + + entryWithHash := result.Entries[1] + + assert.True(test, len(entryWithHash.GUID) > 0) + assert.Contains(test, entryWithHash.GUID, "sha256:") +} + +func TestParseOwnerIdentifier(test *testing.T) { + feedParser := NewParser() + ownerIdentifier := "user-abc" + result, parseError := feedParser.Parse("owned-feed", &ownerIdentifier, []byte(validRSS)) + + require.NoError(test, parseError) + + for _, entry := range result.Entries { + require.NotNil(test, entry.OwnerIdentifier) + assert.Equal(test, "user-abc", *entry.OwnerIdentifier) + } +} + +func TestParseInvalidXML(test *testing.T) { + feedParser := NewParser() + _, parseError := feedParser.Parse("bad-feed", nil, []byte("this is not xml at all")) + + assert.Error(test, parseError) +} + +func TestParsePublishedDateFallsBackToUpdated(test *testing.T) { + feedParser := NewParser() + result, parseError := feedParser.Parse("atom-feed", nil, []byte(validAtom)) + + require.NoError(test, parseError) + + entry := result.Entries[0] + + require.NotNil(test, entry.PublishedAt) + + expectedTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC) + + assert.True(test, 
entry.PublishedAt.Equal(expectedTime)) +} + +func TestParseWordCount(test *testing.T) { + feedParser := NewParser() + result, parseError := feedParser.Parse("feed-123", nil, []byte(validRSS)) + + require.NoError(test, parseError) + + first := result.Entries[0] + + require.NotNil(test, first.WordCount) + assert.Greater(test, *first.WordCount, 0) +} + +func TestStripHTMLTags(test *testing.T) { + testCases := []struct { + name string + input string + expected string + }{ + {"plain text passthrough", "hello world", "hello world"}, + {"strips simple tags", "<p>hello</p>", "hello"}, + {"strips nested tags", "<div><p>hello <strong>world</strong></p></div>", "hello world"}, + {"handles unicode", "<p>café résumé naïve</p>", "café résumé naïve"}, + {"empty string", "", ""}, + {"only tags", "<br/><hr/>", ""}, + } + + for _, testCase := range testCases { + test.Run(testCase.name, func(test *testing.T) { + assert.Equal(test, testCase.expected, stripHTMLTags(testCase.input)) + }) + } +} + +func TestCountWords(test *testing.T) { + emptyResult := countWords("") + + assert.Nil(test, emptyResult) + + twoWords := countWords("hello world") + + require.NotNil(test, twoWords) + assert.Equal(test, 2, *twoWords) + + withExtraSpaces := countWords(" hello world ") + + require.NotNil(test, withExtraSpaces) + assert.Equal(test, 2, *withExtraSpaces) +} + +func TestStringPointerOrNil(test *testing.T) { + assert.Nil(test, stringPointerOrNil("")) + + result := stringPointerOrNil("hello") + + require.NotNil(test, result) + assert.Equal(test, "hello", *result) +} diff --git a/services/worker/internal/pool/pool_test.go b/services/worker/internal/pool/pool_test.go new file mode 100644 index 0000000..f347faf --- /dev/null +++ b/services/worker/internal/pool/pool_test.go @@ -0,0 +1,125 @@ +package pool + +import ( + "context" + "github.com/stretchr/testify/assert" + "log/slog" + "sync" + "sync/atomic" + "testing" + "time" +) + +func newTestLogger() *slog.Logger { + return slog.New(slog.DiscardHandler) 
+} + +func TestWorkerPoolConcurrencyLimit(test *testing.T) { + workerPool := NewWorkerPool(2, newTestLogger()) + workerContext := context.Background() + + var peakConcurrency atomic.Int32 + var currentConcurrency atomic.Int32 + var completedJobs atomic.Int32 + + for jobIndex := 0; jobIndex < 10; jobIndex++ { + workerPool.Submit(workerContext, func(workContext context.Context) { + current := currentConcurrency.Add(1) + + for { + peak := peakConcurrency.Load() + + if current <= peak || peakConcurrency.CompareAndSwap(peak, current) { + break + } + } + + time.Sleep(10 * time.Millisecond) + currentConcurrency.Add(-1) + completedJobs.Add(1) + }) + } + + workerPool.Wait() + assert.Equal(test, int32(10), completedJobs.Load()) + assert.LessOrEqual(test, peakConcurrency.Load(), int32(2)) +} + +func TestWorkerPoolPanicRecovery(test *testing.T) { + workerPool := NewWorkerPool(2, newTestLogger()) + workerContext := context.Background() + + var completedAfterPanic atomic.Bool + + workerPool.Submit(workerContext, func(workContext context.Context) { + panic("intentional test panic") + }) + time.Sleep(20 * time.Millisecond) + workerPool.Submit(workerContext, func(workContext context.Context) { + completedAfterPanic.Store(true) + }) + workerPool.Wait() + assert.True(test, completedAfterPanic.Load()) +} + +func TestWorkerPoolCancelledContext(test *testing.T) { + workerPool := NewWorkerPool(1, newTestLogger()) + cancelledContext, cancelFunction := context.WithCancel(context.Background()) + + var blockingMutex sync.Mutex + + blockingMutex.Lock() + workerPool.Submit(context.Background(), func(workContext context.Context) { + blockingMutex.Lock() + + defer blockingMutex.Unlock() + }) + cancelFunction() + + submitted := workerPool.Submit(cancelledContext, func(workContext context.Context) {}) + + assert.False(test, submitted) + blockingMutex.Unlock() + workerPool.Wait() +} + +func TestWorkerPoolWaitBlocksUntilDone(test *testing.T) { + workerPool := NewWorkerPool(4, newTestLogger()) + 
workerContext := context.Background() + + var counter atomic.Int32 + + for jobIndex := 0; jobIndex < 20; jobIndex++ { + workerPool.Submit(workerContext, func(workContext context.Context) { + time.Sleep(5 * time.Millisecond) + counter.Add(1) + }) + } + + workerPool.Wait() + assert.Equal(test, int32(20), counter.Load()) +} + +func TestWorkerPoolActiveWorkerCount(test *testing.T) { + workerPool := NewWorkerPool(3, newTestLogger()) + + assert.Equal(test, 0, workerPool.ActiveWorkerCount()) + + var releaseMutex sync.Mutex + + releaseMutex.Lock() + workerPool.Submit(context.Background(), func(workContext context.Context) { + releaseMutex.Lock() + + defer releaseMutex.Unlock() + }) + workerPool.Submit(context.Background(), func(workContext context.Context) { + releaseMutex.Lock() + + defer releaseMutex.Unlock() + }) + time.Sleep(20 * time.Millisecond) + assert.Equal(test, 2, workerPool.ActiveWorkerCount()) + releaseMutex.Unlock() + workerPool.Wait() +} diff --git a/services/worker/internal/writer/writer.go b/services/worker/internal/writer/writer.go index f681a4e..543b6e6 100644 --- a/services/worker/internal/writer/writer.go +++ b/services/worker/internal/writer/writer.go @@ -3,11 +3,11 @@ package writer import ( "context" "fmt" + "github.com/Fuwn/asa-news/internal/model" + "github.com/jackc/pgx/v5/pgxpool" "net/url" "strings" "time" - "github.com/Fuwn/asa-news/internal/model" - "github.com/jackc/pgx/v5/pgxpool" ) type Writer struct { |