diff options
author | Teddy Wing | 2023-05-20 14:57:21 +0200 |
---|---|---|
committer | Teddy Wing | 2023-05-20 14:57:21 +0200 |
commit | 8ac64e76c32a1f803460a71d29efd235d9be8edd (patch) | |
tree | 1b91bfd38ad5f2c341d9097ff6642a86327435f3 | |
parent | 68571cceeb663340bc196a6f01e022713a1eb233 (diff) | |
download | gocapturedrefrace-8ac64e76c32a1f803460a71d29efd235d9be8edd.tar.bz2 |
Add support for function literals defined in local variables
In addition to checking function literals after the `go` statement, also
check closures assigned to variables.
-rw-r--r-- | capturedrefrace.go | 64 |
1 files changed, 54 insertions, 10 deletions
diff --git a/capturedrefrace.go b/capturedrefrace.go index 3d7dd3a..e9221a9 100644 --- a/capturedrefrace.go +++ b/capturedrefrace.go @@ -116,8 +116,17 @@ func run(pass *analysis.Pass) (interface{}, error) { } } - // Look for a function literal after the `go` statement. - funcLit, ok := goStmt.Call.Fun.(*ast.FuncLit) + var funcLit *ast.FuncLit + + switch goStmtFunc := goStmt.Call.Fun.(type) { + case *ast.FuncLit: + // Look for a function literal after the `go` statement. + funcLit, ok = goStmt.Call.Fun.(*ast.FuncLit) + + case *ast.Ident: + // Get a function literal stored in a local variable. + funcLit, ok = funcLitFromIdent(goStmtFunc) + } if !ok { return } @@ -144,17 +153,52 @@ func run(pass *analysis.Pass) (interface{}, error) { return nil, nil } -func funcLitFromIdent(ident *ast.Ident) *ast.FuncLit { +// TODO: doc +func funcLitFromIdent(funcIdent *ast.Ident) (funcLit *ast.FuncLit, ok bool) { assignStmt, ok := funcIdent.Obj.Decl.(*ast.AssignStmt) - if ok { - // TODO: Get assignStmt.Rhs[position of ident name in assignStmt.Lhs] - for _, expr := range assignStmt.Rhs { - fl, ok := expr.(*ast.FuncLit) - if ok { - fmt.Printf("funclit: %#v\n", fl) - } + if !ok { + return nil, ok + } + + funcVariableName := funcIdent.Name + + assignmentIndex := -1 + for i, expr := range assignStmt.Lhs { + lhsIdent, ok := expr.(*ast.Ident) + if !ok { + continue + } + + if lhsIdent.Name == funcVariableName { + assignmentIndex = i + break } } + + if assignmentIndex == -1 { + return nil, false + } + + // TODO: Get assignStmt.Rhs[position of ident name in assignStmt.Lhs] + // for _, expr := range assignStmt.Rhs { + // funcLit, ok := expr.(*ast.FuncLit) + // if !ok { + // fmt.Printf("funclit: %#v\n", fl) + // return nil, ok + // } + // } + + if assignmentIndex > len(assignStmt.Rhs)-1 { + return nil, false + } + + expr := assignStmt.Rhs[assignmentIndex] + funcLit, ok = expr.(*ast.FuncLit) + if !ok { + return nil, ok + } + + return funcLit, true } // checkClosure reports variables used in funcLit that are captured from an |