diff options
| -rw-r--r-- | gocapturedrefrace.go | 76 | 
1 files changed, 40 insertions, 36 deletions
| diff --git a/gocapturedrefrace.go b/gocapturedrefrace.go index 1404f78..a3c1699 100644 --- a/gocapturedrefrace.go +++ b/gocapturedrefrace.go @@ -25,55 +25,59 @@ import (  	"go/types"  	"golang.org/x/tools/go/analysis" +	"golang.org/x/tools/go/analysis/passes/inspect" +	"golang.org/x/tools/go/ast/inspector"  )  var version = "0.0.1"  var Analyzer = &analysis.Analyzer{ -	Name: "gocapturedrefrace", -	Doc:  "reports captured references in goroutine closures", -	Run:  run, +	Name:     "gocapturedrefrace", +	Doc:      "reports captured references in goroutine closures", +	Run:      run, +	Requires: []*analysis.Analyzer{inspect.Analyzer},  }  func run(pass *analysis.Pass) (interface{}, error) { -	// TODO: Since we're calling ast.Inspect a bunch of times, maybe it's worthwhile using passes/inspect now. -	for _, file := range pass.Files { -		ast.Inspect( -			file, -			func(node ast.Node) bool { -				// Find `go` statements. -				goStmt, ok := node.(*ast.GoStmt) -				if !ok { -					return true -				} +	inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) -				// Look for a function literal after the `go` statement. -				funcLit, ok := goStmt.Call.Fun.(*ast.FuncLit) -				if !ok { -					return true -				} +	nodeFilter := []ast.Node{ +		(*ast.GoStmt)(nil), +	} -				// Inspect closure argument list. -				for _, arg := range funcLit.Type.Params.List { -					// Report reference arguments. -					_, ok := arg.Type.(*ast.StarExpr) -					if !ok { -						continue -					} - -					pass.Reportf( -						arg.Pos(), -						"reference %s in goroutine closure", -						arg.Names[0], -					) +	inspect.Preorder( +		nodeFilter, +		func(node ast.Node) { +			// Find `go` statements. +			goStmt, ok := node.(*ast.GoStmt) +			if !ok { +				return +			} + +			// Look for a function literal after the `go` statement. +			funcLit, ok := goStmt.Call.Fun.(*ast.FuncLit) +			if !ok { +				return +			} + +			// Inspect closure argument list. +			for _, arg := range funcLit.Type.Params.List { +				// Report reference arguments. +				_, ok := arg.Type.(*ast.StarExpr) +				if !ok { +					continue  				} -				checkClosure(pass, funcLit) +				pass.Reportf( +					arg.Pos(), +					"reference %s in goroutine closure", +					arg.Names[0], +				) +			} -				return true -			}, -		) -	} +			checkClosure(pass, funcLit) +		}, +	)  	return nil, nil  } | 
