aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTeddy Wing2023-05-24 01:28:44 +0200
committerTeddy Wing2023-05-24 01:28:44 +0200
commitbc386735703ac9bb9d26ef1011f5d2c84dd32ec1 (patch)
tree0a1094cc5d2adaef5a71508ca7dc8a6011a3954e
parente9ff801c405082cbfbcc99ba54e5a07f0af908e8 (diff)
downloadgodefererr-bc386735703ac9bb9d26ef1011f5d2c84dd32ec1.tar.bz2
Clean up `checkFunctionReturns()` function
-rw-r--r--defererr.go59
1 files changed, 11 insertions, 48 deletions
diff --git a/defererr.go b/defererr.go
index dfaffc4..536a87a 100644
--- a/defererr.go
+++ b/defererr.go
@@ -171,69 +171,32 @@ func checkFunctionReturns(
return true
}
+ // Ignore any `return` statements before the end of the `defer`
+ // closure.
if returnStmt.Pos() <= fState.firstErrorDeferEndPos {
return true
}
- // TODO: Check whether returnStmt uses error variable.
- fmt.Printf("returnStmt: %#v\n", returnStmt)
-
if returnStmt.Results == nil {
return true
}
- fmt.Printf("returnStmt.Results: %#v\n", returnStmt.Results)
-
- for _, expr := range returnStmt.Results {
- fmt.Printf("returnStmt expr: %#v\n", expr)
-
- t := pass.TypesInfo.Types[expr]
- fmt.Printf("returnStmt expr type: %#v\n", t)
- }
-
- // TODO: Get returnStmt.Results[error index from function result signature]
- // If not variable and name not [error variable name from defer], report diagnostic
+ // Get the value used when returning the error.
returnErrorExpr := returnStmt.Results[errorReturnIndex]
- t := pass.TypesInfo.Types[returnErrorExpr]
- fmt.Printf("returnStmt value type: %#v\n", t)
- fmt.Printf("returnStmt type type: %#v\n", t.Type)
-
returnErrorIdent, ok := returnErrorExpr.(*ast.Ident)
if !ok {
return true
}
- // TODO: Require t.Type to be *types.Named
- _, ok = t.Type.(*types.Named)
- if !ok {
- // TODO: report
+ _, isReturnErrorNamedType :=
+ pass.TypesInfo.Types[returnErrorExpr].Type.(*types.Named)
- pass.Reportf(
- returnErrorIdent.Pos(),
- "does not return '%s'",
- fState.deferErrorVar,
- )
-
- return true
- }
-
- // Or, we want to compare with the error declared in the function signature.
- fmt.Printf("returnError: %#v\n", returnErrorExpr)
-
- if returnErrorIdent.Name == fState.deferErrorVar.Name {
- fmt.Printf(
- "names: return:%#v : defer:%#v\n",
- returnErrorIdent.Name,
- fState.deferErrorVar.Name,
- )
- }
-
- if returnErrorIdent.Name != fState.deferErrorVar.Name {
- fmt.Printf(
- "names: return:%#v : defer:%#v\n",
- returnErrorIdent.Name,
- fState.deferErrorVar.Name,
- )
+ // Ensure the value used when returning the error is a named type
+ // (checking that no nil constant is used), and that the name of
+ // the error variable used in the `return` statement matches the
+ // name of the error variable assigned in the `defer` closure.
+ if !isReturnErrorNamedType ||
+ returnErrorIdent.Name != fState.deferErrorVar.Name {
pass.Reportf(
returnErrorIdent.Pos(),