| author    | Fuwn <[email protected]> | 2026-02-26 15:41:45 -0800 |
|-----------|---------------------------|---------------------------|
| committer | Fuwn <[email protected]> | 2026-02-26 15:41:45 -0800 |
| commit    | fec9114caaa7d274e524793d5eb0cf2ef2c5af11 | |
| tree      | 0897394ccdfaf6633e1a4ca8eb02bff49bb93c00 | |
| parent    | feat: add read-only PLC API compatibility endpoints | |
feat: Apply Iku formatting
30 files changed, 1529 insertions, 44 deletions
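Judging from the hunks below, the Iku style applied here is mostly mechanical: it inserts a blank line between a call and its error check, before final `return` statements and spawned goroutines, and it removes the blank line that used to separate the standard-library import group from the `github.com/Fuwn/plutia/...` group. A minimal before/after sketch of the dominant pattern (the `load` helpers are hypothetical, shown only to illustrate the rule):

```go
package main

import (
	"fmt"
	"os"
)

// loadBefore shows the pre-Iku layout: call, check, and return packed together.
func loadBefore(path string) ([]byte, error) {
	b, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}
	return b, nil
}

// loadAfter shows the same function after Iku formatting: blank lines
// separate the call from its error check and from the final return.
func loadAfter(path string) ([]byte, error) {
	b, err := os.ReadFile(path)

	if err != nil {
		return nil, err
	}

	return b, nil
}

func main() {
	_, _ = loadBefore("config.yaml")
	_, _ = loadAfter("config.yaml")

	fmt.Println("both variants compile; only the layout differs")
}
```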
diff --git a/cmd/plutia/main.go b/cmd/plutia/main.go
index 2685a26..8939f4f 100644
--- a/cmd/plutia/main.go
+++ b/cmd/plutia/main.go
@@ -17,7 +17,6 @@ import (
 	"strings"
 	"syscall"
 	"time"
-
 	"github.com/Fuwn/plutia/internal/api"
 	"github.com/Fuwn/plutia/internal/checkpoint"
 	"github.com/Fuwn/plutia/internal/config"
@@ -45,7 +44,9 @@ func main() {
 		usage()
 		os.Exit(2)
 	}
+
 	cmd := os.Args[1]
+
 	switch cmd {
 	case "serve":
 		if err := runServe(os.Args[2:]); err != nil {
@@ -85,18 +86,24 @@ func runServe(args []string) error {
 	fs := flag.NewFlagSet("serve", flag.ExitOnError)
 	configPath := fs.String("config", "config.yaml", "config path")
 	maxOps := fs.Uint64("max-ops", 0, "max operations to ingest in this process (0 = unlimited)")
+
 	if err := fs.Parse(args); err != nil {
 		return err
 	}
+
 	app, err := bootstrap(*configPath)
+
 	if err != nil {
 		return err
 	}
+
 	defer app.store.Close()
 	defer app.service.Close()
+
 	app.service.SetMaxOps(*maxOps)

 	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
+
 	defer stop()

 	if err := app.service.Replay(ctx); err != nil {
@@ -111,18 +118,21 @@ func runServe(args []string) error {
 		Addr:    app.cfg.ListenAddr,
 		Handler: app.apiServer.Handler(),
 	}
-
 	errCh := make(chan error, 2)
 	pollDone := make(chan struct{})
+
 	go func() {
 		log.Printf("HTTP server listening on %s", app.cfg.ListenAddr)
+
 		if err := httpSrv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
 			errCh <- err
 		}
 	}()
+
 	if !app.service.IsCorrupted() {
 		go func() {
 			defer close(pollDone)
+
 			if err := app.service.Poll(ctx); err != nil && !errors.Is(err, context.Canceled) {
 				errCh <- err
 			}
@@ -132,21 +142,28 @@
 	}

 	var runErr error
+
 	select {
 	case <-ctx.Done():
 	case err := <-errCh:
 		runErr = err
+
 		stop()
 	}
+
 	<-pollDone

 	shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+
 	defer cancel()
+
 	flushErr := app.service.Flush(shutdownCtx)
 	httpErr := httpSrv.Shutdown(shutdownCtx)
+
 	if runErr != nil || flushErr != nil || httpErr != nil {
 		return errors.Join(runErr, flushErr, httpErr)
 	}
+
 	return nil
 }
@@ -154,26 +171,36 @@ func runReplay(args []string) error {
 	fs := flag.NewFlagSet("replay", flag.ExitOnError)
 	configPath := fs.String("config", "config.yaml", "config path")
 	maxOps := fs.Uint64("max-ops", 0, "max operations to ingest in this process (0 = unlimited)")
+
 	if err := fs.Parse(args); err != nil {
 		return err
 	}
+
 	app, err := bootstrap(*configPath)
+
 	if err != nil {
 		return err
 	}
+
 	defer app.store.Close()
 	defer app.service.Close()
+
 	app.service.SetMaxOps(*maxOps)

 	ctx := context.Background()
+
 	if err := app.service.Replay(ctx); err != nil {
 		return err
 	}
+
 	if err := app.service.Flush(ctx); err != nil {
 		return err
 	}
+
 	seq, _ := app.store.GetGlobalSeq()
+
 	fmt.Printf("replay complete at sequence %d\n", seq)
+
 	return nil
 }
@@ -181,43 +208,58 @@ func runVerify(args []string) error {
 	fs := flag.NewFlagSet("verify", flag.ExitOnError)
 	configPath := fs.String("config", "config.yaml", "config path")
 	did := fs.String("did", "", "did to verify")
+
 	if err := fs.Parse(args); err != nil {
 		return err
 	}
+
 	if *did == "" {
 		return fmt.Errorf("--did is required")
 	}
+
 	app, err := bootstrap(*configPath)
+
 	if err != nil {
 		return err
 	}
+
 	defer app.store.Close()
 	defer app.service.Close()

 	if err := app.service.VerifyDID(context.Background(), *did); err != nil {
 		return err
 	}
+
 	fmt.Printf("verification succeeded for %s\n", *did)
+
 	return nil
 }

 func runSnapshot(args []string) error {
 	fs := flag.NewFlagSet("snapshot", flag.ExitOnError)
 	configPath := fs.String("config", "config.yaml", "config path")
+
 	if err := fs.Parse(args); err != nil {
 		return err
 	}
+
 	app, err := bootstrap(*configPath)
+
 	if err != nil {
 		return err
 	}
+
 	defer app.store.Close()
 	defer app.service.Close()
+
 	cp, err := app.service.Snapshot(context.Background())
+
 	if err != nil {
 		return err
 	}
+
 	fmt.Printf("checkpoint sequence=%d hash=%s\n", cp.Sequence, cp.CheckpointHash)
+
 	return nil
 }
@@ -226,42 +268,54 @@ func runBench(args []string) error {
 	configPath := fs.String("config", "config.yaml", "config path")
 	maxOps := fs.Uint64("max-ops", 200000, "max operations to ingest for benchmark")
 	interval := fs.Duration("interval", 10*time.Second, "rolling report interval")
+
 	if err := fs.Parse(args); err != nil {
 		return err
 	}
+
 	app, err := bootstrap(*configPath)
+
 	if err != nil {
 		return err
 	}
+
 	defer app.store.Close()
 	defer app.service.Close()
+
 	app.service.SetMaxOps(*maxOps)

 	startSeq, err := app.store.GetGlobalSeq()
+
 	if err != nil {
 		return err
 	}
+
 	start := time.Now()
 	lastSeq := startSeq
 	pid := os.Getpid()
-
 	done := make(chan struct{})
+
 	go func() {
 		ticker := time.NewTicker(*interval)
+
 		defer ticker.Stop()
+
 		for {
 			select {
 			case <-done:
 				return
 			case <-ticker.C:
 				seq, err := app.store.GetGlobalSeq()
+
 				if err != nil {
 					continue
 				}
+
 				delta := seq - lastSeq
 				lastSeq = seq
 				opsSec := float64(delta) / interval.Seconds()
 				cpuPct, rssKB, _ := processMetrics(pid)
+
 				fmt.Printf("rolling seq=%d ops_sec=%.2f cpu_pct=%.2f rss_kb=%d\n", seq, opsSec, cpuPct, rssKB)
 			}
 		}
@@ -269,23 +323,29 @@ func runBench(args []string) error {
 	ctx := context.Background()
 	replayErr := app.service.Replay(ctx)
+
 	close(done)
+
 	if replayErr != nil {
 		return replayErr
 	}
+
 	if err := app.service.Flush(ctx); err != nil {
 		return err
 	}

 	endSeq, err := app.store.GetGlobalSeq()
+
 	if err != nil {
 		return err
 	}
+
 	elapsed := time.Since(start)
 	totalOps := endSeq - startSeq
 	totalOpsSec := float64(totalOps) / elapsed.Seconds()
 	cpuPct, rssKB, _ := processMetrics(pid)
 	totalBytes, opsBytes, indexBytes, cpBytes, err := diskUsageBreakdown(app.cfg.DataDir)
+
 	if err != nil {
 		return err
 	}
@@ -293,6 +353,7 @@ func runBench(args []string) error {
 	fmt.Printf("bench_total elapsed=%s ops=%d ops_sec=%.2f cpu_pct=%.2f rss_kb=%d\n", elapsed.Round(time.Millisecond), totalOps, totalOpsSec, cpuPct, rssKB)
 	fmt.Printf("bench_disk total_bytes=%d ops_bytes=%d index_bytes=%d checkpoints_bytes=%d bytes_per_op=%.2f\n", totalBytes, opsBytes, indexBytes, cpBytes, float64(totalBytes)/max(1, float64(totalOps)))
+
 	return nil
 }
@@ -300,43 +361,56 @@ func runCompare(args []string) error {
 	fs := flag.NewFlagSet("compare", flag.ExitOnError)
 	configPath := fs.String("config", "config.yaml", "config path")
 	remote := fs.String("remote", "", "remote mirror base URL")
+
 	if err := fs.Parse(args); err != nil {
 		return err
 	}
+
 	if strings.TrimSpace(*remote) == "" {
 		return fmt.Errorf("--remote is required")
 	}

 	app, err := bootstrap(*configPath)
+
 	if err != nil {
 		return err
 	}
+
 	defer app.store.Close()
 	defer app.service.Close()

 	local, ok, err := app.store.GetLatestCheckpoint()
+
 	if err != nil {
 		return fmt.Errorf("load local checkpoint: %w", err)
 	}
+
 	if !ok {
 		return fmt.Errorf("local mirror has no checkpoints")
 	}

 	url := strings.TrimRight(*remote, "/") + "/checkpoints/latest"
 	req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, nil)
+
 	if err != nil {
 		return err
 	}
+
 	httpClient := &http.Client{Timeout: app.cfg.RequestTimeout}
 	resp, err := httpClient.Do(req)
+
 	if err != nil {
 		return fmt.Errorf("fetch remote checkpoint: %w", err)
 	}
+
 	defer resp.Body.Close()
+
 	if resp.StatusCode != http.StatusOK {
 		return fmt.Errorf("remote checkpoint response status=%d", resp.StatusCode)
 	}
+
 	var remoteCP types.CheckpointV1
+
 	if err := json.NewDecoder(resp.Body).Decode(&remoteCP); err != nil {
 		return fmt.Errorf("decode remote checkpoint: %w", err)
 	}
@@ -345,30 +419,41 @@ func runCompare(args []string) error {
 	fmt.Printf("remote sequence=%d root=%s signature_present=%t\n", remoteCP.Sequence, remoteCP.DIDMerkleRoot, remoteCP.Signature != "")

 	mismatch := false
+
 	if local.Sequence == remoteCP.Sequence && local.DIDMerkleRoot != remoteCP.DIDMerkleRoot {
 		fmt.Printf("divergence: matching sequence %d with different merkle roots\n", local.Sequence)
+
 		mismatch = true
 	}
+
 	if local.Sequence > remoteCP.Sequence {
 		fmt.Printf("local mirror is ahead by %d checkpoints\n", local.Sequence-remoteCP.Sequence)
+
 		mismatch = true
 	} else if remoteCP.Sequence > local.Sequence {
 		fmt.Printf("remote mirror is ahead by %d checkpoints\n", remoteCP.Sequence-local.Sequence)
+
 		mismatch = true
 	}
+
 	if (local.Signature == "") != (remoteCP.Signature == "") {
 		fmt.Println("signature presence mismatch between local and remote checkpoints")
+
 		mismatch = true
 	}
+
 	if mismatch {
 		return fmt.Errorf("mirror comparison mismatch")
 	}
+
 	fmt.Println("mirrors match at latest checkpoint")
+
 	return nil
 }

 func runVersion() error {
 	fmt.Print(formatVersion())
+
 	return nil
 }
@@ -379,22 +464,29 @@ func formatVersion() string {
 func bootstrap(path string) (*app, error) {
 	cfg, err := config.Load(path)
+
 	if err != nil {
 		return nil, err
 	}
+
 	for _, p := range []string{cfg.DataDir, filepath.Join(cfg.DataDir, "ops"), filepath.Join(cfg.DataDir, "index"), filepath.Join(cfg.DataDir, "checkpoints")} {
 		if err := os.MkdirAll(p, 0o755); err != nil {
 			return nil, fmt.Errorf("mkdir %s: %w", p, err)
 		}
 	}
+
 	store, err := storage.OpenPebble(cfg.DataDir)
+
 	if err != nil {
 		return nil, err
 	}
+
 	mode, err := store.GetMode()
+
 	if err != nil {
 		return nil, err
 	}
+
 	if mode == "" {
 		if err := store.SetMode(cfg.Mode); err != nil {
 			return nil, err
@@ -404,12 +496,15 @@ func bootstrap(path string) (*app, error) {
 	}

 	var blockLog *storage.BlockLog
+
 	if cfg.Mode == config.ModeMirror {
 		blockLog, err = storage.OpenBlockLog(cfg.DataDir, cfg.ZstdLevel, cfg.BlockSizeMB)
+
 		if err != nil {
 			return nil, err
 		}
 	}
+
 	client := ingest.NewClient(cfg.PLCSource, ingest.ClientOptions{
 		MaxAttempts: cfg.HTTPRetryMaxAttempts,
 		BaseDelay:   cfg.HTTPRetryBaseDelay,
@@ -423,6 +518,7 @@ func bootstrap(path string) (*app, error) {
 		BuildDate: buildDate,
 		GoVersion: runtime.Version(),
 	}))
+
 	return &app{cfg: cfg, store: store, service: service, apiServer: apiServer, checkpointM: checkpointMgr}, nil
 }
@@ -442,54 +538,74 @@ Commands:
 func processMetrics(pid int) (cpuPct float64, rssKB int64, err error) {
 	out, err := exec.Command("ps", "-p", strconv.Itoa(pid), "-o", "pcpu=,rss=").Output()
+
 	if err != nil {
 		return 0, 0, err
 	}
+
 	fields := strings.Fields(string(out))
+
 	if len(fields) < 2 {
 		return 0, 0, fmt.Errorf("unexpected ps output: %q", strings.TrimSpace(string(out)))
 	}
+
 	cpuPct, _ = strconv.ParseFloat(fields[0], 64)
 	rssKB, _ = strconv.ParseInt(fields[1], 10, 64)
+
 	return cpuPct, rssKB, nil
 }

 func diskUsageBreakdown(dataDir string) (total, ops, index, checkpoints int64, err error) {
 	total, err = dirSize(dataDir)
+
 	if err != nil {
 		return 0, 0, 0, 0, err
 	}
+
 	ops, err = dirSize(filepath.Join(dataDir, "ops"))
+
 	if err != nil {
 		return 0, 0, 0, 0, err
 	}
+
 	index, err = dirSize(filepath.Join(dataDir, "index"))
+
 	if err != nil {
 		return 0, 0, 0, 0, err
 	}
+
 	checkpoints, err = dirSize(filepath.Join(dataDir, "checkpoints"))
+
 	if err != nil {
 		return 0, 0, 0, 0, err
 	}
+
 	return total, ops, index, checkpoints, nil
 }

 func dirSize(path string) (int64, error) {
 	var total int64
+
 	err := filepath.WalkDir(path, func(_ string, d os.DirEntry, err error) error {
 		if err != nil {
 			return err
 		}
+
 		if d.IsDir() {
 			return nil
 		}
+
 		info, err := d.Info()
+
 		if err != nil {
 			return err
 		}
+
 		total += info.Size()
+
 		return nil
 	})
+
 	return total, err
 }
@@ -497,5 +613,6 @@ func max(a, b float64) float64 {
 	if a > b {
 		return a
 	}
+
 	return b
 }
diff --git a/cmd/plutia/version_test.go b/cmd/plutia/version_test.go
index 3869de6..0e3b8a4 100644
--- a/cmd/plutia/version_test.go
+++ b/cmd/plutia/version_test.go
@@ -10,6 +10,7 @@ func TestFormatVersionIncludesBuildMetadata(t *testing.T) {
 	version = "v0.1.0"
 	commit = "abc123"
 	buildDate = "2026-02-26T00:00:00Z"
+
 	t.Cleanup(func() {
 		version = oldVersion
 		commit = oldCommit
@@ -23,6 +24,7 @@ func TestFormatVersionIncludesBuildMetadata(t *testing.T) {
 		"BuildDate: 2026-02-26T00:00:00Z",
 		"GoVersion: go",
 	}
+
 	for _, want := range checks {
 		if !strings.Contains(out, want) {
 			t.Fatalf("version output missing %q: %s", want, out)
diff --git a/internal/api/observability.go b/internal/api/observability.go
index ebd7711..4654b28 100644
--- a/internal/api/observability.go
+++ b/internal/api/observability.go
@@ -8,7 +8,6 @@ import (
 	"strings"
 	"sync"
 	"time"
-
 	"github.com/Fuwn/plutia/internal/config"
 	"github.com/Fuwn/plutia/internal/ingest"
 	"github.com/Fuwn/plutia/internal/storage"
@@ -32,9 +31,11 @@ type serverMetrics struct {
 func newServerMetrics(cfg config.Config, store storage.Store, ingestor *ingest.Service) *serverMetrics {
 	reg := prometheus.NewRegistry()
+
 	var diskMu sync.Mutex
 	var diskCached int64
 	var diskCachedAt time.Time
+
 	m := &serverMetrics{
 		registry: reg,
 		checkpointDuration: prometheus.NewHistogram(prometheus.HistogramOpts{
@@ -47,6 +48,7 @@ func newServerMetrics(cfg config.Config, store storage.Store, ingestor *ingest.S
 			Help: "Latest checkpoint sequence generated by this mirror.",
 		}),
 	}
+
 	reg.MustRegister(m.checkpointDuration, m.checkpointSequence)
 	reg.MustRegister(prometheus.NewCounterFunc(
 		prometheus.CounterOpts{
@@ -55,9 +57,11 @@ func newServerMetrics(cfg config.Config, store storage.Store, ingestor *ingest.S
 		},
 		func() float64 {
 			seq, err := store.GetGlobalSeq()
+
 			if err != nil {
 				return 0
 			}
+
 			return float64(seq)
 		},
 	))
@@ -70,6 +74,7 @@ func newServerMetrics(cfg config.Config, store storage.Store, ingestor *ingest.S
 			if ingestor == nil {
 				return 0
 			}
+
 			return ingestor.Stats().IngestOpsPerSec
 		},
 	))
@@ -82,6 +87,7 @@ func newServerMetrics(cfg config.Config, store storage.Store, ingestor *ingest.S
 			if ingestor == nil {
 				return 0
 			}
+
 			return float64(ingestor.Stats().LagOps)
 		},
 	))
@@ -94,6 +100,7 @@ func newServerMetrics(cfg config.Config, store storage.Store, ingestor *ingest.S
 			if ingestor == nil {
 				return 0
 			}
+
 			return float64(ingestor.Stats().VerifyFailures)
 		},
 	))
@@ -104,14 +111,18 @@ func newServerMetrics(cfg config.Config, store storage.Store, ingestor *ingest.S
 		},
 		func() float64 {
 			diskMu.Lock()
+
 			defer diskMu.Unlock()
+
 			if diskCachedAt.IsZero() || time.Since(diskCachedAt) > 5*time.Second {
 				size, err := dirSize(cfg.DataDir)
+
 				if err == nil {
 					diskCached = size
 					diskCachedAt = time.Now()
 				}
 			}
+
 			return float64(diskCached)
 		},
 	))
@@ -124,14 +135,18 @@ func newServerMetrics(cfg config.Config, store storage.Store, ingestor *ingest.S
 			if ingestor != nil {
 				return float64(ingestor.Stats().DIDCount)
 			}
+
 			count := uint64(0)
 			_ = store.ForEachState(func(_ types.StateV1) error {
 				count++
+
 				return nil
 			})
+
 			return float64(count)
 		},
 	))
+
 	return m
 }
@@ -182,18 +197,23 @@ type ipRateLimiter struct {
 func newIPRateLimiter(cfg config.RateLimit) *ipRateLimiter {
 	def := config.Default().RateLimit
+
 	if cfg.ResolveRPS <= 0 {
 		cfg.ResolveRPS = def.ResolveRPS
 	}
+
 	if cfg.ResolveBurst <= 0 {
 		cfg.ResolveBurst = def.ResolveBurst
 	}
+
 	if cfg.ProofRPS <= 0 {
 		cfg.ProofRPS = def.ProofRPS
 	}
+
 	if cfg.ProofBurst <= 0 {
 		cfg.ProofBurst = def.ProofBurst
 	}
+
 	return &ipRateLimiter{
 		buckets: map[string]*tokenBucket{},
 		resolve: endpointPolicy{
@@ -210,7 +230,9 @@ func newIPRateLimiter(cfg config.RateLimit) *ipRateLimiter {
 func (l *ipRateLimiter) Allow(ip string, class limiterClass) bool {
 	now := time.Now()
+
 	l.mu.Lock()
+
 	defer l.mu.Unlock()

 	if now.Sub(l.lastSweep) > 2*time.Minute {
@@ -219,75 +241,99 @@ func (l *ipRateLimiter) Allow(ip string, class limiterClass) bool {
 				delete(l.buckets, key)
 			}
 		}
+
 		l.lastSweep = now
 	}

 	policy := l.resolve
 	routeKey := "resolve"
+
 	if class == limiterProof {
 		policy = l.proof
 		routeKey = "proof"
 	}
+
 	key := routeKey + "|" + ip
 	b, ok := l.buckets[key]
+
 	if !ok {
 		l.buckets[key] = &tokenBucket{
 			tokens:   policy.burst - 1,
 			last:     now,
 			lastSeen: now,
 		}
+
 		return true
 	}
+
 	elapsed := now.Sub(b.last).Seconds()
+
 	if elapsed > 0 {
 		b.tokens += elapsed * policy.rps
+
 		if b.tokens > policy.burst {
 			b.tokens = policy.burst
 		}
 	}
+
 	b.last = now
 	b.lastSeen = now
+
 	if b.tokens < 1 {
 		return false
 	}
+
 	b.tokens--
+
 	return true
 }

 func clientIP(r *http.Request) string {
 	if forwarded := strings.TrimSpace(r.Header.Get("X-Forwarded-For")); forwarded != "" {
 		parts := strings.Split(forwarded, ",")
+
 		if len(parts) > 0 {
 			if ip := strings.TrimSpace(parts[0]); ip != "" {
 				return ip
 			}
 		}
 	}
+
 	if realIP := strings.TrimSpace(r.Header.Get("X-Real-IP")); realIP != "" {
 		return realIP
 	}
+
 	host, _, err := net.SplitHostPort(r.RemoteAddr)
+
 	if err == nil && host != "" {
 		return host
 	}
+
 	return r.RemoteAddr
 }

 func dirSize(path string) (int64, error) {
 	var total int64
+
 	err := filepath.WalkDir(path, func(_ string, d os.DirEntry, err error) error {
 		if err != nil {
 			return err
 		}
+
 		if d.IsDir() {
 			return nil
 		}
+
 		info, err := d.Info()
+
 		if err != nil {
 			return err
 		}
+
 		total += info.Size()
+
 		return nil
 	})
+
 	return total, err
 }
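The `Allow` method above is a per-(route, IP) token bucket: each bucket refills at `rps` tokens per elapsed second up to `burst`, and a request spends one token. A self-contained sketch of that refill-then-spend step, with the surrounding types simplified to the two fields the math needs:

```go
package main

import (
	"fmt"
	"time"
)

type tokenBucket struct {
	tokens float64
	last   time.Time
}

// allow refills the bucket for the elapsed time, caps it at burst,
// then spends one token if at least one is available.
func (b *tokenBucket) allow(now time.Time, rps, burst float64) bool {
	if elapsed := now.Sub(b.last).Seconds(); elapsed > 0 {
		b.tokens += elapsed * rps

		if b.tokens > burst {
			b.tokens = burst
		}
	}

	b.last = now

	if b.tokens < 1 {
		return false
	}

	b.tokens--

	return true
}

func main() {
	b := &tokenBucket{tokens: 2, last: time.Now()}

	for i := 0; i < 3; i++ {
		fmt.Println(b.allow(time.Now(), 1, 2)) // true, true, false
	}
}
```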
diff --git a/internal/api/plc_compatibility_test.go b/internal/api/plc_compatibility_test.go
index 67cbafa..b914278 100644
--- a/internal/api/plc_compatibility_test.go
+++ b/internal/api/plc_compatibility_test.go
@@ -14,7 +14,6 @@ import (
 	"strings"
 	"testing"
 	"time"
-
 	"github.com/Fuwn/plutia/internal/checkpoint"
 	"github.com/Fuwn/plutia/internal/config"
 	"github.com/Fuwn/plutia/internal/ingest"
@@ -24,28 +23,37 @@ import (
 func TestPLCCompatibilityGetDIDMatchesStoredDocument(t *testing.T) {
 	ts, store, _, cleanup := newCompatibilityServer(t)
+
 	defer cleanup()

 	state, ok, err := store.GetState("did:plc:alice")
+
 	if err != nil {
 		t.Fatalf("get state: %v", err)
 	}
+
 	if !ok {
 		t.Fatalf("state not found")
 	}

 	resp, err := http.Get(ts.URL + "/did:plc:alice")
+
 	if err != nil {
 		t.Fatalf("get did: %v", err)
 	}
+
 	defer resp.Body.Close()
+
 	if resp.StatusCode != http.StatusOK {
 		t.Fatalf("status: got %d want 200", resp.StatusCode)
 	}
+
 	if got := resp.Header.Get("Content-Type"); !strings.Contains(got, "application/did+ld+json") {
 		t.Fatalf("content-type mismatch: %s", got)
 	}
+
 	body, _ := io.ReadAll(resp.Body)
+
 	if strings.TrimSpace(string(body)) != strings.TrimSpace(string(state.DIDDocument)) {
 		t.Fatalf("did document mismatch\n got: %s\nwant: %s", string(body), string(state.DIDDocument))
 	}
@@ -53,26 +61,35 @@ func TestPLCCompatibilityGetDIDMatchesStoredDocument(t *testing.T) {
 func TestPLCCompatibilityGetLogOrdered(t *testing.T) {
 	ts, _, recs, cleanup := newCompatibilityServer(t)
+
 	defer cleanup()

 	resp, err := http.Get(ts.URL + "/did:plc:alice/log")
+
 	if err != nil {
 		t.Fatalf("get log: %v", err)
 	}
+
 	defer resp.Body.Close()
+
 	if resp.StatusCode != http.StatusOK {
 		t.Fatalf("status: got %d want 200", resp.StatusCode)
 	}
+
 	var ops []map[string]any
+
 	if err := json.NewDecoder(resp.Body).Decode(&ops); err != nil {
 		t.Fatalf("decode log: %v", err)
 	}
+
 	if len(ops) != 2 {
 		t.Fatalf("log length mismatch: got %d want 2", len(ops))
 	}
+
 	if _, ok := ops[0]["prev"]; ok {
 		t.Fatalf("first op should be genesis without prev")
 	}
+
 	if prev, _ := ops[1]["prev"].(string); prev != recs[0].CID {
 		t.Fatalf("second op prev mismatch: got %q want %q", prev, recs[0].CID)
 	}
@@ -80,29 +97,39 @@ func TestPLCCompatibilityGetLogOrdered(t *testing.T) {
 func TestPLCCompatibilityExportCount(t *testing.T) {
 	ts, _, _, cleanup := newCompatibilityServer(t)
+
 	defer cleanup()

 	resp, err := http.Get(ts.URL + "/export?count=2")
+
 	if err != nil {
 		t.Fatalf("get export: %v", err)
 	}
+
 	defer resp.Body.Close()
+
 	if resp.StatusCode != http.StatusOK {
 		t.Fatalf("status: got %d want 200", resp.StatusCode)
 	}
+
 	if got := resp.Header.Get("Content-Type"); !strings.Contains(got, "application/jsonlines") {
 		t.Fatalf("content-type mismatch: %s", got)
 	}
+
 	body, _ := io.ReadAll(resp.Body)
 	lines := strings.Split(strings.TrimSpace(string(body)), "\n")
+
 	if len(lines) != 2 {
 		t.Fatalf("line count mismatch: got %d want 2", len(lines))
 	}
+
 	for _, line := range lines {
 		var entry map[string]any
+
 		if err := json.Unmarshal([]byte(line), &entry); err != nil {
 			t.Fatalf("decode export line: %v", err)
 		}
+
 		for _, key := range []string{"did", "operation", "cid", "nullified", "createdAt"} {
 			if _, ok := entry[key]; !ok {
 				t.Fatalf("missing export key %q in %v", key, entry)
@@ -113,20 +140,27 @@ func TestPLCCompatibilityExportCount(t *testing.T) {
 func TestPLCCompatibilityPostIsMethodNotAllowed(t *testing.T) {
 	ts, _, _, cleanup := newCompatibilityServer(t)
+
 	defer cleanup()

 	req, err := http.NewRequest(http.MethodPost, ts.URL+"/did:plc:alice", strings.NewReader(`{}`))
+
 	if err != nil {
 		t.Fatalf("new request: %v", err)
 	}
+
 	resp, err := http.DefaultClient.Do(req)
+
 	if err != nil {
 		t.Fatalf("post did: %v", err)
 	}
+
 	defer resp.Body.Close()
+
 	if resp.StatusCode != http.StatusMethodNotAllowed {
 		t.Fatalf("status: got %d want 405", resp.StatusCode)
 	}
+
 	if allow := resp.Header.Get("Allow"); allow != http.MethodGet {
 		t.Fatalf("allow header mismatch: got %q want %q", allow, http.MethodGet)
 	}
@@ -134,14 +168,19 @@ func TestPLCCompatibilityPostIsMethodNotAllowed(t *testing.T) {
 func TestPLCCompatibilityNoVerificationMetadataLeak(t *testing.T) {
 	ts, _, _, cleanup := newCompatibilityServer(t)
+
 	defer cleanup()

 	resp, err := http.Get(ts.URL + "/did:plc:alice")
+
 	if err != nil {
 		t.Fatalf("get did: %v", err)
 	}
+
 	defer resp.Body.Close()
+
 	body, _ := io.ReadAll(resp.Body)
+
 	if strings.Contains(string(body), "checkpoint_reference") {
 		t.Fatalf("compatibility endpoint leaked verification metadata: %s", string(body))
 	}
@@ -149,48 +188,63 @@ func TestPLCCompatibilityNoVerificationMetadataLeak(t *testing.T) {
 func TestPLCCompatibilityProofEndpointStillWorks(t *testing.T) {
 	ts, _, _, cleanup := newCompatibilityServer(t)
+
 	defer cleanup()

 	resp, err := http.Get(ts.URL + "/did/did:plc:alice/proof")
+
 	if err != nil {
 		t.Fatalf("get proof: %v", err)
 	}
+
 	defer resp.Body.Close()
+
 	if resp.StatusCode != http.StatusOK {
 		body, _ := io.ReadAll(resp.Body)
+
 		t.Fatalf("proof status: got %d want 200 body=%s", resp.StatusCode, string(body))
 	}
 }

 func newCompatibilityServer(t *testing.T) (*httptest.Server, *storage.PebbleStore, []types.ExportRecord, func()) {
 	t.Helper()
+
 	tmp := t.TempDir()
 	dataDir := filepath.Join(tmp, "data")
+
 	if err := os.MkdirAll(dataDir, 0o755); err != nil {
 		t.Fatalf("mkdir data: %v", err)
 	}

 	seed := make([]byte, ed25519.SeedSize)
+
 	if _, err := rand.Read(seed); err != nil {
 		t.Fatalf("seed: %v", err)
 	}
+
 	keyPath := filepath.Join(tmp, "mirror.key")
+
 	if err := os.WriteFile(keyPath, []byte(base64.RawURLEncoding.EncodeToString(seed)), 0o600); err != nil {
 		t.Fatalf("write key: %v", err)
 	}

 	recs := buildCheckpointScenarioRecords(t)
 	sourcePath := filepath.Join(tmp, "records.ndjson")
+
 	writeRecordsFile(t, sourcePath, recs)

 	store, err := storage.OpenPebble(dataDir)
+
 	if err != nil {
 		t.Fatalf("open pebble: %v", err)
 	}
+
 	if err := store.SetMode(config.ModeMirror); err != nil {
 		t.Fatalf("set mode: %v", err)
 	}
+
 	bl, err := storage.OpenBlockLog(dataDir, 3, 4)
+
 	if err != nil {
 		t.Fatalf("open block log: %v", err)
 	}
@@ -212,12 +266,15 @@ func newCompatibilityServer(t *testing.T) (*httptest.Server, *storage.PebbleStor
 	}
 	cpMgr := checkpoint.NewManager(store, dataDir, keyPath)
 	svc := ingest.NewService(cfg, store, ingest.NewClient(sourcePath), bl, cpMgr)
+
 	if err := svc.Replay(context.Background()); err != nil {
 		t.Fatalf("replay: %v", err)
 	}
+
 	if err := svc.Flush(context.Background()); err != nil {
 		t.Fatalf("flush: %v", err)
 	}
+
 	if _, err := svc.Snapshot(context.Background()); err != nil {
 		t.Fatalf("snapshot: %v", err)
 	}
@@ -226,7 +283,9 @@ func newCompatibilityServer(t *testing.T) (*httptest.Server, *storage.PebbleStor
 	cleanup := func() {
 		ts.Close()
 		svc.Close()
+
 		_ = store.Close()
 	}
+
 	return ts, store, recs, cleanup
 }
diff --git a/internal/api/server.go b/internal/api/server.go
index f773145..2a3c589 100644
--- a/internal/api/server.go
+++ b/internal/api/server.go
@@ -10,7 +10,6 @@ import (
 	"strconv"
 	"strings"
 	"time"
-
 	"github.com/Fuwn/plutia/internal/checkpoint"
 	"github.com/Fuwn/plutia/internal/config"
 	"github.com/Fuwn/plutia/internal/ingest"
@@ -44,21 +43,27 @@ func NewServer(cfg config.Config, store storage.Store, ingestor *ingest.Service,
 		},
 		limiter: newIPRateLimiter(cfg.RateLimit),
 	}
+
 	for _, opt := range opts {
 		opt(s)
 	}
+
 	s.metrics = newServerMetrics(cfg, store, ingestor)
+
 	if cp, ok, err := store.GetLatestCheckpoint(); err == nil && ok {
 		s.metrics.checkpointSequence.Set(float64(cp.Sequence))
 	}
+
 	if ingestor != nil {
 		ingestor.SetMetricsSink(s.metrics)
 	}
+
 	return s
 }

 func (s *Server) Handler() http.Handler {
 	mux := http.NewServeMux()
+
 	mux.Handle("/health", s.withTimeout(http.HandlerFunc(s.handleHealth)))
 	mux.Handle("/metrics", s.metrics.Handler())
 	mux.Handle("/status", s.withTimeout(http.HandlerFunc(s.handleStatus)))
@@ -67,6 +72,7 @@ func (s *Server) Handler() http.Handler {
 	mux.Handle("/did/", s.withTimeout(http.HandlerFunc(s.handleDID)))
 	mux.Handle("/export", s.withTimeout(http.HandlerFunc(s.handleExportCompatibility)))
 	mux.Handle("/", s.withTimeout(http.HandlerFunc(s.handlePLCCompatibility)))
+
 	return mux
 }

@@ -76,19 +82,27 @@ func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
 func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) {
 	seq, err := s.store.GetGlobalSeq()
+
 	if err != nil {
 		writeErr(w, http.StatusInternalServerError, err)
+
 		return
 	}
+
 	cp, ok, err := s.store.GetLatestCheckpoint()
+
 	if err != nil {
 		writeErr(w, http.StatusInternalServerError, err)
+
 		return
 	}
+
 	stats := ingest.Stats{}
+
 	if s.ingestor != nil {
 		stats = s.ingestor.Stats()
 	}
+
 	payload := map[string]any{
 		"mode":          s.cfg.Mode,
 		"verify_policy": s.cfg.VerifyPolicy,
@@ -96,161 +110,225 @@ func (s *Server) handleStatus(w http.ResponseWriter, r *http.Request) {
 		"stats": stats,
 		"build": s.build,
 	}
+
 	if s.ingestor != nil {
 		payload["corrupted"] = s.ingestor.IsCorrupted()
+
 		if err := s.ingestor.CorruptionError(); err != nil {
 			payload["corruption_error"] = err.Error()
 		}
 	}
+
 	if ok {
 		payload["latest_checkpoint"] = cp
 	}
+
 	writeJSON(w, http.StatusOK, payload)
 }

 func (s *Server) handleLatestCheckpoint(w http.ResponseWriter, r *http.Request) {
 	cp, ok, err := s.store.GetLatestCheckpoint()
+
 	if err != nil {
 		writeErr(w, http.StatusInternalServerError, err)
+
 		return
 	}
+
 	if !ok {
 		writeErr(w, http.StatusNotFound, fmt.Errorf("no checkpoint"))
+
 		return
 	}
+
 	writeJSON(w, http.StatusOK, cp)
 }

 func (s *Server) handleCheckpointBySequence(w http.ResponseWriter, r *http.Request) {
 	path := strings.TrimPrefix(r.URL.Path, "/checkpoints/")
+
 	if path == "" {
 		writeErr(w, http.StatusNotFound, fmt.Errorf("missing checkpoint sequence"))
+
 		return
 	}
+
 	seq, err := strconv.ParseUint(path, 10, 64)
+
 	if err != nil {
 		writeErr(w, http.StatusBadRequest, fmt.Errorf("invalid checkpoint sequence"))
+
 		return
 	}
+
 	cp, ok, err := s.store.GetCheckpoint(seq)
+
 	if err != nil {
 		writeErr(w, http.StatusInternalServerError, err)
+
 		return
 	}
+
 	if !ok {
 		writeErr(w, http.StatusNotFound, fmt.Errorf("checkpoint not found"))
+
 		return
 	}
+
 	writeJSON(w, http.StatusOK, cp)
 }

 func (s *Server) handleDID(w http.ResponseWriter, r *http.Request) {
 	path := strings.TrimPrefix(r.URL.Path, "/did/")
+
 	if path == "" {
 		writeErr(w, http.StatusBadRequest, fmt.Errorf("missing did"))
+
 		return
 	}
+
 	if strings.HasSuffix(path, "/proof") {
 		did := strings.TrimSuffix(path, "/proof")
+
 		if !s.allowRequest(r, limiterProof) {
 			writeErr(w, http.StatusTooManyRequests, fmt.Errorf("proof rate limit exceeded"))
+
 			return
 		}
+
 		s.handleDIDProof(w, r, did)
+
 		return
 	}
+
 	if !s.allowRequest(r, limiterResolve) {
 		writeErr(w, http.StatusTooManyRequests, fmt.Errorf("resolve rate limit exceeded"))
+
 		return
 	}
+
 	s.handleDIDResolve(w, r, path)
 }

 func (s *Server) handleDIDResolve(w http.ResponseWriter, r *http.Request, did string) {
 	state, ok, err := s.store.GetState(did)
+
 	if err != nil {
 		writeErr(w, http.StatusInternalServerError, err)
+
 		return
 	}
+
 	if !ok {
 		writeErr(w, http.StatusNotFound, fmt.Errorf("did not found"))
+
 		return
 	}
+
 	cp, cpOK, err := s.store.GetLatestCheckpoint()
+
 	if err != nil {
 		writeErr(w, http.StatusInternalServerError, err)
+
 		return
 	}
+
 	resp := map[string]any{
 		"did":            did,
 		"did_document":   json.RawMessage(state.DIDDocument),
 		"chain_tip_hash": state.ChainTipHash,
 	}
+
 	if cpOK {
 		resp["checkpoint_reference"] = map[string]any{
 			"sequence":        cp.Sequence,
 			"checkpoint_hash": cp.CheckpointHash,
 		}
 	}
+
 	writeJSON(w, http.StatusOK, resp)
 }

 func (s *Server) handleDIDProof(w http.ResponseWriter, r *http.Request, did string) {
 	if s.ingestor == nil {
 		writeErr(w, http.StatusServiceUnavailable, fmt.Errorf("ingestor unavailable"))
+
 		return
 	}
+
 	if err := s.ingestor.CorruptionError(); err != nil {
 		writeErr(w, http.StatusServiceUnavailable, err)
+
 		return
 	}

 	cp, verifyCheckpointUnchanged, err := s.selectCheckpointForProof(r)
+
 	if err != nil {
 		writeErr(w, http.StatusBadRequest, err)
+
 		return
 	}

 	tipHash, seqs, err := s.ingestor.RecomputeTipAtOrBefore(r.Context(), did, cp.Sequence)
+
 	if err != nil {
 		if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
 			writeErr(w, http.StatusGatewayTimeout, err)
+
 			return
 		}
+
 		writeErr(w, http.StatusNotFound, err)
+
 		return
 	}
+
 	siblings, leafHash, found, err := s.checkpoints.BuildDIDProofAtCheckpoint(r.Context(), did, tipHash, cp.Sequence)
+
 	if err != nil {
 		if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
 			writeErr(w, http.StatusGatewayTimeout, err)
+
 			return
 		}
+
 		writeErr(w, http.StatusInternalServerError, err)
+
 		return
 	}
+
 	if !found {
 		writeErr(w, http.StatusNotFound, fmt.Errorf("did not present in checkpoint state"))
+
 		return
 	}

 	leafBytes, err := hex.DecodeString(leafHash)
+
 	if err != nil {
 		writeErr(w, http.StatusInternalServerError, fmt.Errorf("invalid leaf hash: %w", err))
+
 		return
 	}
+
 	root, err := hex.DecodeString(cp.DIDMerkleRoot)
+
 	if err != nil {
 		writeErr(w, http.StatusInternalServerError, fmt.Errorf("invalid checkpoint root"))
+
 		return
 	}
+
 	if !merkle.VerifyProof(leafBytes, siblings, root) {
 		writeErr(w, http.StatusInternalServerError, fmt.Errorf("inclusion proof failed consistency check"))
+
 		return
 	}

 	if err := verifyCheckpointUnchanged(); err != nil {
 		writeErr(w, http.StatusConflict, err)
+
 		return
 	}
@@ -265,6 +343,7 @@ func (s *Server) handleDIDProof(w http.ResponseWriter, r *http.Request, did stri
 		CheckpointSig:   cp.Signature,
 		CheckpointKeyID: cp.KeyID,
 	}
+
 	writeJSON(w, http.StatusOK, map[string]any{
 		"did":                 did,
 		"checkpoint_sequence": cp.Sequence,
@@ -287,28 +366,39 @@ type plcAuditEntry struct {
 func (s *Server) handlePLCCompatibility(w http.ResponseWriter, r *http.Request) {
 	path := strings.TrimPrefix(r.URL.Path, "/")
+
 	if path == "" {
 		writeErr(w, http.StatusNotFound, fmt.Errorf("not found"))
+
 		return
 	}
+
 	if path == "export" {
 		s.handleExportCompatibility(w, r)
+
 		return
 	}
+
 	parts := strings.Split(path, "/")
 	did := parts[0]
+
 	if !strings.HasPrefix(did, "did:") {
 		writeErr(w, http.StatusNotFound, fmt.Errorf("not found"))
+
 		return
 	}
+
 	if r.Method == http.MethodPost && len(parts) == 1 {
 		w.Header().Set("Allow", http.MethodGet)
 		writeErr(w, http.StatusMethodNotAllowed, fmt.Errorf("write operations are not supported by this mirror"))
+
 		return
 	}
+
 	if r.Method != http.MethodGet {
 		w.Header().Set("Allow", http.MethodGet)
 		writeErr(w, http.StatusMethodNotAllowed, fmt.Errorf("method not allowed"))
+
 		return
 	}
@@ -330,90 +420,129 @@ func (s *Server) handlePLCCompatibility(w http.ResponseWriter, r *http.Request)
 func (s *Server) handleGetDIDCompatibility(w http.ResponseWriter, did string) {
 	state, ok, err := s.store.GetState(did)
+
 	if err != nil {
 		writeErr(w, http.StatusInternalServerError, err)
+
 		return
 	}
+
 	if !ok {
 		writeErr(w, http.StatusNotFound, fmt.Errorf("did not found"))
+
 		return
 	}
+
 	status := http.StatusOK
+
 	if isTombstonedDIDDocument(state.DIDDocument) {
 		status = http.StatusGone
 	}
+
 	w.Header().Set("Content-Type", "application/did+ld+json")
 	w.WriteHeader(status)
+
 	_, _ = w.Write(state.DIDDocument)
 }

 func (s *Server) handleGetDIDLogCompatibility(w http.ResponseWriter, r *http.Request, did string) {
 	if s.ingestor == nil {
 		writeErr(w, http.StatusServiceUnavailable, fmt.Errorf("ingestor unavailable"))
+
 		return
 	}
+
 	logEntries, err := s.ingestor.LoadDIDLog(r.Context(), did)
+
 	if err != nil {
 		if errors.Is(err, ingest.ErrDIDNotFound) || errors.Is(err, ingest.ErrHistoryNotStored) {
 			writeErr(w, http.StatusNotFound, fmt.Errorf("did not found"))
+
 			return
 		}
+
 		if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
 			writeErr(w, http.StatusGatewayTimeout, err)
+
 			return
 		}
+
 		writeErr(w, http.StatusInternalServerError, err)
+
 		return
 	}
+
 	ops := make([]json.RawMessage, 0, len(logEntries))
+
 	for _, rec := range logEntries {
 		ops = append(ops, rec.Operation)
 	}
+
 	writeJSONWithContentType(w, http.StatusOK, "application/json", ops)
 }

 func (s *Server) handleGetDIDLogLastCompatibility(w http.ResponseWriter, r *http.Request, did string) {
 	if s.ingestor == nil {
 		writeErr(w, http.StatusServiceUnavailable, fmt.Errorf("ingestor unavailable"))
+
 		return
 	}
+
 	rec, err := s.ingestor.LoadLatestDIDOperation(r.Context(), did)
+
 	if err != nil {
 		if errors.Is(err, ingest.ErrDIDNotFound) || errors.Is(err, ingest.ErrHistoryNotStored) {
 			writeErr(w, http.StatusNotFound, fmt.Errorf("did not found"))
+
 			return
 		}
+
 		if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
 			writeErr(w, http.StatusGatewayTimeout, err)
+
 			return
 		}
+
 		writeErr(w, http.StatusInternalServerError, err)
+
 		return
 	}
+
 	w.Header().Set("Content-Type", "application/json")
 	w.WriteHeader(http.StatusOK)
+
 	_, _ = w.Write(rec.Operation)
 }

 func (s *Server) handleGetDIDLogAuditCompatibility(w http.ResponseWriter, r *http.Request, did string) {
 	if s.ingestor == nil {
 		writeErr(w, http.StatusServiceUnavailable, fmt.Errorf("ingestor unavailable"))
+
 		return
 	}
+
 	logEntries, err := s.ingestor.LoadDIDLog(r.Context(), did)
+
 	if err != nil {
 		if errors.Is(err, ingest.ErrDIDNotFound) || errors.Is(err, ingest.ErrHistoryNotStored) {
 			writeErr(w, http.StatusNotFound, fmt.Errorf("did not found"))
+
 			return
 		}
+
 		if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
 			writeErr(w, http.StatusGatewayTimeout, err)
+
 			return
 		}
+
 		writeErr(w, http.StatusInternalServerError, err)
+
 		return
 	}
+
 	audit := make([]plcAuditEntry, 0, len(logEntries))
+
 	for _, rec := range logEntries {
 		audit = append(audit, plcAuditEntry{
 			DID: did,
@@ -423,27 +552,37 @@ func (s *Server) handleGetDIDLogAuditCompatibility(w http.ResponseWriter, r *htt
 			CreatedAt: rec.CreatedAt,
 		})
 	}
+
 	writeJSONWithContentType(w, http.StatusOK, "application/json", audit)
 }

 func (s *Server) handleGetDIDDataCompatibility(w http.ResponseWriter, r *http.Request, did string) {
 	if s.ingestor == nil {
 		writeErr(w, http.StatusServiceUnavailable, fmt.Errorf("ingestor unavailable"))
+
 		return
 	}
+
 	data, err := s.ingestor.LoadCurrentPLCData(r.Context(), did)
+
 	if err != nil {
 		if errors.Is(err, ingest.ErrDIDNotFound) {
 			writeErr(w, http.StatusNotFound, fmt.Errorf("did not found"))
+
 			return
 		}
+
 		if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
 			writeErr(w, http.StatusGatewayTimeout, err)
+
 			return
 		}
+
 		writeErr(w, http.StatusInternalServerError, err)
+
 		return
 	}
+
 	writeJSONWithContentType(w, http.StatusOK, "application/json", data)
 }
@@ -451,38 +590,51 @@ func (s *Server) handleExportCompatibility(w http.ResponseWriter, r *http.Reques
 	if r.Method != http.MethodGet {
 		w.Header().Set("Allow", http.MethodGet)
 		writeErr(w, http.StatusMethodNotAllowed, fmt.Errorf("method not allowed"))
+
 		return
 	}
+
 	if s.ingestor == nil {
 		writeErr(w, http.StatusServiceUnavailable, fmt.Errorf("ingestor unavailable"))
+
 		return
 	}

 	count := 1000
+
 	if rawCount := strings.TrimSpace(r.URL.Query().Get("count")); rawCount != "" {
 		n, err := strconv.Atoi(rawCount)
+
 		if err != nil || n < 1 {
 			writeErr(w, http.StatusBadRequest, fmt.Errorf("invalid count query parameter"))
+
 			return
 		}
+
 		if n > 1000 {
 			n = 1000
 		}
+
 		count = n
 	}

 	var after time.Time
+
 	if rawAfter := strings.TrimSpace(r.URL.Query().Get("after")); rawAfter != "" {
 		parsed, err := time.Parse(time.RFC3339, rawAfter)
+
 		if err != nil {
 			writeErr(w, http.StatusBadRequest, fmt.Errorf("invalid after query parameter"))
+
 			return
 		}
+
 		after = parsed
 	}

 	w.Header().Set("Content-Type", "application/jsonlines")
 	w.WriteHeader(http.StatusOK)
+
 	flusher, _ := w.(http.Flusher)
 	enc := json.NewEncoder(w)
 	err := s.ingestor.StreamExport(r.Context(), after, count, func(rec types.ExportRecord) error {
@@ -493,14 +645,18 @@ func (s *Server) handleExportCompatibility(w http.ResponseWriter, r *http.Reques
 			Nullified: rec.Nullified,
 			CreatedAt: rec.CreatedAt,
 		}
+
 		if err := enc.Encode(entry); err != nil {
 			return err
 		}
+
 		if flusher != nil {
 			flusher.Flush()
 		}
+
 		return nil
 	})
+
 	if err != nil {
 		// Response has already started; best effort termination.
 		return
@@ -511,22 +667,30 @@ func isTombstonedDIDDocument(raw []byte) bool {
 	if len(raw) == 0 {
 		return false
 	}
+
 	var doc map[string]any
+
 	if err := json.Unmarshal(raw, &doc); err != nil {
 		return false
 	}
+
 	deactivated, _ := doc["deactivated"].(bool)
+
 	return deactivated
 }

 func (s *Server) withTimeout(next http.Handler) http.Handler {
 	timeout := s.cfg.RequestTimeout
+
 	if timeout <= 0 {
 		timeout = 10 * time.Second
 	}
+
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		ctx, cancel := context.WithTimeout(r.Context(), timeout)
+
 		defer cancel()
+
 		next.ServeHTTP(w, r.WithContext(ctx))
 	})
 }
@@ -535,53 +699,70 @@ func (s *Server) allowRequest(r *http.Request, class limiterClass) bool {
 	if s.limiter == nil {
 		return true
 	}
+
 	return s.limiter.Allow(clientIP(r), class)
 }

 func (s *Server) selectCheckpointForProof(r *http.Request) (types.CheckpointV1, func() error, error) {
 	checkpointParam := strings.TrimSpace(r.URL.Query().Get("checkpoint"))
+
 	if checkpointParam == "" {
 		cp, ok, err := s.store.GetLatestCheckpoint()
+
 		if err != nil {
 			return types.CheckpointV1{}, nil, err
 		}
+
 		if !ok {
 			return types.CheckpointV1{}, nil, fmt.Errorf("no checkpoint available")
 		}
+
 		return cp, func() error {
 			now, ok, err := s.store.GetLatestCheckpoint()
+
 			if err != nil {
 				return err
 			}
+
 			if !ok {
 				return fmt.Errorf("latest checkpoint disappeared during request")
 			}
+
 			if now.CheckpointHash != cp.CheckpointHash {
 				return fmt.Errorf("checkpoint advanced during proof generation")
 			}
+
 			return nil
 		}, nil
 	}

 	seq, err := strconv.ParseUint(checkpointParam, 10, 64)
+
 	if err != nil {
 		return types.CheckpointV1{}, nil, fmt.Errorf("invalid checkpoint query parameter")
 	}
+
 	cp, ok, err := s.store.GetCheckpoint(seq)
+
 	if err != nil {
 		return types.CheckpointV1{}, nil, err
 	}
+
 	if !ok {
 		return types.CheckpointV1{}, nil, fmt.Errorf("checkpoint %d unavailable", seq)
 	}
+
 	return cp, func() error {
 		again, ok, err := s.store.GetCheckpoint(seq)
+
 		if err != nil {
 			return err
 		}
+
 		if !ok || again.CheckpointHash != cp.CheckpointHash {
 			return fmt.Errorf("checkpoint %d changed during proof generation", seq)
 		}
+
 		return nil
 	}, nil
 }
@@ -592,10 +773,13 @@ func writeJSON(w http.ResponseWriter, code int, v any) {
 func writeJSONWithContentType(w http.ResponseWriter, code int, contentType string, v any) {
 	w.Header().Set("Content-Type", "application/json")
+
 	if strings.TrimSpace(contentType) != "" {
 		w.Header().Set("Content-Type", contentType)
 	}
+
 	w.WriteHeader(code)
+
 	_ = json.NewEncoder(w).Encode(v)
 }
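handleDIDProof re-checks every proof it is about to serve: it hashes the `did + chainTipHash` leaf, walks the sibling path, and compares against the checkpoint's `DIDMerkleRoot` before responding. A client holding a proof response could in principle repeat the same check; the sketch below assumes access to the `internal/merkle` helpers visible in this diff (`HashLeaf`, `VerifyProof`, `Sibling`) and uses placeholder inputs:

```go
package main

import (
	"encoding/hex"
	"fmt"

	"github.com/Fuwn/plutia/internal/merkle"
)

// verifyInclusion mirrors the server-side consistency check above:
// hash the DID+tip leaf, then check the sibling path against the
// checkpoint's DID Merkle root (served as a hex string).
func verifyInclusion(did, chainTipHash, rootHex string, siblings []merkle.Sibling) (bool, error) {
	leaf := merkle.HashLeaf([]byte(did + chainTipHash))

	root, err := hex.DecodeString(rootHex)
	if err != nil {
		return false, fmt.Errorf("decode root: %w", err)
	}

	return merkle.VerifyProof(leaf, siblings, root), nil
}

func main() {
	// Hypothetical values; in practice these come from
	// GET /did/<did>/proof and GET /checkpoints/<seq>.
	ok, err := verifyInclusion("did:plc:alice", "tip-hash", "00ff", nil)
	fmt.Println(ok, err)
}
```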
diff --git a/internal/api/server_checkpoint_test.go b/internal/api/server_checkpoint_test.go
index 9a12f61..531f24d 100644
--- a/internal/api/server_checkpoint_test.go
+++ b/internal/api/server_checkpoint_test.go
@@ -5,7 +5,6 @@ import (
 	"path/filepath"
 	"strings"
 	"testing"
-
 	"github.com/Fuwn/plutia/internal/config"
 	"github.com/Fuwn/plutia/internal/storage"
 	"github.com/Fuwn/plutia/internal/types"
@@ -14,9 +13,11 @@ import (
 func TestSelectCheckpointForProofDetectsLatestAdvance(t *testing.T) {
 	dataDir := t.TempDir()
 	store, err := storage.OpenPebble(dataDir)
+
 	if err != nil {
 		t.Fatalf("open pebble: %v", err)
 	}
+
 	defer store.Close()

 	if err := store.PutCheckpoint(types.CheckpointV1{
@@ -32,9 +33,11 @@ func TestSelectCheckpointForProofDetectsLatestAdvance(t *testing.T) {
 	srv := NewServer(config.Default(), store, nil, nil)
 	req := httptest.NewRequest("GET", "/did/did:plc:alice/proof", nil)
 	cp, verifyUnchanged, err := srv.selectCheckpointForProof(req)
+
 	if err != nil {
 		t.Fatalf("select checkpoint: %v", err)
 	}
+
 	if cp.Sequence != 10 {
 		t.Fatalf("selected unexpected checkpoint sequence: got %d want 10", cp.Sequence)
 	}
@@ -57,9 +60,11 @@ func TestSelectCheckpointForProofDetectsLatestAdvance(t *testing.T) {
 func TestSelectCheckpointForProofDetectsHistoricalMutation(t *testing.T) {
 	dataDir := t.TempDir()
 	store, err := storage.OpenPebble(filepath.Clean(dataDir))
+
 	if err != nil {
 		t.Fatalf("open pebble: %v", err)
 	}
+
 	defer store.Close()

 	if err := store.PutCheckpoint(types.CheckpointV1{
@@ -75,9 +80,11 @@ func TestSelectCheckpointForProofDetectsHistoricalMutation(t *testing.T) {
 	srv := NewServer(config.Default(), store, nil, nil)
 	req := httptest.NewRequest("GET", "/did/did:plc:alice/proof?checkpoint=20", nil)
 	cp, verifyUnchanged, err := srv.selectCheckpointForProof(req)
+
 	if err != nil {
 		t.Fatalf("select historical checkpoint: %v", err)
 	}
+
 	if cp.CheckpointHash != "cp-20-a" {
 		t.Fatalf("selected unexpected checkpoint hash: got %s", cp.CheckpointHash)
 	}
diff --git a/internal/api/server_hardening_test.go b/internal/api/server_hardening_test.go
index bb9f24c..fcdf87b 100644
--- a/internal/api/server_hardening_test.go
+++ b/internal/api/server_hardening_test.go
@@ -8,7 +8,6 @@ import (
 	"strings"
 	"testing"
 	"time"
-
 	"github.com/Fuwn/plutia/internal/config"
 	"github.com/Fuwn/plutia/internal/storage"
 	"github.com/Fuwn/plutia/internal/types"
@@ -16,10 +15,13 @@ import (
 func TestResolveRateLimitPerIP(t *testing.T) {
 	store, err := storage.OpenPebble(t.TempDir())
+
 	if err != nil {
 		t.Fatalf("open pebble: %v", err)
 	}
+
 	defer store.Close()
+
 	if err := store.PutState(types.StateV1{Version: 1, DID: "did:plc:alice", DIDDocument: []byte(`{"id":"did:plc:alice"}`), ChainTipHash: "tip", LatestOpSeq: 1, UpdatedAt: time.Now().UTC()}); err != nil {
 		t.Fatalf("put state: %v", err)
 	}
@@ -31,11 +33,12 @@ func TestResolveRateLimitPerIP(t *testing.T) {
 	cfg.RateLimit.ProofRPS = 1
 	cfg.RateLimit.ProofBurst = 1
 	h := NewServer(cfg, store, nil, nil).Handler()
-
 	req1 := httptest.NewRequest(http.MethodGet, "/did/did:plc:alice", nil)
 	req1.RemoteAddr = "203.0.113.7:12345"
 	rr1 := httptest.NewRecorder()
+
 	h.ServeHTTP(rr1, req1)
+
 	if rr1.Code != http.StatusOK {
 		t.Fatalf("first request status: got %d want %d", rr1.Code, http.StatusOK)
 	}
@@ -43,7 +46,9 @@ func TestResolveRateLimitPerIP(t *testing.T) {
 	req2 := httptest.NewRequest(http.MethodGet, "/did/did:plc:alice", nil)
 	req2.RemoteAddr = "203.0.113.7:12345"
 	rr2 := httptest.NewRecorder()
+
 	h.ServeHTTP(rr2, req2)
+
 	if rr2.Code != http.StatusTooManyRequests {
 		t.Fatalf("second request status: got %d want %d", rr2.Code, http.StatusTooManyRequests)
 	}
@@ -51,9 +56,11 @@ func TestResolveRateLimitPerIP(t *testing.T) {
 func TestStatusIncludesBuildInfo(t *testing.T) {
 	store, err := storage.OpenPebble(t.TempDir())
+
 	if err != nil {
 		t.Fatalf("open pebble: %v", err)
 	}
+
 	defer store.Close()

 	h := NewServer(config.Default(), store, nil, nil, WithBuildInfo(BuildInfo{
@@ -62,25 +69,31 @@ func TestStatusIncludesBuildInfo(t *testing.T) {
 		BuildDate: "2026-02-26T00:00:00Z",
 		GoVersion: "go1.test",
 	})).Handler()
-
 	req := httptest.NewRequest(http.MethodGet, "/status", nil)
 	rr := httptest.NewRecorder()
+
 	h.ServeHTTP(rr, req)
+
 	if rr.Code != http.StatusOK {
 		t.Fatalf("status code: got %d want %d", rr.Code, http.StatusOK)
 	}

 	var payload map[string]any
+
 	if err := json.Unmarshal(rr.Body.Bytes(), &payload); err != nil {
 		t.Fatalf("decode status: %v", err)
 	}
+
 	build, ok := payload["build"].(map[string]any)
+
 	if !ok {
 		t.Fatalf("missing build section: %v", payload)
 	}
+
 	if got := build["version"]; got != "v0.1.0" {
 		t.Fatalf("unexpected build version: %v", got)
 	}
+
 	if got := build["commit"]; got != "abc123" {
 		t.Fatalf("unexpected build commit: %v", got)
 	}
@@ -88,20 +101,26 @@ func TestStatusIncludesBuildInfo(t *testing.T) {
 func TestMetricsExposeRequiredSeries(t *testing.T) {
 	store, err := storage.OpenPebble(t.TempDir())
+
 	if err != nil {
 		t.Fatalf("open pebble: %v", err)
 	}
+
 	defer store.Close()

 	h := NewServer(config.Default(), store, nil, nil).Handler()
 	req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
 	rr := httptest.NewRecorder()
+
 	h.ServeHTTP(rr, req)
+
 	if rr.Code != http.StatusOK {
 		t.Fatalf("metrics status: got %d want %d", rr.Code, http.StatusOK)
 	}
+
 	body, _ := io.ReadAll(rr.Body)
 	text := string(body)
+
 	for _, metric := range []string{
 		"ingest_ops_total",
 		"ingest_ops_per_second",
diff --git a/internal/api/server_integration_test.go b/internal/api/server_integration_test.go
index 1736750..edb64bd 100644
--- a/internal/api/server_integration_test.go
+++ b/internal/api/server_integration_test.go
@@ -14,7 +14,6 @@ import (
 	"path/filepath"
 	"strings"
 	"testing"
-
 	"github.com/Fuwn/plutia/internal/checkpoint"
 	"github.com/Fuwn/plutia/internal/config"
 	"github.com/Fuwn/plutia/internal/ingest"
@@ -25,32 +24,41 @@ import (
 func TestProofAgainstOlderCheckpointAfterFurtherIngest(t *testing.T) {
 	tmp := t.TempDir()
 	dataDir := filepath.Join(tmp, "data")
+
 	if err := os.MkdirAll(dataDir, 0o755); err != nil {
 		t.Fatalf("mkdir data: %v", err)
 	}

 	keyPath := filepath.Join(tmp, "mirror.key")
 	seed := make([]byte, ed25519.SeedSize)
+
 	if _, err := rand.Read(seed); err != nil {
 		t.Fatalf("seed: %v", err)
 	}
+
 	if err := os.WriteFile(keyPath, []byte(base64.RawURLEncoding.EncodeToString(seed)), 0o600); err != nil {
 		t.Fatalf("write key: %v", err)
 	}

 	recs := buildCheckpointScenarioRecords(t)
 	sourcePath := filepath.Join(tmp, "records.ndjson")
+
 	writeRecordsFile(t, sourcePath, recs)

 	store, err := storage.OpenPebble(dataDir)
+
 	if err != nil {
 		t.Fatalf("open pebble: %v", err)
 	}
+
 	defer store.Close()
+
 	if err := store.SetMode(config.ModeMirror); err != nil {
 		t.Fatalf("set mode: %v", err)
 	}
+
 	bl, err := storage.OpenBlockLog(dataDir, 3, 4)
+
 	if err != nil {
 		t.Fatalf("open block log: %v", err)
 	}
@@ -69,29 +77,37 @@ func TestProofAgainstOlderCheckpointAfterFurtherIngest(t *testing.T) {
 	}
 	cpMgr := checkpoint.NewManager(store, dataDir, keyPath)
 	svc := ingest.NewService(cfg, store, ingest.NewClient(sourcePath), bl, cpMgr)
+
 	if err := svc.Replay(context.Background()); err != nil {
 		t.Fatalf("replay: %v", err)
 	}
+
 	if err := svc.Flush(context.Background()); err != nil {
 		t.Fatalf("flush: %v", err)
 	}

 	cp2, ok, err := store.GetCheckpoint(2)
+
 	if err != nil || !ok {
 		t.Fatalf("checkpoint 2 missing: ok=%v err=%v", ok, err)
 	}

 	ts := httptest.NewServer(NewServer(cfg, store, svc, cpMgr).Handler())
+
 	defer ts.Close()

 	url := ts.URL + "/did/" + strings.ReplaceAll("did:plc:alice", ":", "%3A") + "/proof?checkpoint=2"
 	resp, err := http.Get(url)
+
 	if err != nil {
 		t.Fatalf("get proof: %v", err)
 	}
+
 	defer resp.Body.Close()
+
 	if resp.StatusCode != http.StatusOK {
 		bodyBytes, _ := io.ReadAll(resp.Body)
+
 		t.Fatalf("proof status: %d body=%s", resp.StatusCode, string(bodyBytes))
 	}
@@ -102,15 +118,19 @@ func TestProofAgainstOlderCheckpointAfterFurtherIngest(t *testing.T) {
 			MerkleRoot string `json:"merkle_root"`
 		} `json:"inclusion_proof"`
 	}
+
 	if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
 		t.Fatalf("decode response: %v", err)
 	}
+
 	if body.CheckpointSequence != 2 {
 		t.Fatalf("unexpected checkpoint sequence: got %d want 2", body.CheckpointSequence)
 	}
+
 	if body.ChainTipReference != recs[0].CID {
 		t.Fatalf("expected old tip at checkpoint=2: got %s want %s", body.ChainTipReference, recs[0].CID)
 	}
+
 	if body.InclusionProof.MerkleRoot != cp2.DIDMerkleRoot {
 		t.Fatalf("merkle root mismatch: got %s want %s", body.InclusionProof.MerkleRoot, cp2.DIDMerkleRoot)
 	}
@@ -119,34 +139,44 @@ func TestProofAgainstOlderCheckpointAfterFurtherIngest(t *testing.T) {
 func TestCorruptedBlockRefusesProof(t *testing.T) {
 	tmp := t.TempDir()
 	dataDir := filepath.Join(tmp, "data")
+
 	if err := os.MkdirAll(dataDir, 0o755); err != nil {
 		t.Fatalf("mkdir data: %v", err)
 	}

 	seed := make([]byte, ed25519.SeedSize)
+
 	if _, err := rand.Read(seed); err != nil {
 		t.Fatalf("seed: %v", err)
 	}
+
 	keyPath := filepath.Join(tmp, "mirror.key")
+
 	if err := os.WriteFile(keyPath, []byte(base64.RawURLEncoding.EncodeToString(seed)), 0o600); err != nil {
 		t.Fatalf("write key: %v", err)
 	}

 	recs := buildCheckpointScenarioRecords(t)
 	sourcePath := filepath.Join(tmp, "records.ndjson")
+
 	writeRecordsFile(t, sourcePath, recs)

 	store, err := storage.OpenPebble(dataDir)
+
 	if err != nil {
 		t.Fatalf("open pebble: %v", err)
 	}
+
 	if err := store.SetMode(config.ModeMirror); err != nil {
 		t.Fatalf("set mode: %v", err)
 	}
+
 	bl, err := storage.OpenBlockLog(dataDir, 3, 4)
+
 	if err != nil {
 		t.Fatalf("open block log: %v", err)
 	}
+
 	cfg := config.Config{
 		Mode:    config.ModeMirror,
 		DataDir: dataDir,
@@ -161,49 +191,67 @@ func TestCorruptedBlockRefusesProof(t *testing.T) {
 	}
 	cpMgr := checkpoint.NewManager(store, dataDir, keyPath)
 	svc := ingest.NewService(cfg, store, ingest.NewClient(sourcePath), bl, cpMgr)
+
 	if err := svc.Replay(context.Background()); err != nil {
 		t.Fatalf("replay: %v", err)
 	}
+
 	if _, err := svc.Snapshot(context.Background()); err != nil {
 		t.Fatalf("snapshot: %v", err)
 	}
+
 	svc.Close()
+
 	if err := store.Close(); err != nil {
 		t.Fatalf("close store: %v", err)
 	}

 	blockPath := filepath.Join(dataDir, "ops", "000001.zst")
 	b, err := os.ReadFile(blockPath)
+
 	if err != nil {
 		t.Fatalf("read block: %v", err)
 	}
+
 	b[len(b)/2] ^= 0xFF
+
 	if err := os.WriteFile(blockPath, b, 0o644); err != nil {
 		t.Fatalf("write corrupted block: %v", err)
 	}

 	store2, err := storage.OpenPebble(dataDir)
+
 	if err != nil {
 		t.Fatalf("open store2: %v", err)
 	}
+
 	defer store2.Close()
+
 	bl2, err := storage.OpenBlockLog(dataDir, 3, 4)
+
 	if err != nil {
 		t.Fatalf("open blocklog2: %v", err)
 	}
+
 	svc2 := ingest.NewService(cfg, store2, ingest.NewClient("file:///nonexistent"), bl2, checkpoint.NewManager(store2, dataDir, keyPath))
+
 	if !svc2.IsCorrupted() {
 		t.Fatalf("expected service to detect corruption on restart")
 	}

 	ts := httptest.NewServer(NewServer(cfg, store2, svc2, checkpoint.NewManager(store2, dataDir, keyPath)).Handler())
+
 	defer ts.Close()
+
 	url := ts.URL + "/did/" + strings.ReplaceAll("did:plc:alice", ":", "%3A") + "/proof"
 	resp, err := http.Get(url)
+
 	if err != nil {
 		t.Fatalf("request proof: %v", err)
 	}
+
 	defer resp.Body.Close()
+
 	if resp.StatusCode != http.StatusServiceUnavailable {
 		t.Fatalf("expected 503 for corrupted proof, got %d", resp.StatusCode)
 	}
@@ -211,13 +259,18 @@ func TestCorruptedBlockRefusesProof(t *testing.T) {
 func writeRecordsFile(t *testing.T, path string, recs []types.ExportRecord) {
 	t.Helper()
+
 	f, err := os.Create(path)
+
 	if err != nil {
 		t.Fatalf("create records file: %v", err)
 	}
+
 	defer f.Close()
+
 	for _, rec := range recs {
 		b, _ := json.Marshal(rec)
+
 		if _, err := fmt.Fprintln(f, string(b)); err != nil {
 			t.Fatalf("write record: %v", err)
 		}
@@ -226,36 +279,44 @@ func writeRecordsFile(t *testing.T, path string, recs []types.ExportRecord) {
 func buildCheckpointScenarioRecords(t *testing.T) []types.ExportRecord {
 	t.Helper()
+
 	pub, priv, err := ed25519.GenerateKey(rand.Reader)
+
 	if err != nil {
 		t.Fatalf("generate key: %v", err)
 	}
+
 	mk := func(seq uint64, did, prev string) types.ExportRecord {
 		unsigned := map[string]any{
 			"did":       did,
 			"didDoc":    map[string]any{"id": did, "seq": seq},
 			"publicKey": base64.RawURLEncoding.EncodeToString(pub),
 		}
+
 		if prev != "" {
 			unsigned["prev"] = prev
 		}
+
 		payload, _ := json.Marshal(unsigned)
 		canon, _ := types.CanonicalizeJSON(payload)
 		sig := ed25519.Sign(priv, canon)
 		op := map[string]any{}
+
 		for k, v := range unsigned {
 			op[k] = v
 		}
+
 		op["sigPayload"] = base64.RawURLEncoding.EncodeToString(canon)
 		op["sig"] = base64.RawURLEncoding.EncodeToString(sig)
 		raw, _ := json.Marshal(op)
 		opCanon, _ := types.CanonicalizeJSON(raw)
 		cid := types.ComputeDigestCID(opCanon)
+
 		return types.ExportRecord{Seq: seq, DID: did, CID: cid, Operation: raw}
 	}
-
 	r1 := mk(1, "did:plc:alice", "")
 	r2 := mk(2, "did:plc:bob", "")
 	r3 := mk(3, "did:plc:alice", r1.CID)
+
 	return []types.ExportRecord{r1, r2, r3}
 }
err == nil && ok { prev = latest.CheckpointHash } else if err != nil { @@ -92,9 +100,11 @@ func (m *Manager) signAndPersist(sequence uint64, didRoot string, blockHashes [] KeyID: keyID, } payload, err := marshalCheckpointPayload(unsigned) + if err != nil { return types.CheckpointV1{}, err } + sum := sha256.Sum256(payload) unsigned.CheckpointHash = hex.EncodeToString(sum[:]) unsigned.Signature = base64.RawURLEncoding.EncodeToString(ed25519.Sign(privateKey, payload)) @@ -102,84 +112,112 @@ func (m *Manager) signAndPersist(sequence uint64, didRoot string, blockHashes [] if err := m.store.PutCheckpoint(unsigned); err != nil { return types.CheckpointV1{}, fmt.Errorf("persist checkpoint: %w", err) } + if err := m.writeCheckpointFile(unsigned); err != nil { return types.CheckpointV1{}, err } + return unsigned, nil } func (m *Manager) BuildDIDProofAtCheckpoint(ctx context.Context, did, chainTipHash string, checkpointSeq uint64) ([]merkle.Sibling, string, bool, error) { snapshot, err := m.LoadStateSnapshot(checkpointSeq) + if err != nil { return nil, "", false, err } + leaves := make([][]byte, len(snapshot.Leaves)) index := -1 leafHashHex := "" + for i, s := range snapshot.Leaves { if err := ctx.Err(); err != nil { return nil, "", false, err } + h := merkle.HashLeaf([]byte(s.DID + s.ChainTipHash)) leaves[i] = h + if s.DID == did && s.ChainTipHash == chainTipHash { index = i leafHashHex = hex.EncodeToString(h) } } + if index < 0 { return nil, "", false, nil } + proof := merkle.BuildProof(leaves, index) + return proof, leafHashHex, true, nil } func (m *Manager) LoadStateSnapshot(sequence uint64) (types.CheckpointStateSnapshotV1, error) { path := filepath.Join(m.dataDir, "checkpoints", fmt.Sprintf("%020d.state.json", sequence)) b, err := os.ReadFile(path) + if err != nil { return types.CheckpointStateSnapshotV1{}, fmt.Errorf("read checkpoint state snapshot: %w", err) } + var snap types.CheckpointStateSnapshotV1 + if err := json.Unmarshal(b, &snap); err != nil { return types.CheckpointStateSnapshotV1{}, fmt.Errorf("unmarshal checkpoint state snapshot: %w", err) } + if snap.Sequence != sequence { return types.CheckpointStateSnapshotV1{}, fmt.Errorf("snapshot sequence mismatch: got %d want %d", snap.Sequence, sequence) } + return snap, nil } func didMerkle(states []types.StateV1) (string, []types.DIDLeaf) { sorted := append([]types.StateV1(nil), states...) 
+ sort.Slice(sorted, func(i, j int) bool { return sorted[i].DID < sorted[j].DID }) + acc := merkle.NewAccumulator() snapshotLeaves := make([]types.DIDLeaf, 0, len(sorted)) + for _, s := range sorted { snapshotLeaves = append(snapshotLeaves, types.DIDLeaf{ DID: s.DID, ChainTipHash: s.ChainTipHash, }) + acc.AddLeafHash(merkle.HashLeaf([]byte(s.DID + s.ChainTipHash))) } + root := acc.RootDuplicateLast() + return hex.EncodeToString(root), snapshotLeaves } func blockMerkle(hashes []string) string { if len(hashes) == 0 { root := merkle.Root(nil) + return hex.EncodeToString(root) } + leaves := make([][]byte, 0, len(hashes)) + for _, h := range hashes { decoded, err := hex.DecodeString(strings.TrimSpace(h)) + if err != nil { decoded = []byte(h) } + leaves = append(leaves, merkle.HashLeaf(decoded)) } + root := merkle.Root(leaves) + return hex.EncodeToString(root) } @@ -188,33 +226,42 @@ func marshalCheckpointPayload(cp types.CheckpointV1) ([]byte, error) { clone.Signature = "" clone.CheckpointHash = "" b, err := json.Marshal(clone) + if err != nil { return nil, fmt.Errorf("marshal checkpoint payload: %w", err) } + return types.CanonicalizeJSON(b) } func (m *Manager) writeCheckpointFile(cp types.CheckpointV1) error { dir := filepath.Join(m.dataDir, "checkpoints") + if err := os.MkdirAll(dir, 0o755); err != nil { return fmt.Errorf("mkdir checkpoints: %w", err) } + path := filepath.Join(dir, fmt.Sprintf("%020d.json", cp.Sequence)) b, err := json.MarshalIndent(cp, "", " ") + if err != nil { return fmt.Errorf("marshal checkpoint: %w", err) } + if err := os.WriteFile(path, b, 0o644); err != nil { return fmt.Errorf("write checkpoint file: %w", err) } + return nil } func (m *Manager) writeCheckpointStateSnapshot(sequence uint64, leaves []types.DIDLeaf) error { dir := filepath.Join(m.dataDir, "checkpoints") + if err := os.MkdirAll(dir, 0o755); err != nil { return fmt.Errorf("mkdir checkpoints: %w", err) } + snap := types.CheckpointStateSnapshotV1{ Version: 1, Sequence: sequence, @@ -222,32 +269,41 @@ func (m *Manager) writeCheckpointStateSnapshot(sequence uint64, leaves []types.D Leaves: leaves, } b, err := json.Marshal(snap) + if err != nil { return fmt.Errorf("marshal checkpoint state snapshot: %w", err) } + path := filepath.Join(dir, fmt.Sprintf("%020d.state.json", sequence)) + if err := os.WriteFile(path, b, 0o644); err != nil { return fmt.Errorf("write checkpoint state snapshot: %w", err) } + return nil } func (m *Manager) writeCheckpointStateSnapshotFromStore(sequence uint64) (string, int, time.Duration, error) { dir := filepath.Join(m.dataDir, "checkpoints") + if err := os.MkdirAll(dir, 0o755); err != nil { return "", 0, 0, fmt.Errorf("mkdir checkpoints: %w", err) } + finalPath := filepath.Join(dir, fmt.Sprintf("%020d.state.json", sequence)) tmpPath := finalPath + ".tmp" - f, err := os.Create(tmpPath) + if err != nil { return "", 0, 0, fmt.Errorf("create checkpoint state snapshot temp file: %w", err) } + createdAt := time.Now().UTC().Format(time.RFC3339) w := bufio.NewWriterSize(f, 1<<20) + if _, err := fmt.Fprintf(w, `{"v":1,"sequence":%d,"created_at":"%s","leaves":[`, sequence, createdAt); err != nil { _ = f.Close() + return "", 0, 0, fmt.Errorf("write snapshot header: %w", err) } @@ -255,61 +311,82 @@ func (m *Manager) writeCheckpointStateSnapshotFromStore(sequence uint64) (string didCount := 0 merkleCompute := time.Duration(0) acc := merkle.NewAccumulator() + if err := m.store.ForEachState(func(s types.StateV1) error { leaf := types.DIDLeaf{ DID: s.DID, ChainTipHash: s.ChainTipHash, } b, err := 
json.Marshal(leaf) + if err != nil { return fmt.Errorf("marshal checkpoint leaf: %w", err) } + if !first { if _, err := w.WriteString(","); err != nil { return err } } + if _, err := w.Write(b); err != nil { return err } + first = false hashStart := time.Now() + acc.AddLeafHash(merkle.HashLeaf([]byte(leaf.DID + leaf.ChainTipHash))) + merkleCompute += time.Since(hashStart) + didCount++ + return nil }); err != nil { _ = f.Close() + return "", 0, 0, err } if _, err := w.WriteString(`]}`); err != nil { _ = f.Close() + return "", 0, 0, fmt.Errorf("write snapshot trailer: %w", err) } + if err := w.Flush(); err != nil { _ = f.Close() + return "", 0, 0, fmt.Errorf("flush snapshot writer: %w", err) } + if err := f.Sync(); err != nil { _ = f.Close() + return "", 0, 0, fmt.Errorf("sync snapshot file: %w", err) } + if err := f.Close(); err != nil { return "", 0, 0, fmt.Errorf("close snapshot file: %w", err) } + if err := os.Rename(tmpPath, finalPath); err != nil { return "", 0, 0, fmt.Errorf("rename snapshot file: %w", err) } + return hex.EncodeToString(acc.RootDuplicateLast()), didCount, merkleCompute, nil } func (m *Manager) loadSigningKey() (ed25519.PrivateKey, string, error) { data, err := os.ReadFile(m.keyPath) + if err != nil { return nil, "", fmt.Errorf("read mirror private key: %w", err) } + text := strings.TrimSpace(string(data)) + if text == "" { return nil, "", errors.New("empty mirror private key") } @@ -321,11 +398,14 @@ func (m *Manager) loadSigningKey() (ed25519.PrivateKey, string, error) { var k struct { PrivateKey string `json:"private_key"` } + if err := json.Unmarshal(data, &k); err == nil && strings.TrimSpace(k.PrivateKey) != "" { raw, err := decodeKeyString(k.PrivateKey) + if err != nil { return nil, "", err } + return keyFromRaw(raw) } @@ -335,32 +415,42 @@ func (m *Manager) loadSigningKey() (ed25519.PrivateKey, string, error) { func decodeKeyString(v string) ([]byte, error) { if strings.HasPrefix(v, "did:key:") { mb := strings.TrimPrefix(v, "did:key:") + if mb == "" || mb[0] != 'z' { return nil, errors.New("unsupported did:key format") } + decoded, err := base58.Decode(mb[1:]) + if err != nil { return nil, err } + if len(decoded) < 34 { return nil, errors.New("invalid did:key length") } + return decoded[len(decoded)-32:], nil } + if isHexString(v) { if b, err := hex.DecodeString(v); err == nil { return b, nil } } + if b, err := base64.RawURLEncoding.DecodeString(v); err == nil { return b, nil } + if b, err := base64.StdEncoding.DecodeString(v); err == nil { return b, nil } + if b, err := hex.DecodeString(v); err == nil { return b, nil } + return nil, errors.New("unknown key encoding") } @@ -368,12 +458,15 @@ func isHexString(v string) bool { if len(v) == 0 || len(v)%2 != 0 { return false } + for _, r := range v { if (r >= '0' && r <= '9') || (r >= 'a' && r <= 'f') || (r >= 'A' && r <= 'F') { continue } + return false } + return true } @@ -382,10 +475,12 @@ func keyFromRaw(raw []byte) (ed25519.PrivateKey, string, error) { case ed25519.SeedSize: pk := ed25519.NewKeyFromSeed(raw) kid := keyID(pk.Public().(ed25519.PublicKey)) + return pk, kid, nil case ed25519.PrivateKeySize: pk := ed25519.PrivateKey(raw) kid := keyID(pk.Public().(ed25519.PublicKey)) + return pk, kid, nil default: return nil, "", fmt.Errorf("invalid private key length %d", len(raw)) @@ -394,5 +489,6 @@ func keyFromRaw(raw []byte) (ed25519.PrivateKey, string, error) { func keyID(pub ed25519.PublicKey) string { sum := sha256.Sum256(pub) + return "ed25519:" + hex.EncodeToString(sum[:8]) } diff --git 
a/internal/checkpoint/checkpoint_test.go b/internal/checkpoint/checkpoint_test.go index 8149129..156aa9e 100644 --- a/internal/checkpoint/checkpoint_test.go +++ b/internal/checkpoint/checkpoint_test.go @@ -9,7 +9,6 @@ import ( "os" "path/filepath" "testing" - "github.com/Fuwn/plutia/internal/storage" "github.com/Fuwn/plutia/internal/types" ) @@ -17,16 +16,21 @@ import ( func TestBuildAndStoreCheckpoint(t *testing.T) { tmp := t.TempDir() store, err := storage.OpenPebble(tmp) + if err != nil { t.Fatalf("open pebble: %v", err) } + defer store.Close() pub, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { t.Fatalf("generate key: %v", err) } + keyPath := filepath.Join(tmp, "mirror.key") + if err := os.WriteFile(keyPath, []byte(base64.RawURLEncoding.EncodeToString(priv)), 0o600); err != nil { t.Fatalf("write key: %v", err) } @@ -37,28 +41,37 @@ func TestBuildAndStoreCheckpoint(t *testing.T) { {Version: 1, DID: "did:plc:b", ChainTipHash: "tip-b"}, } cp, err := mgr.BuildAndStore(42, states, []string{"abc", "def"}) + if err != nil { t.Fatalf("build checkpoint: %v", err) } + if cp.Signature == "" || cp.CheckpointHash == "" { t.Fatalf("expected signed checkpoint") } + payload, err := marshalCheckpointPayload(cp) + if err != nil { t.Fatalf("marshal payload: %v", err) } + rawSig, err := base64.RawURLEncoding.DecodeString(cp.Signature) + if err != nil { t.Fatalf("decode signature: %v", err) } + if !ed25519.Verify(pub, payload, rawSig) { t.Fatalf("signature verification failed") } stored, ok, err := store.GetCheckpoint(cp.Sequence) + if err != nil || !ok { t.Fatalf("stored checkpoint missing: ok=%v err=%v", ok, err) } + if stored.CheckpointHash != cp.CheckpointHash { t.Fatalf("checkpoint hash mismatch") } @@ -67,33 +80,42 @@ func TestBuildAndStoreCheckpoint(t *testing.T) { func TestCheckpointRootStability(t *testing.T) { tmp := t.TempDir() store, err := storage.OpenPebble(tmp) + if err != nil { t.Fatalf("open pebble: %v", err) } + defer store.Close() if err := store.PutState(types.StateV1{Version: 1, DID: "did:plc:b", ChainTipHash: "tip-b"}); err != nil { t.Fatalf("put state b: %v", err) } + if err := store.PutState(types.StateV1{Version: 1, DID: "did:plc:a", ChainTipHash: "tip-a"}); err != nil { t.Fatalf("put state a: %v", err) } _, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { t.Fatalf("generate key: %v", err) } + keyPath := filepath.Join(tmp, "mirror.key") + if err := os.WriteFile(keyPath, []byte(base64.RawURLEncoding.EncodeToString(priv)), 0o600); err != nil { t.Fatalf("write key: %v", err) } mgr := NewManager(store, tmp, keyPath) cp1, err := mgr.BuildAndStoreFromStore(100, []string{"abc"}) + if err != nil { t.Fatalf("build checkpoint 1: %v", err) } + cp2, err := mgr.BuildAndStoreFromStore(200, []string{"abc"}) + if err != nil { t.Fatalf("build checkpoint 2: %v", err) } @@ -101,6 +123,7 @@ func TestCheckpointRootStability(t *testing.T) { if cp1.DIDMerkleRoot != cp2.DIDMerkleRoot { t.Fatalf("did root changed for identical state: %s vs %s", cp1.DIDMerkleRoot, cp2.DIDMerkleRoot) } + if cp1.BlockMerkleRoot != cp2.BlockMerkleRoot { t.Fatalf("block root changed for identical block set: %s vs %s", cp1.BlockMerkleRoot, cp2.BlockMerkleRoot) } @@ -109,30 +132,40 @@ func TestCheckpointRootStability(t *testing.T) { func TestBuildDIDProofAtCheckpointHonorsContextCancellation(t *testing.T) { tmp := t.TempDir() store, err := storage.OpenPebble(tmp) + if err != nil { t.Fatalf("open pebble: %v", err) } + defer store.Close() _, priv, err := ed25519.GenerateKey(rand.Reader) + if err != 
nil { t.Fatalf("generate key: %v", err) } + keyPath := filepath.Join(tmp, "mirror.key") + if err := os.WriteFile(keyPath, []byte(base64.RawURLEncoding.EncodeToString(priv)), 0o600); err != nil { t.Fatalf("write key: %v", err) } + checkpointsDir := filepath.Join(tmp, "checkpoints") + if err := os.MkdirAll(checkpointsDir, 0o755); err != nil { t.Fatalf("mkdir checkpoints: %v", err) } + leaves := make([]types.DIDLeaf, 500) + for i := range leaves { leaves[i] = types.DIDLeaf{ DID: "did:plc:test-" + string(rune('a'+(i%26))), ChainTipHash: "tip", } } + snapshot := types.CheckpointStateSnapshotV1{ Version: 1, Sequence: 10, @@ -140,16 +173,20 @@ func TestBuildDIDProofAtCheckpointHonorsContextCancellation(t *testing.T) { Leaves: leaves, } b, err := json.Marshal(snapshot) + if err != nil { t.Fatalf("marshal snapshot: %v", err) } + if err := os.WriteFile(filepath.Join(checkpointsDir, "00000000000000000010.state.json"), b, 0o644); err != nil { t.Fatalf("write snapshot: %v", err) } mgr := NewManager(store, tmp, keyPath) ctx, cancel := context.WithCancel(context.Background()) + cancel() + if _, _, _, err := mgr.BuildDIDProofAtCheckpoint(ctx, "did:plc:test-a", "tip", 10); err == nil { t.Fatalf("expected context cancellation error") } diff --git a/internal/config/config.go b/internal/config/config.go index 609bde5..abb7444 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -8,14 +8,12 @@ import ( "strconv" "strings" "time" - "gopkg.in/yaml.v3" ) const ( ModeResolver = "resolver" ModeMirror = "mirror" - VerifyFull = "full" VerifyLazy = "lazy" VerifyStateOnly = "state-only" @@ -77,19 +75,25 @@ func Default() Config { func Load(path string) (Config, error) { cfg := Default() + if path != "" { b, err := os.ReadFile(path) + if err != nil { return Config{}, fmt.Errorf("read config: %w", err) } + if err := yaml.Unmarshal(b, &cfg); err != nil { return Config{}, fmt.Errorf("parse config: %w", err) } } + applyEnv(&cfg) + if err := cfg.Validate(); err != nil { return Config{}, err } + return cfg, nil } @@ -154,64 +158,84 @@ func (c Config) Validate() error { if c.Mode != ModeResolver && c.Mode != ModeMirror { return fmt.Errorf("invalid mode %q", c.Mode) } + switch c.VerifyPolicy { case VerifyFull, VerifyLazy, VerifyStateOnly: default: return fmt.Errorf("invalid verify policy %q", c.VerifyPolicy) } + if c.DataDir == "" { return errors.New("data_dir is required") } + if c.PLCSource == "" { return errors.New("plc_source is required") } + if c.ZstdLevel < 1 || c.ZstdLevel > 22 { return fmt.Errorf("zstd_level must be between 1 and 22, got %d", c.ZstdLevel) } + if c.BlockSizeMB < 4 || c.BlockSizeMB > 16 { return fmt.Errorf("block_size_mb must be between 4 and 16, got %d", c.BlockSizeMB) } + if c.CheckpointInterval == 0 { return errors.New("checkpoint_interval must be > 0") } + if c.CommitBatchSize <= 0 || c.CommitBatchSize > 4096 { return fmt.Errorf("commit_batch_size must be between 1 and 4096, got %d", c.CommitBatchSize) } + if c.VerifyWorkers <= 0 || c.VerifyWorkers > 1024 { return fmt.Errorf("verify_workers must be between 1 and 1024, got %d", c.VerifyWorkers) } + if c.ListenAddr == "" { return errors.New("listen_addr is required") } + if c.PollInterval <= 0 { return errors.New("poll_interval must be > 0") } + if c.RequestTimeout <= 0 { return errors.New("request_timeout must be > 0") } + if c.RateLimit.ResolveRPS <= 0 { return errors.New("rate_limit.resolve_rps must be > 0") } + if c.RateLimit.ResolveBurst <= 0 { return errors.New("rate_limit.resolve_burst must be > 0") } + if 
c.RateLimit.ProofRPS <= 0 { return errors.New("rate_limit.proof_rps must be > 0") } + if c.RateLimit.ProofBurst <= 0 { return errors.New("rate_limit.proof_burst must be > 0") } + if c.HTTPRetryMaxAttempts < 1 || c.HTTPRetryMaxAttempts > 32 { return fmt.Errorf("http_retry_max_attempts must be between 1 and 32, got %d", c.HTTPRetryMaxAttempts) } + if c.HTTPRetryBaseDelay <= 0 { return errors.New("http_retry_base_delay must be > 0") } + if c.HTTPRetryMaxDelay <= 0 { return errors.New("http_retry_max_delay must be > 0") } + if c.HTTPRetryBaseDelay > c.HTTPRetryMaxDelay { return errors.New("http_retry_base_delay must be <= http_retry_max_delay") } + return nil } diff --git a/internal/ingest/client.go b/internal/ingest/client.go index d25b73f..12d8bd6 100644 --- a/internal/ingest/client.go +++ b/internal/ingest/client.go @@ -17,7 +17,6 @@ import ( "strconv" "strings" "time" - "github.com/Fuwn/plutia/internal/types" ) @@ -39,17 +38,21 @@ func NewClient(source string, opts ...ClientOptions) *Client { BaseDelay: 250 * time.Millisecond, MaxDelay: 10 * time.Second, } + if len(opts) > 0 { if opts[0].MaxAttempts > 0 { cfg.MaxAttempts = opts[0].MaxAttempts } + if opts[0].BaseDelay > 0 { cfg.BaseDelay = opts[0].BaseDelay } + if opts[0].MaxDelay > 0 { cfg.MaxDelay = opts[0].MaxDelay } } + transport := &http.Transport{ Proxy: http.ProxyFromEnvironment, DialContext: (&net.Dialer{Timeout: 10 * time.Second, KeepAlive: 30 * time.Second}).DialContext, @@ -60,6 +63,7 @@ func NewClient(source string, opts ...ClientOptions) *Client { TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, } + return &Client{ source: strings.TrimRight(source, "/"), opts: cfg, @@ -80,45 +84,63 @@ func (c *Client) FetchExportLimited(ctx context.Context, after uint64, limit uin } u, err := url.Parse(c.source) + if err != nil { return nil, fmt.Errorf("parse plc source: %w", err) } + u.Path = strings.TrimRight(u.Path, "/") + "/export" q := u.Query() + q.Set("after", fmt.Sprintf("%d", after)) - u.RawQuery = q.Encode() + u.RawQuery = q.Encode() req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) + if err != nil { return nil, fmt.Errorf("new request: %w", err) } + maxAttempts := c.opts.MaxAttempts + if maxAttempts < 1 { maxAttempts = 1 } + var lastErr error + for attempt := 1; attempt <= maxAttempts; attempt++ { records, retryAfter, retryable, err := c.fetchExportOnce(req, limit) + if err == nil { return records, nil } + lastErr = err + if !retryable || attempt == maxAttempts || ctx.Err() != nil { break } + delay := retryAfter + if delay <= 0 { delay = c.backoffDelay(attempt) } + log.Printf("retrying plc export fetch attempt=%d after_seq=%d delay=%s reason=%v", attempt, after, delay, err) + timer := time.NewTimer(delay) + select { case <-ctx.Done(): timer.Stop() + return nil, ctx.Err() case <-timer.C: } } + return nil, lastErr } @@ -133,11 +155,15 @@ func (e httpStatusError) Error() string { func (c *Client) fetchExportOnce(req *http.Request, limit uint64) ([]types.ExportRecord, time.Duration, bool, error) { reqClone := req.Clone(req.Context()) + reqClone.Header.Set("Accept-Encoding", "gzip") + resp, err := c.http.Do(reqClone) + if err != nil { return nil, 0, isTransientNetworkErr(err), fmt.Errorf("fetch export: %w", err) } + defer resp.Body.Close() if resp.StatusCode != http.StatusOK { @@ -145,12 +171,16 @@ func (c *Client) fetchExportOnce(req *http.Request, limit uint64) ([]types.Expor body := strings.TrimSpace(string(b)) retryDelay := parseRetryAfter(resp.Header.Get("Retry-After")) err := 
httpStatusError{StatusCode: resp.StatusCode, Body: body} + return nil, retryDelay, shouldRetryStatus(resp.StatusCode), err } + records, err := decodeExportBody(resp.Body, limit) + if err != nil { return nil, 0, false, err } + return records, 0, false, nil } @@ -158,23 +188,31 @@ func (c *Client) backoffDelay(attempt int) time.Duration { if attempt < 1 { attempt = 1 } + delay := c.opts.BaseDelay + if delay <= 0 { delay = 250 * time.Millisecond } + maxDelay := c.opts.MaxDelay + if maxDelay <= 0 { maxDelay = 10 * time.Second } + for i := 1; i < attempt; i++ { delay *= 2 + if delay >= maxDelay { return maxDelay } } + if delay > maxDelay { return maxDelay } + return delay } @@ -189,18 +227,23 @@ func shouldRetryStatus(status int) bool { func parseRetryAfter(v string) time.Duration { v = strings.TrimSpace(v) + if v == "" { return 0 } + if secs, err := strconv.Atoi(v); err == nil && secs > 0 { return time.Duration(secs) * time.Second } + if ts, err := http.ParseTime(v); err == nil { delay := time.Until(ts) + if delay > 0 { return delay } } + return 0 } @@ -208,63 +251,85 @@ func isTransientNetworkErr(err error) bool { if err == nil { return false } + var netErr net.Error + if errors.As(err, &netErr) { return true } + return errors.Is(err, io.ErrUnexpectedEOF) || errors.Is(err, io.EOF) } func decodeExportBody(r io.Reader, limit uint64) ([]types.ExportRecord, error) { br := bufio.NewReader(r) first, err := peekFirstNonSpace(br) + if err != nil { if err == io.EOF { return nil, nil } + return nil, err } if first == '[' { b, err := io.ReadAll(br) + if err != nil { return nil, fmt.Errorf("read export body: %w", err) } + trimmed := bytes.TrimSpace(b) + var records []types.ExportRecord + if err := json.Unmarshal(trimmed, &records); err != nil { return nil, fmt.Errorf("decode export json array: %w", err) } + if limit > 0 && uint64(len(records)) > limit { records = records[:limit] } + return records, nil } + out := make([]types.ExportRecord, 0, 1024) + for { line, err := br.ReadBytes('\n') isEOF := errors.Is(err, io.EOF) + if err != nil && !isEOF { return nil, fmt.Errorf("read ndjson line: %w", err) } + if limit > 0 && uint64(len(out)) >= limit { return out, nil } + trimmed := bytes.TrimSpace(line) + if len(trimmed) > 0 { var rec types.ExportRecord + if err := json.Unmarshal(trimmed, &rec); err != nil { if isEOF && isTrailingNDJSONPartial(err) { return out, nil } + return nil, fmt.Errorf("decode ndjson line: %w", err) } + out = append(out, rec) } + if isEOF { break } } + return out, nil } @@ -272,19 +337,23 @@ func isTrailingNDJSONPartial(err error) bool { if !errors.Is(err, io.ErrUnexpectedEOF) && !strings.Contains(err.Error(), "unexpected end of JSON input") { return false } + return true } func peekFirstNonSpace(br *bufio.Reader) (byte, error) { for { b, err := br.ReadByte() + if err != nil { return 0, err } + if !isSpace(b) { if err := br.UnreadByte(); err != nil { return 0, fmt.Errorf("unread byte: %w", err) } + return b, nil } } @@ -301,27 +370,37 @@ func isSpace(b byte) bool { func (c *Client) fetchFromFile(after uint64, limit uint64) ([]types.ExportRecord, error) { path := c.source + if strings.HasPrefix(path, "file://") { path = strings.TrimPrefix(path, "file://") } + path = filepath.Clean(path) b, err := os.ReadFile(path) + if err != nil { return nil, fmt.Errorf("read source file: %w", err) } + recs, err := decodeExportBody(bytes.NewReader(b), 0) + if err != nil { return nil, err } + out := make([]types.ExportRecord, 0, len(recs)) + for _, r := range recs { if r.Seq <= after { continue } + out = 
append(out, r) + if limit > 0 && uint64(len(out)) >= limit { break } } + return out, nil } diff --git a/internal/ingest/client_retry_test.go b/internal/ingest/client_retry_test.go index a23fceb..b019ac6 100644 --- a/internal/ingest/client_retry_test.go +++ b/internal/ingest/client_retry_test.go @@ -12,15 +12,20 @@ import ( func TestFetchExportLimitedRetries429ThenSucceeds(t *testing.T) { var attempts atomic.Int32 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { n := attempts.Add(1) + if n <= 2 { w.Header().Set("Retry-After", "0") http.Error(w, "rate limited", http.StatusTooManyRequests) + return } + _, _ = fmt.Fprintln(w, `{"seq":1,"did":"did:plc:alice","cid":"cid1","operation":{"x":1}}`) })) + defer ts.Close() client := NewClient(ts.URL, ClientOptions{ @@ -29,12 +34,15 @@ func TestFetchExportLimitedRetries429ThenSucceeds(t *testing.T) { MaxDelay: 2 * time.Millisecond, }) records, err := client.FetchExportLimited(context.Background(), 0, 0) + if err != nil { t.Fatalf("fetch export: %v", err) } + if len(records) != 1 { t.Fatalf("record count mismatch: got %d want 1", len(records)) } + if got := attempts.Load(); got != 3 { t.Fatalf("attempt count mismatch: got %d want 3", got) } @@ -42,10 +50,12 @@ func TestFetchExportLimitedRetries429ThenSucceeds(t *testing.T) { func TestFetchExportLimitedDoesNotRetry400(t *testing.T) { var attempts atomic.Int32 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { attempts.Add(1) http.Error(w, "bad request", http.StatusBadRequest) })) + defer ts.Close() client := NewClient(ts.URL, ClientOptions{ @@ -54,9 +64,11 @@ func TestFetchExportLimitedDoesNotRetry400(t *testing.T) { MaxDelay: 2 * time.Millisecond, }) _, err := client.FetchExportLimited(context.Background(), 0, 0) + if err == nil { t.Fatalf("expected error for 400 response") } + if got := attempts.Load(); got != 1 { t.Fatalf("unexpected retries on 400: got attempts=%d want 1", got) } diff --git a/internal/ingest/client_test.go b/internal/ingest/client_test.go index ff80225..50595ec 100644 --- a/internal/ingest/client_test.go +++ b/internal/ingest/client_test.go @@ -11,14 +11,16 @@ func TestDecodeExportBody_IgnoresTrailingPartialNDJSONLine(t *testing.T) { `{"seq":2,"did":"did:plc:bob","cid":"cid2","operation":{"x":2}}`, `{"seq":3,"did":"did:plc:carol","cid":"cid3","operation":{"x":3}`, }, "\n") - records, err := decodeExportBody(strings.NewReader(body), 0) + if err != nil { t.Fatalf("decode export body: %v", err) } + if len(records) != 2 { t.Fatalf("record count mismatch: got %d want 2", len(records)) } + if records[0].Seq != 1 || records[1].Seq != 2 { t.Fatalf("unexpected sequences: got [%d %d], want [1 2]", records[0].Seq, records[1].Seq) } @@ -30,8 +32,8 @@ func TestDecodeExportBody_FailsOnMalformedNonTrailingNDJSONLine(t *testing.T) { `{"seq":2,"did":"did:plc:bob","cid":"cid2","operation":{"x":2}`, `{"seq":3,"did":"did:plc:carol","cid":"cid3","operation":{"x":3}}`, }, "\n") - _, err := decodeExportBody(strings.NewReader(body), 0) + if err == nil { t.Fatalf("expected malformed middle line to fail") } diff --git a/internal/ingest/service.go b/internal/ingest/service.go index f6b2433..1779db9 100644 --- a/internal/ingest/service.go +++ b/internal/ingest/service.go @@ -15,7 +15,6 @@ import ( "sync" "sync/atomic" "time" - "github.com/Fuwn/plutia/internal/checkpoint" "github.com/Fuwn/plutia/internal/config" "github.com/Fuwn/plutia/internal/state" @@ -73,18 +72,23 @@ func NewService(cfg config.Config, store storage.Store, 
client *Client, blockLog blockLog: blockLog, startedAt: time.Now(), } + if didCount, err := countStates(store); err == nil { atomic.StoreUint64(&s.stats.DIDCount, didCount) } + if cp, ok, err := store.GetLatestCheckpoint(); err == nil && ok { atomic.StoreUint64(&s.stats.CheckpointSeq, cp.Sequence) } + if cfg.Mode == config.ModeMirror && blockLog != nil { s.appender = newBlockAppender(blockLog, 1024) + if _, err := blockLog.ValidateAndRecover(store); err != nil { s.MarkCorrupted(err) } } + return s } @@ -95,9 +99,11 @@ func (s *Service) SetMaxOps(max uint64) { func (s *Service) Replay(ctx context.Context) error { for { changed, err := s.RunOnce(ctx) + if err != nil { return err } + if !changed { return nil } @@ -106,11 +112,14 @@ func (s *Service) Replay(ctx context.Context) error { func (s *Service) Poll(ctx context.Context) error { ticker := time.NewTicker(s.cfg.PollInterval) + defer ticker.Stop() + for { if _, err := s.RunOnce(ctx); err != nil { atomic.AddUint64(&s.stats.Errors, 1) } + select { case <-ctx.Done(): return ctx.Err() @@ -123,42 +132,59 @@ func (s *Service) RunOnce(ctx context.Context) (bool, error) { if err := s.CorruptionError(); err != nil { return false, err } + if s.maxOps > 0 && s.runOps >= s.maxOps { return false, nil } + lastSeq, err := s.store.GetGlobalSeq() + if err != nil { return false, fmt.Errorf("load global sequence: %w", err) } + limit := uint64(0) + if s.maxOps > 0 { remaining := s.maxOps - s.runOps + if remaining == 0 { return false, nil } + limit = remaining } + records, err := s.client.FetchExportLimited(ctx, lastSeq, limit) + if err != nil { return false, err } + if len(records) == 0 { atomic.StoreUint64(&s.stats.LagOps, 0) + return false, nil } + latestSourceSeq := records[len(records)-1].Seq committed, err := s.processRecords(ctx, records) s.runOps += committed + if err != nil { atomic.AddUint64(&s.stats.Errors, 1) + return committed > 0, err } + lastCommitted := atomic.LoadUint64(&s.stats.LastSeq) + if latestSourceSeq > lastCommitted { atomic.StoreUint64(&s.stats.LagOps, latestSourceSeq-lastCommitted) } else { atomic.StoreUint64(&s.stats.LagOps, 0) } + return true, nil } @@ -177,50 +203,72 @@ type verifyResult struct { func (s *Service) processRecords(ctx context.Context, records []types.ExportRecord) (uint64, error) { verified, err := s.verifyRecords(ctx, records) + if err != nil { return 0, err } + return s.commitVerified(ctx, verified) } func (s *Service) verifyRecords(ctx context.Context, records []types.ExportRecord) ([]verifyResult, error) { _ = ctx workers := s.cfg.VerifyWorkers + if workers < 1 { workers = 1 } + queues := make([]chan verifyTask, workers) results := make(chan verifyResult, len(records)) + var wg sync.WaitGroup for i := 0; i < workers; i++ { queues[i] = make(chan verifyTask, 64) + wg.Add(1) + go func(queue <-chan verifyTask) { defer wg.Done() + cache := map[string]*types.StateV1{} + for task := range queue { op, err := types.ParseOperation(task.rec) + if err != nil { results <- verifyResult{index: task.index, err: err} + continue } + existing, err := s.loadExistingState(cache, op.DID) + if err != nil { results <- verifyResult{index: task.index, err: err} + continue } + if err := s.verifier.VerifyOperation(op, existing); err != nil { atomic.AddUint64(&s.stats.VerifyFailures, 1) + results <- verifyResult{index: task.index, err: err} + continue } + next, err := state.ComputeNextState(op, existing) + if err != nil { results <- verifyResult{index: task.index, err: err} + continue } + cache[op.DID] = cloneState(&next) + results <- 
verifyResult{index: task.index, op: op, state: next, newDID: existing == nil} } }(queues[i]) @@ -228,13 +276,17 @@ func (s *Service) verifyRecords(ctx context.Context, records []types.ExportRecor for idx, rec := range records { worker := didWorker(rec.DID, workers) + queues[worker] <- verifyTask{index: idx, rec: rec} } + for _, q := range queues { close(q) } + wg.Wait() close(results) + return collectVerifiedInOrder(len(records), results) } @@ -242,15 +294,21 @@ func (s *Service) loadExistingState(cache map[string]*types.StateV1, did string) if existing, ok := cache[did]; ok { return cloneState(existing), nil } + stateVal, ok, err := s.store.GetState(did) + if err != nil { return nil, err } + if !ok { cache[did] = nil + return nil, nil } + cache[did] = cloneState(&stateVal) + return cloneState(&stateVal), nil } @@ -259,37 +317,48 @@ func collectVerifiedInOrder(total int, results <-chan verifyResult) ([]verifyRes seen := make([]bool, total) firstErr := error(nil) received := 0 + for r := range results { received++ + if r.index < 0 || r.index >= total { if firstErr == nil { firstErr = fmt.Errorf("verify worker returned out-of-range index %d", r.index) } + continue } + if seen[r.index] { if firstErr == nil { firstErr = fmt.Errorf("duplicate verify result for index %d", r.index) } + continue } + seen[r.index] = true ordered[r.index] = r + if r.err != nil && firstErr == nil { firstErr = r.err } } + if received != total && firstErr == nil { firstErr = fmt.Errorf("verify result count mismatch: got %d want %d", received, total) } + if firstErr != nil { return nil, firstErr } + for i := range seen { if !seen[i] { return nil, fmt.Errorf("missing verify result index %d", i) } } + return ordered, nil } @@ -297,8 +366,10 @@ func didWorker(did string, workers int) int { if workers <= 1 { return 0 } + hasher := fnv.New32a() _, _ = hasher.Write([]byte(did)) + return int(hasher.Sum32() % uint32(workers)) } @@ -306,50 +377,64 @@ func cloneState(in *types.StateV1) *types.StateV1 { if in == nil { return nil } + out := *in out.DIDDocument = append([]byte(nil), in.DIDDocument...) out.RotationKeys = append([]string(nil), in.RotationKeys...) 
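The didWorker helper above is what keeps verification both parallel and order-safe: hashing the DID with FNV-1a and reducing modulo the worker count routes every operation for a given DID to the same queue, so per-DID ordering holds while unrelated DIDs verify concurrently. A standalone toy run of the same function (the worker count and DIDs are illustrative):

package main

import (
	"fmt"
	"hash/fnv"
)

// didWorker matches the sharding function in internal/ingest/service.go.
func didWorker(did string, workers int) int {
	if workers <= 1 {
		return 0
	}

	hasher := fnv.New32a()
	_, _ = hasher.Write([]byte(did))

	return int(hasher.Sum32() % uint32(workers))
}

func main() {
	dids := []string{"did:plc:alice", "did:plc:alice", "did:plc:bob", "did:plc:carol"}

	for _, d := range dids {
		// A repeated DID always maps to the same worker index.
		fmt.Printf("%s -> worker %d\n", d, didWorker(d, 4))
	}
}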
+ return &out } func (s *Service) commitVerified(ctx context.Context, verified []verifyResult) (uint64, error) { _ = ctx batchSize := s.cfg.CommitBatchSize + if batchSize < 1 { batchSize = 1 } + pendingOps := make([]storage.OperationMutation, 0, batchSize) pendingSeqs := make([]uint64, 0, batchSize) pendingBlockHashes := map[uint64]string{} pendingNewDIDs := map[string]struct{}{} + var committed uint64 commit := func() error { if len(pendingOps) == 0 { return nil } + if s.cfg.Mode == config.ModeMirror { flush, err := s.appender.Flush() + if err != nil { return err } + if flush != nil { pendingBlockHashes[flush.BlockID] = flush.Hash } } + blockEntries := make([]storage.BlockHashEntry, 0, len(pendingBlockHashes)) + for id, hash := range pendingBlockHashes { blockEntries = append(blockEntries, storage.BlockHashEntry{BlockID: id, Hash: hash}) } + sort.Slice(blockEntries, func(i, j int) bool { return blockEntries[i].BlockID < blockEntries[j].BlockID }) if err := s.store.ApplyOperationsBatch(pendingOps, blockEntries); err != nil { return err } + lastSeq := pendingSeqs[len(pendingSeqs)-1] + atomic.StoreUint64(&s.stats.LastSeq, lastSeq) atomic.AddUint64(&s.stats.IngestedOps, uint64(len(pendingOps))) atomic.AddUint64(&s.stats.DIDCount, uint64(len(pendingNewDIDs))) + committed += uint64(len(pendingOps)) if s.checkpoints != nil { @@ -364,59 +449,78 @@ func (s *Service) commitVerified(ctx context.Context, verified []verifyResult) ( pendingOps = pendingOps[:0] pendingSeqs = pendingSeqs[:0] + clear(pendingBlockHashes) clear(pendingNewDIDs) + return nil } for _, v := range verified { var ref *types.BlockRefV1 + if s.cfg.Mode == config.ModeMirror { if s.appender == nil { return committed, fmt.Errorf("mirror mode requires block appender") } + result, err := s.appender.Append(v.op) + if err != nil { return committed, err } + if result.Flush != nil { pendingBlockHashes[result.Flush.BlockID] = result.Flush.Hash } + result.Ref.Received = time.Now().UTC().Format(time.RFC3339) ref = &result.Ref } + pendingOps = append(pendingOps, storage.OperationMutation{State: v.state, Ref: ref}) pendingSeqs = append(pendingSeqs, v.op.Sequence) + if v.newDID { pendingNewDIDs[v.op.DID] = struct{}{} } + if len(pendingOps) >= batchSize { if err := commit(); err != nil { return committed, err } } } + if err := commit(); err != nil { return committed, err } + return committed, nil } func (s *Service) createCheckpoint(sequence uint64) error { blocks, err := s.store.ListBlockHashes() + if err != nil { return fmt.Errorf("list blocks for checkpoint: %w", err) } + done := make(chan struct{}) usageCh := make(chan checkpointProcessUsage, 1) + go sampleCheckpointProcessUsage(done, usageCh) cp, metrics, err := s.checkpoints.BuildAndStoreFromStoreWithMetrics(sequence, blocks) + close(done) + usage := <-usageCh + if err != nil { return fmt.Errorf("build checkpoint: %w", err) } + log.Printf( "checkpoint_metrics seq=%d did_count=%d merkle_compute_ms=%d total_checkpoint_ms=%d cpu_pct=%.2f rss_peak_mb=%.2f completed_unix_ms=%d", cp.Sequence, @@ -428,41 +532,54 @@ func (s *Service) createCheckpoint(sequence uint64) error { time.Now().UnixMilli(), ) atomic.StoreUint64(&s.stats.CheckpointSeq, cp.Sequence) + if s.metricsSink != nil { s.metricsSink.ObserveCheckpoint(metrics.Total, cp.Sequence) } + return nil } func (s *Service) Snapshot(ctx context.Context) (types.CheckpointV1, error) { if s.appender != nil { flush, err := s.appender.Flush() + if err != nil { return types.CheckpointV1{}, err } + if flush != nil { if err := 
s.store.ApplyOperationsBatch(nil, []storage.BlockHashEntry{{BlockID: flush.BlockID, Hash: flush.Hash}}); err != nil { return types.CheckpointV1{}, err } } } + _ = ctx seq, err := s.store.GetGlobalSeq() + if err != nil { return types.CheckpointV1{}, err } + blocks, err := s.store.ListBlockHashes() + if err != nil { return types.CheckpointV1{}, err } + cp, metrics, err := s.checkpoints.BuildAndStoreFromStoreWithMetrics(seq, blocks) + if err != nil { return types.CheckpointV1{}, err } + atomic.StoreUint64(&s.stats.CheckpointSeq, cp.Sequence) + if s.metricsSink != nil { s.metricsSink.ObserveCheckpoint(metrics.Total, cp.Sequence) } + return cp, nil } @@ -474,31 +591,44 @@ type checkpointProcessUsage struct { func sampleCheckpointProcessUsage(done <-chan struct{}, out chan<- checkpointProcessUsage) { pid := os.Getpid() ticker := time.NewTicker(25 * time.Millisecond) + defer ticker.Stop() + var cpuSum float64 var samples int var rssPeak int64 + sample := func() { cpu, rss, err := processMetrics(pid) + if err != nil { return } + cpuSum += cpu + samples++ + if rss > rssPeak { rssPeak = rss } } + sample() + for { select { case <-done: sample() + avg := 0.0 + if samples > 0 { avg = cpuSum / float64(samples) } + out <- checkpointProcessUsage{CPUPercentAvg: avg, RSSPeakKB: rssPeak} + return case <-ticker.C: sample() @@ -508,15 +638,20 @@ func sampleCheckpointProcessUsage(done <-chan struct{}, out chan<- checkpointPro func processMetrics(pid int) (cpuPct float64, rssKB int64, err error) { out, err := exec.Command("ps", "-p", strconv.Itoa(pid), "-o", "pcpu=,rss=").Output() + if err != nil { return 0, 0, err } + fields := strings.Fields(string(out)) + if len(fields) < 2 { return 0, 0, fmt.Errorf("unexpected ps output: %q", strings.TrimSpace(string(out))) } + cpuPct, _ = strconv.ParseFloat(fields[0], 64) rssKB, _ = strconv.ParseInt(fields[1], 10, 64) + return cpuPct, rssKB, nil } @@ -524,51 +659,69 @@ func (s *Service) VerifyDID(ctx context.Context, did string) error { if err := s.CorruptionError(); err != nil { return err } + if s.cfg.Mode != config.ModeMirror { _, ok, err := s.store.GetState(did) + if err != nil { return err } + if !ok { return fmt.Errorf("did not found") } + return nil } + seqs, err := s.store.ListDIDSequences(did) + if err != nil { return err } + if len(seqs) == 0 { return fmt.Errorf("no operations for did %s", did) } var previous *types.StateV1 + for _, seq := range seqs { if err := ctx.Err(); err != nil { return err } + ref, ok, err := s.store.GetOpSeqRef(seq) + if err != nil { return err } + if !ok { return fmt.Errorf("missing op reference for seq %d", seq) } + payload, err := s.blockLog.ReadRecord(ref) + if err != nil { return err } + rec := types.ExportRecord{Seq: seq, DID: did, CID: ref.CID, Operation: payload} op, err := types.ParseOperation(rec) + if err != nil { return err } + if err := s.verifier.VerifyOperation(op, previous); err != nil { return fmt.Errorf("verify seq %d failed: %w", seq, err) } + s := types.StateV1{DID: did, ChainTipHash: op.CID, LatestOpSeq: seq} previous = &s } + return nil } @@ -576,51 +729,69 @@ func (s *Service) RecomputeTipAtOrBefore(ctx context.Context, did string, sequen if err := s.CorruptionError(); err != nil { return "", nil, err } + if s.cfg.Mode != config.ModeMirror { return "", nil, errors.New("historical proof requires mirror mode") } + seqs, err := s.store.ListDIDSequences(did) + if err != nil { return "", nil, err } + if len(seqs) == 0 { return "", nil, fmt.Errorf("no operations for did %s", did) } + filtered := make([]uint64, 0, 
len(seqs)) + for _, seq := range seqs { if seq <= sequence { filtered = append(filtered, seq) } } + if len(filtered) == 0 { return "", nil, fmt.Errorf("no operations for did %s at checkpoint %d", did, sequence) } + sort.Slice(filtered, func(i, j int) bool { return filtered[i] < filtered[j] }) var previous *types.StateV1 var tip string + for _, seq := range filtered { if err := ctx.Err(); err != nil { return "", nil, err } + ref, ok, err := s.store.GetOpSeqRef(seq) + if err != nil { return "", nil, err } + if !ok { return "", nil, fmt.Errorf("missing op reference for seq %d", seq) } + payload, err := s.blockLog.ReadRecord(ref) + if err != nil { return "", nil, err } + op, err := types.ParseOperation(types.ExportRecord{Seq: seq, DID: did, CID: ref.CID, Operation: payload}) + if err != nil { return "", nil, err } + if err := s.verifier.VerifyOperation(op, previous); err != nil { return "", nil, fmt.Errorf("verify seq %d failed: %w", seq, err) } + tip = op.CID next := types.StateV1{ DID: did, @@ -628,12 +799,15 @@ func (s *Service) RecomputeTipAtOrBefore(ctx context.Context, did string, sequen LatestOpSeq: seq, } prior := []string(nil) + if previous != nil { prior = append(prior, previous.RotationKeys...) } + next.RotationKeys = extractRotationKeysFromPayload(op.Payload, prior) previous = &next } + return tip, filtered, nil } @@ -641,32 +815,44 @@ func (s *Service) LoadDIDLog(ctx context.Context, did string) ([]types.ExportRec if err := s.CorruptionError(); err != nil { return nil, err } + if s.cfg.Mode != config.ModeMirror || s.blockLog == nil { return nil, ErrHistoryNotStored } + seqs, err := s.store.ListDIDSequences(did) + if err != nil { return nil, err } + if len(seqs) == 0 { return nil, ErrDIDNotFound } + out := make([]types.ExportRecord, 0, len(seqs)) + for _, seq := range seqs { if err := ctx.Err(); err != nil { return nil, err } + ref, ok, err := s.store.GetOpSeqRef(seq) + if err != nil { return nil, err } + if !ok { return nil, fmt.Errorf("missing op reference for seq %d", seq) } + payload, err := s.blockLog.ReadRecord(ref) + if err != nil { return nil, err } + out = append(out, types.ExportRecord{ Seq: seq, DID: did, @@ -676,6 +862,7 @@ func (s *Service) LoadDIDLog(ctx context.Context, did string) ([]types.ExportRec Operation: json.RawMessage(payload), }) } + return out, nil } @@ -683,31 +870,43 @@ func (s *Service) LoadLatestDIDOperation(ctx context.Context, did string) (types if err := s.CorruptionError(); err != nil { return types.ExportRecord{}, err } + if s.cfg.Mode != config.ModeMirror || s.blockLog == nil { return types.ExportRecord{}, ErrHistoryNotStored } + seqs, err := s.store.ListDIDSequences(did) + if err != nil { return types.ExportRecord{}, err } + if len(seqs) == 0 { return types.ExportRecord{}, ErrDIDNotFound } + lastSeq := seqs[len(seqs)-1] + if err := ctx.Err(); err != nil { return types.ExportRecord{}, err } + ref, ok, err := s.store.GetOpSeqRef(lastSeq) + if err != nil { return types.ExportRecord{}, err } + if !ok { return types.ExportRecord{}, fmt.Errorf("missing op reference for seq %d", lastSeq) } + payload, err := s.blockLog.ReadRecord(ref) + if err != nil { return types.ExportRecord{}, err } + return types.ExportRecord{ Seq: lastSeq, DID: did, @@ -722,32 +921,44 @@ func (s *Service) LoadCurrentPLCData(ctx context.Context, did string) (map[strin if err := s.CorruptionError(); err != nil { return nil, err } + if s.cfg.Mode != config.ModeMirror || s.blockLog == nil { state, ok, err := s.store.GetState(did) + if err != nil { return nil, err } + if !ok { return 
nil, ErrDIDNotFound } + var doc map[string]any + if err := json.Unmarshal(state.DIDDocument, &doc); err != nil { return nil, err } + return doc, nil } + last, err := s.LoadLatestDIDOperation(ctx, did) + if err != nil { return nil, err } + var op map[string]any + if err := json.Unmarshal(last.Operation, &op); err != nil { return nil, fmt.Errorf("decode latest operation: %w", err) } + delete(op, "sig") delete(op, "signature") delete(op, "sigPayload") delete(op, "signaturePayload") + return op, nil } @@ -755,16 +966,21 @@ func (s *Service) StreamExport(ctx context.Context, after time.Time, limit int, if err := s.CorruptionError(); err != nil { return err } + if s.cfg.Mode != config.ModeMirror || s.blockLog == nil { return nil } + if limit <= 0 { limit = 1000 } + const maxCount = 1000 + if limit > maxCount { limit = maxCount } + afterSet := !after.IsZero() emitted := 0 stop := errors.New("stop export iteration") @@ -772,25 +988,33 @@ func (s *Service) StreamExport(ctx context.Context, after time.Time, limit int, if err := ctx.Err(); err != nil { return err } + if emitted >= limit { return stop } + if afterSet { if strings.TrimSpace(ref.Received) == "" { return nil } + ts, err := time.Parse(time.RFC3339, ref.Received) + if err != nil { return nil } + if !ts.After(after) { return nil } } + payload, err := s.blockLog.ReadRecord(ref) + if err != nil { return err } + rec := types.ExportRecord{ Seq: seq, DID: ref.DID, @@ -799,41 +1023,54 @@ func (s *Service) StreamExport(ctx context.Context, after time.Time, limit int, Nullified: detectNullified(payload), Operation: json.RawMessage(payload), } + emitted++ + return emit(rec) }) + if err == nil || errors.Is(err, stop) { return nil } + return err } func detectNullified(operation []byte) bool { var payload map[string]any + if err := json.Unmarshal(operation, &payload); err != nil { return false } + v, _ := payload["nullified"].(bool) + return v } func (s *Service) Flush(ctx context.Context) error { _ = ctx + if err := s.CorruptionError(); err != nil { return err } + if s.appender == nil { return nil } + flush, err := s.appender.Flush() + if err != nil { return err } + if flush != nil { if err := s.store.ApplyOperationsBatch(nil, []storage.BlockHashEntry{{BlockID: flush.BlockID, Hash: flush.Hash}}); err != nil { return err } } + return nil } @@ -841,6 +1078,7 @@ func (s *Service) MarkCorrupted(err error) { if err == nil { return } + s.corrupted.Store(true) s.corruptErr.Store(err.Error()) } @@ -853,10 +1091,13 @@ func (s *Service) CorruptionError() error { if !s.IsCorrupted() { return nil } + msg, _ := s.corruptErr.Load().(string) + if msg == "" { msg = "data corruption detected" } + return fmt.Errorf("data corruption detected: %s", msg) } @@ -865,15 +1106,19 @@ func extractRotationKeysFromPayload(payload map[string]any, prior []string) []st seen := map[string]struct{}{} add := func(v string) { v = strings.TrimSpace(v) + if v == "" { return } + if _, ok := seen[v]; ok { return } + seen[v] = struct{}{} out = append(out, v) } + if arr, ok := payload["rotationKeys"].([]any); ok { for _, v := range arr { if s, ok := v.(string); ok { @@ -881,17 +1126,21 @@ func extractRotationKeysFromPayload(payload map[string]any, prior []string) []st } } } + if recovery, ok := payload["recoveryKey"].(string); ok { add(recovery) } + if signing, ok := payload["signingKey"].(string); ok { add(signing) } + if len(out) == 0 { for _, v := range prior { add(v) } } + return out } @@ -908,9 +1157,11 @@ func (s *Service) SetMetricsSink(sink MetricsSink) { func (s *Service) Stats() 
Stats { ingested := atomic.LoadUint64(&s.stats.IngestedOps) opsPerSec := 0.0 + if elapsed := time.Since(s.startedAt).Seconds(); elapsed > 0 { opsPerSec = float64(ingested) / elapsed } + return Stats{ IngestedOps: ingested, Errors: atomic.LoadUint64(&s.stats.Errors), @@ -925,12 +1176,15 @@ func (s *Service) Stats() Stats { func countStates(store storage.Store) (uint64, error) { var count uint64 + if err := store.ForEachState(func(types.StateV1) error { count++ + return nil }); err != nil { return 0, err } + return count, nil } @@ -959,34 +1213,47 @@ func newBlockAppender(log *storage.BlockLog, buffer int) *blockAppender { queue: make(chan appendRequest, buffer), closed: make(chan struct{}), } + go ba.run() + return ba } func (b *blockAppender) run() { defer close(b.closed) + for req := range b.queue { if req.Flush { flush, err := b.log.Flush() + req.Reply <- appendResult{Flush: flush, Err: err} + continue } + ref, flush, err := b.log.Append(req.Operation.Sequence, req.Operation.DID, req.Operation.CID, req.Operation.Prev, req.Operation.CanonicalBytes) + req.Reply <- appendResult{Ref: ref, Flush: flush, Err: err} } } func (b *blockAppender) Append(op types.ParsedOperation) (appendResult, error) { reply := make(chan appendResult, 1) + b.queue <- appendRequest{Operation: op, Reply: reply} + res := <-reply + return res, res.Err } func (b *blockAppender) Flush() (*storage.FlushInfo, error) { reply := make(chan appendResult, 1) + b.queue <- appendRequest{Flush: true, Reply: reply} + res := <-reply + return res.Flush, res.Err } diff --git a/internal/ingest/service_integration_test.go b/internal/ingest/service_integration_test.go index e01dd4e..9985367 100644 --- a/internal/ingest/service_integration_test.go +++ b/internal/ingest/service_integration_test.go @@ -10,7 +10,6 @@ import ( "os" "path/filepath" "testing" - "github.com/Fuwn/plutia/internal/checkpoint" "github.com/Fuwn/plutia/internal/config" "github.com/Fuwn/plutia/internal/storage" @@ -20,15 +19,19 @@ import ( func TestReplayIntegration(t *testing.T) { tmp := t.TempDir() dataDir := filepath.Join(tmp, "data") + if err := os.MkdirAll(dataDir, 0o755); err != nil { t.Fatalf("mkdir data dir: %v", err) } keySeed := make([]byte, ed25519.SeedSize) + if _, err := rand.Read(keySeed); err != nil { t.Fatalf("rand seed: %v", err) } + keyPath := filepath.Join(tmp, "mirror.key") + if err := os.WriteFile(keyPath, []byte(base64.RawURLEncoding.EncodeToString(keySeed)), 0o600); err != nil { t.Fatalf("write mirror key: %v", err) } @@ -36,27 +39,35 @@ func TestReplayIntegration(t *testing.T) { records := buildSignedRecords(t) sourcePath := filepath.Join(tmp, "sample.ndjson") file, err := os.Create(sourcePath) + if err != nil { t.Fatalf("create source: %v", err) } + for _, rec := range records { b, _ := json.Marshal(rec) + if _, err := fmt.Fprintln(file, string(b)); err != nil { t.Fatalf("write source: %v", err) } } + file.Close() store, err := storage.OpenPebble(dataDir) + if err != nil { t.Fatalf("open pebble: %v", err) } + defer store.Close() + if err := store.SetMode(config.ModeMirror); err != nil { t.Fatalf("set mode: %v", err) } bl, err := storage.OpenBlockLog(dataDir, 3, 4) + if err != nil { t.Fatalf("open block log: %v", err) } @@ -73,31 +84,39 @@ func TestReplayIntegration(t *testing.T) { MirrorPrivateKeyPath: keyPath, } service := NewService(cfg, store, NewClient(sourcePath), bl, checkpoint.NewManager(store, dataDir, keyPath)) + if err := service.Replay(context.Background()); err != nil { t.Fatalf("replay: %v", err) } + if err := 
service.Flush(context.Background()); err != nil { t.Fatalf("flush: %v", err) } seq, err := store.GetGlobalSeq() + if err != nil { t.Fatalf("get global seq: %v", err) } + if seq != 3 { t.Fatalf("global seq mismatch: got %d want 3", seq) } s, ok, err := store.GetState("did:plc:alice") + if err != nil { t.Fatalf("get state: %v", err) } + if !ok { t.Fatalf("missing alice state") } + if s.LatestOpSeq != 2 { t.Fatalf("latest op seq mismatch for alice: got %d want 2", s.LatestOpSeq) } + if err := service.VerifyDID(context.Background(), "did:plc:alice"); err != nil { t.Fatalf("verify alice did: %v", err) } @@ -105,6 +124,7 @@ func TestReplayIntegration(t *testing.T) { if _, err := service.Snapshot(context.Background()); err != nil { t.Fatalf("snapshot: %v", err) } + if _, ok, err := store.GetLatestCheckpoint(); err != nil || !ok { t.Fatalf("expected latest checkpoint, err=%v ok=%v", err, ok) } @@ -112,14 +132,19 @@ func TestReplayIntegration(t *testing.T) { func buildSignedRecords(t *testing.T) []types.ExportRecord { t.Helper() + pub1, priv1, err := ed25519.GenerateKey(rand.Reader) + if err != nil { t.Fatalf("generate key1: %v", err) } + pub2, priv2, err := ed25519.GenerateKey(rand.Reader) + if err != nil { t.Fatalf("generate key2: %v", err) } + var out []types.ExportRecord mk := func(seq uint64, did, prev string, pub ed25519.PublicKey, priv ed25519.PrivateKey) types.ExportRecord { @@ -127,9 +152,11 @@ func buildSignedRecords(t *testing.T) []types.ExportRecord { "did": did, "didDoc": map[string]any{"id": did, "seq": seq}, } + if prev != "" { payloadDoc["prev"] = prev } + payloadBytes, _ := json.Marshal(payloadDoc) canon, _ := types.CanonicalizeJSON(payloadBytes) sig := ed25519.Sign(priv, canon) @@ -140,34 +167,41 @@ func buildSignedRecords(t *testing.T) []types.ExportRecord { "sigPayload": base64.RawURLEncoding.EncodeToString(canon), "sig": base64.RawURLEncoding.EncodeToString(sig), } + if prev != "" { op["prev"] = prev } + opRaw, _ := json.Marshal(op) opCanon, _ := types.CanonicalizeJSON(opRaw) cid := types.ComputeDigestCID(opCanon) + return types.ExportRecord{Seq: seq, DID: did, CID: cid, Operation: opRaw} } - rec1 := mk(1, "did:plc:alice", "", pub1, priv1) rec2 := mk(2, "did:plc:alice", rec1.CID, pub1, priv1) rec3 := mk(3, "did:plc:bob", "", pub2, priv2) out = append(out, rec1, rec2, rec3) + return out } func TestRecomputeTipAtOrBeforeHonorsContextCancellation(t *testing.T) { tmp := t.TempDir() dataDir := filepath.Join(tmp, "data") + if err := os.MkdirAll(dataDir, 0o755); err != nil { t.Fatalf("mkdir data dir: %v", err) } keySeed := make([]byte, ed25519.SeedSize) + if _, err := rand.Read(keySeed); err != nil { t.Fatalf("rand seed: %v", err) } + keyPath := filepath.Join(tmp, "mirror.key") + if err := os.WriteFile(keyPath, []byte(base64.RawURLEncoding.EncodeToString(keySeed)), 0o600); err != nil { t.Fatalf("write mirror key: %v", err) } @@ -175,29 +209,39 @@ func TestRecomputeTipAtOrBeforeHonorsContextCancellation(t *testing.T) { records := buildSignedRecords(t) sourcePath := filepath.Join(tmp, "sample.ndjson") file, err := os.Create(sourcePath) + if err != nil { t.Fatalf("create source: %v", err) } + for _, rec := range records { b, _ := json.Marshal(rec) + if _, err := fmt.Fprintln(file, string(b)); err != nil { t.Fatalf("write source: %v", err) } } + file.Close() store, err := storage.OpenPebble(dataDir) + if err != nil { t.Fatalf("open pebble: %v", err) } + defer store.Close() + if err := store.SetMode(config.ModeMirror); err != nil { t.Fatalf("set mode: %v", err) } + bl, err := 
storage.OpenBlockLog(dataDir, 3, 4) + if err != nil { t.Fatalf("open block log: %v", err) } + cfg := config.Config{ Mode: config.ModeMirror, DataDir: dataDir, @@ -211,15 +255,19 @@ func TestRecomputeTipAtOrBeforeHonorsContextCancellation(t *testing.T) { PollInterval: 5, } service := NewService(cfg, store, NewClient(sourcePath), bl, checkpoint.NewManager(store, dataDir, keyPath)) + if err := service.Replay(context.Background()); err != nil { t.Fatalf("replay: %v", err) } + if err := service.Flush(context.Background()); err != nil { t.Fatalf("flush: %v", err) } ctx, cancel := context.WithCancel(context.Background()) + cancel() + if _, _, err := service.RecomputeTipAtOrBefore(ctx, "did:plc:alice", 2); err == nil { t.Fatalf("expected cancellation error") } diff --git a/internal/ingest/service_order_test.go b/internal/ingest/service_order_test.go index 0e2299c..1d7c766 100644 --- a/internal/ingest/service_order_test.go +++ b/internal/ingest/service_order_test.go @@ -2,24 +2,28 @@ package ingest import ( "testing" - "github.com/Fuwn/plutia/internal/types" ) func TestCollectVerifiedInOrder_OutOfOrderInput(t *testing.T) { results := make(chan verifyResult, 3) + results <- verifyResult{index: 2, op: types.ParsedOperation{Sequence: 3}} results <- verifyResult{index: 0, op: types.ParsedOperation{Sequence: 1}} results <- verifyResult{index: 1, op: types.ParsedOperation{Sequence: 2}} + close(results) ordered, err := collectVerifiedInOrder(3, results) + if err != nil { t.Fatalf("collect verified: %v", err) } + if len(ordered) != 3 { t.Fatalf("ordered length mismatch: got %d want 3", len(ordered)) } + for i := 0; i < 3; i++ { if ordered[i].op.Sequence != uint64(i+1) { t.Fatalf("unexpected sequence at index %d: got %d want %d", i, ordered[i].op.Sequence, i+1) @@ -29,8 +33,10 @@ func TestCollectVerifiedInOrder_OutOfOrderInput(t *testing.T) { func TestCollectVerifiedInOrder_MissingResult(t *testing.T) { results := make(chan verifyResult, 2) + results <- verifyResult{index: 0, op: types.ParsedOperation{Sequence: 1}} results <- verifyResult{index: 2, op: types.ParsedOperation{Sequence: 3}} + close(results) if _, err := collectVerifiedInOrder(3, results); err == nil { diff --git a/internal/merkle/tree.go b/internal/merkle/tree.go index ff97217..023b4ec 100644 --- a/internal/merkle/tree.go +++ b/internal/merkle/tree.go @@ -12,27 +12,36 @@ type Sibling struct { func HashLeaf(leaf []byte) []byte { s := sha256.Sum256(leaf) + return s[:] } func Root(leaves [][]byte) []byte { if len(leaves) == 0 { empty := sha256.Sum256(nil) + return empty[:] } + level := cloneLevel(leaves) + for len(level) > 1 { next := make([][]byte, 0, (len(level)+1)/2) + for i := 0; i < len(level); i += 2 { left := level[i] right := left + if i+1 < len(level) { right = level[i+1] } + next = append(next, hashPair(left, right)) } + level = next } + return level[0] } @@ -40,15 +49,19 @@ func BuildProof(leaves [][]byte, index int) []Sibling { if len(leaves) == 0 || index < 0 || index >= len(leaves) { return nil } + proof := make([]Sibling, 0) level := cloneLevel(leaves) idx := index + for len(level) > 1 { if idx%2 == 0 { sib := idx + 1 + if sib >= len(level) { sib = idx } + proof = append(proof, Sibling{Hash: hex.EncodeToString(level[sib]), Left: false}) } else { sib := idx - 1 @@ -56,44 +69,57 @@ func BuildProof(leaves [][]byte, index int) []Sibling { } next := make([][]byte, 0, (len(level)+1)/2) + for i := 0; i < len(level); i += 2 { left := level[i] right := left + if i+1 < len(level) { right = level[i+1] } + next = append(next, hashPair(left, 
right)) } + idx /= 2 level = next } + return proof } func VerifyProof(leafHash []byte, proof []Sibling, root []byte) bool { cur := append([]byte(nil), leafHash...) + for _, s := range proof { sib, err := hex.DecodeString(s.Hash) + if err != nil { return false } + var combined []byte + if s.Left { combined = append(append([]byte(nil), sib...), cur...) } else { combined = append(append([]byte(nil), cur...), sib...) } + h := sha256.Sum256(combined) cur = h[:] } + return hex.EncodeToString(cur) == hex.EncodeToString(root) } func cloneLevel(in [][]byte) [][]byte { out := make([][]byte, len(in)) + for i := range in { out[i] = append([]byte(nil), in[i]...) } + return out } @@ -110,19 +136,26 @@ func (a *Accumulator) AddLeafHash(leafHash []byte) { if len(leafHash) == 0 { return } + cur := append([]byte(nil), leafHash...) level := 0 + for { if level >= len(a.levels) { a.levels = append(a.levels, nil) } + if a.levels[level] == nil { a.levels[level] = cur + a.pendingCount++ + break } + cur = hashPair(a.levels[level], cur) a.levels[level] = nil + a.pendingCount-- level++ } @@ -131,40 +164,53 @@ func (a *Accumulator) AddLeafHash(leafHash []byte) { func (a *Accumulator) RootDuplicateLast() []byte { if a.pendingCount == 0 { empty := sha256.Sum256(nil) + return empty[:] } + clone := &Accumulator{ levels: cloneLevel(a.levels), pendingCount: a.pendingCount, } + for clone.pendingCount > 1 { level := clone.lowestPendingLevel() cur := clone.levels[level] clone.levels[level] = nil + clone.pendingCount-- cur = hashPair(cur, cur) + level++ + for { if level >= len(clone.levels) { clone.levels = append(clone.levels, nil) } + if clone.levels[level] == nil { clone.levels[level] = cur + clone.pendingCount++ + break } + cur = hashPair(clone.levels[level], cur) clone.levels[level] = nil + clone.pendingCount-- level++ } } + return clone.highestPendingHash() } func hashPair(left, right []byte) []byte { h := sha256.Sum256(append(append([]byte(nil), left...), right...)) + return h[:] } @@ -174,6 +220,7 @@ func (a *Accumulator) lowestPendingLevel() int { return i } } + return 0 } @@ -183,6 +230,8 @@ func (a *Accumulator) highestPendingHash() []byte { return a.levels[i] } } + empty := sha256.Sum256(nil) + return empty[:] } diff --git a/internal/merkle/tree_test.go b/internal/merkle/tree_test.go index 9fa791d..b2546a8 100644 --- a/internal/merkle/tree_test.go +++ b/internal/merkle/tree_test.go @@ -9,12 +9,14 @@ import ( func TestRootAndProof(t *testing.T) { inputs := [][]byte{HashLeaf([]byte("a")), HashLeaf([]byte("b")), HashLeaf([]byte("c"))} root := Root(inputs) + if len(root) == 0 { t.Fatalf("expected root") } for i := range inputs { proof := BuildProof(inputs, i) + if !VerifyProof(inputs[i], proof, root) { t.Fatalf("proof verification failed for index %d", i) } @@ -23,6 +25,7 @@ func TestRootAndProof(t *testing.T) { func TestRootEmpty(t *testing.T) { r := Root(nil) + if got := hex.EncodeToString(r); got != "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" { t.Fatalf("unexpected empty root: %s", got) } @@ -32,13 +35,17 @@ func TestAccumulatorRootMatchesRoot(t *testing.T) { for n := 1; n <= 128; n++ { leaves := make([][]byte, 0, n) acc := NewAccumulator() + for i := 0; i < n; i++ { leaf := HashLeaf([]byte(fmt.Sprintf("leaf-%d", i))) leaves = append(leaves, leaf) + acc.AddLeafHash(leaf) } + want := hex.EncodeToString(Root(leaves)) got := hex.EncodeToString(acc.RootDuplicateLast()) + if got != want { t.Fatalf("root mismatch n=%d got=%s want=%s", n, got, want) } diff --git a/internal/state/engine.go 
b/internal/state/engine.go index 4ecdaca..0f5caad 100644 --- a/internal/state/engine.go +++ b/internal/state/engine.go @@ -4,7 +4,6 @@ import ( "encoding/json" "fmt" "time" - "github.com/Fuwn/plutia/internal/config" "github.com/Fuwn/plutia/internal/storage" "github.com/Fuwn/plutia/internal/types" @@ -21,41 +20,55 @@ func New(store storage.Store, mode string) *Engine { func (e *Engine) Apply(op types.ParsedOperation, blockRef *types.BlockRefV1) error { existing, ok, err := e.store.GetState(op.DID) + if err != nil { return fmt.Errorf("load state: %w", err) } + var prior *types.StateV1 + if ok { prior = &existing } + next, err := ComputeNextState(op, prior) + if err != nil { return err } + if pebbleStore, ok := e.store.(*storage.PebbleStore); ok { includeOpRef := e.mode == config.ModeMirror && blockRef != nil + if err := pebbleStore.ApplyOperationBatch(next, blockRef, includeOpRef); err != nil { return fmt.Errorf("apply operation batch: %w", err) } + return nil } + if err := e.store.PutState(next); err != nil { return fmt.Errorf("put state: %w", err) } + if err := e.store.SetChainHead(op.DID, op.Sequence); err != nil { return fmt.Errorf("set chain head: %w", err) } + if err := e.store.AddDIDSequence(op.DID, op.Sequence); err != nil { return fmt.Errorf("append did sequence: %w", err) } + if e.mode == config.ModeMirror && blockRef != nil { if err := e.store.PutOpSeqRef(op.Sequence, *blockRef); err != nil { return fmt.Errorf("put opseq ref: %w", err) } } + if err := e.store.SetGlobalSeq(op.Sequence); err != nil { return fmt.Errorf("set global seq: %w", err) } + return nil } @@ -64,12 +77,16 @@ func ComputeNextState(op types.ParsedOperation, existing *types.StateV1) (types. if existing.ChainTipHash == op.CID { return *existing, nil } + return types.StateV1{}, fmt.Errorf("non-monotonic sequence for %s: got %d <= %d", op.DID, op.Sequence, existing.LatestOpSeq) } + doc, err := materializeDIDDocument(op.DID, op.Payload) + if err != nil { return types.StateV1{}, err } + next := types.StateV1{ Version: 1, DID: op.DID, @@ -79,9 +96,11 @@ func ComputeNextState(op types.ParsedOperation, existing *types.StateV1) (types. UpdatedAt: time.Now().UTC(), } next.RotationKeys = extractRotationKeys(op.Payload) + if len(next.RotationKeys) == 0 && existing != nil { next.RotationKeys = append([]string(nil), existing.RotationKeys...) 
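A toy illustration of the rotation-key fallback shown here in ComputeNextState: keys come from the operation payload when present (rotationKeys, then recoveryKey and signingKey), and are inherited from the prior state otherwise. This sketch condenses extractRotationKeys plus the inheritance step and models only that fallback, not the full state transition; the sample payload and key strings are made up.

package main

import "fmt"

// nextRotationKeys: payload-supplied keys win; otherwise prior keys carry over.
func nextRotationKeys(payload map[string]any, prior []string) []string {
	out := []string{}
	seen := map[string]struct{}{}

	add := func(v string) {
		if v == "" {
			return
		}

		if _, ok := seen[v]; ok {
			return
		}

		seen[v] = struct{}{}
		out = append(out, v)
	}

	if arr, ok := payload["rotationKeys"].([]any); ok {
		for _, v := range arr {
			if s, ok := v.(string); ok {
				add(s)
			}
		}
	}

	if recovery, ok := payload["recoveryKey"].(string); ok {
		add(recovery)
	}

	if signing, ok := payload["signingKey"].(string); ok {
		add(signing)
	}

	// Fallback: an update without key material keeps the prior key set.
	if len(out) == 0 {
		return append([]string(nil), prior...)
	}

	return out
}

func main() {
	prior := []string{"did:key:zExampleRotationKey"}
	update := map[string]any{"handle": "alice.example"} // no key fields

	fmt.Println(nextRotationKeys(update, prior)) // [did:key:zExampleRotationKey]
}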
} + return next, nil } @@ -89,31 +108,39 @@ func materializeDIDDocument(did string, payload map[string]any) ([]byte, error) for _, k := range []string{"did_document", "didDoc", "document", "didDocument"} { if v, ok := payload[k]; ok { b, err := json.Marshal(v) + if err != nil { return nil, fmt.Errorf("marshal did document field %s: %w", k, err) } + return types.CanonicalizeJSON(b) } } + if v, ok := payload["alsoKnownAs"]; ok { doc := map[string]any{ "id": did, "alsoKnownAs": v, } b, err := json.Marshal(doc) + if err != nil { return nil, fmt.Errorf("marshal derived did document: %w", err) } + return types.CanonicalizeJSON(b) } doc := map[string]any{"id": did} + if typ, _ := payload["type"].(string); typ == "plc_tombstone" || typ == "tombstone" { doc["deactivated"] = true } + if handle, ok := payload["handle"].(string); ok && handle != "" { doc["alsoKnownAs"] = []string{"at://" + handle} } + if serviceEndpoint, ok := payload["service"].(string); ok && serviceEndpoint != "" { doc["service"] = []map[string]any{ { @@ -123,6 +150,7 @@ func materializeDIDDocument(did string, payload map[string]any) ([]byte, error) }, } } + if signingKey, ok := payload["signingKey"].(string); ok && signingKey != "" { doc["verificationMethod"] = []map[string]any{ { @@ -134,10 +162,13 @@ func materializeDIDDocument(did string, payload map[string]any) ([]byte, error) } doc["authentication"] = []string{"#atproto"} } + b, err := json.Marshal(doc) + if err != nil { return nil, fmt.Errorf("marshal synthesized did document: %w", err) } + return types.CanonicalizeJSON(b) } @@ -148,9 +179,11 @@ func extractRotationKeys(payload map[string]any) []string { if v == "" { return } + if _, ok := seen[v]; ok { return } + seen[v] = struct{}{} keys = append(keys, v) } @@ -162,11 +195,14 @@ func extractRotationKeys(payload map[string]any) []string { } } } + if recovery, ok := payload["recoveryKey"].(string); ok { add(recovery) } + if signing, ok := payload["signingKey"].(string); ok { add(signing) } + return keys } diff --git a/internal/storage/blocklog.go b/internal/storage/blocklog.go index 879d3bd..e343b45 100644 --- a/internal/storage/blocklog.go +++ b/internal/storage/blocklog.go @@ -13,7 +13,6 @@ import ( "strings" "sync" "time" - "github.com/Fuwn/plutia/internal/types" "github.com/klauspost/compress/zstd" ) @@ -33,7 +32,6 @@ type BlockLog struct { dir string zstdLevel int targetSize int - mu sync.Mutex buf bytes.Buffer blockID uint64 @@ -41,13 +39,17 @@ type BlockLog struct { func OpenBlockLog(dataDir string, zstdLevel int, targetMB int) (*BlockLog, error) { dir := filepath.Join(dataDir, "ops") + if err := os.MkdirAll(dir, 0o755); err != nil { return nil, fmt.Errorf("mkdir ops: %w", err) } + nextID, err := detectNextBlockID(dir) + if err != nil { return nil, err } + return &BlockLog{ dir: dir, zstdLevel: zstdLevel, @@ -58,10 +60,12 @@ func OpenBlockLog(dataDir string, zstdLevel int, targetMB int) (*BlockLog, error func (l *BlockLog) Append(seq uint64, did, cid, prev string, canonical []byte) (types.BlockRefV1, *FlushInfo, error) { l.mu.Lock() + defer l.mu.Unlock() record := encodeRecord(canonical) offset := uint64(l.buf.Len()) + if _, err := l.buf.Write(record); err != nil { return types.BlockRefV1{}, nil, fmt.Errorf("buffer write: %w", err) } @@ -82,16 +86,21 @@ func (l *BlockLog) Append(seq uint64, did, cid, prev string, canonical []byte) ( if l.buf.Len() < l.targetSize { return ref, nil, nil } + flush, err := l.flushLocked() + if err != nil { return types.BlockRefV1{}, nil, err } + return ref, flush, nil } func (l *BlockLog) 
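The synthesis branch of `materializeDIDDocument` derives a minimal document from well-known payload fields when no embedded document is present. A condensed, self-contained restatement (plain `json.Marshal` stands in for the package's `CanonicalizeJSON`):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Condensed restatement of the synthesis branch shown above: when no
// embedded document field is present, a minimal doc is derived from
// well-known payload fields.
func synthesizeDoc(did string, payload map[string]any) ([]byte, error) {
	doc := map[string]any{"id": did}
	if typ, _ := payload["type"].(string); typ == "plc_tombstone" || typ == "tombstone" {
		doc["deactivated"] = true
	}
	if handle, ok := payload["handle"].(string); ok && handle != "" {
		doc["alsoKnownAs"] = []string{"at://" + handle}
	}
	return json.Marshal(doc)
}

func main() {
	b, err := synthesizeDoc("did:plc:alice", map[string]any{"handle": "alice.example"})
	if err != nil {
		panic(err)
	}
	// json.Marshal sorts map keys, so:
	fmt.Println(string(b)) // {"alsoKnownAs":["at://alice.example"],"id":"did:plc:alice"}
}
```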
Flush() (*FlushInfo, error) { l.mu.Lock() + defer l.mu.Unlock() + return l.flushLocked() } @@ -99,12 +108,16 @@ func (l *BlockLog) flushLocked() (*FlushInfo, error) { if l.buf.Len() == 0 { return nil, nil } + encLevel := zstd.EncoderLevelFromZstd(l.zstdLevel) enc, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(encLevel)) + if err != nil { return nil, fmt.Errorf("zstd encoder: %w", err) } + compressed := enc.EncodeAll(l.buf.Bytes(), nil) + if err := enc.Close(); err != nil { return nil, fmt.Errorf("close zstd encoder: %w", err) } @@ -112,36 +125,50 @@ func (l *BlockLog) flushLocked() (*FlushInfo, error) { path := filepath.Join(l.dir, fmt.Sprintf("%06d.zst", l.blockID)) tmpPath := path + ".tmp" f, err := os.Create(tmpPath) + if err != nil { return nil, fmt.Errorf("create temp block file: %w", err) } + half := len(compressed) / 2 + if _, err := f.Write(compressed[:half]); err != nil { _ = f.Close() + return nil, fmt.Errorf("write temp block first half: %w", err) } + if sleepMs, err := strconv.Atoi(strings.TrimSpace(os.Getenv("PLUTIA_TEST_SLOW_FLUSH_MS"))); err == nil && sleepMs > 0 { time.Sleep(time.Duration(sleepMs) * time.Millisecond) } + if _, err := f.Write(compressed[half:]); err != nil { _ = f.Close() + return nil, fmt.Errorf("write temp block second half: %w", err) } + if err := f.Sync(); err != nil { _ = f.Close() + return nil, fmt.Errorf("sync temp block file: %w", err) } + if err := f.Close(); err != nil { return nil, fmt.Errorf("close temp block file: %w", err) } + if err := os.Rename(tmpPath, path); err != nil { return nil, fmt.Errorf("rename temp block file: %w", err) } + sum := sha256.Sum256(compressed) hash := hex.EncodeToString(sum[:]) l.blockID++ + l.buf.Reset() + return &FlushInfo{BlockID: l.blockID - 1, Hash: hash}, nil } @@ -151,56 +178,77 @@ func (l *BlockLog) ValidateAndRecover(store Store) (*IntegrityReport, error) { RemovedOrphans: make([]uint64, 0), } entries, err := os.ReadDir(l.dir) + if err != nil { return nil, fmt.Errorf("read ops dir: %w", err) } + storedEntries, err := store.ListBlockHashEntries() + if err != nil { return nil, fmt.Errorf("list stored block hashes: %w", err) } + stored := make(map[uint64]string, len(storedEntries)) + for _, e := range storedEntries { stored[e.BlockID] = e.Hash } files := make(map[uint64]string) + for _, e := range entries { name := e.Name() + if strings.HasSuffix(name, ".tmp") { path := filepath.Join(l.dir, name) + if err := os.Remove(path); err != nil { return nil, fmt.Errorf("remove temp block file %s: %w", name, err) } + report.RemovedTempFiles = append(report.RemovedTempFiles, name) + continue } + if !strings.HasSuffix(name, ".zst") { continue } + base := strings.TrimSuffix(name, ".zst") id, err := strconv.Atoi(base) + if err != nil { continue } + files[uint64(id)] = filepath.Join(l.dir, name) } for id, path := range files { expected, ok := stored[id] + if !ok { if err := os.Remove(path); err != nil { return nil, fmt.Errorf("remove orphan block %d: %w", id, err) } + report.RemovedOrphans = append(report.RemovedOrphans, id) + continue } + actual, err := fileHash(path) + if err != nil { return nil, fmt.Errorf("hash block %d: %w", id, err) } + if actual != expected { return nil, fmt.Errorf("block hash mismatch block=%d expected=%s got=%s", id, expected, actual) } + report.VerifiedBlocks++ } @@ -216,115 +264,157 @@ func (l *BlockLog) ValidateAndRecover(store Store) (*IntegrityReport, error) { func (l *BlockLog) ReadRecord(ref types.BlockRefV1) ([]byte, error) { path := filepath.Join(l.dir, fmt.Sprintf("%06d.zst", ref.BlockID)) 
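The flush path above writes compressed bytes to a `.tmp` sibling, fsyncs, closes, then renames; the two-half write exists only to widen the window the `PLUTIA_TEST_SLOW_FLUSH_MS` hook exploits, and `ValidateAndRecover` deletes any `.tmp` leftovers. A generic sketch of the same atomic-write pattern:

```go
package main

import (
	"fmt"
	"os"
)

// writeFileAtomic mirrors the flush pattern in the hunks above: write to a
// ".tmp" sibling, fsync, close, then rename. A crash mid-write leaves only
// a ".tmp" file, which recovery can simply delete.
func writeFileAtomic(path string, data []byte) error {
	tmp := path + ".tmp"
	f, err := os.Create(tmp)
	if err != nil {
		return err
	}
	if _, err := f.Write(data); err != nil {
		f.Close()
		return err
	}
	if err := f.Sync(); err != nil { // data reaches disk before the rename
		f.Close()
		return err
	}
	if err := f.Close(); err != nil {
		return err
	}
	return os.Rename(tmp, path) // atomic on POSIX filesystems
}

func main() {
	if err := writeFileAtomic("block.zst", []byte("compressed bytes")); err != nil {
		fmt.Println("write failed:", err)
	}
}
```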
compressed, err := os.ReadFile(path) + if err != nil { return nil, fmt.Errorf("read block file: %w", err) } + dec, err := zstd.NewReader(nil) + if err != nil { return nil, fmt.Errorf("zstd reader: %w", err) } + decompressed, err := dec.DecodeAll(compressed, nil) + if err != nil { return nil, fmt.Errorf("decode block: %w", err) } + dec.Close() + if ref.Offset+ref.Length > uint64(len(decompressed)) { return nil, fmt.Errorf("record bounds out of range") } + record := decompressed[ref.Offset : ref.Offset+ref.Length] _, payload, err := decodeRecord(record) + if err != nil { return nil, err } + return payload, nil } func (l *BlockLog) IterateBlockRecords(blockID uint64, fn func(offset uint64, payload []byte) error) error { path := filepath.Join(l.dir, fmt.Sprintf("%06d.zst", blockID)) compressed, err := os.ReadFile(path) + if err != nil { return fmt.Errorf("read block file: %w", err) } + dec, err := zstd.NewReader(nil) + if err != nil { return fmt.Errorf("zstd reader: %w", err) } + decompressed, err := dec.DecodeAll(compressed, nil) + if err != nil { return fmt.Errorf("decode block: %w", err) } + dec.Close() var offset uint64 + for offset < uint64(len(decompressed)) { start := offset length, n := binary.Uvarint(decompressed[offset:]) + if n <= 0 { return fmt.Errorf("invalid varint at offset %d", offset) } + offset += uint64(n) + if offset+length > uint64(len(decompressed)) { return fmt.Errorf("record out of bounds at offset %d", start) } + payload := decompressed[offset : offset+length] offset += length + if err := fn(start, payload); err != nil { return err } } + return nil } func encodeRecord(payload []byte) []byte { var hdr [binary.MaxVarintLen64]byte + n := binary.PutUvarint(hdr[:], uint64(len(payload))) out := make([]byte, n+len(payload)) + copy(out, hdr[:n]) copy(out[n:], payload) + return out } func decodeRecord(b []byte) (uint64, []byte, error) { length, n := binary.Uvarint(b) + if n <= 0 { return 0, nil, fmt.Errorf("invalid varint record") } + if uint64(n)+length != uint64(len(b)) { return 0, nil, fmt.Errorf("record length mismatch") } + return length, b[n:], nil } func detectNextBlockID(dir string) (uint64, error) { entries, err := os.ReadDir(dir) + if err != nil { return 0, fmt.Errorf("read ops dir: %w", err) } + ids := make([]int, 0) + for _, e := range entries { name := e.Name() + if !strings.HasSuffix(name, ".zst") { continue } + base := strings.TrimSuffix(name, ".zst") n, err := strconv.Atoi(base) + if err != nil { continue } + ids = append(ids, n) } + if len(ids) == 0 { return 1, nil } + sort.Ints(ids) + return uint64(ids[len(ids)-1] + 1), nil } func fileHash(path string) (string, error) { b, err := os.ReadFile(path) + if err != nil { return "", err } + sum := sha256.Sum256(b) + return hex.EncodeToString(sum[:]), nil } diff --git a/internal/storage/blocklog_test.go b/internal/storage/blocklog_test.go index 23045cc..a71febf 100644 --- a/internal/storage/blocklog_test.go +++ b/internal/storage/blocklog_test.go @@ -9,6 +9,7 @@ import ( func TestBlockLogAppendReadFlush(t *testing.T) { tmp := t.TempDir() log, err := OpenBlockLog(tmp, 3, 4) + if err != nil { t.Fatalf("OpenBlockLog: %v", err) } @@ -16,40 +17,53 @@ func TestBlockLogAppendReadFlush(t *testing.T) { payload1 := []byte(`{"op":1}`) payload2 := []byte(`{"op":2}`) ref1, flush, err := log.Append(1, "did:plc:a", "cid1", "", payload1) + if err != nil { t.Fatalf("append 1: %v", err) } + if flush != nil { t.Fatalf("unexpected automatic flush") } + ref2, _, err := log.Append(2, "did:plc:a", "cid2", "cid1", payload2) + if err != nil { 
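`encodeRecord`/`decodeRecord` frame each payload with a uvarint byte count, and `IterateBlockRecords` walks those frames end to end. A self-contained sketch of the framing:

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// Length-prefixed framing as in encodeRecord/decodeRecord above: a uvarint
// byte count followed by the payload, frames laid end to end.
func encode(payload []byte) []byte {
	var hdr [binary.MaxVarintLen64]byte
	n := binary.PutUvarint(hdr[:], uint64(len(payload)))
	return append(hdr[:n:n], payload...)
}

func main() {
	buf := append(encode([]byte(`{"op":1}`)), encode([]byte(`{"op":2}`))...)
	for off := 0; off < len(buf); {
		length, n := binary.Uvarint(buf[off:])
		if n <= 0 || off+n+int(length) > len(buf) {
			fmt.Println("corrupt frame at offset", off)
			return
		}
		fmt.Printf("offset=%d payload=%s\n", off, buf[off+n:off+n+int(length)])
		off += n + int(length)
	}
}
```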
t.Fatalf("append 2: %v", err) } flushed, err := log.Flush() + if err != nil { t.Fatalf("flush: %v", err) } + if flushed == nil { t.Fatalf("expected flushed block") } + blockFile := filepath.Join(tmp, "ops", "000001.zst") + if _, err := os.Stat(blockFile); err != nil { t.Fatalf("expected block file: %v", err) } got1, err := log.ReadRecord(ref1) + if err != nil { t.Fatalf("read record1: %v", err) } + if string(got1) != string(payload1) { t.Fatalf("payload1 mismatch: got %s want %s", got1, payload1) } + got2, err := log.ReadRecord(ref2) + if err != nil { t.Fatalf("read record2: %v", err) } + if string(got2) != string(payload2) { t.Fatalf("payload2 mismatch: got %s want %s", got2, payload2) } diff --git a/internal/storage/pebble_store.go b/internal/storage/pebble_store.go index 109310f..55c04fb 100644 --- a/internal/storage/pebble_store.go +++ b/internal/storage/pebble_store.go @@ -6,7 +6,6 @@ import ( "fmt" "path/filepath" "sort" - "github.com/Fuwn/plutia/internal/types" "github.com/cockroachdb/pebble" ) @@ -20,9 +19,11 @@ type PebbleStore struct { func OpenPebble(dataDir string) (*PebbleStore, error) { indexDir := filepath.Join(dataDir, "index") db, err := pebble.Open(indexDir, &pebble.Options{}) + if err != nil { return nil, fmt.Errorf("open pebble: %w", err) } + return &PebbleStore{db: db}, nil } @@ -30,6 +31,7 @@ func (p *PebbleStore) Close() error { return p.db.Close() } func (p *PebbleStore) GetMode() (string, error) { v, ok, err := p.getString(metaKey("mode")) + return v, okOrErr(ok, err) } @@ -39,80 +41,103 @@ func (p *PebbleStore) SetMode(mode string) error { func (p *PebbleStore) GetGlobalSeq() (uint64, error) { v, ok, err := p.getUint64(metaKey("global_seq")) + if err != nil { return 0, err } + if !ok { return 0, nil } + return v, nil } func (p *PebbleStore) SetGlobalSeq(seq uint64) error { b := make([]byte, 8) + binary.BigEndian.PutUint64(b, seq) + return p.db.Set(metaKey("global_seq"), b, pebble.Sync) } func (p *PebbleStore) PutState(state types.StateV1) error { b, err := json.Marshal(state) + if err != nil { return fmt.Errorf("marshal state: %w", err) } + return p.db.Set(didKey(state.DID), b, pebble.Sync) } func (p *PebbleStore) ApplyOperationBatch(state types.StateV1, ref *types.BlockRefV1, includeOpRef bool) error { var opRef *types.BlockRefV1 + if includeOpRef { opRef = ref } + return p.ApplyOperationsBatch([]OperationMutation{{State: state, Ref: opRef}}, nil) } func (p *PebbleStore) GetState(did string) (types.StateV1, bool, error) { b, ok, err := p.getBytes(didKey(did)) + if err != nil || !ok { return types.StateV1{}, ok, err } + var s types.StateV1 + if err := json.Unmarshal(b, &s); err != nil { return types.StateV1{}, false, fmt.Errorf("unmarshal state: %w", err) } + return s, true, nil } func (p *PebbleStore) ListStates() ([]types.StateV1, error) { states := make([]types.StateV1, 0) + if err := p.ForEachState(func(s types.StateV1) error { states = append(states, s) + return nil }); err != nil { return nil, err } + sort.Slice(states, func(i, j int) bool { return states[i].DID < states[j].DID }) + return states, nil } func (p *PebbleStore) ForEachState(fn func(types.StateV1) error) error { iter, err := p.db.NewIter(&pebble.IterOptions{LowerBound: []byte("did:"), UpperBound: []byte("did;")}) + if err != nil { return fmt.Errorf("new iterator: %w", err) } + defer iter.Close() for iter.First(); iter.Valid(); iter.Next() { var s types.StateV1 + if err := json.Unmarshal(iter.Value(), &s); err != nil { return fmt.Errorf("unmarshal state: %w", err) } + if err := fn(s); err != nil 
{ return err } } + if err := iter.Error(); err != nil { return fmt.Errorf("iterate states: %w", err) } + return nil } @@ -120,31 +145,41 @@ func (p *PebbleStore) ApplyOperationsBatch(ops []OperationMutation, blockHashes if len(ops) == 0 && len(blockHashes) == 0 { return nil } + batch := p.db.NewBatch() + defer batch.Close() for _, op := range ops { stateBytes, err := json.Marshal(op.State) + if err != nil { return fmt.Errorf("marshal state: %w", err) } + seqBytes := make([]byte, 8) + binary.BigEndian.PutUint64(seqBytes, op.State.LatestOpSeq) if err := batch.Set(didKey(op.State.DID), stateBytes, nil); err != nil { return err } + if err := batch.Set(chainKey(op.State.DID), seqBytes, nil); err != nil { return err } + if err := batch.Set(didOpKey(op.State.DID, op.State.LatestOpSeq), []byte{1}, nil); err != nil { return err } + if op.Ref != nil { refBytes, err := json.Marshal(op.Ref) + if err != nil { return fmt.Errorf("marshal opseq ref: %w", err) } + if err := batch.Set(opSeqKey(op.State.LatestOpSeq), refBytes, nil); err != nil { return err } @@ -160,7 +195,9 @@ func (p *PebbleStore) ApplyOperationsBatch(ops []OperationMutation, blockHashes if len(ops) > 0 { lastSeq := ops[len(ops)-1].State.LatestOpSeq seqBytes := make([]byte, 8) + binary.BigEndian.PutUint64(seqBytes, lastSeq) + if err := batch.Set(metaKey("global_seq"), seqBytes, nil); err != nil { return err } @@ -171,7 +208,9 @@ func (p *PebbleStore) ApplyOperationsBatch(ops []OperationMutation, blockHashes func (p *PebbleStore) SetChainHead(did string, seq uint64) error { b := make([]byte, 8) + binary.BigEndian.PutUint64(b, seq) + return p.db.Set(chainKey(did), b, pebble.Sync) } @@ -187,71 +226,94 @@ func (p *PebbleStore) ListDIDSequences(did string) ([]uint64, error) { prefix := []byte("didop:" + did + ":") upper := append(append([]byte(nil), prefix...), 0xFF) iter, err := p.db.NewIter(&pebble.IterOptions{LowerBound: prefix, UpperBound: upper}) + if err != nil { return nil, fmt.Errorf("new iterator: %w", err) } + defer iter.Close() seqs := make([]uint64, 0) + for iter.First(); iter.Valid(); iter.Next() { key := iter.Key() + if len(key) < len(prefix)+8 { continue } + seq := binary.BigEndian.Uint64(key[len(prefix):]) seqs = append(seqs, seq) } + if err := iter.Error(); err != nil { return nil, fmt.Errorf("iterate did op sequences: %w", err) } + sort.Slice(seqs, func(i, j int) bool { return seqs[i] < seqs[j] }) + return seqs, nil } func (p *PebbleStore) PutOpSeqRef(seq uint64, ref types.BlockRefV1) error { b, err := json.Marshal(ref) + if err != nil { return fmt.Errorf("marshal opseq ref: %w", err) } + return p.db.Set(opSeqKey(seq), b, pebble.Sync) } func (p *PebbleStore) GetOpSeqRef(seq uint64) (types.BlockRefV1, bool, error) { b, ok, err := p.getBytes(opSeqKey(seq)) + if err != nil || !ok { return types.BlockRefV1{}, ok, err } + var ref types.BlockRefV1 + if err := json.Unmarshal(b, &ref); err != nil { return types.BlockRefV1{}, false, fmt.Errorf("unmarshal opseq ref: %w", err) } + return ref, true, nil } func (p *PebbleStore) ForEachOpSeqRef(fn func(seq uint64, ref types.BlockRefV1) error) error { iter, err := p.db.NewIter(&pebble.IterOptions{LowerBound: []byte("opseq:"), UpperBound: []byte("opseq;")}) + if err != nil { return fmt.Errorf("new iterator: %w", err) } + defer iter.Close() for iter.First(); iter.Valid(); iter.Next() { key := iter.Key() + if len(key) < len("opseq:")+8 { continue } + seq := binary.BigEndian.Uint64(key[len("opseq:"):]) + var ref types.BlockRefV1 + if err := json.Unmarshal(iter.Value(), &ref); err != nil { return 
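`ApplyOperationsBatch` stages state, chain head, per-DID index, opseq ref, and `global_seq` in a single batch so they become durable together; the commit call itself falls outside the visible hunks. A sketch of the same pattern against Pebble's public API (keys here are illustrative):

```go
package main

import (
	"fmt"

	"github.com/cockroachdb/pebble"
)

// Sketch of the batch pattern in ApplyOperationsBatch above: stage related
// keys in one batch and commit with Sync, so either every key lands durably
// or none do (the property the durability test later in this diff checks).
func main() {
	db, err := pebble.Open("demo-db", &pebble.Options{})
	if err != nil {
		panic(err)
	}
	defer db.Close()

	batch := db.NewBatch()
	defer batch.Close()

	_ = batch.Set([]byte("did:did:plc:a"), []byte(`{"seq":1}`), nil)
	_ = batch.Set([]byte("meta:global_seq"), []byte{0, 0, 0, 0, 0, 0, 0, 1}, nil)

	if err := batch.Commit(pebble.Sync); err != nil {
		panic(err)
	}
	fmt.Println("batch committed durably")
}
```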
fmt.Errorf("unmarshal opseq ref: %w", err) } + if err := fn(seq, ref); err != nil { return err } } + if err := iter.Error(); err != nil { return fmt.Errorf("iterate opseq refs: %w", err) } + return nil } @@ -265,127 +327,168 @@ func (p *PebbleStore) GetBlockHash(blockID uint64) (string, bool, error) { func (p *PebbleStore) ListBlockHashes() ([]string, error) { iter, err := p.db.NewIter(&pebble.IterOptions{LowerBound: []byte("block:"), UpperBound: []byte("block;")}) + if err != nil { return nil, fmt.Errorf("new iterator: %w", err) } + defer iter.Close() hashes := make([]string, 0) + for iter.First(); iter.Valid(); iter.Next() { hashes = append(hashes, string(iter.Value())) } + if err := iter.Error(); err != nil { return nil, fmt.Errorf("iterate blocks: %w", err) } + return hashes, nil } func (p *PebbleStore) ListBlockHashEntries() ([]BlockHashEntry, error) { iter, err := p.db.NewIter(&pebble.IterOptions{LowerBound: []byte("block:"), UpperBound: []byte("block;")}) + if err != nil { return nil, fmt.Errorf("new iterator: %w", err) } + defer iter.Close() out := make([]BlockHashEntry, 0) + for iter.First(); iter.Valid(); iter.Next() { key := iter.Key() + if len(key) < len("block:")+8 { continue } + id := binary.BigEndian.Uint64(key[len("block:"):]) out = append(out, BlockHashEntry{BlockID: id, Hash: string(iter.Value())}) } + if err := iter.Error(); err != nil { return nil, fmt.Errorf("iterate block entries: %w", err) } + sort.Slice(out, func(i, j int) bool { return out[i].BlockID < out[j].BlockID }) + return out, nil } func (p *PebbleStore) PutCheckpoint(cp types.CheckpointV1) error { b, err := json.Marshal(cp) + if err != nil { return fmt.Errorf("marshal checkpoint: %w", err) } + if err := p.db.Set(checkpointKey(cp.Sequence), b, pebble.Sync); err != nil { return err } + latest := make([]byte, 8) + binary.BigEndian.PutUint64(latest, cp.Sequence) + return p.db.Set(metaKey("latest_checkpoint"), latest, pebble.Sync) } func (p *PebbleStore) GetCheckpoint(sequence uint64) (types.CheckpointV1, bool, error) { b, ok, err := p.getBytes(checkpointKey(sequence)) + if err != nil || !ok { return types.CheckpointV1{}, ok, err } + var cp types.CheckpointV1 + if err := json.Unmarshal(b, &cp); err != nil { return types.CheckpointV1{}, false, fmt.Errorf("unmarshal checkpoint: %w", err) } + return cp, true, nil } func (p *PebbleStore) GetLatestCheckpoint() (types.CheckpointV1, bool, error) { seq, ok, err := p.getUint64(metaKey("latest_checkpoint")) + if err != nil || !ok { return types.CheckpointV1{}, ok, err } + return p.GetCheckpoint(seq) } func (p *PebbleStore) getBytes(key []byte) ([]byte, bool, error) { v, closer, err := p.db.Get(key) + if err != nil { if err == pebble.ErrNotFound { return nil, false, nil } + return nil, false, err } + defer closer.Close() + return append([]byte(nil), v...), true, nil } func (p *PebbleStore) getString(key []byte) (string, bool, error) { b, ok, err := p.getBytes(key) + if err != nil || !ok { return "", ok, err } + return string(b), true, nil } func (p *PebbleStore) getUint64(key []byte) (uint64, bool, error) { b, ok, err := p.getBytes(key) + if err != nil || !ok { return 0, ok, err } + if len(b) != 8 { return 0, false, fmt.Errorf("invalid uint64 value length for %q", key) } + return binary.BigEndian.Uint64(b), true, nil } func didKey(did string) []byte { return []byte("did:" + did) } + func chainKey(did string) []byte { return []byte("chain:" + did) } + func metaKey(k string) []byte { return []byte("meta:" + k) } + func checkpointKey(s uint64) []byte { return 
append([]byte("checkpoint:"), u64bytes(s)...) } + func opSeqKey(s uint64) []byte { return append([]byte("opseq:"), u64bytes(s)...) } + func blockKey(id uint64) []byte { return append([]byte("block:"), u64bytes(id)...) } + func didOpKey(did string, seq uint64) []byte { return append([]byte("didop:"+did+":"), u64bytes(seq)...) } func u64bytes(v uint64) []byte { b := make([]byte, 8) + binary.BigEndian.PutUint64(b, v) + return b } @@ -393,8 +496,10 @@ func okOrErr(ok bool, err error) error { if err != nil { return err } + if !ok { return nil } + return nil } diff --git a/internal/storage/pebble_store_batch_durability_test.go b/internal/storage/pebble_store_batch_durability_test.go index 9d881c3..256ddda 100644 --- a/internal/storage/pebble_store_batch_durability_test.go +++ b/internal/storage/pebble_store_batch_durability_test.go @@ -3,13 +3,13 @@ package storage import ( "testing" "time" - "github.com/Fuwn/plutia/internal/types" ) func TestApplyOperationsBatch_DurabilityBetweenBatches(t *testing.T) { tmp := t.TempDir() store, err := OpenPebble(tmp) + if err != nil { t.Fatalf("open pebble: %v", err) } @@ -40,24 +40,30 @@ func TestApplyOperationsBatch_DurabilityBetweenBatches(t *testing.T) { }, []BlockHashEntry{{BlockID: 1, Hash: "abc"}}); err != nil { t.Fatalf("apply first batch: %v", err) } + if err := store.Close(); err != nil { t.Fatalf("close store: %v", err) } // Simulate a crash before the second batch commit by reopening without applying it. reopened, err := OpenPebble(tmp) + if err != nil { t.Fatalf("reopen pebble: %v", err) } + defer reopened.Close() seq, err := reopened.GetGlobalSeq() + if err != nil { t.Fatalf("get global seq: %v", err) } + if seq != 2 { t.Fatalf("global seq mismatch after simulated crash: got %d want 2", seq) } + if _, ok, err := reopened.GetOpSeqRef(3); err != nil { t.Fatalf("get opseq 3: %v", err) } else if ok { diff --git a/internal/storage/pebble_store_batch_test.go b/internal/storage/pebble_store_batch_test.go index 85afc92..742e7bf 100644 --- a/internal/storage/pebble_store_batch_test.go +++ b/internal/storage/pebble_store_batch_test.go @@ -3,16 +3,17 @@ package storage import ( "testing" "time" - "github.com/Fuwn/plutia/internal/types" ) func TestPebbleStoreApplyOperationBatch(t *testing.T) { tmp := t.TempDir() store, err := OpenPebble(tmp) + if err != nil { t.Fatalf("open pebble: %v", err) } + defer store.Close() state := types.StateV1{ @@ -39,35 +40,43 @@ func TestPebbleStoreApplyOperationBatch(t *testing.T) { } gotState, ok, err := store.GetState(state.DID) + if err != nil || !ok { t.Fatalf("get state: ok=%v err=%v", ok, err) } + if gotState.ChainTipHash != state.ChainTipHash || gotState.LatestOpSeq != state.LatestOpSeq { t.Fatalf("state mismatch: got tip=%s seq=%d", gotState.ChainTipHash, gotState.LatestOpSeq) } head, ok, err := store.GetChainHead(state.DID) + if err != nil || !ok || head != state.LatestOpSeq { t.Fatalf("chain head mismatch: head=%d ok=%v err=%v", head, ok, err) } seqs, err := store.ListDIDSequences(state.DID) + if err != nil { t.Fatalf("list did sequences: %v", err) } + if len(seqs) != 1 || seqs[0] != state.LatestOpSeq { t.Fatalf("did sequence mismatch: %v", seqs) } gotRef, ok, err := store.GetOpSeqRef(state.LatestOpSeq) + if err != nil || !ok { t.Fatalf("get opseq ref: ok=%v err=%v", ok, err) } + if gotRef.BlockID != ref.BlockID || gotRef.CID != ref.CID { t.Fatalf("op ref mismatch: got block=%d cid=%s", gotRef.BlockID, gotRef.CID) } globalSeq, err := store.GetGlobalSeq() + if err != nil || globalSeq != state.LatestOpSeq { 
t.Fatalf("global seq mismatch: seq=%d err=%v", globalSeq, err) } diff --git a/internal/storage/store.go b/internal/storage/store.go index 05ce5ae..5267386 100644 --- a/internal/storage/store.go +++ b/internal/storage/store.go @@ -14,33 +14,26 @@ type OperationMutation struct { type Store interface { Close() error - GetMode() (string, error) SetMode(mode string) error - GetGlobalSeq() (uint64, error) SetGlobalSeq(seq uint64) error - PutState(state types.StateV1) error GetState(did string) (types.StateV1, bool, error) ListStates() ([]types.StateV1, error) ForEachState(fn func(types.StateV1) error) error ApplyOperationsBatch(ops []OperationMutation, blockHashes []BlockHashEntry) error - SetChainHead(did string, seq uint64) error GetChainHead(did string) (uint64, bool, error) AddDIDSequence(did string, seq uint64) error ListDIDSequences(did string) ([]uint64, error) - PutOpSeqRef(seq uint64, ref types.BlockRefV1) error GetOpSeqRef(seq uint64) (types.BlockRefV1, bool, error) ForEachOpSeqRef(fn func(seq uint64, ref types.BlockRefV1) error) error - PutBlockHash(blockID uint64, hash string) error GetBlockHash(blockID uint64) (string, bool, error) ListBlockHashes() ([]string, error) ListBlockHashEntries() ([]BlockHashEntry, error) - PutCheckpoint(cp types.CheckpointV1) error GetCheckpoint(sequence uint64) (types.CheckpointV1, bool, error) GetLatestCheckpoint() (types.CheckpointV1, bool, error) diff --git a/internal/types/operation.go b/internal/types/operation.go index 2facb82..d659a92 100644 --- a/internal/types/operation.go +++ b/internal/types/operation.go @@ -30,16 +30,21 @@ type ParsedOperation struct { func CanonicalizeJSON(raw []byte) ([]byte, error) { trimmed := bytes.TrimSpace(raw) + if len(trimmed) == 0 { return nil, fmt.Errorf("empty json payload") } + if !json.Valid(trimmed) { return nil, fmt.Errorf("invalid json payload") } + var out bytes.Buffer + if err := json.Compact(&out, trimmed); err != nil { return nil, fmt.Errorf("compact json: %w", err) } + return out.Bytes(), nil } @@ -47,19 +52,26 @@ func ParseOperation(rec ExportRecord) (ParsedOperation, error) { if rec.DID == "" { return ParsedOperation{}, fmt.Errorf("missing did") } + canonical, err := CanonicalizeJSON(rec.Operation) + if err != nil { return ParsedOperation{}, err } + payload := map[string]any{} + if err := json.Unmarshal(canonical, &payload); err != nil { return ParsedOperation{}, fmt.Errorf("decode operation: %w", err) } + prev, _ := payload["prev"].(string) cid := rec.CID + if cid == "" { cid = ComputeDigestCID(canonical) } + return ParsedOperation{ DID: rec.DID, CanonicalBytes: canonical, @@ -73,5 +85,6 @@ func ParseOperation(rec ExportRecord) (ParsedOperation, error) { func ComputeDigestCID(payload []byte) string { sum := sha256.Sum256(payload) + return "sha256:" + hex.EncodeToString(sum[:]) } diff --git a/internal/types/operation_test.go b/internal/types/operation_test.go index a0cbc38..691d598 100644 --- a/internal/types/operation_test.go +++ b/internal/types/operation_test.go @@ -8,9 +8,11 @@ import ( func TestCanonicalizeJSON(t *testing.T) { raw := []byte(` { "z": 1, "a" : { "b" : 2 } } `) canon, err := CanonicalizeJSON(raw) + if err != nil { t.Fatalf("CanonicalizeJSON returned error: %v", err) } + if got, want := string(canon), `{"z":1,"a":{"b":2}}`; got != want { t.Fatalf("canonical mismatch: got %s want %s", got, want) } @@ -23,12 +25,15 @@ func TestParseOperationComputesCID(t *testing.T) { Operation: json.RawMessage(`{"didDoc":{"id":"did:plc:test"},"sig":"abc","publicKey":"def"}`), } op, err := 
ParseOperation(rec) + if err != nil { t.Fatalf("ParseOperation returned error: %v", err) } + if op.CID == "" { t.Fatalf("expected computed CID") } + if op.Sequence != rec.Seq { t.Fatalf("seq mismatch: got %d want %d", op.Sequence, rec.Seq) } diff --git a/internal/verify/verifier.go b/internal/verify/verifier.go index ae648d3..4cb6b38 100644 --- a/internal/verify/verifier.go +++ b/internal/verify/verifier.go @@ -13,7 +13,6 @@ import ( "fmt" "math/big" "strings" - "github.com/Fuwn/plutia/internal/config" "github.com/Fuwn/plutia/internal/types" "github.com/decred/dcrd/dcrec/secp256k1/v4" @@ -38,6 +37,7 @@ func (v *Verifier) ShouldVerify(existing *types.StateV1, seq uint64) bool { if existing == nil { return true } + return seq >= existing.LatestOpSeq default: return true @@ -53,36 +53,49 @@ func (v *Verifier) VerifyOperation(op types.ParsedOperation, existing *types.Sta if op.Prev == "" { return errors.New("missing prev on non-genesis operation") } + if op.Prev != existing.ChainTipHash { return fmt.Errorf("prev linkage mismatch: got %s want %s", op.Prev, existing.ChainTipHash) } } sig, ok := findString(op.Payload, "sig", "signature") + if !ok || strings.TrimSpace(sig) == "" { return errors.New("missing signature") } + pubKeys := make([]string, 0, 8) + if existing != nil && len(existing.RotationKeys) > 0 { pubKeys = append(pubKeys, existing.RotationKeys...) } + pubKeys = append(pubKeys, extractPublicKeys(op.Payload)...) + if len(pubKeys) == 0 { return errors.New("missing verification public key") } + sigBytes, err := decodeFlexible(sig) + if err != nil { return fmt.Errorf("decode signature: %w", err) } + payload, err := signaturePayload(op.Payload) + if err != nil { return err } + for _, key := range pubKeys { vk, err := decodePublicKey(key) + if err != nil { continue } + switch vk.Algo { case "ed25519": if ed25519.Verify(vk.Ed25519, payload, sigBytes) { @@ -98,18 +111,21 @@ func (v *Verifier) VerifyOperation(op types.ParsedOperation, existing *types.Sta } } } + return errors.New("signature verification failed") } func signaturePayload(m map[string]any) ([]byte, error) { if raw, ok := findString(m, "sigPayload", "signaturePayload"); ok && raw != "" { decoded, err := decodeFlexible(raw) + if err == nil && json.Valid(decoded) { return types.CanonicalizeJSON(decoded) } } clone := make(map[string]any, len(m)) + for k, v := range m { switch k { case "sig", "signature", "sigPayload", "signaturePayload": @@ -118,14 +134,19 @@ func signaturePayload(m map[string]any) ([]byte, error) { clone[k] = v } } + encMode, err := cbor.CanonicalEncOptions().EncMode() + if err != nil { return nil, fmt.Errorf("init canonical cbor encoder: %w", err) } + b, err := encMode.Marshal(clone) + if err != nil { return nil, fmt.Errorf("marshal signature payload cbor: %w", err) } + return b, nil } @@ -135,6 +156,7 @@ func findString(m map[string]any, keys ...string) (string, bool) { return v, true } } + return "", false } @@ -143,15 +165,19 @@ func extractPublicKeys(payload map[string]any) []string { seen := map[string]struct{}{} add := func(v string) { v = strings.TrimSpace(v) + if v == "" { return } + if _, ok := seen[v]; ok { return } + seen[v] = struct{}{} out = append(out, v) } + if arr, ok := payload["rotationKeys"].([]any); ok { for _, v := range arr { if s, ok := v.(string); ok { @@ -159,19 +185,23 @@ func extractPublicKeys(payload map[string]any) []string { } } } + if v, ok := findString(payload, "publicKey", "verificationMethod", "signingKey", "recoveryKey"); ok { add(v) } + if vm, ok := 
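`signaturePayload` drops the signature fields and re-encodes the remainder with canonical CBOR so signer and verifier derive identical bytes regardless of map iteration order. A self-contained sketch using the same `fxamacker/cbor/v2` options:

```go
package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"

	"github.com/fxamacker/cbor/v2"
)

// Mirrors the payload construction in signaturePayload above: drop the
// signature fields, then encode canonically so key order cannot change the
// signed bytes. The did:key value is a placeholder.
func main() {
	op := map[string]any{"type": "create", "signingKey": "did:key:zExample", "sig": "abc"}

	clone := make(map[string]any, len(op))
	for k, v := range op {
		if k == "sig" || k == "signature" {
			continue
		}
		clone[k] = v
	}

	enc, err := cbor.CanonicalEncOptions().EncMode()
	if err != nil {
		panic(err)
	}
	b, err := enc.Marshal(clone)
	if err != nil {
		panic(err)
	}
	sum := sha256.Sum256(b)
	fmt.Println(hex.EncodeToString(sum[:])) // digest both sides agree on
}
```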
payload["verificationMethods"].(map[string]any); ok { if v, ok := vm["atproto"].(string); ok { add(v) } + for _, anyV := range vm { if s, ok := anyV.(string); ok { add(s) } } } + return out } @@ -185,41 +215,54 @@ type verificationKey struct { func decodePublicKey(value string) (verificationKey, error) { if strings.HasPrefix(value, "did:key:") { mb := strings.TrimPrefix(value, "did:key:") + if mb == "" || mb[0] != 'z' { return verificationKey{}, errors.New("did:key must be multibase base58btc") } + decoded, err := base58.Decode(mb[1:]) + if err != nil { return verificationKey{}, fmt.Errorf("decode did:key base58: %w", err) } + if len(decoded) < 3 { return verificationKey{}, errors.New("invalid did:key length") } + code, n := binary.Uvarint(decoded) + if n <= 0 || n >= len(decoded) { return verificationKey{}, errors.New("invalid did:key multicodec prefix") } + keyBytes := decoded[n:] + switch code { case 0xED: if len(keyBytes) != ed25519.PublicKeySize { return verificationKey{}, errors.New("invalid did:key ed25519 length") } + return verificationKey{ Algo: "ed25519", Ed25519: ed25519.PublicKey(keyBytes), }, nil case 0xE7: pub, err := secp256k1.ParsePubKey(keyBytes) + if err != nil { return verificationKey{}, fmt.Errorf("parse secp256k1 did:key: %w", err) } + return verificationKey{Algo: "secp256k1", Secp256k1: pub}, nil case 0x1200: x, y := elliptic.UnmarshalCompressed(elliptic.P256(), keyBytes) + if x == nil || y == nil { return verificationKey{}, errors.New("parse p256 did:key: invalid compressed key") } + return verificationKey{ Algo: "p256", P256: &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}, @@ -228,23 +271,30 @@ func decodePublicKey(value string) (verificationKey, error) { return verificationKey{}, errors.New("unsupported did:key multicodec") } } + b, err := decodeFlexible(value) + if err != nil { return verificationKey{}, fmt.Errorf("decode public key: %w", err) } + if len(b) == ed25519.PublicKeySize { return verificationKey{Algo: "ed25519", Ed25519: ed25519.PublicKey(b)}, nil } + pub, err := secp256k1.ParsePubKey(b) + if err == nil { return verificationKey{Algo: "secp256k1", Secp256k1: pub}, nil } + if x, y := elliptic.UnmarshalCompressed(elliptic.P256(), b); x != nil && y != nil { return verificationKey{ Algo: "p256", P256: &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y}, }, nil } + return verificationKey{}, fmt.Errorf("invalid public key length/type: %d", len(b)) } @@ -252,20 +302,28 @@ func verifySecp256k1(pub *secp256k1.PublicKey, payload, sig []byte) bool { if pub == nil { return false } + var parsed *secpECDSA.Signature + if len(sig) == 64 { var r, s secp256k1.ModNScalar + r.SetByteSlice(sig[:32]) s.SetByteSlice(sig[32:]) + parsed = secpECDSA.NewSignature(&r, &s) } else { der, err := secpECDSA.ParseDERSignature(sig) + if err != nil { return false } + parsed = der } + sum := sha256.Sum256(payload) + return parsed.Verify(sum[:], pub) } @@ -273,12 +331,16 @@ func verifyP256(pub *ecdsa.PublicKey, payload, sig []byte) bool { if pub == nil { return false } + sum := sha256.Sum256(payload) + if len(sig) == 64 { r := new(big.Int).SetBytes(sig[:32]) s := new(big.Int).SetBytes(sig[32:]) + return ecdsa.Verify(pub, sum[:], r, s) } + return ecdsa.VerifyASN1(pub, sum[:], sig) } @@ -286,11 +348,14 @@ func decodeFlexible(v string) ([]byte, error) { if b, err := base64.RawURLEncoding.DecodeString(v); err == nil { return b, nil } + if b, err := base64.StdEncoding.DecodeString(v); err == nil { return b, nil } + if b, err := hex.DecodeString(v); err == nil { return b, nil } + return 
nil, errors.New("unsupported encoding") } diff --git a/internal/verify/verifier_test.go b/internal/verify/verifier_test.go index 0b38411..e4f46fa 100644 --- a/internal/verify/verifier_test.go +++ b/internal/verify/verifier_test.go @@ -9,7 +9,6 @@ import ( "encoding/base64" "encoding/json" "testing" - "github.com/Fuwn/plutia/internal/config" "github.com/Fuwn/plutia/internal/types" "github.com/decred/dcrd/dcrec/secp256k1/v4" @@ -20,9 +19,11 @@ import ( func TestVerifyOperationValidSignature(t *testing.T) { pub, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { t.Fatalf("generate key: %v", err) } + payloadDoc := []byte(`{"did":"did:plc:alice","didDoc":{"id":"did:plc:alice"}}`) sig := ed25519.Sign(priv, payloadDoc) opJSON := map[string]any{ @@ -38,10 +39,13 @@ func TestVerifyOperationValidSignature(t *testing.T) { DID: "did:plc:alice", Operation: raw, }) + if err != nil { t.Fatalf("parse operation: %v", err) } + v := New(config.VerifyFull) + if err := v.VerifyOperation(op, nil); err != nil { t.Fatalf("verify operation: %v", err) } @@ -49,9 +53,11 @@ func TestVerifyOperationValidSignature(t *testing.T) { func TestVerifyOperationPrevMismatch(t *testing.T) { pub, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { t.Fatalf("generate key: %v", err) } + payloadDoc := []byte(`{"did":"did:plc:alice","didDoc":{"id":"did:plc:alice"},"prev":"sha256:wrong"}`) sig := ed25519.Sign(priv, payloadDoc) opJSON := map[string]any{ @@ -64,11 +70,14 @@ func TestVerifyOperationPrevMismatch(t *testing.T) { } raw, _ := json.Marshal(opJSON) op, err := types.ParseOperation(types.ExportRecord{Seq: 2, DID: "did:plc:alice", Operation: raw}) + if err != nil { t.Fatalf("parse operation: %v", err) } + v := New(config.VerifyFull) existing := &types.StateV1{DID: "did:plc:alice", ChainTipHash: "sha256:right", LatestOpSeq: 1} + if err := v.VerifyOperation(op, existing); err == nil { t.Fatalf("expected prev mismatch error") } @@ -76,13 +85,14 @@ func TestVerifyOperationPrevMismatch(t *testing.T) { func TestVerifyOperationSecp256k1(t *testing.T) { priv, err := secp256k1.GeneratePrivateKey() + if err != nil { t.Fatalf("generate secp256k1 key: %v", err) } + pubKey := priv.PubKey() didKeyBytes := append([]byte{0xE7, 0x01}, pubKey.SerializeCompressed()...) didKey := "did:key:z" + base58.Encode(didKeyBytes) - unsigned := map[string]any{ "type": "create", "prev": nil, @@ -91,16 +101,19 @@ func TestVerifyOperationSecp256k1(t *testing.T) { "signingKey": didKey, } enc, err := cbor.CanonicalEncOptions().EncMode() + if err != nil { t.Fatalf("init cbor encoder: %v", err) } + payload, err := enc.Marshal(unsigned) + if err != nil { t.Fatalf("marshal cbor: %v", err) } + sum := sha256.Sum256(payload) sig := secpECDSA.Sign(priv, sum[:]).Serialize() - opJSON := map[string]any{ "type": "create", "prev": nil, @@ -115,10 +128,13 @@ func TestVerifyOperationSecp256k1(t *testing.T) { DID: "did:plc:alice", Operation: raw, }) + if err != nil { t.Fatalf("parse operation: %v", err) } + v := New(config.VerifyFull) + if err := v.VerifyOperation(op, nil); err != nil { t.Fatalf("verify operation: %v", err) } @@ -126,13 +142,14 @@ func TestVerifyOperationSecp256k1(t *testing.T) { func TestVerifyOperationP256(t *testing.T) { priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { t.Fatalf("generate p256 key: %v", err) } + compressed := elliptic.MarshalCompressed(priv.Curve, priv.PublicKey.X, priv.PublicKey.Y) didKeyBytes := append([]byte{0x80, 0x24}, compressed...) 
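`decodeFlexible` tries base64url, then standard base64, then hex, and the first alphabet that accepts the string wins. Worth noting: a string such as `deadbeef` is valid in more than one of those encodings and the readings differ, so callers are implicitly trusting the ordering. A small demonstration:

```go
package main

import (
	"encoding/base64"
	"encoding/hex"
	"fmt"
)

// decodeFlexible above tries base64url, then standard base64, then hex; an
// ambiguous string decodes under the first matching alphabet. "deadbeef" is
// valid base64url *and* hex, and the two readings differ.
func main() {
	s := "deadbeef"
	b64, _ := base64.RawURLEncoding.DecodeString(s)
	hx, _ := hex.DecodeString(s)
	fmt.Printf("as base64url: %x (len %d)\n", b64, len(b64)) // what decodeFlexible returns
	fmt.Printf("as hex:       %x (len %d)\n", hx, len(hx))
}
```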
didKey := "did:key:z" + base58.Encode(didKeyBytes) - unsigned := map[string]any{ "type": "create", "prev": nil, @@ -141,21 +158,28 @@ func TestVerifyOperationP256(t *testing.T) { "signingKey": didKey, } enc, err := cbor.CanonicalEncOptions().EncMode() + if err != nil { t.Fatalf("init cbor encoder: %v", err) } + payload, err := enc.Marshal(unsigned) + if err != nil { t.Fatalf("marshal cbor: %v", err) } + sum := sha256.Sum256(payload) r, s, err := ecdsa.Sign(rand.Reader, priv, sum[:]) + if err != nil { t.Fatalf("sign p256: %v", err) } + sig := make([]byte, 64) rb := r.Bytes() sb := s.Bytes() + copy(sig[32-len(rb):32], rb) copy(sig[64-len(sb):], sb) @@ -173,10 +197,13 @@ func TestVerifyOperationP256(t *testing.T) { DID: "did:plc:alice", Operation: raw, }) + if err != nil { t.Fatalf("parse operation: %v", err) } + v := New(config.VerifyFull) + if err := v.VerifyOperation(op, nil); err != nil { t.Fatalf("verify operation: %v", err) } |