package checkpoint

import (
	"bufio"
	"context"
	"crypto/ed25519"
	"crypto/sha256"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"sort"
	"strings"
	"time"

	"github.com/Fuwn/plutia/internal/merkle"
	"github.com/Fuwn/plutia/internal/storage"
	"github.com/Fuwn/plutia/internal/types"
	"github.com/mr-tron/base58"
)

// Manager builds, signs, and persists checkpoints. Checkpoint files and state
// snapshot files are written under <dataDir>/checkpoints, and the signing key
// is (re)read from keyPath each time a checkpoint is signed.
type Manager struct {
	store   storage.Store
	dataDir string
	keyPath string
}

// BuildMetrics reports size and timing information for one checkpoint build.
type BuildMetrics struct {
	// DIDCount is the number of DID states included in the checkpoint.
	DIDCount int
	// MerkleCompute is the cumulative time spent hashing leaves into the DID
	// Merkle accumulator.
	MerkleCompute time.Duration
	// Total is the wall-clock time of the whole build: writing the state
	// snapshot, signing, and persisting the checkpoint.
	Total time.Duration
}

// NewManager returns a Manager that persists checkpoints to store and to
// files under dataDir, signing them with the key read from keyPath.
func NewManager(store storage.Store, dataDir, keyPath string) *Manager {
	return &Manager{store: store, dataDir: dataDir, keyPath: keyPath}
}

// BuildAndStore builds a checkpoint at sequence from an in-memory set of
// states. The states are sorted by DID (on a copy; the caller's slice is not
// reordered) to compute the DID Merkle root, a state snapshot file is written,
// and the checkpoint is then signed and persisted.
func (m *Manager) BuildAndStore(sequence uint64, states []types.StateV1, blockHashes []string) (types.CheckpointV1, error) {
	didRoot, leaves := didMerkle(states)
	if err := m.writeCheckpointStateSnapshot(sequence, leaves); err != nil {
		return types.CheckpointV1{}, err
	}
	return m.signAndPersist(sequence, didRoot, blockHashes)
}

// BuildAndStoreFromStore is BuildAndStoreFromStoreWithMetrics with the
// metrics discarded.
func (m *Manager) BuildAndStoreFromStore(sequence uint64, blockHashes []string) (types.CheckpointV1, error) {
	cp, _, err := m.BuildAndStoreFromStoreWithMetrics(sequence, blockHashes)
	return cp, err
}

// BuildAndStoreFromStoreWithMetrics builds a checkpoint at sequence by
// streaming every state from the store into the state snapshot file and the
// DID Merkle accumulator, then signs and persists the checkpoint. On error the
// returned checkpoint and metrics are zero values.
func (m *Manager) BuildAndStoreFromStoreWithMetrics(sequence uint64, blockHashes []string) (types.CheckpointV1, BuildMetrics, error) {
	start := time.Now()
	didRoot, didCount, merkleCompute, err := m.writeCheckpointStateSnapshotFromStore(sequence)
	if err != nil {
		return types.CheckpointV1{}, BuildMetrics{}, err
	}
	cp, err := m.signAndPersist(sequence, didRoot, blockHashes)
	if err != nil {
		return types.CheckpointV1{}, BuildMetrics{}, err
	}
	return cp, BuildMetrics{
		DIDCount:      didCount,
		MerkleCompute: merkleCompute,
		Total:         time.Since(start),
	}, nil
}

