diff options
| -rw-r--r-- | capturedrefrace.go | 64 | 
1 files changed, 54 insertions, 10 deletions
| diff --git a/capturedrefrace.go b/capturedrefrace.go index 3d7dd3a..e9221a9 100644 --- a/capturedrefrace.go +++ b/capturedrefrace.go @@ -116,8 +116,17 @@ func run(pass *analysis.Pass) (interface{}, error) {  				}  			} -			// Look for a function literal after the `go` statement. -			funcLit, ok := goStmt.Call.Fun.(*ast.FuncLit) +			var funcLit *ast.FuncLit + +			switch goStmtFunc := goStmt.Call.Fun.(type) { +			case *ast.FuncLit: +				// Look for a function literal after the `go` statement. +				funcLit, ok = goStmt.Call.Fun.(*ast.FuncLit) + +			case *ast.Ident: +				// Get a function literal stored in a local variable. +				funcLit, ok = funcLitFromIdent(goStmtFunc) +			}  			if !ok {  				return  			} @@ -144,17 +153,52 @@ func run(pass *analysis.Pass) (interface{}, error) {  	return nil, nil  } -func funcLitFromIdent(ident *ast.Ident) *ast.FuncLit { +// TODO: doc +func funcLitFromIdent(funcIdent *ast.Ident) (funcLit *ast.FuncLit, ok bool) {  	assignStmt, ok := funcIdent.Obj.Decl.(*ast.AssignStmt) -	if ok { -		// TODO: Get assignStmt.Rhs[position of ident name in assignStmt.Lhs] -		for _, expr := range assignStmt.Rhs { -			fl, ok := expr.(*ast.FuncLit) -			if ok { -				fmt.Printf("funclit: %#v\n", fl) -			} +	if !ok { +		return nil, ok +	} + +	funcVariableName := funcIdent.Name + +	assignmentIndex := -1 +	for i, expr := range assignStmt.Lhs { +		lhsIdent, ok := expr.(*ast.Ident) +		if !ok { +			continue +		} + +		if lhsIdent.Name == funcVariableName { +			assignmentIndex = i +			break  		}  	} + +	if assignmentIndex == -1 { +		return nil, false +	} + +	// TODO: Get assignStmt.Rhs[position of ident name in assignStmt.Lhs] +	// for _, expr := range assignStmt.Rhs { +	// 	funcLit, ok := expr.(*ast.FuncLit) +	// 	if !ok { +	// 		fmt.Printf("funclit: %#v\n", fl) +	// 		return nil, ok +	// 	} +	// } + +	if assignmentIndex > len(assignStmt.Rhs)-1 { +		return nil, false +	} + +	expr := assignStmt.Rhs[assignmentIndex] +	funcLit, ok = expr.(*ast.FuncLit) +	if !ok { +		return nil, ok +	} + +	return funcLit, true  }  // checkClosure reports variables used in funcLit that are captured from an | 
