diff options
| author | Fuwn <[email protected]> | 2026-02-27 07:13:17 +0000 |
|---|---|---|
| committer | Fuwn <[email protected]> | 2026-02-27 07:13:17 +0000 |
| commit | 856e2994722e2e7f67b47d55b8e673ddabcebe83 (patch) | |
| tree | 5a4e108384038eaa072d8e6c5f71ab68901fb431 /internal/collect | |
| download | kivia-main.tar.xz kivia-main.zip | |
Diffstat (limited to 'internal/collect')
| -rw-r--r-- | internal/collect/collect.go | 331 |
1 files changed, 331 insertions, 0 deletions
diff --git a/internal/collect/collect.go b/internal/collect/collect.go new file mode 100644 index 0000000..ccb3b46 --- /dev/null +++ b/internal/collect/collect.go @@ -0,0 +1,331 @@ +package collect + +import ( + "bytes" + "fmt" + "go/ast" + "go/parser" + "go/printer" + "go/token" + "io/fs" + "os" + "path/filepath" + "sort" + "strings" +) + +type Context struct { + EnclosingFunction string `json:"enclosingFunction,omitempty"` + Type string `json:"type,omitempty"` + ValueExpression string `json:"valueExpression,omitempty"` + ParentType string `json:"parentType,omitempty"` +} + +type Identifier struct { + Name string `json:"name"` + Kind string `json:"kind"` + File string `json:"file"` + Line int `json:"line"` + Column int `json:"column"` + Context Context `json:"context"` +} + +func FromPath(path string) ([]Identifier, error) { + files, err := discoverFiles(path) + + if err != nil { + return nil, err + } + + fileSet := token.NewFileSet() + identifiers := make([]Identifier, 0, 128) + + for _, filePath := range files { + fileNode, parseErr := parser.ParseFile(fileSet, filePath, nil, parser.SkipObjectResolution) + + if parseErr != nil { + return nil, fmt.Errorf("Failed to parse %s: %w", filePath, parseErr) + } + + collector := visitor{ + fileSet: fileSet, + file: filePath, + } + + ast.Walk(&collector, fileNode) + + identifiers = append(identifiers, collector.identifiers...) + } + + return identifiers, nil +} + +type visitor struct { + fileSet *token.FileSet + file string + identifiers []Identifier + functionStack []string + typeStack []string +} + +func (identifierVisitor *visitor) Visit(node ast.Node) ast.Visitor { + switch typedNode := node.(type) { + case *ast.FuncDecl: + identifierVisitor.addIdentifier(typedNode.Name, "function", Context{}) + + identifierVisitor.functionStack = append(identifierVisitor.functionStack, typedNode.Name.Name) + + identifierVisitor.captureFieldList(typedNode.Recv, "receiver") + identifierVisitor.captureFieldList(typedNode.Type.Params, "parameter") + identifierVisitor.captureFieldList(typedNode.Type.Results, "result") + + return leaveScope(identifierVisitor, func() { + identifierVisitor.functionStack = identifierVisitor.functionStack[:len(identifierVisitor.functionStack)-1] + }) + case *ast.TypeSpec: + identifierVisitor.addIdentifier(typedNode.Name, "type", Context{}) + + identifierVisitor.typeStack = append(identifierVisitor.typeStack, typedNode.Name.Name) + + identifierVisitor.captureTypeMembers(typedNode.Name.Name, typedNode.Type) + + return leaveScope(identifierVisitor, func() { identifierVisitor.typeStack = identifierVisitor.typeStack[:len(identifierVisitor.typeStack)-1] }) + case *ast.ValueSpec: + declaredType := renderExpression(identifierVisitor.fileSet, typedNode.Type) + rightHandValue := renderExpressionList(identifierVisitor.fileSet, typedNode.Values) + + for _, name := range typedNode.Names { + identifierVisitor.addIdentifier(name, "variable", Context{Type: declaredType, ValueExpression: rightHandValue}) + } + case *ast.AssignStmt: + if typedNode.Tok != token.DEFINE { + break + } + + rightHandValue := renderExpressionList(identifierVisitor.fileSet, typedNode.Rhs) + + for index, left := range typedNode.Lhs { + identifierNode, ok := left.(*ast.Ident) + + if !ok { + continue + } + + assignmentContext := Context{ValueExpression: rightHandValue} + + if index < len(typedNode.Rhs) { + assignmentContext.Type = inferTypeFromExpression(typedNode.Rhs[index]) + } + + identifierVisitor.addIdentifier(identifierNode, "variable", assignmentContext) + } + case *ast.RangeStmt: + if typedNode.Tok != token.DEFINE { + break + } + + if keyIdentifier, ok := typedNode.Key.(*ast.Ident); ok { + identifierVisitor.addIdentifier(keyIdentifier, "rangeKey", Context{ValueExpression: renderExpression(identifierVisitor.fileSet, typedNode.X)}) + } + + if valueIdentifier, ok := typedNode.Value.(*ast.Ident); ok { + identifierVisitor.addIdentifier(valueIdentifier, "rangeValue", Context{ValueExpression: renderExpression(identifierVisitor.fileSet, typedNode.X)}) + } + } + + return identifierVisitor +} + +type scopeExit struct { + parent *visitor + onLeave func() +} + +func leaveScope(parent *visitor, onLeave func()) ast.Visitor { + return &scopeExit{parent: parent, onLeave: onLeave} +} + +func (scopeExitVisitor *scopeExit) Visit(node ast.Node) ast.Visitor { + if node == nil { + scopeExitVisitor.onLeave() + + return nil + } + + return scopeExitVisitor.parent +} + +func (identifierVisitor *visitor) captureFieldList(fields *ast.FieldList, kind string) { + if fields == nil { + return + } + + for _, field := range fields.List { + declaredType := renderExpression(identifierVisitor.fileSet, field.Type) + + for _, name := range field.Names { + identifierVisitor.addIdentifier(name, kind, Context{Type: declaredType}) + } + } +} + +func (identifierVisitor *visitor) captureTypeMembers(typeName string, typeExpression ast.Expr) { + switch typedType := typeExpression.(type) { + case *ast.StructType: + if typedType.Fields == nil { + return + } + + for _, field := range typedType.Fields.List { + memberType := renderExpression(identifierVisitor.fileSet, field.Type) + + for _, fieldName := range field.Names { + identifierVisitor.addIdentifier(fieldName, "field", Context{Type: memberType, ParentType: typeName}) + } + } + case *ast.InterfaceType: + if typedType.Methods == nil { + return + } + + for _, method := range typedType.Methods.List { + memberType := renderExpression(identifierVisitor.fileSet, method.Type) + + for _, methodName := range method.Names { + identifierVisitor.addIdentifier(methodName, "interfaceMethod", Context{Type: memberType, ParentType: typeName}) + } + } + } +} + +func (identifierVisitor *visitor) addIdentifier(identifier *ast.Ident, kind string, context Context) { + if identifier == nil || identifier.Name == "_" { + return + } + + position := identifierVisitor.fileSet.Position(identifier.NamePos) + context.EnclosingFunction = currentFunction(identifierVisitor.functionStack) + identifierVisitor.identifiers = append(identifierVisitor.identifiers, Identifier{ + Name: identifier.Name, + Kind: kind, + File: identifierVisitor.file, + Line: position.Line, + Column: position.Column, + Context: context, + }) +} + +func currentFunction(stack []string) string { + if len(stack) == 0 { + return "" + } + + return stack[len(stack)-1] +} + +func discoverFiles(path string) ([]string, error) { + searchRoot := path + recursive := false + + if strings.HasSuffix(path, "/...") { + searchRoot = strings.TrimSuffix(path, "/...") + recursive = true + } + + if searchRoot == "" { + searchRoot = "." + } + + pathFileDetails, err := os.Stat(searchRoot) + + if err != nil { + return nil, err + } + + if !pathFileDetails.IsDir() { + if strings.HasSuffix(searchRoot, ".go") { + return []string{searchRoot}, nil + } + + return nil, fmt.Errorf("Path %q is not a Go file.", searchRoot) + } + + files := make([]string, 0, 64) + walkErr := filepath.WalkDir(searchRoot, func(candidate string, entry fs.DirEntry, walkError error) error { + if walkError != nil { + return walkError + } + + if entry.IsDir() { + name := entry.Name() + + if name == ".git" || name == "vendor" || name == "node_modules" { + return filepath.SkipDir + } + + if !recursive && candidate != searchRoot { + return filepath.SkipDir + } + + return nil + } + + if strings.HasSuffix(candidate, ".go") { + files = append(files, candidate) + } + + return nil + }) + + if walkErr != nil { + return nil, walkErr + } + + sort.Strings(files) + + return files, nil +} + +func renderExpression(fileSet *token.FileSet, expression ast.Expr) string { + if expression == nil { + return "" + } + + var buffer bytes.Buffer + + if err := printer.Fprint(&buffer, fileSet, expression); err != nil { + return "" + } + + return buffer.String() +} + +func renderExpressionList(fileSet *token.FileSet, expressions []ast.Expr) string { + if len(expressions) == 0 { + return "" + } + + parts := make([]string, 0, len(expressions)) + + for _, expression := range expressions { + parts = append(parts, renderExpression(fileSet, expression)) + } + + return strings.Join(parts, ", ") +} + +func inferTypeFromExpression(expression ast.Expr) string { + switch typedExpression := expression.(type) { + case *ast.CallExpr: + switch functionExpression := typedExpression.Fun.(type) { + case *ast.Ident: + return functionExpression.Name + case *ast.SelectorExpr: + return functionExpression.Sel.Name + } + + return "" + default: + return "" + } +} |