// signAndPersist assembles the CheckpointV1 for sequence and chains it to the
// latest stored checkpoint via PreviousCheckpointHash (empty when none exists).
// CheckpointHash is the hex SHA-256 of the canonical JSON payload (with
// Signature and CheckpointHash blanked), and Signature is an Ed25519 signature
// over that same payload, encoded as unpadded base64url. The checkpoint is
// saved to the store first and then written to its JSON file.
func (m *Manager) signAndPersist(sequence uint64, didRoot string, blockHashes []string) (types.CheckpointV1, error) {
	privateKey, kid, err := m.loadSigningKey()
	if err != nil {
		return types.CheckpointV1{}, err
	}
	blockRoot := blockMerkle(blockHashes)

	prev := ""
	latest, ok, err := m.store.GetLatestCheckpoint()
	if err != nil {
		return types.CheckpointV1{}, fmt.Errorf("load latest checkpoint: %w", err)
	}
	if ok {
		prev = latest.CheckpointHash
	}

	cp := types.CheckpointV1{
		Version:                1,
		Sequence:               sequence,
		Timestamp:              time.Now().UTC().Format(time.RFC3339),
		DIDMerkleRoot:          didRoot,
		BlockMerkleRoot:        blockRoot,
		PreviousCheckpointHash: prev,
		KeyID:                  kid,
	}
	payload, err := marshalCheckpointPayload(cp)
	if err != nil {
		return types.CheckpointV1{}, err
	}
	sum := sha256.Sum256(payload)
	cp.CheckpointHash = hex.EncodeToString(sum[:])
	cp.Signature = base64.RawURLEncoding.EncodeToString(ed25519.Sign(privateKey, payload))

	if err := m.store.PutCheckpoint(cp); err != nil {
		return types.CheckpointV1{}, fmt.Errorf("persist checkpoint: %w", err)
	}
	if err := m.writeCheckpointFile(cp); err != nil {
		return types.CheckpointV1{}, err
	}
	return cp, nil
}

// BuildDIDProofAtCheckpoint loads the state snapshot for checkpointSeq and
// builds a Merkle inclusion proof for the leaf (did, chainTipHash). It returns
// the proof siblings, the hex leaf hash, and true when the leaf is found; when
// no leaf matches it returns found=false with a nil error. If several leaves
// match, the last one wins. ctx is checked before hashing each leaf.
//
// NOTE(review): the proof comes from merkle.BuildProof while roots are built
// with Accumulator.RootDuplicateLast; this assumes both use the same tree
// shape — confirm in the merkle package.
func (m *Manager) BuildDIDProofAtCheckpoint(ctx context.Context, did, chainTipHash string, checkpointSeq uint64) ([]merkle.Sibling, string, bool, error) {
	snapshot, err := m.LoadStateSnapshot(checkpointSeq)
	if err != nil {
		return nil, "", false, err
	}
	leaves := make([][]byte, len(snapshot.Leaves))
	index := -1
	leafHashHex := ""
	for i, s := range snapshot.Leaves {
		if err := ctx.Err(); err != nil {
			return nil, "", false, err
		}
		h := merkle.HashLeaf([]byte(s.DID + s.ChainTipHash))
		leaves[i] = h
		if s.DID == did && s.ChainTipHash == chainTipHash {
			index = i
			leafHashHex = hex.EncodeToString(h)
		}
	}
	if index < 0 {
		return nil, "", false, nil
	}
	return merkle.BuildProof(leaves, index), leafHashHex, true, nil
}

// LoadStateSnapshot reads and decodes the state snapshot file for sequence and
// rejects it if its recorded sequence differs from the requested one.
func (m *Manager) LoadStateSnapshot(sequence uint64) (types.CheckpointStateSnapshotV1, error) {
	b, err := os.ReadFile(m.stateSnapshotPath(sequence))
	if err != nil {
		return types.CheckpointStateSnapshotV1{}, fmt.Errorf("read checkpoint state snapshot: %w", err)
	}
	var snap types.CheckpointStateSnapshotV1
	if err := json.Unmarshal(b, &snap); err != nil {
		return types.CheckpointStateSnapshotV1{}, fmt.Errorf("unmarshal checkpoint state snapshot: %w", err)
	}
	if snap.Sequence != sequence {
		return types.CheckpointStateSnapshotV1{}, fmt.Errorf("snapshot sequence mismatch: got %d want %d", snap.Sequence, sequence)
	}
	return snap, nil
}

// ensureCheckpointDir creates <dataDir>/checkpoints if needed and returns it.
func (m *Manager) ensureCheckpointDir() (string, error) {
	dir := filepath.Join(m.dataDir, "checkpoints")
	if err := os.MkdirAll(dir, 0o755); err != nil {
		return "", fmt.Errorf("mkdir checkpoints: %w", err)
	}
	return dir, nil
}

// stateSnapshotPath returns the state snapshot file path for sequence.
func (m *Manager) stateSnapshotPath(sequence uint64) string {
	return filepath.Join(m.dataDir, "checkpoints", fmt.Sprintf("%020d.state.json", sequence))
}

