aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTeddy Wing2023-05-20 14:57:21 +0200
committerTeddy Wing2023-05-20 14:57:21 +0200
commit8ac64e76c32a1f803460a71d29efd235d9be8edd (patch)
tree1b91bfd38ad5f2c341d9097ff6642a86327435f3
parent68571cceeb663340bc196a6f01e022713a1eb233 (diff)
downloadgocapturedrefrace-8ac64e76c32a1f803460a71d29efd235d9be8edd.tar.bz2
Add support for function literals defined in local variables
In addition to checking function literals after the `go` statement, also check closures assigned to variables.
-rw-r--r--capturedrefrace.go64
1 files changed, 54 insertions, 10 deletions
diff --git a/capturedrefrace.go b/capturedrefrace.go
index 3d7dd3a..e9221a9 100644
--- a/capturedrefrace.go
+++ b/capturedrefrace.go
@@ -116,8 +116,17 @@ func run(pass *analysis.Pass) (interface{}, error) {
}
}
- // Look for a function literal after the `go` statement.
- funcLit, ok := goStmt.Call.Fun.(*ast.FuncLit)
+ var funcLit *ast.FuncLit
+
+ switch goStmtFunc := goStmt.Call.Fun.(type) {
+ case *ast.FuncLit:
+ // Look for a function literal after the `go` statement.
+ funcLit, ok = goStmt.Call.Fun.(*ast.FuncLit)
+
+ case *ast.Ident:
+ // Get a function literal stored in a local variable.
+ funcLit, ok = funcLitFromIdent(goStmtFunc)
+ }
if !ok {
return
}
@@ -144,17 +153,52 @@ func run(pass *analysis.Pass) (interface{}, error) {
return nil, nil
}
-func funcLitFromIdent(ident *ast.Ident) *ast.FuncLit {
+// TODO: doc
+func funcLitFromIdent(funcIdent *ast.Ident) (funcLit *ast.FuncLit, ok bool) {
assignStmt, ok := funcIdent.Obj.Decl.(*ast.AssignStmt)
- if ok {
- // TODO: Get assignStmt.Rhs[position of ident name in assignStmt.Lhs]
- for _, expr := range assignStmt.Rhs {
- fl, ok := expr.(*ast.FuncLit)
- if ok {
- fmt.Printf("funclit: %#v\n", fl)
- }
+ if !ok {
+ return nil, ok
+ }
+
+ funcVariableName := funcIdent.Name
+
+ assignmentIndex := -1
+ for i, expr := range assignStmt.Lhs {
+ lhsIdent, ok := expr.(*ast.Ident)
+ if !ok {
+ continue
+ }
+
+ if lhsIdent.Name == funcVariableName {
+ assignmentIndex = i
+ break
}
}
+
+ if assignmentIndex == -1 {
+ return nil, false
+ }
+
+ // TODO: Get assignStmt.Rhs[position of ident name in assignStmt.Lhs]
+ // for _, expr := range assignStmt.Rhs {
+ // funcLit, ok := expr.(*ast.FuncLit)
+ // if !ok {
+ // fmt.Printf("funclit: %#v\n", fl)
+ // return nil, ok
+ // }
+ // }
+
+ if assignmentIndex > len(assignStmt.Rhs)-1 {
+ return nil, false
+ }
+
+ expr := assignStmt.Rhs[assignmentIndex]
+ funcLit, ok = expr.(*ast.FuncLit)
+ if !ok {
+ return nil, ok
+ }
+
+ return funcLit, true
}
// checkClosure reports variables used in funcLit that are captured from an