summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFuwn <[email protected]>2026-02-10 01:59:01 -0800
committerFuwn <[email protected]>2026-02-10 01:59:01 -0800
commit871985bc9eb42c6a088563e7c34db181f603f407 (patch)
tree31299597a9f246d332b3bf6d5e2bed177648b577
parentfeat: reorder feature grid by attention-grabbing impact (diff)
downloadasa.news-871985bc9eb42c6a088563e7c34db181f603f407.tar.xz
asa.news-871985bc9eb42c6a088563e7c34db181f603f407.zip
fix: harden CI and close remaining test/security gaps
- Make webhook URL tests deterministic with injectable DNS resolver
- Wire tier parity checker into CI and root scripts
- Add rate_limits cleanup cron job (hourly, >1hr retention)
- Change rate limiter to fail closed on RPC error
- Add Go worker tests: parser, SSRF protection, error classification, authentication, and worker pool (48 test functions)
-rw-r--r--.github/workflows/ci.yml3
-rw-r--r--apps/web/lib/rate-limit.ts2
-rw-r--r--apps/web/lib/validate-webhook-url.test.ts34
-rw-r--r--apps/web/lib/validate-webhook-url.ts21
-rw-r--r--package.json3
-rw-r--r--scripts/check-tier-parity.ts29
-rw-r--r--services/worker/go.mod2
-rw-r--r--services/worker/internal/fetcher/authentication_test.go90
-rw-r--r--services/worker/internal/fetcher/errors_test.go169
-rw-r--r--services/worker/internal/fetcher/ssrf_protection_test.go114
-rw-r--r--services/worker/internal/parser/parser_test.go287
-rw-r--r--services/worker/internal/pool/pool_test.go125
-rw-r--r--services/worker/internal/writer/writer.go4
-rw-r--r--supabase/schema.sql37
14 files changed, 881 insertions, 39 deletions
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index c8d20a7..ade8b00 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -22,6 +22,9 @@ jobs:
- run: pnpm install --frozen-lockfile
+ - name: tier parity check
+ run: pnpm check:tier-parity
+
- name: lint
run: pnpm lint
diff --git a/apps/web/lib/rate-limit.ts b/apps/web/lib/rate-limit.ts
index c68f02c..45c3a7f 100644
--- a/apps/web/lib/rate-limit.ts
+++ b/apps/web/lib/rate-limit.ts
@@ -16,7 +16,7 @@ export async function rateLimit(
if (error) {
console.error("rate limit check failed:", error)
- return { success: true, remaining: limit }
+ return { success: false, remaining: 0 }
}
return {
diff --git a/apps/web/lib/validate-webhook-url.test.ts b/apps/web/lib/validate-webhook-url.test.ts
index 3375133..78f5b62 100644
--- a/apps/web/lib/validate-webhook-url.test.ts
+++ b/apps/web/lib/validate-webhook-url.test.ts
@@ -1,5 +1,20 @@
import { describe, it, expect } from "vitest"
-import { validateWebhookUrl } from "./validate-webhook-url"
+import { validateWebhookUrl, type DnsResolver } from "./validate-webhook-url"
+
+const publicResolver: DnsResolver = {
+ resolve4: async () => ["93.184.216.34"],
+ resolve6: async () => [],
+}
+
+const privateResolver: DnsResolver = {
+ resolve4: async () => ["192.168.1.1"],
+ resolve6: async () => [],
+}
+
+const failingResolver: DnsResolver = {
+ resolve4: async () => { throw new Error("ENOTFOUND") },
+ resolve6: async () => { throw new Error("ENOTFOUND") },
+}
describe("validateWebhookUrl", () => {
it("rejects empty urls", async () => {
@@ -44,13 +59,26 @@ describe("validateWebhookUrl", () => {
}
})
+ it("rejects hostnames that resolve to private addresses", async () => {
+ const result = await validateWebhookUrl("https://internal.example.com/webhook", privateResolver)
+ expect(result.valid).toBe(false)
+ })
+
+ it("rejects hostnames that fail to resolve", async () => {
+ const result = await validateWebhookUrl("https://nonexistent.example.com/webhook", failingResolver)
+ expect(result.valid).toBe(false)
+ if (!result.valid) {
+ expect(result.error).toContain("could not be resolved")
+ }
+ })
+
it("accepts valid public https urls", async () => {
- const result = await validateWebhookUrl("https://example.com/webhook")
+ const result = await validateWebhookUrl("https://example.com/webhook", publicResolver)
expect(result.valid).toBe(true)
})
it("trims whitespace from urls", async () => {
- const result = await validateWebhookUrl(" https://example.com/webhook ")
+ const result = await validateWebhookUrl(" https://example.com/webhook ", publicResolver)
expect(result.valid).toBe(true)
if (result.valid) {
expect(result.url).toBe("https://example.com/webhook")
diff --git a/apps/web/lib/validate-webhook-url.ts b/apps/web/lib/validate-webhook-url.ts
index 75ec76e..c980770 100644
--- a/apps/web/lib/validate-webhook-url.ts
+++ b/apps/web/lib/validate-webhook-url.ts
@@ -1,4 +1,14 @@
-import { resolve4, resolve6 } from "dns/promises"
+import { resolve4 as defaultResolve4, resolve6 as defaultResolve6 } from "dns/promises"
+
+export interface DnsResolver {
+ resolve4: (hostname: string) => Promise<string[]>
+ resolve6: (hostname: string) => Promise<string[]>
+}
+
+const defaultResolver: DnsResolver = {
+ resolve4: defaultResolve4,
+ resolve6: defaultResolve6,
+}
const PRIVATE_IPV4_RANGES: Array<[number, number, number]> = [
[0, 0, 8],
@@ -51,7 +61,10 @@ function isPrivateIPv6(address: string): boolean {
return false
}
-export async function validateWebhookUrl(rawUrl: string): Promise<{
+export async function validateWebhookUrl(
+ rawUrl: string,
+ resolver: DnsResolver = defaultResolver
+): Promise<{
valid: true
url: string
} | {
@@ -91,13 +104,13 @@ export async function validateWebhookUrl(rawUrl: string): Promise<{
let resolvedAddresses: string[] = []
try {
- const ipv4Addresses = await resolve4(hostname)
+ const ipv4Addresses = await resolver.resolve4(hostname)
resolvedAddresses = resolvedAddresses.concat(ipv4Addresses)
} catch {
}
try {
- const ipv6Addresses = await resolve6(hostname)
+ const ipv6Addresses = await resolver.resolve6(hostname)
resolvedAddresses = resolvedAddresses.concat(ipv6Addresses)
} catch {
}
diff --git a/package.json b/package.json
index 38a8f47..20c9701 100644
--- a/package.json
+++ b/package.json
@@ -5,7 +5,8 @@
"build": "turbo build",
"dev": "turbo dev",
"lint": "turbo lint",
- "test": "turbo test"
+ "test": "turbo test",
+ "check:tier-parity": "pnpm dlx tsx scripts/check-tier-parity.ts"
},
"devDependencies": {
"turbo": "^2"
diff --git a/scripts/check-tier-parity.ts b/scripts/check-tier-parity.ts
index 0681af0..d7ee8a4 100644
--- a/scripts/check-tier-parity.ts
+++ b/scripts/check-tier-parity.ts
@@ -1,26 +1,6 @@
import { readFileSync } from "fs"
import { resolve } from "path"
-
-const TIER_LIMITS = {
- free: {
- maximumFeeds: 10,
- maximumFolders: 3,
- maximumMutedKeywords: 5,
- maximumCustomFeeds: 1,
- },
- pro: {
- maximumFeeds: 200,
- maximumFolders: 10000,
- maximumMutedKeywords: 10000,
- maximumCustomFeeds: 1000,
- },
- developer: {
- maximumFeeds: 500,
- maximumFolders: 10000,
- maximumMutedKeywords: 10000,
- maximumCustomFeeds: 1000,
- },
-} as const
+import { TIER_LIMITS } from "../packages/shared/source/index.ts"
const TRIGGER_MAP: Record<string, keyof (typeof TIER_LIMITS)["free"]> = {
check_subscription_limit: "maximumFeeds",
@@ -29,8 +9,7 @@ const TRIGGER_MAP: Record<string, keyof (typeof TIER_LIMITS)["free"]> = {
check_custom_feed_limit: "maximumCustomFeeds",
}
-const CASE_PATTERN =
- /when\s+'(\w+)'\s+then\s+(\d+)/g
+const CASE_PATTERN = /when\s+'(\w+)'\s+then\s+(\d+)/g
function extractSqlLimits(
schemaContent: string,
@@ -45,7 +24,8 @@ function extractSqlLimits(
throw new Error(`function ${functionName} not found in schema`)
}
- const caseLinePattern = /maximum_allowed\s*:=\s*case\s+current_tier\s+(.*?)\s+end/is
+ const caseLinePattern =
+ /maximum_allowed\s*:=\s*case\s+current_tier\s+(.*?)\s+end/is
const caseMatch = functionMatch[0].match(caseLinePattern)
if (!caseMatch) {
throw new Error(`case expression not found in ${functionName}`)
@@ -87,7 +67,6 @@ for (const [functionName, tsKey] of Object.entries(TRIGGER_MAP)) {
if (hasErrors) {
console.error("\nTier limit parity check FAILED")
- // eslint-disable-next-line no-process-exit
process.exit(1)
} else {
console.log("Tier limit parity check PASSED — all limits match")
diff --git a/services/worker/go.mod b/services/worker/go.mod
index 2588959..332a44d 100644
--- a/services/worker/go.mod
+++ b/services/worker/go.mod
@@ -8,6 +8,7 @@ require (
github.com/craigpastro/pgmq-go v0.6.0
github.com/jackc/pgx/v5 v5.7.2
github.com/mmcdole/gofeed v1.3.0
+ github.com/stretchr/testify v1.11.1
)
require (
@@ -58,7 +59,6 @@ require (
github.com/shirou/gopsutil/v3 v3.24.5 // indirect
github.com/shoenig/go-m1cpu v0.1.6 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
- github.com/stretchr/testify v1.11.1 // indirect
github.com/testcontainers/testcontainers-go v0.35.0 // indirect
github.com/tklauser/go-sysconf v0.3.14 // indirect
github.com/tklauser/numcpus v0.9.0 // indirect
diff --git a/services/worker/internal/fetcher/authentication_test.go b/services/worker/internal/fetcher/authentication_test.go
new file mode 100644
index 0000000..bc840b9
--- /dev/null
+++ b/services/worker/internal/fetcher/authentication_test.go
@@ -0,0 +1,90 @@
+package fetcher
+
+import (
+ "encoding/base64"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "net/http"
+ "testing"
+)
+
+func TestApplyBearerAuthentication(test *testing.T) {
+ request, _ := http.NewRequest(http.MethodGet, "https://example.com/feed", nil)
+ authenticationError := ApplyAuthentication(request, AuthenticationConfiguration{
+ AuthenticationType: "bearer",
+ AuthenticationValue: "my-secret-token",
+ })
+
+ require.NoError(test, authenticationError)
+ assert.Equal(test, "Bearer my-secret-token", request.Header.Get("Authorization"))
+}
+
+func TestApplyBasicAuthentication(test *testing.T) {
+ request, _ := http.NewRequest(http.MethodGet, "https://example.com/feed", nil)
+ authenticationError := ApplyAuthentication(request, AuthenticationConfiguration{
+ AuthenticationType: "basic",
+ AuthenticationValue: "user:pass",
+ })
+
+ require.NoError(test, authenticationError)
+
+ expectedEncoded := base64.StdEncoding.EncodeToString([]byte("user:pass"))
+
+ assert.Equal(test, "Basic "+expectedEncoded, request.Header.Get("Authorization"))
+}
+
+func TestApplyQueryParamAuthentication(test *testing.T) {
+ request, _ := http.NewRequest(http.MethodGet, "https://example.com/feed?existing=value", nil)
+ authenticationError := ApplyAuthentication(request, AuthenticationConfiguration{
+ AuthenticationType: "query_param",
+ AuthenticationValue: "api_key=abc123",
+ })
+
+ require.NoError(test, authenticationError)
+ assert.Equal(test, "abc123", request.URL.Query().Get("api_key"))
+ assert.Equal(test, "value", request.URL.Query().Get("existing"))
+}
+
+func TestApplyEmptyAuthenticationType(test *testing.T) {
+ request, _ := http.NewRequest(http.MethodGet, "https://example.com/feed", nil)
+ authenticationError := ApplyAuthentication(request, AuthenticationConfiguration{
+ AuthenticationType: "",
+ AuthenticationValue: "",
+ })
+
+ require.NoError(test, authenticationError)
+ assert.Empty(test, request.Header.Get("Authorization"))
+}
+
+func TestApplyUnsupportedAuthenticationType(test *testing.T) {
+ request, _ := http.NewRequest(http.MethodGet, "https://example.com/feed", nil)
+ authenticationError := ApplyAuthentication(request, AuthenticationConfiguration{
+ AuthenticationType: "oauth2",
+ AuthenticationValue: "token",
+ })
+
+ assert.Error(test, authenticationError)
+ assert.Contains(test, authenticationError.Error(), "unsupported")
+}
+
+func TestApplyQueryParamInvalidFormat(test *testing.T) {
+ request, _ := http.NewRequest(http.MethodGet, "https://example.com/feed", nil)
+ authenticationError := ApplyAuthentication(request, AuthenticationConfiguration{
+ AuthenticationType: "query_param",
+ AuthenticationValue: "no-equals-sign",
+ })
+
+ assert.Error(test, authenticationError)
+ assert.Contains(test, authenticationError.Error(), "param_name=value")
+}
+
+func TestApplyQueryParamWithEqualsInValue(test *testing.T) {
+ request, _ := http.NewRequest(http.MethodGet, "https://example.com/feed", nil)
+ authenticationError := ApplyAuthentication(request, AuthenticationConfiguration{
+ AuthenticationType: "query_param",
+ AuthenticationValue: "token=abc=def",
+ })
+
+ require.NoError(test, authenticationError)
+ assert.Equal(test, "abc=def", request.URL.Query().Get("token"))
+}
diff --git a/services/worker/internal/fetcher/errors_test.go b/services/worker/internal/fetcher/errors_test.go
new file mode 100644
index 0000000..e81251b
--- /dev/null
+++ b/services/worker/internal/fetcher/errors_test.go
@@ -0,0 +1,169 @@
+package fetcher
+
+import (
+ "fmt"
+ "github.com/stretchr/testify/assert"
+ "net"
+ "net/url"
+ "testing"
+)
+
+func TestClassifyHTTPStatus401(test *testing.T) {
+ fetchError := ClassifyError(nil, 401)
+
+ assert.Equal(test, 401, fetchError.StatusCode)
+ assert.False(test, fetchError.Retryable)
+ assert.Contains(test, fetchError.UserMessage, "authentication")
+}
+
+func TestClassifyHTTPStatus403(test *testing.T) {
+ fetchError := ClassifyError(nil, 403)
+
+ assert.Equal(test, 403, fetchError.StatusCode)
+ assert.False(test, fetchError.Retryable)
+ assert.Contains(test, fetchError.UserMessage, "forbidden")
+}
+
+func TestClassifyHTTPStatus404(test *testing.T) {
+ fetchError := ClassifyError(nil, 404)
+
+ assert.Equal(test, 404, fetchError.StatusCode)
+ assert.False(test, fetchError.Retryable)
+ assert.Contains(test, fetchError.UserMessage, "not found")
+}
+
+func TestClassifyHTTPStatus410(test *testing.T) {
+ fetchError := ClassifyError(nil, 410)
+
+ assert.Equal(test, 410, fetchError.StatusCode)
+ assert.False(test, fetchError.Retryable)
+ assert.Contains(test, fetchError.UserMessage, "permanently removed")
+}
+
+func TestClassifyHTTPStatus429(test *testing.T) {
+ fetchError := ClassifyError(nil, 429)
+
+ assert.Equal(test, 429, fetchError.StatusCode)
+ assert.True(test, fetchError.Retryable)
+ assert.Contains(test, fetchError.UserMessage, "rate limited")
+}
+
+func TestClassifyHTTPServerErrors(test *testing.T) {
+ serverErrorCodes := []int{500, 502, 503, 504}
+
+ for _, statusCode := range serverErrorCodes {
+ test.Run(fmt.Sprintf("status_%d", statusCode), func(test *testing.T) {
+ fetchError := ClassifyError(nil, statusCode)
+
+ assert.Equal(test, statusCode, fetchError.StatusCode)
+ assert.True(test, fetchError.Retryable)
+ assert.Contains(test, fetchError.UserMessage, "server error")
+ })
+ }
+}
+
+func TestClassifyHTTPUnknownClientError(test *testing.T) {
+ fetchError := ClassifyError(nil, 418)
+
+ assert.Equal(test, 418, fetchError.StatusCode)
+ assert.False(test, fetchError.Retryable)
+ assert.Contains(test, fetchError.UserMessage, "418")
+}
+
+func TestClassifyDNSNotFoundError(test *testing.T) {
+ dnsError := &net.DNSError{
+ Name: "nonexistent.example.com",
+ IsNotFound: true,
+ }
+ fetchError := ClassifyError(dnsError, 0)
+
+ assert.False(test, fetchError.Retryable)
+ assert.Contains(test, fetchError.UserMessage, "DNS")
+}
+
+func TestClassifyDNSTemporaryError(test *testing.T) {
+ dnsError := &net.DNSError{
+ Name: "flaky.example.com",
+ IsNotFound: false,
+ }
+ fetchError := ClassifyError(dnsError, 0)
+
+ assert.True(test, fetchError.Retryable)
+ assert.Contains(test, fetchError.UserMessage, "DNS")
+}
+
+func TestClassifyTimeoutError(test *testing.T) {
+ timeoutError := &url.Error{
+ Op: "Get",
+ URL: "https://slow.example.com",
+ Err: &timeoutErr{},
+ }
+ fetchError := ClassifyError(timeoutError, 0)
+
+ assert.True(test, fetchError.Retryable)
+ assert.Contains(test, fetchError.UserMessage, "timed out")
+}
+
+type timeoutErr struct{}
+
+func (timeoutError *timeoutErr) Error() string { return "connection timed out" }
+func (timeoutError *timeoutErr) Timeout() bool { return true }
+func (timeoutError *timeoutErr) Temporary() bool { return true }
+
+func TestClassifyNetworkOpError(test *testing.T) {
+ opError := &net.OpError{
+ Op: "dial",
+ Net: "tcp",
+ Err: fmt.Errorf("connection refused"),
+ }
+ fetchError := ClassifyError(opError, 0)
+
+ assert.True(test, fetchError.Retryable)
+ assert.Contains(test, fetchError.UserMessage, "network")
+}
+
+func TestClassifyTLSError(test *testing.T) {
+ tlsError := fmt.Errorf("tls: handshake failure")
+ fetchError := ClassifyError(tlsError, 0)
+
+ assert.False(test, fetchError.Retryable)
+ assert.Contains(test, fetchError.UserMessage, "TLS")
+}
+
+func TestClassifyCertificateError(test *testing.T) {
+ certError := fmt.Errorf("x509: certificate has expired")
+ fetchError := ClassifyError(certError, 0)
+
+ assert.False(test, fetchError.Retryable)
+ assert.Contains(test, fetchError.UserMessage, "certificate")
+}
+
+func TestClassifyNilError(test *testing.T) {
+ fetchError := ClassifyError(nil, 0)
+
+ assert.True(test, fetchError.Retryable)
+ assert.Contains(test, fetchError.UserMessage, "unknown")
+}
+
+func TestFetchErrorInterface(test *testing.T) {
+ underlyingError := fmt.Errorf("some network issue")
+ fetchError := &FetchError{
+ StatusCode: 0,
+ UserMessage: "test error",
+ Retryable: true,
+ UnderlyingError: underlyingError,
+ }
+
+ assert.Contains(test, fetchError.Error(), "test error")
+ assert.Contains(test, fetchError.Error(), "some network issue")
+ assert.Equal(test, underlyingError, fetchError.Unwrap())
+}
+
+func TestFetchErrorWithoutUnderlying(test *testing.T) {
+ fetchError := &FetchError{
+ UserMessage: "standalone error",
+ }
+
+ assert.Equal(test, "standalone error", fetchError.Error())
+ assert.Nil(test, fetchError.Unwrap())
+}
diff --git a/services/worker/internal/fetcher/ssrf_protection_test.go b/services/worker/internal/fetcher/ssrf_protection_test.go
new file mode 100644
index 0000000..3e78380
--- /dev/null
+++ b/services/worker/internal/fetcher/ssrf_protection_test.go
@@ -0,0 +1,114 @@
+package fetcher
+
+import (
+ "github.com/stretchr/testify/assert"
+ "net"
+ "testing"
+)
+
+func TestIsReservedAddress(test *testing.T) {
+ reservedAddresses := []struct {
+ name string
+ address string
+ }{
+ {"loopback ipv4", "127.0.0.1"},
+ {"loopback ipv4 alternate", "127.0.0.2"},
+ {"private 10.x", "10.0.0.1"},
+ {"private 10.x deep", "10.255.255.255"},
+ {"private 172.16.x", "172.16.0.1"},
+ {"private 172.31.x", "172.31.255.255"},
+ {"private 192.168.x", "192.168.1.1"},
+ {"link-local", "169.254.1.1"},
+ {"null address", "0.0.0.1"},
+ {"ipv6 loopback", "::1"},
+ {"ipv6 unique local fc", "fc00::1"},
+ {"ipv6 unique local fd", "fd00::1"},
+ {"ipv6 link-local", "fe80::1"},
+ }
+
+ for _, testCase := range reservedAddresses {
+ test.Run(testCase.name, func(test *testing.T) {
+ parsedIP := net.ParseIP(testCase.address)
+
+ assert.True(test, isReservedAddress(parsedIP), "expected %s to be reserved", testCase.address)
+ })
+ }
+}
+
+func TestIsNotReservedAddress(test *testing.T) {
+ publicAddresses := []struct {
+ name string
+ address string
+ }{
+ {"google dns", "8.8.8.8"},
+ {"cloudflare dns", "1.1.1.1"},
+ {"random public", "93.184.216.34"},
+ {"public 172.32", "172.32.0.1"},
+ {"public ipv6", "2001:db8::1"},
+ }
+
+ for _, testCase := range publicAddresses {
+ test.Run(testCase.name, func(test *testing.T) {
+ parsedIP := net.ParseIP(testCase.address)
+
+ assert.False(test, isReservedAddress(parsedIP), "expected %s to not be reserved", testCase.address)
+ })
+ }
+}
+
+func TestValidateFeedURLRejectsUnsupportedSchemes(test *testing.T) {
+ unsupportedURLs := []string{
+ "ftp://example.com/feed",
+ "file:///etc/passwd",
+ "javascript:alert(1)",
+ "data:text/html,hello",
+ }
+
+ for _, feedURL := range unsupportedURLs {
+ test.Run(feedURL, func(test *testing.T) {
+ assert.Error(test, ValidateFeedURL(feedURL))
+ })
+ }
+}
+
+func TestValidateFeedURLRejectsEmptyHostname(test *testing.T) {
+ assert.Error(test, ValidateFeedURL("http:///path"))
+}
+
+func TestValidateFeedURLRejectsPrivateIPs(test *testing.T) {
+ privateURLs := []string{
+ "http://127.0.0.1/feed",
+ "https://10.0.0.1/feed",
+ "http://192.168.1.1/feed",
+ "http://172.16.0.1/feed",
+ "http://[::1]/feed",
+ }
+
+ for _, feedURL := range privateURLs {
+ test.Run(feedURL, func(test *testing.T) {
+ assert.Error(test, ValidateFeedURL(feedURL))
+ })
+ }
+}
+
+func TestValidateFeedURLAcceptsPublicIPs(test *testing.T) {
+ publicURLs := []string{
+ "http://93.184.216.34/feed",
+ "https://8.8.8.8/feed",
+ }
+
+ for _, feedURL := range publicURLs {
+ test.Run(feedURL, func(test *testing.T) {
+ assert.NoError(test, ValidateFeedURL(feedURL))
+ })
+ }
+}
+
+func TestValidateFeedURLAcceptsHTTPAndHTTPS(test *testing.T) {
+ assert.NoError(test, ValidateFeedURL("http://93.184.216.34/feed"))
+ assert.NoError(test, ValidateFeedURL("https://93.184.216.34/feed"))
+}
+
+func TestValidateFeedURLRejectsInvalidURL(test *testing.T) {
+ assert.Error(test, ValidateFeedURL("://broken"))
+}
diff --git a/services/worker/internal/parser/parser_test.go b/services/worker/internal/parser/parser_test.go
new file mode 100644
index 0000000..7a28132
--- /dev/null
+++ b/services/worker/internal/parser/parser_test.go
@@ -0,0 +1,287 @@
+package parser
+
+import (
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "testing"
+ "time"
+)
+
+const validRSS = `<?xml version="1.0" encoding="UTF-8"?>
+<rss version="2.0">
+ <channel>
+ <title>Test Feed</title>
+ <link>https://example.com</link>
+ <item>
+ <guid>entry-1</guid>
+ <title>First Entry</title>
+ <link>https://example.com/1</link>
+ <description>Summary of the first entry.</description>
+ <content:encoded xmlns:content="http://purl.org/rss/1.0/modules/content/">&lt;p&gt;Full content here.&lt;/p&gt;</content:encoded>
+ <author>Alice</author>
+ <pubDate>Mon, 01 Jan 2024 12:00:00 GMT</pubDate>
+ </item>
+ <item>
+ <guid>entry-2</guid>
+ <title>Second Entry</title>
+ <link>https://example.com/2</link>
+ <description>Summary of the second entry.</description>
+ <pubDate>Tue, 02 Jan 2024 12:00:00 GMT</pubDate>
+ </item>
+ </channel>
+</rss>`
+const validAtom = `<?xml version="1.0" encoding="UTF-8"?>
+<feed xmlns="http://www.w3.org/2005/Atom">
+ <title>Atom Feed</title>
+ <link href="https://atom.example.com"/>
+ <entry>
+ <id>atom-1</id>
+ <title>Atom Entry</title>
+ <link href="https://atom.example.com/1"/>
+ <summary>An atom summary.</summary>
+ <updated>2024-01-01T12:00:00Z</updated>
+ <author><name>Bob</name></author>
+ </entry>
+</feed>`
+const podcastRSS = `<?xml version="1.0" encoding="UTF-8"?>
+<rss version="2.0">
+ <channel>
+ <title>My Podcast</title>
+ <link>https://podcast.example.com</link>
+ <item>
+ <guid>ep-1</guid>
+ <title>Episode 1</title>
+ <enclosure url="https://cdn.example.com/ep1.mp3" type="audio/mpeg" length="12345678"/>
+ <pubDate>Mon, 01 Jan 2024 12:00:00 GMT</pubDate>
+ </item>
+ <item>
+ <guid>ep-2</guid>
+ <title>Episode 2</title>
+ <enclosure url="https://cdn.example.com/ep2.mp3" type="audio/mpeg" length="87654321"/>
+ <pubDate>Tue, 02 Jan 2024 12:00:00 GMT</pubDate>
+ </item>
+ </channel>
+</rss>`
+const mixedEnclosureRSS = `<?xml version="1.0" encoding="UTF-8"?>
+<rss version="2.0">
+ <channel>
+ <title>Mixed Feed</title>
+ <link>https://mixed.example.com</link>
+ <item>
+ <guid>item-audio</guid>
+ <title>Audio Item</title>
+ <enclosure url="https://cdn.example.com/audio.mp3" type="audio/mpeg" length="1000"/>
+ </item>
+ <item>
+ <guid>item-text</guid>
+ <title>Text Item</title>
+ <description>Just a text item.</description>
+ </item>
+ </channel>
+</rss>`
+const noGUIDRSS = `<?xml version="1.0" encoding="UTF-8"?>
+<rss version="2.0">
+ <channel>
+ <title>No GUID Feed</title>
+ <link>https://example.com</link>
+ <item>
+ <title>No GUID but has link</title>
+ <link>https://example.com/no-guid</link>
+ </item>
+ <item>
+ <title>No GUID no link</title>
+ <description>Only title and description.</description>
+ </item>
+ </channel>
+</rss>`
+
+func TestParseValidRSS(test *testing.T) {
+ feedParser := NewParser()
+ result, parseError := feedParser.Parse("feed-123", nil, []byte(validRSS))
+
+ require.NoError(test, parseError)
+ assert.Equal(test, "Test Feed", result.FeedTitle)
+ assert.Equal(test, "https://example.com", result.SiteURL)
+ assert.Len(test, result.Entries, 2)
+ assert.Equal(test, 0.0, result.AudioEnclosureRatio)
+
+ first := result.Entries[0]
+
+ assert.Equal(test, "entry-1", first.GUID)
+ assert.Equal(test, "feed-123", first.FeedIdentifier)
+ require.NotNil(test, first.Title)
+ assert.Equal(test, "First Entry", *first.Title)
+ require.NotNil(test, first.URL)
+ assert.Equal(test, "https://example.com/1", *first.URL)
+ require.NotNil(test, first.Summary)
+ assert.Equal(test, "Summary of the first entry.", *first.Summary)
+ require.NotNil(test, first.ContentHTML)
+ assert.Contains(test, *first.ContentHTML, "Full content here.")
+ require.NotNil(test, first.PublishedAt)
+ assert.Nil(test, first.OwnerIdentifier)
+}
+
+func TestParseValidAtom(test *testing.T) {
+ feedParser := NewParser()
+ result, parseError := feedParser.Parse("atom-feed", nil, []byte(validAtom))
+
+ require.NoError(test, parseError)
+ assert.Equal(test, "Atom Feed", result.FeedTitle)
+ assert.Len(test, result.Entries, 1)
+
+ entry := result.Entries[0]
+
+ assert.Equal(test, "atom-1", entry.GUID)
+ require.NotNil(test, entry.Title)
+ assert.Equal(test, "Atom Entry", *entry.Title)
+ require.NotNil(test, entry.Author)
+ assert.Equal(test, "Bob", *entry.Author)
+}
+
+func TestParsePodcastFeed(test *testing.T) {
+ feedParser := NewParser()
+ result, parseError := feedParser.Parse("podcast-1", nil, []byte(podcastRSS))
+
+ require.NoError(test, parseError)
+ assert.Equal(test, "My Podcast", result.FeedTitle)
+ assert.Len(test, result.Entries, 2)
+ assert.InDelta(test, 1.0, result.AudioEnclosureRatio, 0.01)
+
+ firstEpisode := result.Entries[0]
+
+ require.NotNil(test, firstEpisode.EnclosureURL)
+ assert.Equal(test, "https://cdn.example.com/ep1.mp3", *firstEpisode.EnclosureURL)
+ require.NotNil(test, firstEpisode.EnclosureType)
+ assert.Equal(test, "audio/mpeg", *firstEpisode.EnclosureType)
+ require.NotNil(test, firstEpisode.EnclosureLength)
+ assert.Equal(test, int64(12345678), *firstEpisode.EnclosureLength)
+}
+
+func TestParseMixedEnclosureFeed(test *testing.T) {
+ feedParser := NewParser()
+ result, parseError := feedParser.Parse("mixed-1", nil, []byte(mixedEnclosureRSS))
+
+ require.NoError(test, parseError)
+ assert.Len(test, result.Entries, 2)
+ assert.InDelta(test, 0.5, result.AudioEnclosureRatio, 0.01)
+
+ audioItem := result.Entries[0]
+
+ require.NotNil(test, audioItem.EnclosureURL)
+
+ textItem := result.Entries[1]
+
+ assert.Nil(test, textItem.EnclosureURL)
+ assert.Nil(test, textItem.EnclosureType)
+ assert.Nil(test, textItem.EnclosureLength)
+}
+
+func TestParseGUIDFallback(test *testing.T) {
+ feedParser := NewParser()
+ result, parseError := feedParser.Parse("no-guid-feed", nil, []byte(noGUIDRSS))
+
+ require.NoError(test, parseError)
+ assert.Len(test, result.Entries, 2)
+
+ entryWithLink := result.Entries[0]
+
+ assert.Equal(test, "https://example.com/no-guid", entryWithLink.GUID)
+
+ entryWithHash := result.Entries[1]
+
+ assert.True(test, len(entryWithHash.GUID) > 0)
+ assert.Contains(test, entryWithHash.GUID, "sha256:")
+}
+
+func TestParseOwnerIdentifier(test *testing.T) {
+ feedParser := NewParser()
+ ownerIdentifier := "user-abc"
+ result, parseError := feedParser.Parse("owned-feed", &ownerIdentifier, []byte(validRSS))
+
+ require.NoError(test, parseError)
+
+ for _, entry := range result.Entries {
+ require.NotNil(test, entry.OwnerIdentifier)
+ assert.Equal(test, "user-abc", *entry.OwnerIdentifier)
+ }
+}
+
+func TestParseInvalidXML(test *testing.T) {
+ feedParser := NewParser()
+ _, parseError := feedParser.Parse("bad-feed", nil, []byte("this is not xml at all"))
+
+ assert.Error(test, parseError)
+}
+
+func TestParsePublishedDateFallsBackToUpdated(test *testing.T) {
+ feedParser := NewParser()
+ result, parseError := feedParser.Parse("atom-feed", nil, []byte(validAtom))
+
+ require.NoError(test, parseError)
+
+ entry := result.Entries[0]
+
+ require.NotNil(test, entry.PublishedAt)
+
+ expectedTime := time.Date(2024, 1, 1, 12, 0, 0, 0, time.UTC)
+
+ assert.True(test, entry.PublishedAt.Equal(expectedTime))
+}
+
+func TestParseWordCount(test *testing.T) {
+ feedParser := NewParser()
+ result, parseError := feedParser.Parse("feed-123", nil, []byte(validRSS))
+
+ require.NoError(test, parseError)
+
+ first := result.Entries[0]
+
+ require.NotNil(test, first.WordCount)
+ assert.Greater(test, *first.WordCount, 0)
+}
+
+func TestStripHTMLTags(test *testing.T) {
+ testCases := []struct {
+ name string
+ input string
+ expected string
+ }{
+ {"plain text passthrough", "hello world", "hello world"},
+ {"strips simple tags", "<p>hello</p>", "hello"},
+ {"strips nested tags", "<div><p>hello <strong>world</strong></p></div>", "hello world"},
+ {"handles unicode", "<p>café résumé naïve</p>", "café résumé naïve"},
+ {"empty string", "", ""},
+ {"only tags", "<br/><hr/>", ""},
+ }
+
+ for _, testCase := range testCases {
+ test.Run(testCase.name, func(test *testing.T) {
+ assert.Equal(test, testCase.expected, stripHTMLTags(testCase.input))
+ })
+ }
+}
+
+func TestCountWords(test *testing.T) {
+ emptyResult := countWords("")
+
+ assert.Nil(test, emptyResult)
+
+ twoWords := countWords("hello world")
+
+ require.NotNil(test, twoWords)
+ assert.Equal(test, 2, *twoWords)
+
+ withExtraSpaces := countWords(" hello world ")
+
+ require.NotNil(test, withExtraSpaces)
+ assert.Equal(test, 2, *withExtraSpaces)
+}
+
+func TestStringPointerOrNil(test *testing.T) {
+ assert.Nil(test, stringPointerOrNil(""))
+
+ result := stringPointerOrNil("hello")
+
+ require.NotNil(test, result)
+ assert.Equal(test, "hello", *result)
+}
diff --git a/services/worker/internal/pool/pool_test.go b/services/worker/internal/pool/pool_test.go
new file mode 100644
index 0000000..f347faf
--- /dev/null
+++ b/services/worker/internal/pool/pool_test.go
@@ -0,0 +1,125 @@
+package pool
+
+import (
+ "context"
+ "github.com/stretchr/testify/assert"
+ "log/slog"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+func newTestLogger() *slog.Logger {
+ return slog.New(slog.DiscardHandler)
+}
+
+func TestWorkerPoolConcurrencyLimit(test *testing.T) {
+ workerPool := NewWorkerPool(2, newTestLogger())
+ workerContext := context.Background()
+
+ var peakConcurrency atomic.Int32
+ var currentConcurrency atomic.Int32
+ var completedJobs atomic.Int32
+
+ for jobIndex := 0; jobIndex < 10; jobIndex++ {
+ workerPool.Submit(workerContext, func(workContext context.Context) {
+ current := currentConcurrency.Add(1)
+
+ for {
+ peak := peakConcurrency.Load()
+
+ if current <= peak || peakConcurrency.CompareAndSwap(peak, current) {
+ break
+ }
+ }
+
+ time.Sleep(10 * time.Millisecond)
+ currentConcurrency.Add(-1)
+ completedJobs.Add(1)
+ })
+ }
+
+ workerPool.Wait()
+ assert.Equal(test, int32(10), completedJobs.Load())
+ assert.LessOrEqual(test, peakConcurrency.Load(), int32(2))
+}
+
+func TestWorkerPoolPanicRecovery(test *testing.T) {
+ workerPool := NewWorkerPool(2, newTestLogger())
+ workerContext := context.Background()
+
+ var completedAfterPanic atomic.Bool
+
+ workerPool.Submit(workerContext, func(workContext context.Context) {
+ panic("intentional test panic")
+ })
+ time.Sleep(20 * time.Millisecond)
+ workerPool.Submit(workerContext, func(workContext context.Context) {
+ completedAfterPanic.Store(true)
+ })
+ workerPool.Wait()
+ assert.True(test, completedAfterPanic.Load())
+}
+
+func TestWorkerPoolCancelledContext(test *testing.T) {
+ workerPool := NewWorkerPool(1, newTestLogger())
+ cancelledContext, cancelFunction := context.WithCancel(context.Background())
+
+ var blockingMutex sync.Mutex
+
+ blockingMutex.Lock()
+ workerPool.Submit(context.Background(), func(workContext context.Context) {
+ blockingMutex.Lock()
+
+ defer blockingMutex.Unlock()
+ })
+ cancelFunction()
+
+ submitted := workerPool.Submit(cancelledContext, func(workContext context.Context) {})
+
+ assert.False(test, submitted)
+ blockingMutex.Unlock()
+ workerPool.Wait()
+}
+
+func TestWorkerPoolWaitBlocksUntilDone(test *testing.T) {
+ workerPool := NewWorkerPool(4, newTestLogger())
+ workerContext := context.Background()
+
+ var counter atomic.Int32
+
+ for jobIndex := 0; jobIndex < 20; jobIndex++ {
+ workerPool.Submit(workerContext, func(workContext context.Context) {
+ time.Sleep(5 * time.Millisecond)
+ counter.Add(1)
+ })
+ }
+
+ workerPool.Wait()
+ assert.Equal(test, int32(20), counter.Load())
+}
+
+func TestWorkerPoolActiveWorkerCount(test *testing.T) {
+ workerPool := NewWorkerPool(3, newTestLogger())
+
+ assert.Equal(test, 0, workerPool.ActiveWorkerCount())
+
+ var releaseMutex sync.Mutex
+
+ releaseMutex.Lock()
+ workerPool.Submit(context.Background(), func(workContext context.Context) {
+ releaseMutex.Lock()
+
+ defer releaseMutex.Unlock()
+ })
+ workerPool.Submit(context.Background(), func(workContext context.Context) {
+ releaseMutex.Lock()
+
+ defer releaseMutex.Unlock()
+ })
+ time.Sleep(20 * time.Millisecond)
+ assert.Equal(test, 2, workerPool.ActiveWorkerCount())
+ releaseMutex.Unlock()
+ workerPool.Wait()
+}
diff --git a/services/worker/internal/writer/writer.go b/services/worker/internal/writer/writer.go
index f681a4e..543b6e6 100644
--- a/services/worker/internal/writer/writer.go
+++ b/services/worker/internal/writer/writer.go
@@ -3,11 +3,11 @@ package writer
import (
"context"
"fmt"
+ "github.com/Fuwn/asa-news/internal/model"
+ "github.com/jackc/pgx/v5/pgxpool"
"net/url"
"strings"
"time"
- "github.com/Fuwn/asa-news/internal/model"
- "github.com/jackc/pgx/v5/pgxpool"
)
type Writer struct {
diff --git a/supabase/schema.sql b/supabase/schema.sql
index 4f61f0b..f5009f7 100644
--- a/supabase/schema.sql
+++ b/supabase/schema.sql
@@ -2,7 +2,7 @@
-- PostgreSQL database dump
--
--- \restrict dG3SzK2uS18gU4m6KVz2cn807h79RHaJThBHm9wHtsku5jl48o1kEUpbawNImmT
+-- \restrict EIsfb99KJ1xngT5GOQDV7f1tIMMDJiOJsvB1YZBvCSU5Qj8tsxmafuQpd9VaeJM
-- Dumped from database version 17.6
-- Dumped by pg_dump version 17.6
@@ -474,6 +474,29 @@ $$;
ALTER FUNCTION "public"."cleanup_stale_entries"() OWNER TO "postgres";
--
+-- Name: cleanup_stale_rate_limits(); Type: FUNCTION; Schema: public; Owner: postgres
+--
+
+CREATE OR REPLACE FUNCTION "public"."cleanup_stale_rate_limits"() RETURNS integer
+ LANGUAGE "plpgsql" SECURITY DEFINER
+ SET "search_path" TO 'public'
+ AS $$
+DECLARE
+ deleted_count integer;
+BEGIN
+  DELETE FROM public.rate_limits
+  WHERE window_start
+    < now() - interval '1 hour';
+
+ GET DIAGNOSTICS deleted_count = ROW_COUNT;
+ RETURN deleted_count;
+END;
+$$;
+
+
+ALTER FUNCTION "public"."cleanup_stale_rate_limits"() OWNER TO "postgres";
+
+--
-- Name: decrement_custom_feed_count(); Type: FUNCTION; Schema: public; Owner: postgres
--
@@ -3206,6 +3229,16 @@ GRANT ALL ON FUNCTION "public"."cleanup_stale_entries"() TO "service_role";
--
+-- Name: FUNCTION "cleanup_stale_rate_limits"(); Type: ACL; Schema: public; Owner: postgres
+--
+
+REVOKE ALL ON FUNCTION "public"."cleanup_stale_rate_limits"() FROM PUBLIC;
+GRANT ALL ON FUNCTION "public"."cleanup_stale_rate_limits"() TO "anon";
+GRANT ALL ON FUNCTION "public"."cleanup_stale_rate_limits"() TO "authenticated";
+GRANT ALL ON FUNCTION "public"."cleanup_stale_rate_limits"() TO "service_role";
+
+
+--
-- Name: FUNCTION "decrement_custom_feed_count"(); Type: ACL; Schema: public; Owner: postgres
--
@@ -3686,5 +3719,5 @@ ALTER DEFAULT PRIVILEGES FOR ROLE "postgres" IN SCHEMA "public" GRANT ALL ON TAB
-- PostgreSQL database dump complete
--
--- \unrestrict dG3SzK2uS18gU4m6KVz2cn807h79RHaJThBHm9wHtsku5jl48o1kEUpbawNImmT
+-- \unrestrict EIsfb99KJ1xngT5GOQDV7f1tIMMDJiOJsvB1YZBvCSU5Qj8tsxmafuQpd9VaeJM