// didMerkle sorts a copy of states by DID and returns the hex DID Merkle root
// together with the snapshot leaves in that same order. Each leaf hash is
// merkle.HashLeaf over the concatenation DID + ChainTipHash.
func didMerkle(states []types.StateV1) (string, []types.DIDLeaf) {
	sorted := append([]types.StateV1(nil), states...)
	sort.Slice(sorted, func(i, j int) bool { return sorted[i].DID < sorted[j].DID })
	acc := merkle.NewAccumulator()
	snapshotLeaves := make([]types.DIDLeaf, 0, len(sorted))
	for _, s := range sorted {
		snapshotLeaves = append(snapshotLeaves, types.DIDLeaf{
			DID:          s.DID,
			ChainTipHash: s.ChainTipHash,
		})
		acc.AddLeafHash(merkle.HashLeaf([]byte(s.DID + s.ChainTipHash)))
	}
	return hex.EncodeToString(acc.RootDuplicateLast()), snapshotLeaves
}

// blockMerkle returns the hex Merkle root over hashes, in the given order.
// Each entry is hex-decoded after trimming whitespace; entries that are not
// valid hex are hashed as their raw (untrimmed) bytes instead.
func blockMerkle(hashes []string) string {
	if len(hashes) == 0 {
		return hex.EncodeToString(merkle.Root(nil))
	}
	leaves := make([][]byte, 0, len(hashes))
	for _, h := range hashes {
		decoded, err := hex.DecodeString(strings.TrimSpace(h))
		if err != nil {
			decoded = []byte(h)
		}
		leaves = append(leaves, merkle.HashLeaf(decoded))
	}
	return hex.EncodeToString(merkle.Root(leaves))
}

// marshalCheckpointPayload returns the canonical JSON bytes that are hashed
// and signed: cp with Signature and CheckpointHash cleared.
func marshalCheckpointPayload(cp types.CheckpointV1) ([]byte, error) {
	clone := cp
	clone.Signature = ""
	clone.CheckpointHash = ""
	b, err := json.Marshal(clone)
	if err != nil {
		return nil, fmt.Errorf("marshal checkpoint payload: %w", err)
	}
	return types.CanonicalizeJSON(b)
}

// atomicWriteFile writes data to path via a sibling "<path>.tmp" file that is
// fsynced, closed, and then renamed over path, so an interrupted write does not
// leave a truncated file at path (rename is atomic on POSIX filesystems). The
// temp file is removed if any step fails.
func atomicWriteFile(path string, data []byte, perm os.FileMode) error {
	tmpPath := path + ".tmp"
	f, err := os.OpenFile(tmpPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm)
	if err != nil {
		return err
	}
	if _, err := f.Write(data); err != nil {
		_ = f.Close()
		_ = os.Remove(tmpPath)
		return err
	}
	if err := f.Sync(); err != nil {
		_ = f.Close()
		_ = os.Remove(tmpPath)
		return err
	}
	if err := f.Close(); err != nil {
		_ = os.Remove(tmpPath)
		return err
	}
	if err := os.Rename(tmpPath, path); err != nil {
		_ = os.Remove(tmpPath)
		return err
	}
	return nil
}

// writeCheckpointFile writes cp as indented JSON to
// <dataDir>/checkpoints/<sequence>.json, replacing the file atomically.
func (m *Manager) writeCheckpointFile(cp types.CheckpointV1) error {
	dir, err := m.ensureCheckpointDir()
	if err != nil {
		return err
	}
	path := filepath.Join(dir, fmt.Sprintf("%020d.json", cp.Sequence))
	b, err := json.MarshalIndent(cp, "", " ")
	if err != nil {
		return fmt.Errorf("marshal checkpoint: %w", err)
	}
	if err := atomicWriteFile(path, b, 0o644); err != nil {
		return fmt.Errorf("write checkpoint file: %w", err)
	}
	return nil
}

// writeCheckpointStateSnapshot writes the in-memory leaves as the state
// snapshot for sequence, replacing the file atomically.
func (m *Manager) writeCheckpointStateSnapshot(sequence uint64, leaves []types.DIDLeaf) error {
	if _, err := m.ensureCheckpointDir(); err != nil {
		return err
	}
	snap := types.CheckpointStateSnapshotV1{
		Version:   1,
		Sequence:  sequence,
		CreatedAt: time.Now().UTC().Format(time.RFC3339),
		Leaves:    leaves,
	}
	b, err := json.Marshal(snap)
	if err != nil {
		return fmt.Errorf("marshal checkpoint state snapshot: %w", err)
	}
	if err := atomicWriteFile(m.stateSnapshotPath(sequence), b, 0o644); err != nil {
		return fmt.Errorf("write checkpoint state snapshot: %w", err)
	}
	return nil
}

