diff options
| author | Fuwn <[email protected]> | 2026-02-26 20:20:09 -0800 |
|---|---|---|
| committer | Fuwn <[email protected]> | 2026-02-26 20:20:09 -0800 |
| commit | 80dc07bf7bbf3ee9f7191a0446199d74cbb2d341 (patch) | |
| tree | 746f1518723b64a5270131a1e19b6c6f5a312926 /internal/api | |
| parent | chore: remove accidentally committed binary (diff) | |
| download | plutia-test-80dc07bf7bbf3ee9f7191a0446199d74cbb2d341.tar.xz plutia-test-80dc07bf7bbf3ee9f7191a0446199d74cbb2d341.zip | |
fix: align PLC compatibility read endpoints with plc.directory schema
Diffstat (limited to 'internal/api')
| -rw-r--r-- | internal/api/plc_compatibility_test.go | 122 | ||||
| -rw-r--r-- | internal/api/server.go | 291 |
2 files changed, 366 insertions, 47 deletions
diff --git a/internal/api/plc_compatibility_test.go b/internal/api/plc_compatibility_test.go index fe7ee82..45428d8 100644 --- a/internal/api/plc_compatibility_test.go +++ b/internal/api/plc_compatibility_test.go @@ -21,25 +21,55 @@ import ( "time" ) -func TestPLCCompatibilityGetDIDMatchesStoredDocument(t *testing.T) { - ts, store, _, cleanup := newCompatibilityServer(t) +func TestPLCCompatibilityGetDIDShape(t *testing.T) { + ts, _, _, cleanup := newCompatibilityServer(t) defer cleanup() - state, ok, err := store.GetState("did:plc:alice") + resp, err := http.Get(ts.URL + "/did:plc:alice") if err != nil { - t.Fatalf("get state: %v", err) + t.Fatalf("get did: %v", err) } - if !ok { - t.Fatalf("state not found") + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("status: got %d want 200", resp.StatusCode) } - resp, err := http.Get(ts.URL + "/did:plc:alice") + if got := resp.Header.Get("Content-Type"); !strings.Contains(got, "application/did+ld+json") { + t.Fatalf("content-type mismatch: %s", got) + } + + var body map[string]any + + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + t.Fatalf("decode body: %v", err) + } + + if _, ok := body["@context"].([]any); !ok { + t.Fatalf("missing or invalid @context: %#v", body["@context"]) + } + + if got, _ := body["id"].(string); got != "did:plc:alice" { + t.Fatalf("id mismatch: got %q want %q", got, "did:plc:alice") + } + + if _, ok := body["authentication"]; ok { + t.Fatalf("unexpected authentication field in compatibility did document") + } +} + +func TestPLCCompatibilityGetDataShape(t *testing.T) { + ts, _, _, cleanup := newCompatibilityServer(t) + + defer cleanup() + + resp, err := http.Get(ts.URL + "/did:plc:alice/data") if err != nil { - t.Fatalf("get did: %v", err) + t.Fatalf("get data: %v", err) } defer resp.Body.Close() @@ -48,14 +78,34 @@ func TestPLCCompatibilityGetDIDMatchesStoredDocument(t *testing.T) { t.Fatalf("status: got %d want 200", resp.StatusCode) } - if got := resp.Header.Get("Content-Type"); !strings.Contains(got, "application/did+ld+json") { - t.Fatalf("content-type mismatch: %s", got) + var body map[string]any + + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + t.Fatalf("decode body: %v", err) } - body, _ := io.ReadAll(resp.Body) + required := []string{"did", "verificationMethods", "rotationKeys", "alsoKnownAs", "services"} + + for _, key := range required { + if _, ok := body[key]; !ok { + t.Fatalf("missing key %q in /data response", key) + } + } - if strings.TrimSpace(string(body)) != strings.TrimSpace(string(state.DIDDocument)) { - t.Fatalf("did document mismatch\n got: %s\nwant: %s", string(body), string(state.DIDDocument)) + if _, ok := body["verificationMethods"].(map[string]any); !ok { + t.Fatalf("verificationMethods has wrong type: %#v", body["verificationMethods"]) + } + + if _, ok := body["rotationKeys"].([]any); !ok { + t.Fatalf("rotationKeys has wrong type: %#v", body["rotationKeys"]) + } + + if _, ok := body["alsoKnownAs"].([]any); !ok { + t.Fatalf("alsoKnownAs has wrong type: %#v", body["alsoKnownAs"]) + } + + if _, ok := body["services"].(map[string]any); !ok { + t.Fatalf("services has wrong type: %#v", body["services"]) } } @@ -164,6 +214,20 @@ func TestPLCCompatibilityPostIsMethodNotAllowed(t *testing.T) { if allow := resp.Header.Get("Allow"); allow != http.MethodGet { t.Fatalf("allow header mismatch: got %q want %q", allow, http.MethodGet) } + + var body map[string]any + + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + t.Fatalf("decode body: %v", err) + } + + if _, ok := body["message"].(string); !ok { + t.Fatalf("expected PLC-style message field, got: %v", body) + } + + if _, ok := body["error"]; ok { + t.Fatalf("unexpected internal error field in compatibility response: %v", body) + } } func TestPLCCompatibilityNoVerificationMetadataLeak(t *testing.T) { @@ -206,6 +270,38 @@ func TestPLCCompatibilityProofEndpointStillWorks(t *testing.T) { } } +func TestPLCCompatibilityNotFoundUsesPLCErrorShape(t *testing.T) { + ts, _, _, cleanup := newCompatibilityServer(t) + + defer cleanup() + + resp, err := http.Get(ts.URL + "/did:plc:not-registered") + + if err != nil { + t.Fatalf("get missing did: %v", err) + } + + defer resp.Body.Close() + + if resp.StatusCode != http.StatusNotFound { + t.Fatalf("status: got %d want 404", resp.StatusCode) + } + + var body map[string]any + + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + t.Fatalf("decode body: %v", err) + } + + if _, ok := body["message"].(string); !ok { + t.Fatalf("expected message field for 404, got: %v", body) + } + + if _, ok := body["error"]; ok { + t.Fatalf("unexpected error field in compatibility 404 body: %v", body) + } +} + func newCompatibilityServer(t *testing.T) (*httptest.Server, *storage.PebbleStore, []types.ExportRecord, func()) { t.Helper() diff --git a/internal/api/server.go b/internal/api/server.go index a3ef211..e9e1e15 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -14,6 +14,7 @@ import ( "github.com/Fuwn/plutia/internal/types" "github.com/Fuwn/plutia/pkg/proof" "net/http" + "sort" "strconv" "strings" "time" @@ -364,11 +365,39 @@ type plcAuditEntry struct { CreatedAt string `json:"createdAt"` } +type plcDIDDocument struct { + Context []string `json:"@context"` + ID string `json:"id"` + AlsoKnownAs []string `json:"alsoKnownAs,omitempty"` + VerificationMethod []plcVerificationEntry `json:"verificationMethod,omitempty"` + Service []plcServiceEntry `json:"service,omitempty"` + Deactivated bool `json:"deactivated,omitempty"` +} + +type plcVerificationEntry struct { + ID string `json:"id"` + Type string `json:"type"` + Controller string `json:"controller"` + PublicKeyMultibase string `json:"publicKeyMultibase"` +} + +type plcServiceEntry struct { + ID string `json:"id"` + Type string `json:"type"` + ServiceEndpoint string `json:"serviceEndpoint"` +} + +var plcDIDContexts = []string{ + "https://www.w3.org/ns/did/v1", + "https://w3id.org/security/multikey/v1", + "https://w3id.org/security/suites/secp256k1-2019/v1", +} + func (s *Server) handlePLCCompatibility(w http.ResponseWriter, r *http.Request) { path := strings.TrimPrefix(r.URL.Path, "/") if path == "" { - writeErr(w, http.StatusNotFound, fmt.Errorf("not found")) + writeCompatibilityErr(w, http.StatusNotFound, "not found") return } @@ -383,28 +412,28 @@ func (s *Server) handlePLCCompatibility(w http.ResponseWriter, r *http.Request) did := parts[0] if !strings.HasPrefix(did, "did:") { - writeErr(w, http.StatusNotFound, fmt.Errorf("not found")) + writeCompatibilityErr(w, http.StatusNotFound, "not found") return } if r.Method == http.MethodPost && len(parts) == 1 { w.Header().Set("Allow", http.MethodGet) - writeErr(w, http.StatusMethodNotAllowed, fmt.Errorf("write operations are not supported by this mirror")) + writeCompatibilityErr(w, http.StatusMethodNotAllowed, "write operations are not supported by this mirror") return } if r.Method != http.MethodGet { w.Header().Set("Allow", http.MethodGet) - writeErr(w, http.StatusMethodNotAllowed, fmt.Errorf("method not allowed")) + writeCompatibilityErr(w, http.StatusMethodNotAllowed, "method not allowed") return } switch { case len(parts) == 1: - s.handleGetDIDCompatibility(w, did) + s.handleGetDIDCompatibility(w, r, did) case len(parts) == 2 && parts[1] == "data": s.handleGetDIDDataCompatibility(w, r, did) case len(parts) == 2 && parts[1] == "log": @@ -414,40 +443,63 @@ func (s *Server) handlePLCCompatibility(w http.ResponseWriter, r *http.Request) case len(parts) == 3 && parts[1] == "log" && parts[2] == "audit": s.handleGetDIDLogAuditCompatibility(w, r, did) default: - writeErr(w, http.StatusNotFound, fmt.Errorf("not found")) + writeCompatibilityErr(w, http.StatusNotFound, "not found") } } -func (s *Server) handleGetDIDCompatibility(w http.ResponseWriter, did string) { +func (s *Server) handleGetDIDCompatibility(w http.ResponseWriter, r *http.Request, did string) { state, ok, err := s.store.GetState(did) if err != nil { - writeErr(w, http.StatusInternalServerError, err) + writeCompatibilityErr(w, http.StatusInternalServerError, err.Error()) return } if !ok { - writeErr(w, http.StatusNotFound, fmt.Errorf("did not found")) + writeCompatibilityErr(w, http.StatusNotFound, "DID not registered: "+did) + + return + } + + if s.ingestor == nil { + writeCompatibilityErr(w, http.StatusServiceUnavailable, "ingestor unavailable") + + return + } + + data, err := s.ingestor.LoadCurrentPLCData(r.Context(), did) + if err != nil { + if errors.Is(err, ingest.ErrDIDNotFound) { + writeCompatibilityErr(w, http.StatusNotFound, "DID not registered: "+did) + + return + } + + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + writeCompatibilityErr(w, http.StatusGatewayTimeout, err.Error()) + + return + } + + writeCompatibilityErr(w, http.StatusInternalServerError, err.Error()) return } status := http.StatusOK - if isTombstonedDIDDocument(state.DIDDocument) { + deactivated := isTombstonedDIDDocument(state.DIDDocument) + if deactivated { status = http.StatusGone } - w.Header().Set("Content-Type", "application/did+ld+json") - w.WriteHeader(status) - - _, _ = w.Write(state.DIDDocument) + writeJSONWithContentType(w, status, "application/did+ld+json", buildPLCDIDDocument(did, data, deactivated)) } func (s *Server) handleGetDIDLogCompatibility(w http.ResponseWriter, r *http.Request, did string) { if s.ingestor == nil { - writeErr(w, http.StatusServiceUnavailable, fmt.Errorf("ingestor unavailable")) + writeCompatibilityErr(w, http.StatusServiceUnavailable, "ingestor unavailable") return } @@ -456,18 +508,18 @@ func (s *Server) handleGetDIDLogCompatibility(w http.ResponseWriter, r *http.Req if err != nil { if errors.Is(err, ingest.ErrDIDNotFound) || errors.Is(err, ingest.ErrHistoryNotStored) { - writeErr(w, http.StatusNotFound, fmt.Errorf("did not found")) + writeCompatibilityErr(w, http.StatusNotFound, "DID not registered: "+did) return } if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - writeErr(w, http.StatusGatewayTimeout, err) + writeCompatibilityErr(w, http.StatusGatewayTimeout, err.Error()) return } - writeErr(w, http.StatusInternalServerError, err) + writeCompatibilityErr(w, http.StatusInternalServerError, err.Error()) return } @@ -483,7 +535,7 @@ func (s *Server) handleGetDIDLogCompatibility(w http.ResponseWriter, r *http.Req func (s *Server) handleGetDIDLogLastCompatibility(w http.ResponseWriter, r *http.Request, did string) { if s.ingestor == nil { - writeErr(w, http.StatusServiceUnavailable, fmt.Errorf("ingestor unavailable")) + writeCompatibilityErr(w, http.StatusServiceUnavailable, "ingestor unavailable") return } @@ -492,18 +544,18 @@ func (s *Server) handleGetDIDLogLastCompatibility(w http.ResponseWriter, r *http if err != nil { if errors.Is(err, ingest.ErrDIDNotFound) || errors.Is(err, ingest.ErrHistoryNotStored) { - writeErr(w, http.StatusNotFound, fmt.Errorf("did not found")) + writeCompatibilityErr(w, http.StatusNotFound, "DID not registered: "+did) return } if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - writeErr(w, http.StatusGatewayTimeout, err) + writeCompatibilityErr(w, http.StatusGatewayTimeout, err.Error()) return } - writeErr(w, http.StatusInternalServerError, err) + writeCompatibilityErr(w, http.StatusInternalServerError, err.Error()) return } @@ -516,7 +568,7 @@ func (s *Server) handleGetDIDLogLastCompatibility(w http.ResponseWriter, r *http func (s *Server) handleGetDIDLogAuditCompatibility(w http.ResponseWriter, r *http.Request, did string) { if s.ingestor == nil { - writeErr(w, http.StatusServiceUnavailable, fmt.Errorf("ingestor unavailable")) + writeCompatibilityErr(w, http.StatusServiceUnavailable, "ingestor unavailable") return } @@ -525,18 +577,18 @@ func (s *Server) handleGetDIDLogAuditCompatibility(w http.ResponseWriter, r *htt if err != nil { if errors.Is(err, ingest.ErrDIDNotFound) || errors.Is(err, ingest.ErrHistoryNotStored) { - writeErr(w, http.StatusNotFound, fmt.Errorf("did not found")) + writeCompatibilityErr(w, http.StatusNotFound, "DID not registered: "+did) return } if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - writeErr(w, http.StatusGatewayTimeout, err) + writeCompatibilityErr(w, http.StatusGatewayTimeout, err.Error()) return } - writeErr(w, http.StatusInternalServerError, err) + writeCompatibilityErr(w, http.StatusInternalServerError, err.Error()) return } @@ -558,7 +610,7 @@ func (s *Server) handleGetDIDLogAuditCompatibility(w http.ResponseWriter, r *htt func (s *Server) handleGetDIDDataCompatibility(w http.ResponseWriter, r *http.Request, did string) { if s.ingestor == nil { - writeErr(w, http.StatusServiceUnavailable, fmt.Errorf("ingestor unavailable")) + writeCompatibilityErr(w, http.StatusServiceUnavailable, "ingestor unavailable") return } @@ -567,18 +619,18 @@ func (s *Server) handleGetDIDDataCompatibility(w http.ResponseWriter, r *http.Re if err != nil { if errors.Is(err, ingest.ErrDIDNotFound) { - writeErr(w, http.StatusNotFound, fmt.Errorf("did not found")) + writeCompatibilityErr(w, http.StatusNotFound, "DID not registered: "+did) return } if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - writeErr(w, http.StatusGatewayTimeout, err) + writeCompatibilityErr(w, http.StatusGatewayTimeout, err.Error()) return } - writeErr(w, http.StatusInternalServerError, err) + writeCompatibilityErr(w, http.StatusInternalServerError, err.Error()) return } @@ -589,13 +641,13 @@ func (s *Server) handleGetDIDDataCompatibility(w http.ResponseWriter, r *http.Re func (s *Server) handleExportCompatibility(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { w.Header().Set("Allow", http.MethodGet) - writeErr(w, http.StatusMethodNotAllowed, fmt.Errorf("method not allowed")) + writeCompatibilityErr(w, http.StatusMethodNotAllowed, "method not allowed") return } if s.ingestor == nil { - writeErr(w, http.StatusServiceUnavailable, fmt.Errorf("ingestor unavailable")) + writeCompatibilityErr(w, http.StatusServiceUnavailable, "ingestor unavailable") return } @@ -606,7 +658,7 @@ func (s *Server) handleExportCompatibility(w http.ResponseWriter, r *http.Reques n, err := strconv.Atoi(rawCount) if err != nil || n < 1 { - writeErr(w, http.StatusBadRequest, fmt.Errorf("invalid count query parameter")) + writeCompatibilityErr(w, http.StatusBadRequest, "invalid count query parameter") return } @@ -624,7 +676,7 @@ func (s *Server) handleExportCompatibility(w http.ResponseWriter, r *http.Reques parsed, err := time.Parse(time.RFC3339, rawAfter) if err != nil { - writeErr(w, http.StatusBadRequest, fmt.Errorf("invalid after query parameter")) + writeCompatibilityErr(w, http.StatusBadRequest, "invalid after query parameter") return } @@ -679,6 +731,177 @@ func isTombstonedDIDDocument(raw []byte) bool { return deactivated } +func buildPLCDIDDocument(did string, plcData map[string]any, deactivated bool) plcDIDDocument { + doc := plcDIDDocument{ + Context: append([]string(nil), plcDIDContexts...), + ID: did, + AlsoKnownAs: extractStringArray(plcData["alsoKnownAs"]), + Deactivated: deactivated, + } + + verificationMethods := extractVerificationMethodMap(plcData["verificationMethods"]) + if len(verificationMethods) > 0 { + names := make([]string, 0, len(verificationMethods)) + for name := range verificationMethods { + names = append(names, name) + } + + sort.Strings(names) + + doc.VerificationMethod = make([]plcVerificationEntry, 0, len(names)) + for _, name := range names { + value := verificationMethods[name] + if strings.TrimSpace(value) == "" { + continue + } + + doc.VerificationMethod = append(doc.VerificationMethod, plcVerificationEntry{ + ID: did + "#" + name, + Type: "Multikey", + Controller: did, + PublicKeyMultibase: value, + }) + } + } + + services := extractServicesMap(plcData["services"]) + if len(services) > 0 { + names := make([]string, 0, len(services)) + for name := range services { + names = append(names, name) + } + + sort.Strings(names) + + doc.Service = make([]plcServiceEntry, 0, len(names)) + for _, name := range names { + entry := services[name] + typ := entry["type"] + endpoint := entry["endpoint"] + + if strings.TrimSpace(endpoint) == "" { + continue + } + + doc.Service = append(doc.Service, plcServiceEntry{ + ID: "#" + name, + Type: typ, + ServiceEndpoint: endpoint, + }) + } + } + + return doc +} + +func extractStringArray(v any) []string { + switch raw := v.(type) { + case []string: + out := make([]string, 0, len(raw)) + for _, item := range raw { + item = strings.TrimSpace(item) + if item == "" { + continue + } + + out = append(out, item) + } + + return out + case []any: + out := make([]string, 0, len(raw)) + for _, item := range raw { + s, _ := item.(string) + if strings.TrimSpace(s) == "" { + continue + } + + out = append(out, s) + } + + return out + default: + return nil + } +} + +func extractVerificationMethodMap(v any) map[string]string { + out := map[string]string{} + + switch vm := v.(type) { + case map[string]string: + for name, key := range vm { + if strings.TrimSpace(key) == "" { + continue + } + + out[name] = key + } + case map[string]any: + for name, raw := range vm { + key, _ := raw.(string) + if strings.TrimSpace(key) == "" { + continue + } + + out[name] = key + } + } + + return out +} + +func extractServicesMap(v any) map[string]map[string]string { + out := map[string]map[string]string{} + + switch services := v.(type) { + case map[string]map[string]string: + for name, entry := range services { + endpoint := strings.TrimSpace(entry["endpoint"]) + if endpoint == "" { + endpoint = strings.TrimSpace(entry["serviceEndpoint"]) + } + + if endpoint == "" { + continue + } + + out[name] = map[string]string{ + "type": entry["type"], + "endpoint": endpoint, + } + } + case map[string]any: + for name, raw := range services { + entry, ok := raw.(map[string]any) + if !ok { + continue + } + + typ, _ := entry["type"].(string) + endpoint, _ := entry["endpoint"].(string) + if endpoint == "" { + endpoint, _ = entry["serviceEndpoint"].(string) + } + + if strings.TrimSpace(endpoint) == "" { + continue + } + + out[name] = map[string]string{ + "type": typ, + "endpoint": endpoint, + } + } + } + + return out +} + +func writeCompatibilityErr(w http.ResponseWriter, code int, message string) { + writeJSON(w, code, map[string]any{"message": message}) +} + func (s *Server) withTimeout(next http.Handler) http.Handler { timeout := s.cfg.RequestTimeout |