diff options
| -rw-r--r-- | defererr.go | 59 | 
1 files changed, 11 insertions, 48 deletions
| diff --git a/defererr.go b/defererr.go index dfaffc4..536a87a 100644 --- a/defererr.go +++ b/defererr.go @@ -171,69 +171,32 @@ func checkFunctionReturns(  				return true  			} +			// Ignore any `return` statements before the end of the `defer` +			// closure.  			if returnStmt.Pos() <= fState.firstErrorDeferEndPos {  				return true  			} -			// TODO: Check whether returnStmt uses error variable. -			fmt.Printf("returnStmt: %#v\n", returnStmt) -  			if returnStmt.Results == nil {  				return true  			} -			fmt.Printf("returnStmt.Results: %#v\n", returnStmt.Results) - -			for _, expr := range returnStmt.Results { -				fmt.Printf("returnStmt expr: %#v\n", expr) - -				t := pass.TypesInfo.Types[expr] -				fmt.Printf("returnStmt expr type: %#v\n", t) -			} - -			// TODO: Get returnStmt.Results[error index from function result signature] -			// If not variable and name not [error variable name from defer], report diagnostic +			// Get the value used when returning the error.  			returnErrorExpr := returnStmt.Results[errorReturnIndex] -			t := pass.TypesInfo.Types[returnErrorExpr] -			fmt.Printf("returnStmt value type: %#v\n", t) -			fmt.Printf("returnStmt type type: %#v\n", t.Type) -  			returnErrorIdent, ok := returnErrorExpr.(*ast.Ident)  			if !ok {  				return true  			} -			// TODO: Require t.Type to be *types.Named -			_, ok = t.Type.(*types.Named) -			if !ok { -				// TODO: report +			_, isReturnErrorNamedType := +				pass.TypesInfo.Types[returnErrorExpr].Type.(*types.Named) -				pass.Reportf( -					returnErrorIdent.Pos(), -					"does not return '%s'", -					fState.deferErrorVar, -				) - -				return true -			} - -			// Or, we want to compare with the error declared in the function signature. -			fmt.Printf("returnError: %#v\n", returnErrorExpr) - -			if returnErrorIdent.Name == fState.deferErrorVar.Name { -				fmt.Printf( -					"names: return:%#v : defer:%#v\n", -					returnErrorIdent.Name, -					fState.deferErrorVar.Name, -				) -			} - -			if returnErrorIdent.Name != fState.deferErrorVar.Name { -				fmt.Printf( -					"names: return:%#v : defer:%#v\n", -					returnErrorIdent.Name, -					fState.deferErrorVar.Name, -				) +			// Ensure the value used when returning the error is a named type +			// (checking that no nil constant is used), and that the name of +			// the error variable used in the `return` statement matches the +			// name of the error variable assigned in the `defer` closure. +			if !isReturnErrorNamedType || +				returnErrorIdent.Name != fState.deferErrorVar.Name {  				pass.Reportf(  					returnErrorIdent.Pos(), | 