// writeCheckpointStateSnapshotFromStore streams every state from the store
// into the state snapshot file for sequence while feeding the same leaves into
// a DID Merkle accumulator. The JSON is written incrementally through a 1 MiB
// buffer to a temp file that is fsynced and renamed into place; on failure the
// temp file is removed. It returns the hex DID Merkle root, the number of
// leaves, and the time spent hashing leaves.
//
// NOTE(review): unlike didMerkle, leaves are not sorted here, so this assumes
// ForEachState yields states in ascending DID order; otherwise the two build
// paths produce different roots — confirm against the store implementation.
// The hand-written header also assumes CheckpointStateSnapshotV1's JSON tags
// are "v", "sequence", "created_at", and "leaves" — confirm.
func (m *Manager) writeCheckpointStateSnapshotFromStore(sequence uint64) (string, int, time.Duration, error) {
	if _, err := m.ensureCheckpointDir(); err != nil {
		return "", 0, 0, err
	}
	finalPath := m.stateSnapshotPath(sequence)
	tmpPath := finalPath + ".tmp"
	f, err := os.Create(tmpPath)
	if err != nil {
		return "", 0, 0, fmt.Errorf("create checkpoint state snapshot temp file: %w", err)
	}
	// Every failure path below closes f before returning; this deferred
	// cleanup then removes the partial temp file.
	committed := false
	defer func() {
		if !committed {
			_ = os.Remove(tmpPath)
		}
	}()
	fail := func(err error) (string, int, time.Duration, error) {
		_ = f.Close()
		return "", 0, 0, err
	}

	createdAt := time.Now().UTC().Format(time.RFC3339)
	w := bufio.NewWriterSize(f, 1<<20)
	if _, err := fmt.Fprintf(w, `{"v":1,"sequence":%d,"created_at":"%s","leaves":[`, sequence, createdAt); err != nil {
		return fail(fmt.Errorf("write snapshot header: %w", err))
	}

	didCount := 0
	var merkleCompute time.Duration
	acc := merkle.NewAccumulator()
	err = m.store.ForEachState(func(s types.StateV1) error {
		leaf := types.DIDLeaf{
			DID:          s.DID,
			ChainTipHash: s.ChainTipHash,
		}
		b, err := json.Marshal(leaf)
		if err != nil {
			return fmt.Errorf("marshal checkpoint leaf: %w", err)
		}
		// Separate elements with commas; didCount is only incremented after a
		// leaf has been written, so it doubles as the "not first" flag.
		if didCount > 0 {
			if _, err := w.WriteString(","); err != nil {
				return err
			}
		}
		if _, err := w.Write(b); err != nil {
			return err
		}
		hashStart := time.Now()
		acc.AddLeafHash(merkle.HashLeaf([]byte(leaf.DID + leaf.ChainTipHash)))
		merkleCompute += time.Since(hashStart)
		didCount++
		return nil
	})
	if err != nil {
		return fail(err)
	}

	if _, err := w.WriteString(`]}`); err != nil {
		return fail(fmt.Errorf("write snapshot trailer: %w", err))
	}
	if err := w.Flush(); err != nil {
		return fail(fmt.Errorf("flush snapshot writer: %w", err))
	}
	if err := f.Sync(); err != nil {
		return fail(fmt.Errorf("sync snapshot file: %w", err))
	}
	if err := f.Close(); err != nil {
		return "", 0, 0, fmt.Errorf("close snapshot file: %w", err)
	}
	if err := os.Rename(tmpPath, finalPath); err != nil {
		return "", 0, 0, fmt.Errorf("rename snapshot file: %w", err)
	}
	committed = true
	return hex.EncodeToString(acc.RootDuplicateLast()), didCount, merkleCompute, nil
}

