summary refs log tree commit diff
path: root/services/worker
diff options
context:
space:
mode:
author: Fuwn <[email protected]> 2026-02-10 01:59:01 -0800
committer: Fuwn <[email protected]> 2026-02-10 01:59:01 -0800
commit: 871985bc9eb42c6a088563e7c34db181f603f407 (patch)
tree: 31299597a9f246d332b3bf6d5e2bed177648b577 /services/worker
parent: feat: reorder feature grid by attention-grabbing impact (diff)
download: asa.news-871985bc9eb42c6a088563e7c34db181f603f407.tar.xz
asa.news-871985bc9eb42c6a088563e7c34db181f603f407.zip
fix: harden CI and close remaining test/security gaps
- Make webhook URL tests deterministic with injectable DNS resolver
- Wire tier parity checker into CI and root scripts
- Add rate_limits cleanup cron job (hourly, >1hr retention)
- Change rate limiter to fail closed on RPC error
- Add Go worker tests: parser, SSRF protection, error classification, authentication, and worker pool (48 test functions)
Diffstat (limited to 'services/worker')
-rw-r--r-- services/worker/go.mod 2
-rw-r--r-- services/worker/internal/fetcher/authentication_test.go 90
-rw-r--r-- services/worker/internal/fetcher/errors_test.go 169
-rw-r--r-- services/worker/internal/fetcher/ssrf_protection_test.go 114
-rw-r--r-- services/worker/internal/parser/parser_test.go 287
-rw-r--r-- services/worker/internal/pool/pool_test.go 125
-rw-r--r-- services/worker/internal/writer/writer.go 4
7 files changed, 788 insertions, 3 deletions
diff --git a/services/worker/go.mod b/services/worker/go.mod
index 2588959..332a44d 100644
--- a/services/worker/go.mod
+++ b/services/worker/go.mod
@@ -8,6 +8,7 @@ require (
github.com/craigpastro/pgmq-go v0.6.0
github.com/jackc/pgx/v5 v5.7.2
github.com/mmcdole/gofeed v1.3.0
+ github.com/stretchr/testify v1.11.1
)
require (
@@ -58,7 +59,6 @@ require (
github.com/shirou/gopsutil/v3 v3.24.5 // indirect
github.com/shoenig/go-m1cpu v0.1.6 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
- github.com/stretchr/testify v1.11.1 // indirect
github.com/testcontainers/testcontainers-go v0.35.0 // indirect
github.com/tklauser/go-sysconf v0.3.14 // indirect
github.com/tklauser/numcpus v0.9.0 // indirect
diff --git a/services/worker/internal/fetcher/authentication_test.go b/services/worker/internal/fetcher/authentication_test.go
new file mode 100644
index 0000000..bc840b9
--- /dev/null
+++ b/services/worker/internal/fetcher/authentication_test.go
@@ -0,0 +1,90 @@
+package fetcher
+
+import (
+ "encoding/base64"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "net/http"
+ "testing"
+)
+
+func TestApplyBearerAuthentication(test *testing.T) {
+ request, _ := http.NewRequest(http.MethodGet, "https://example.com/feed", nil)
+ authenticationError := ApplyAuthentication(request, AuthenticationConfiguration{
+ AuthenticationType: "bearer",
+ AuthenticationValue: "my-secret-token",
+ })
+
+ require.NoError(test, authenticationError)
+ assert.Equal(test, "Bearer my-secret-token", request.Header.Get("Authorization"))
+}
+
+func TestApplyBasicAuthentication(test *testing.T) {
+ request, _ := http.NewRequest(http.MethodGet, "https://example.com/feed", nil)
+ authenticationError := ApplyAuthentication(request, AuthenticationConfiguration{
+ AuthenticationType: "basic",
+ AuthenticationValue: "user:pass",
+ })
+
+ require.NoError(test, authenticationError)
+
+ expectedEncoded := base64.StdEncoding.EncodeToString([]byte("user:pass"))
+
+ assert.Equal(test, "Basic "+expectedEncoded, request.Header.Get("Authorization"))
+}
+
+func TestApplyQueryParamAuthentication(test *testing.T) {
+ request, _ := http.NewRequest(http.MethodGet, "https://example.com/feed?existing=value", nil)
+ authenticationError := ApplyAuthentication(request, AuthenticationConfiguration{
+ AuthenticationType: "query_param",
+ AuthenticationValue: "api_key=abc123",
+ })
+
+ require.NoError(test, authenticationError)
+ assert.Equal(test, "abc123", request.URL.Query().Get("api_key"))
+ assert.Equal(test, "value", request.URL.Query().Get("existing"))
+}
+
+func TestApplyEmptyAuthenticationType(test *testing.T) {
+ request, _ := http.NewRequest(http.MethodGet, "https://example.com/feed", nil)
+ authenticationError := ApplyAuthentication(request, AuthenticationConfiguration{
+ AuthenticationType: "",
+ AuthenticationValue: "",
+ })
+
+ require.NoError(test, authenticationError)
+ assert.Empty(test, request.Header.Get("Authorization"))
+}
+
+func TestApplyUnsupportedAuthenticationType(test *testing.T) {
+ request, _ := http.NewRequest(http.MethodGet, "https://example.com/feed", nil)
+ authenticationError := ApplyAuthentication(request, AuthenticationConfiguration{
+ AuthenticationType: "oauth2",
+ AuthenticationValue: "token",
+ })
+
+ assert.Error(test, authenticationError)
+ assert.Contains(test, authenticationError.Error(), "unsupported")
+}
+
+func TestApplyQueryParamInvalidFormat(test *testing.T) {
+ request, _ := http.NewRequest(http.MethodGet, "https://example.com/feed", nil)
+ authenticationError := ApplyAuthentication(request, AuthenticationConfiguration{
+ AuthenticationType: "query_param",
+ AuthenticationValue: "no-equals-sign",
+ })
+
+ assert.Error(test, authenticationError)
+ assert.Contains(test, authenticationError.Error(), "param_name=value")
+}
+
+func TestApplyQueryParamWithEqualsInValue(test *testing.T) {
+ request, _ := http.NewRequest(http.MethodGet, "https://example.com/feed", nil)
+ authenticationError := ApplyAuthentication(request, AuthenticationConfiguration{
+ AuthenticationType: "query_param",
+ AuthenticationValue: "token=abc=def",
+ })
+
+ require.NoError(test, authenticationError)
+ assert.Equal(test, "abc=def", request.URL.Query().Get("token"))
+}
diff --git a/services/worker/internal/fetcher/errors_test.go b/services/worker/internal/fetcher/errors_test.go
new file mode 100644
index 0000000..e81251b
--- /dev/null
+++ b/services/worker/internal/fetcher/errors_test.go
@@ -0,0 +1,169 @@
+package fetcher
+
+import (
+ "fmt"
+ "github.com/stretchr/testify/assert"
+ "net"
+ "net/url"
+ "testing"
+)
+
+func TestClassifyHTTPStatus401(test *testing.T) {
+ fetchError := ClassifyError(nil, 401)
+
+ assert.Equal(test, 401, fetchError.StatusCode)
+ assert.False(test, fetchError.Retryable)
+ assert.Contains(test, fetchError.UserMessage, "authentication")
+}
+
+func TestClassifyHTTPStatus403(test *testing.T) {
+ fetchError := ClassifyError(nil, 403)
+
+ assert.Equal(test, 403, fetchError.StatusCode)
+ assert.False(test, fetchError.Retryable)
+ assert.Contains(test, fetchError.UserMessage, "forbidden")
+}
+
+func TestClassifyHTTPStatus404(test *testing.T) {
+ fetchError := ClassifyError(nil, 404)
+
+ assert.Equal(test, 404, fetchError.StatusCode)
+ assert.False(test, fetchError.Retryable)
+ assert.Contains(test, fetchError.UserMessage, "not found")
+}
+
+func TestClassifyHTTPStatus410(test *testing.T) {
+ fetchError := ClassifyError(nil, 410)
+
+ assert.Equal(test, 410, fetchError.StatusCode)
+ assert.False(test, fetchError.Retryable)
+ assert.Contains(test, fetchError.UserMessage, "permanently removed")
+}
+
+func TestClassifyHTTPStatus429(test *testing.T) {
+ fetchError := ClassifyError(nil, 429)
+
+ assert.Equal(test, 429, fetchError.StatusCode)
+ assert.True(test, fetchError.Retryable)
+ assert.Contains(test, fetchError.UserMessage, "rate limited")
+}
+
+func TestClassifyHTTPServerErrors(test *testing.T) {
+ serverErrorCodes := []int{500, 502, 503, 504}
+
+ for _, statusCode := range serverErrorCodes {
+ test.Run(fmt.Sprintf("status_%d", statusCode), func(test *testing.T) {
+ fetchError := ClassifyError(nil, statusCode)
+
+ assert.Equal(test, statusCode, fetchError.StatusCode)
+ assert.True(test, fetchError.Retryable)
+ assert.Contains(test, fetchError.UserMessage, "server error")
+ })
+ }
+}
+
+func TestClassifyHTTPUnknownClientError(test *testing.T) {
+ fetchError := ClassifyError(nil, 418)
+
+ assert.Equal(test, 418, fetchError.StatusCode)
+ assert.False(test, fetchError.Retryable)
+ assert.Contains(test, fetchError.UserMessage, "418")
+}
+
+func TestClassifyDNSNotFoundError(test *testing.T) {
+ dnsError := &net.DNSError{
+ Name: "nonexistent.example.com",
+ IsNotFound: true,
+ }
+ fetchError := ClassifyError(dnsError, 0)
+
+ assert.False(test, fetchError.Retryable)
+ assert.Contains(test, fetchError.UserMessage, "DNS")
+}
+
+func TestClassifyDNSTemporaryError(test *testing.T) {
+ dnsError := &net.DNSError{
+ Name: "flaky.example.com",
+ IsNotFound: false,
+ }
+ fetchError := ClassifyError(dnsError, 0)
+
+ assert.True(test, fetchError.Retryable)
+ assert.Contains(test, fetchError.UserMessage, "DNS")
+}
+
+func TestClassifyTimeoutError(test *testing.T) {
+ timeoutError := &url.Error{
+ Op: "Get",
+ URL: "https://slow.example.com",
+ Err: &timeoutErr{},
+ }
+ fetchError := ClassifyError(timeoutError, 0)
+
+ assert.True(test, fetchError.Retryable)
+ assert.Contains(test, fetchError.UserMessage, "timed out")
+}
+
+type timeoutErr struct{}
+
+func (timeoutError *timeoutErr) Error() string { return "connection timed out" }
+func (timeoutError *timeoutErr) Timeout() bool { return true }
+func (timeoutError *timeoutErr) Temporary() bool { return true }
+
+func TestClassifyNetworkOpError(test *testing.T) {
+ opError := &net.OpError{
+ Op: "dial",
+ Net: "tcp",
+ Err: fmt.Errorf("connection refused"),
+ }
+ fetchError := ClassifyError(opError, 0)
+
+ assert.True(test, fetchError.Retryable)
+ assert.Contains(test, fetchError.UserMessage, "network")
+}
+
+func TestClassifyTLSError(test *testing.T) {
+ tlsError := fmt.Errorf("tls: handshake failure")
+ fetchError := ClassifyError(tlsError, 0)
+
+ assert.False(test, fetchError.Retryable)
+ assert.Contains(test, fetchError.UserMessage, "TLS")
+}
+
+func TestClassifyCertificateError(test *testing.T) {
+ certError := fmt.Errorf("x509: certificate has expired")
+ fetchError := ClassifyError(certError, 0)
+
+ assert.False(test, fetchError.Retryable)
+ assert.Contains(test, fetchError.UserMessage, "certificate")
+}
+
+func TestClassifyNilError(test *testing.T) {
+ fetchError := ClassifyError(nil, 0)
+
+ assert.True(test, fetchError.Retryable)
+ assert.Contains(test, fetchError.UserMessage, "unknown")
+}
+
+func TestFetchErrorInterface(test *testing.T) {
+ underlyingError := fmt.Errorf("some network issue")
+ fetchError := &FetchError{
+ StatusCode: 0,
+ UserMessage: "test error",
+ Retryable: true,
+ UnderlyingError: underlyingError,
+ }
+
+ assert.Contains(test, fetchError.Error(), "test error")
+ assert.Contains(test, fetchError.Error(), "some network issue")
+ assert.Equal(test, underlyingError, fetchError.Unwrap())
+}
+
+func TestFetchErrorWithoutUnderlying(test *testing.T) {
+ fetchError := &FetchError{
+ UserMessage: "standalone error",
+ }
+
+ assert.Equal(test, "standalone error", fetchError.Error())
+ assert.Nil(test, fetchError.Unwrap())
+}
diff --git a/services/worker/internal/fetcher/ssrf_protection_test.go b/services/worker/internal/fetcher/ssrf_protection_test.go
new file mode 100644
index 0000000..3e78380
--- /dev/null
+++ b/services/worker/internal/fetcher/ssrf_protection_test.go
@@ -0,0 +1,114 @@
+package fetcher
+
+import (
+ "github.com/stretchr/testify/assert"
+ "net"
+ "testing"
+)
+
+func TestIsReservedAddress(test *testing.T) {
+ reservedAddresses := []struct {
+ name string
+ address string
+ }{
+ {"loopback ipv4", "127.0.0.1"},
+ {"loopback ipv4 alternate", "127.0.0.2"},
+ {"private 10.x", "10.0.0.1"},
+ {"private 10.x deep", "10.255.255.255"},
+ {"private 172.16.x", "172.16.0.1"},
+ {"private 172.31.x", "172.31.255.255"},
+ {"private 192.168.x", "192.168.1.1"},
+ {"link-local", "169.254.1.1"},
+ {"null address", "0.0.0.1"},
+ {"ipv6 loopback", "::1"},
+ {"ipv6 unique local fc", "fc00::1"},
+ {"ipv6 unique local fd", "fd00::1"},
+ {"ipv6 link-local", "fe80::1"},
+ }
+
+ for _, testCase := range reservedAddresses {
+ test.Run(testCase.name, func(test *testing.T) {
+ parsedIP := net.ParseIP(testCase.address)
+
+ assert.True(test, isReservedAddress(parsedIP), "expected %s to be reserved", testCase.address)
+ })
+ }
+}
+
+func TestIsNotReservedAddress(test *testing.T) {
+ publicAddresses := []struct {
+ name string
+ address string
+ }{
+ {"google dns", "8.8.8.8"},
+ {"cloudflare dns", "1.1.1.1"},
+ {"random public", "93.184.216.34"},
+ {"public 172.32", "172.32.0.1"},
+ {"public ipv6", "2001:db8::1"},
+ }
+
+ for _, testCase := range publicAddresses {
+ test.Run(testCase.name, func(test *testing.T) {
+ parsedIP := net.ParseIP(testCase.address)
+
+ assert.False(test, isReservedAddress(parsedIP), "expected %s to not be reserved", testCase.address)
+ })
+ }
+}
+
+func TestValidateFeedURLRejectsUnsupportedSchemes(test *testing.T) {
+ unsupportedURLs := []string{
+ "ftp://example.com/feed",
+ "file:///etc/passwd",
+ "javascript:alert(1)",
+ "data:text/html,hello",
+ }
+
+ for _, feedURL := range unsupportedURLs {
+ test.Run(feedURL, func(test *testing.T) {
+ assert.Error(test, ValidateFeedURL(feedURL))
+ })
+ }
+}
+
+func TestValidateFeedURLRejectsEmptyHostname(test *testing.T) {
+ assert.Error(test, ValidateFeedURL("http:///path"))
+}
+
+func TestValidateFeedURLRejectsPrivateIPs(test *testing.T) {
+ privateURLs := []string{
+ "http://127.0.0.1/feed",
+ "https://10.0.0.1/feed",
+ "http://192.168.1.1/feed",
+ "http://172.16.0.1/feed",
+ "http://[::1]/feed",
+ }
+
+ for _, feedURL := range privateURLs {
+ test.Run(feedURL, func(test *testing.T) {
+ assert.Error(test, ValidateFeedURL(feedURL))
+ })
+ }
+}
+
+func TestValidateFeedURLAcceptsPublicIPs(test *testing.T) {
+ publicURLs := []string{
+ "http://93.184.216.34/feed",
+ "https://8.8.8.8/feed",
+ }
+
+ for _, feedURL := range publicURLs {
+ test.Run(feedURL, func(test *testing.T) {
+ assert.NoError(test, ValidateFeedURL(feedURL))
+ })
+ }
+}
+
+func TestValidateFeedURLAcceptsHTTPAndHTTPS(test *testing.T) {
+ assert.NoError(test, ValidateFeedURL("http://93.184.216.34/feed"))
+ assert.NoError(test, ValidateFeedURL("https://93.184.216.34/feed"))
+}
+
+func TestValidateFeedURLRejectsInvalidURL(test *testing.T) {
+ assert.Error(test, ValidateFeedURL("://broken"))
+}
diff --git a/services/worker/internal/parser/parser_test.go b/services/worker/internal/parser/parser_test.go
new file mode 100644
index 0000000..7a28132
--- /dev/null
+++ b/services/worker/internal/parser/parser_test.go
@@ -0,0 +1,287 @@
+package parser
+
+import (
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "testing"
+ "time"
+)
+
+const validRSS = `<?xml version="1.0" encoding="UTF-8"?>
+<rss version="2.0">
+ <channel>
+ <title>Test Feed</title>
+ <link>https://example.com</link>
+ <item>
+ <guid>entry-1</guid>
+ <title>First Entry</title>
+ <link>https://example.com/1</link>
+ <description>Summary of the first entry.</description>
+ <content:encoded xmlns:content="http://purl.org/rss/1.0/modules/content/">&lt;p&gt;Full content here.&lt;/p&gt;</content:encoded>
+ <author>Alice</author>
+ <pubDate>Mon, 01 Jan 2024 12:00:00 GMT</pubDate>
+ </item>
+ <item>
+ <guid>entry-2</guid>
+ <title>Second Entry</title>
+ <link>https://example.com/2</link>
+ <description>Summary of the second entry.</description>
+ <pubDate>Tue, 02 Jan 2024 12:00:00 GMT</pubDate>
+ </item>
+ </channel>
+</rss>`
+const validAtom = `<?xml version="1.0" encoding="UTF-8"?>
+<feed xmlns="http://www.w3.org/2005/Atom">
+ <title>Atom Feed</title>
+ <link href="https://atom.example.com"/>
+ <entry>
+ <id>atom-1</id>
+ <title>Atom Entry</title>
+ <link href="https://atom.example.com/1"/>
+ <summary>An atom summary.</summary>
+ <updated>2024-01-01T12:00:00Z</updated>
+ <author><name>Bob</name></author>
+ </entry>
+</feed>`
+const podcastRSS = `<?xml version="1.0" encoding="UTF-8"?>
+<rss version="2.0">
+ <channel>
+ <title>My Podcast</title>
+ <link>https://podcast.example.com</link>
+ <item>
+ <guid>ep-1</guid>
+ <title>Episode 1</title>
+ <enclosure url="https://cdn.example.com/ep1.mp3" type="audio/mpeg" length="12345678"/>
+ <pubDate>Mon, 01 Jan 2024 12:00:00 GMT</pubDate>
+ </item>
+ <item>
+ <guid>ep-2</guid>
+ <title>Episode 2</title>
+ <enclosure url="https://cdn.example.com/ep2.mp3" type="audio/mpeg" length="87654321"/>
+ <pubDate>Tue, 02 Jan 2024 12:00:00 GMT</pubDate>
+ </item>
+ </channel>
+</rss>`
+const mixedEnclosureRSS = `<?xml version="1.0" encoding="UTF-8"?>
+<rss version="2.0">
+ <channel>
+ <title>Mixed Feed</title>
+ <link>https://mixed.example.com</link>
+ <item>
+ <guid>item-audio</guid>
+ <title>Audio Item</title>
+ <enclosure url="https://cdn.example.com/audio.mp3" type="audio/mpeg" length="1000"/>
+ </item>
+ <item>
+ <guid>item-text</guid>
+ <title>Text Item</title>
+ <description>Just a text item.</description>
+ </item>
+ </channel>
+</rss>`
+const noGUIDRSS = `<?xml version="1.0" encoding="UTF-8"?>
+<rss version="2.0">
+ <channel>
+ <title>No GUID Feed</title>
+ <link>https://example.com</link>
+ <item>
+ <title>No GUID but has link</title>
+ <link>https://example.com/no-guid</link>
+ </item>
+ <item>
+ <title>No GUID no link</title>
+ <description>Only title and description.</description>
+ </item>
+ </channel>
+</rss>`
+
+func TestParseValidRSS(test *testing.T) {
+ feedParser := NewParser()
+ result, parseError := feedParser.Parse("feed-123", nil, []byte(validRSS))
+
+ require.NoError(test, parseError)
+ assert.Equal(test, "Test Feed", result.FeedTitle)
+ assert.Equal(test, "https://example.com", result.SiteURL)
+ assert.Len(test, result.Entries, 2)
+ assert.Equal(test, 0.0, result.AudioEnclosureRatio)
+
+ first := result.Entries[0]
+
+ assert.Equal(test, "entry-1", first.GUID)
+ assert.Equal(test, "feed-123", first.FeedIdentifier)
+ require.NotNil(test, first.Title)
+ assert.Equal(test, "First Entry", *first.Title)
+ require.NotNil(test, first.URL)
+ assert.Equal(test, "https://example.com/1", *first.URL)
+ require.NotNil(test, first.Summary)
+ assert.Equal(test, "Summary of the first entry.", *first.Summary)
+ require.NotNil(test, first.ContentHTML)
+ assert.Contains(test, *first.ContentHTML, "Full content here.")
+ require.NotNil(test, first.PublishedAt)
+ assert.Nil(test, first.OwnerIdentifier)
+}
+
+func TestParseValidAtom(test *testing.T) {
+ feedParser := NewParser()
+ result, parseError := feedParser.Parse("atom-feed", nil, []byte(validAtom))
+
+ require.NoError(test, parseError)
+ assert.Equal(test, "Atom Feed", result.FeedTitle)
+ assert.Len(test, result.Entries, 1)
+
+ entry := result.Entries[0]
+
+ assert.Equal(test, "atom-1", entry.GUID)
+ require.NotNil(test, entry.Title)
+ assert.Equal(test, "Atom Entry", *entry.Title)
+ require.NotNil(test, entry.Author)
+ assert.Equal(test, "Bob", *entry.Author)
+}
+
+func TestParsePodcastFeed(test *testing.T) {
+ feedParser := NewParser()
+ result, parseError := feedParser.Parse("podcast-1", nil, []byte(podcastRSS))
+
+ require.NoError(test, parseError)
+ assert.Equal(test, "My Podcast", result.FeedTitle)
+ assert.Len(test, result.Entries, 2)
+ assert.InDelta(test, 1.0, result.AudioEnclosureRatio, 0.01)
+
+ firstEpisode := result.Entries[0]
+
+ require.NotNil(test, firstEpisode.EnclosureURL)
+ assert.Equal(test, "https://cdn.example.com/ep1.mp3", *firstEpisode.EnclosureURL)
+ require.NotNil(test, firstEpisode.EnclosureType)
+ assert.Equal(test, "audio/mpeg", *firstEpisode.EnclosureType)
+ require.NotNil(test, firstEpisode.EnclosureLength)
+ assert.Equal(test, int64(12345678), *firstEpisode.EnclosureLength)
+}
+
+func TestParseMixedEnclosureFeed(test *testing.T) {
+ feedParser := NewParser()
+ result, parseError := feedParser.Parse("mixed-1", nil, []byte(mixedEnclosureRSS))
+
+ require.NoError(test, parseError)
+ assert.Len(test, result.Entries, 2)
+ assert.InDelta(test, 0.5, result.AudioEnclosureRatio, 0.01)
+
+ audioItem := result.Entries[0]
+
+ require.NotNil(test, audioItem.EnclosureURL)
+
+ textItem := result.Entries[1]
+
+ assert.Nil(test, textItem.EnclosureURL)
+ assert.Nil(test, textItem.EnclosureType)
+ assert.Nil(test, textItem.EnclosureLength)
+}
+
+func TestParseGUIDFallback(test *testing.T) {
+ feedParser := NewParser()
+ result, parseError := feedParser.Parse("no-guid-feed", nil, []byte(noGUIDRSS))
+
+ require.NoError(test, parseError)
+ assert.Len(test, result.Entries, 2)
+
+ entryWithLink := result.Entries[0]
+
+ assert.Equal(test, "https://example.com/no-guid", entryWithLink.GUID)
+
+ entryWithHash := result.Entries[1]
+
+ assert.True(test, len(entryWithHash.GUID) > 0)
+ assert.Contains(test, entryWithHash.GUID, "sha256:")
+}
+
+func TestParseOwnerIdentifier(test *testing.T) {
+ feedParser := NewParser()
+ ownerIdentifier := "user-abc"
+ result, parseError := feedParser.Parse("owned-feed", &ownerIdentifier, []byte(validRSS))
+
+ require.NoError(test, parseError)
+
+ for _, entry := range result.Entries {
+ require.NotNil(test, entry.OwnerIdentifier)
+ assert.Equal(test, "user-abc", *entry.OwnerIdentifier)
+ }
+}
+
+func TestParseInvalidXML(test *testing.T) {
+ feedParser := NewParser()
+ _, parseError := feedParser.Parse("bad-feed", nil, []byte("this is not xml at all"))
+
+ assert.Error(test, parseError)
+}
+
+func TestParsePublishedDateFallsBackToUpdated(test *testing.T) {
+ feedParser := NewParser()
+ result, parseError := feedParser.Parse("atom-feed", nil, []byte(validAtom))
+
+ require.NoError(test, parseError)
+
+ entry := result.Entries[0]
+
+ require.NotNil(test, entry.PublishedAt)
+
+ expectedTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
+
+ assert.True(test, entry.PublishedAt.Equal(expectedTime))
+}
+
+func TestParseWordCount(test *testing.T) {
+ feedParser := NewParser()
+ result, parseError := feedParser.Parse("feed-123", nil, []byte(validRSS))
+
+ require.NoError(test, parseError)
+
+ first := result.Entries[0]
+
+ require.NotNil(test, first.WordCount)
+ assert.Greater(test, *first.WordCount, 0)
+}
+
+func TestStripHTMLTags(test *testing.T) {
+ testCases := []struct {
+ name string
+ input string
+ expected string
+ }{
+ {"plain text passthrough", "hello world", "hello world"},
+ {"strips simple tags", "<p>hello</p>", "hello"},
+ {"strips nested tags", "<div><p>hello <strong>world</strong></p></div>", "hello world"},
+ {"handles unicode", "<p>café résumé naïve</p>", "café résumé naïve"},
+ {"empty string", "", ""},
+ {"only tags", "<br/><hr/>", ""},
+ }
+
+ for _, testCase := range testCases {
+ test.Run(testCase.name, func(test *testing.T) {
+ assert.Equal(test, testCase.expected, stripHTMLTags(testCase.input))
+ })
+ }
+}
+
+func TestCountWords(test *testing.T) {
+ emptyResult := countWords("")
+
+ assert.Nil(test, emptyResult)
+
+ twoWords := countWords("hello world")
+
+ require.NotNil(test, twoWords)
+ assert.Equal(test, 2, *twoWords)
+
+ withExtraSpaces := countWords(" hello world ")
+
+ require.NotNil(test, withExtraSpaces)
+ assert.Equal(test, 2, *withExtraSpaces)
+}
+
+func TestStringPointerOrNil(test *testing.T) {
+ assert.Nil(test, stringPointerOrNil(""))
+
+ result := stringPointerOrNil("hello")
+
+ require.NotNil(test, result)
+ assert.Equal(test, "hello", *result)
+}
diff --git a/services/worker/internal/pool/pool_test.go b/services/worker/internal/pool/pool_test.go
new file mode 100644
index 0000000..f347faf
--- /dev/null
+++ b/services/worker/internal/pool/pool_test.go
@@ -0,0 +1,125 @@
+package pool
+
+import (
+ "context"
+ "github.com/stretchr/testify/assert"
+ "log/slog"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+func newTestLogger() *slog.Logger {
+ return slog.New(slog.DiscardHandler)
+}
+
+func TestWorkerPoolConcurrencyLimit(test *testing.T) {
+ workerPool := NewWorkerPool(2, newTestLogger())
+ workerContext := context.Background()
+
+ var peakConcurrency atomic.Int32
+ var currentConcurrency atomic.Int32
+ var completedJobs atomic.Int32
+
+ for jobIndex := 0; jobIndex < 10; jobIndex++ {
+ workerPool.Submit(workerContext, func(workContext context.Context) {
+ current := currentConcurrency.Add(1)
+
+ for {
+ peak := peakConcurrency.Load()
+
+ if current <= peak || peakConcurrency.CompareAndSwap(peak, current) {
+ break
+ }
+ }
+
+ time.Sleep(10 * time.Millisecond)
+ currentConcurrency.Add(-1)
+ completedJobs.Add(1)
+ })
+ }
+
+ workerPool.Wait()
+ assert.Equal(test, int32(10), completedJobs.Load())
+ assert.LessOrEqual(test, peakConcurrency.Load(), int32(2))
+}
+
+func TestWorkerPoolPanicRecovery(test *testing.T) {
+ workerPool := NewWorkerPool(2, newTestLogger())
+ workerContext := context.Background()
+
+ var completedAfterPanic atomic.Bool
+
+ workerPool.Submit(workerContext, func(workContext context.Context) {
+ panic("intentional test panic")
+ })
+ time.Sleep(20 * time.Millisecond)
+ workerPool.Submit(workerContext, func(workContext context.Context) {
+ completedAfterPanic.Store(true)
+ })
+ workerPool.Wait()
+ assert.True(test, completedAfterPanic.Load())
+}
+
+func TestWorkerPoolCancelledContext(test *testing.T) {
+ workerPool := NewWorkerPool(1, newTestLogger())
+ cancelledContext, cancelFunction := context.WithCancel(context.Background())
+
+ var blockingMutex sync.Mutex
+
+ blockingMutex.Lock()
+ workerPool.Submit(context.Background(), func(workContext context.Context) {
+ blockingMutex.Lock()
+
+ defer blockingMutex.Unlock()
+ })
+ cancelFunction()
+
+ submitted := workerPool.Submit(cancelledContext, func(workContext context.Context) {})
+
+ assert.False(test, submitted)
+ blockingMutex.Unlock()
+ workerPool.Wait()
+}
+
+func TestWorkerPoolWaitBlocksUntilDone(test *testing.T) {
+ workerPool := NewWorkerPool(4, newTestLogger())
+ workerContext := context.Background()
+
+ var counter atomic.Int32
+
+ for jobIndex := 0; jobIndex < 20; jobIndex++ {
+ workerPool.Submit(workerContext, func(workContext context.Context) {
+ time.Sleep(5 * time.Millisecond)
+ counter.Add(1)
+ })
+ }
+
+ workerPool.Wait()
+ assert.Equal(test, int32(20), counter.Load())
+}
+
+func TestWorkerPoolActiveWorkerCount(test *testing.T) {
+ workerPool := NewWorkerPool(3, newTestLogger())
+
+ assert.Equal(test, 0, workerPool.ActiveWorkerCount())
+
+ var releaseMutex sync.Mutex
+
+ releaseMutex.Lock()
+ workerPool.Submit(context.Background(), func(workContext context.Context) {
+ releaseMutex.Lock()
+
+ defer releaseMutex.Unlock()
+ })
+ workerPool.Submit(context.Background(), func(workContext context.Context) {
+ releaseMutex.Lock()
+
+ defer releaseMutex.Unlock()
+ })
+ time.Sleep(20 * time.Millisecond)
+ assert.Equal(test, 2, workerPool.ActiveWorkerCount())
+ releaseMutex.Unlock()
+ workerPool.Wait()
+}
diff --git a/services/worker/internal/writer/writer.go b/services/worker/internal/writer/writer.go
index f681a4e..543b6e6 100644
--- a/services/worker/internal/writer/writer.go
+++ b/services/worker/internal/writer/writer.go
@@ -3,11 +3,11 @@ package writer
import (
"context"
"fmt"
+ "github.com/Fuwn/asa-news/internal/model"
+ "github.com/jackc/pgx/v5/pgxpool"
"net/url"
"strings"
"time"
- "github.com/Fuwn/asa-news/internal/model"
- "github.com/jackc/pgx/v5/pgxpool"
)
type Writer struct {