// loadSigningKey reads the Ed25519 signing key from m.keyPath and returns it
// with its key ID. The file may hold the key directly as a string accepted by
// decodeKeyString, or a JSON object with a "private_key" string field.
func (m *Manager) loadSigningKey() (ed25519.PrivateKey, string, error) {
	data, err := os.ReadFile(m.keyPath)
	if err != nil {
		return nil, "", fmt.Errorf("read mirror private key: %w", err)
	}
	text := strings.TrimSpace(string(data))
	if text == "" {
		return nil, "", errors.New("empty mirror private key")
	}
	if raw, err := decodeKeyString(text); err == nil {
		return keyFromRaw(raw)
	}
	var k struct {
		PrivateKey string `json:"private_key"`
	}
	if err := json.Unmarshal(data, &k); err == nil {
		// Decode the same trimmed value that the emptiness check inspects, so
		// surrounding whitespace in the JSON string does not break decoding.
		if keyText := strings.TrimSpace(k.PrivateKey); keyText != "" {
			raw, err := decodeKeyString(keyText)
			if err != nil {
				return nil, "", err
			}
			return keyFromRaw(raw)
		}
	}
	return nil, "", errors.New("unsupported private key format")
}

// decodeKeyString decodes key material from, in order of preference:
// a "did:key:z<base58btc>" string (the last 32 decoded bytes are returned),
// hex, unpadded base64url, or padded standard base64. Hex is tried before
// base64 because every even-length hex string is also valid base64url.
//
// NOTE(review): the did:key multicodec prefix is not checked. A did:key that
// encodes an ed25519 PUBLIC key would be accepted here and used as a private
// seed — consider rejecting the ed25519-pub prefix (0xed 0x01).
func decodeKeyString(v string) ([]byte, error) {
	if strings.HasPrefix(v, "did:key:") {
		mb := strings.TrimPrefix(v, "did:key:")
		if mb == "" || mb[0] != 'z' {
			return nil, errors.New("unsupported did:key format")
		}
		decoded, err := base58.Decode(mb[1:])
		if err != nil {
			return nil, err
		}
		if len(decoded) < 34 {
			return nil, errors.New("invalid did:key length")
		}
		return decoded[len(decoded)-32:], nil
	}
	if isHexString(v) {
		if b, err := hex.DecodeString(v); err == nil {
			return b, nil
		}
	}
	if b, err := base64.RawURLEncoding.DecodeString(v); err == nil {
		return b, nil
	}
	if b, err := base64.StdEncoding.DecodeString(v); err == nil {
		return b, nil
	}
	// A trailing plain-hex attempt would be dead code: any input rejected by
	// isHexString is either empty (already accepted by base64 above), odd in
	// length, or contains a non-hex character, all of which hex rejects.
	return nil, errors.New("unknown key encoding")
}

// isHexString reports whether v is a non-empty, even-length string consisting
// only of ASCII hex digits (either case).
func isHexString(v string) bool {
	if len(v) == 0 || len(v)%2 != 0 {
		return false
	}
	for _, r := range v {
		if (r >= '0' && r <= '9') || (r >= 'a' && r <= 'f') || (r >= 'A' && r <= 'F') {
			continue
		}
		return false
	}
	return true
}

// keyFromRaw builds an Ed25519 private key from either a 32-byte seed or a
// 64-byte seed||public key, and returns it with its key ID.
func keyFromRaw(raw []byte) (ed25519.PrivateKey, string, error) {
	switch len(raw) {
	case ed25519.SeedSize:
		pk := ed25519.NewKeyFromSeed(raw)
		return pk, keyID(pk.Public().(ed25519.PublicKey)), nil
	case ed25519.PrivateKeySize:
		// ed25519.Sign hashes the stored public half (raw[32:]), so a key whose
		// public half does not match its seed would produce signatures that do
		// not verify, and keyID would name the wrong key. Re-derive from the
		// seed and reject inconsistent keys.
		pk := ed25519.NewKeyFromSeed(raw[:ed25519.SeedSize])
		pub := pk.Public().(ed25519.PublicKey)
		if !pub.Equal(ed25519.PublicKey(raw[ed25519.SeedSize:])) {
			return nil, "", errors.New("private key public half does not match seed")
		}
		return pk, keyID(pub), nil
	default:
		return nil, "", fmt.Errorf("invalid private key length %d", len(raw))
	}
}

// keyID returns "ed25519:" followed by the hex of the first 8 bytes of the
// SHA-256 of pub.
func keyID(pub ed25519.PublicKey) string {
	sum := sha256.Sum256(pub)
	return "ed25519:" + hex.EncodeToString(sum[:8])
}