prysm-pulse/tools/analyzers/errcheck/analyzer.go
Preston Van Loon d5ddd012bc
Enforce error handling and checking type assertions (#5403)
* Enforce error handling and checking type assertions
* Reference issue #5404 in the TODO message
* doc description
* Merge branch 'master' into errcheck
* fix tests and address @nisdas feedbacK
* gaz
* fix docker image
2020-04-13 04:11:09 +00:00

443 lines
13 KiB
Go

// Package errcheck implements a static analysis analyzer to ensure that errors are handled in Go
// code. This analyzer was adapted from https://github.com/kisielk/errcheck (MIT License).
package errcheck
import (
"errors"
"fmt"
"go/ast"
"go/token"
"go/types"
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/analysis/passes/inspect"
"golang.org/x/tools/go/ast/inspector"
)
// Doc is the description attached to Analyzer through its Doc field.
const Doc = "This tool enforces all errors must be handled and that type assertions test that " +
	"the type implements the given interface to prevent runtime panics."

// Analyzer runs static analysis. It depends on the inspect pass for an AST
// inspector and performs its checks in run.
var Analyzer = &analysis.Analyzer{
	Name:     "errcheck",
	Doc:      Doc,
	Run:      run,
	Requires: []*analysis.Analyzer{inspect.Analyzer},
}
// exclusions is the set of function names whose returned errors may be ignored
// without a report; ignoreCall looks names up here.
//
// Entries are spelled "pkg.Func" for package-level functions and
// "(Type).Method" for methods. Entries of the form "pkg.Func(Arg)" appear
// intended to exclude a function only when its first argument is Arg.
var exclusions = map[string]bool{
	// bytes
	"(*bytes.Buffer).Write":       true,
	"(*bytes.Buffer).WriteByte":   true,
	"(*bytes.Buffer).WriteRune":   true,
	"(*bytes.Buffer).WriteString": true,

	// fmt
	"fmt.Errorf":                     true,
	"fmt.Print":                      true,
	"fmt.Printf":                     true,
	"fmt.Println":                    true,
	"fmt.Fprint(*bytes.Buffer)":      true,
	"fmt.Fprintf(*bytes.Buffer)":     true,
	"fmt.Fprintln(*bytes.Buffer)":    true,
	"fmt.Fprint(*strings.Builder)":   true,
	"fmt.Fprintf(*strings.Builder)":  true,
	"fmt.Fprintln(*strings.Builder)": true,
	"fmt.Fprint(os.Stderr)":          true,
	"fmt.Fprintf(os.Stderr)":         true,
	"fmt.Fprintln(os.Stderr)":        true,

	// math/rand
	"math/rand.Read":         true,
	"(*math/rand.Rand).Read": true,

	// hash
	"(hash.Hash).Write": true,
}
// run is the analysis entry point. It visits every expression, go, defer and
// assignment statement in the package and reports:
//   - calls whose error result is discarded, either because the call is used
//     as a statement (including go/defer) or because the error is assigned
//     to "_" (recover() is treated as error-returning);
//   - type assertions whose "ok" result is not checked.
//
// Calls for which ignoreCall returns true are never reported.
func run(pass *analysis.Pass) (interface{}, error) {
	inspect, ok := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
	if !ok {
		return nil, errors.New("analyzer is not type *inspector.Inspector")
	}

	// Only statements are dispatched on; call expressions are examined through
	// the statements that contain them.
	nodeFilter := []ast.Node{
		(*ast.ExprStmt)(nil),
		(*ast.GoStmt)(nil),
		(*ast.DeferStmt)(nil),
		(*ast.AssignStmt)(nil),
	}

	inspect.Preorder(nodeFilter, func(node ast.Node) {
		switch stmt := node.(type) {
		case *ast.ExprStmt:
			if call, ok := stmt.X.(*ast.CallExpr); ok {
				checkDiscardedCall(pass, call)
			}
		case *ast.GoStmt:
			checkDiscardedCall(pass, stmt.Call)
		case *ast.DeferStmt:
			checkDiscardedCall(pass, stmt.Call)
		case *ast.AssignStmt:
			if len(stmt.Rhs) == 1 {
				checkSingleValueAssign(pass, stmt)
			} else {
				checkMultiValueAssign(pass, stmt)
			}
		}
	})

	return nil, nil
}

// checkDiscardedCall reports call when all of its results are thrown away
// (expression, go or defer statement) and it returns an error or is recover().
func checkDiscardedCall(pass *analysis.Pass, call *ast.CallExpr) {
	if !ignoreCall(pass, call) && callReturnsError(pass, call) {
		reportUnhandledError(pass, call.Lparen, call)
	}
}

// checkSingleValueAssign handles assignments with exactly one right-hand
// expression:
//   - "a, _ := f()": each "_" that receives an error result of f (or the
//     result of recover()) is reported;
//   - "v := x.(T)" and "v, _ := x.(T)": the unchecked assertion is reported.
//
// Type switch guards ("switch v := x.(type)") are skipped.
func checkSingleValueAssign(pass *analysis.Pass, stmt *ast.AssignStmt) {
	switch rhs := stmt.Rhs[0].(type) {
	case *ast.CallExpr:
		if ignoreCall(pass, rhs) {
			return
		}
		isError := errorsByArg(pass, rhs)
		// recover() returns interface{}, so errorsByArg cannot tell that it
		// carries an error; it is shortcut explicitly.
		recovered := isRecover(pass, rhs)
		for i, lhs := range stmt.Lhs {
			id, ok := lhs.(*ast.Ident)
			if !ok || id.Name != "_" {
				continue
			}
			// errorsByArg falls back to a one-element slice when the call's
			// type is not a tuple, so guard the index instead of panicking.
			if recovered || (i < len(isError) && isError[i]) {
				reportUnhandledError(pass, id.NamePos, rhs)
			}
		}
	case *ast.TypeAssertExpr:
		if rhs.Type == nil {
			// Type switch guard; there is no "ok" form to check.
			return
		}
		if len(stmt.Lhs) < 2 {
			// Assertion result not read.
			reportUnhandledTypeAssertion(pass, rhs.Pos())
		} else if id, ok := stmt.Lhs[1].(*ast.Ident); ok && id.Name == "_" {
			// Assertion result explicitly ignored.
			reportUnhandledTypeAssertion(pass, id.NamePos)
		}
	}
}

// checkMultiValueAssign handles "a, b = x, y" assignments, where each
// right-hand expression is single-valued and pairs with the left-hand operand
// at the same index. Error-returning calls assigned to "_" are reported, and so
// is every type assertion: a single-valued assertion has no "ok" result to check.
func checkMultiValueAssign(pass *analysis.Pass, stmt *ast.AssignStmt) {
	for i, rhs := range stmt.Rhs {
		if i >= len(stmt.Lhs) {
			break
		}
		lhsID, isIdent := stmt.Lhs[i].(*ast.Ident)
		switch expr := rhs.(type) {
		case *ast.CallExpr:
			if !isIdent || lhsID.Name != "_" || ignoreCall(pass, expr) {
				continue
			}
			if callReturnsError(pass, expr) {
				reportUnhandledError(pass, lhsID.NamePos, expr)
			}
		case *ast.TypeAssertExpr:
			if expr.Type == nil {
				// Type switch guards never appear in multi-value assignments.
				continue
			}
			pos := expr.Pos()
			if isIdent {
				pos = lhsID.NamePos
			}
			reportUnhandledTypeAssertion(pass, pos)
		}
	}
}
// reportUnhandledError emits a diagnostic at pos naming the function whose
// error result was not handled. The name comes from fullName and may be empty
// when the callee cannot be resolved.
func reportUnhandledError(pass *analysis.Pass, pos token.Pos, call *ast.CallExpr) {
	callee := fullName(pass, call)
	pass.Reportf(pos, "Unhandled error for function call %s", callee)
}
// reportUnhandledTypeAssertion emits a diagnostic at pos for a type assertion
// whose success is not tested.
func reportUnhandledTypeAssertion(pass *analysis.Pass, pos token.Pos) {
	const message = "Unhandled type assertion check. You must test whether or not an " +
		"interface implements the asserted type."
	pass.Reportf(pos, message)
}
// fullName returns the fully-qualified name of the function invoked by call, as
// produced by types.Func.FullName (for example "fmt.Printf" or
// "(*bytes.Buffer).Write").
//
// Both selector calls ("pkg.F()", "x.M()") and plain identifier calls ("f()")
// are resolved, so diagnostics for package-local functions still name the
// callee. It returns "" when call is nil or the callee is not a *types.Func
// (for example a builtin such as recover, or a function-typed variable).
func fullName(pass *analysis.Pass, call *ast.CallExpr) string {
	if _, fn, ok := selectorAndFunc(pass, call); ok {
		return fn.FullName()
	}
	if call == nil {
		return ""
	}
	id, ok := call.Fun.(*ast.Ident)
	if !ok {
		return ""
	}
	fn, ok := pass.TypesInfo.ObjectOf(id).(*types.Func)
	if !ok {
		return ""
	}
	return fn.FullName()
}
// selectorAndFunc extracts the selector expression and the function it
// resolves to from a call. For the call expression "a.b()", the selector is
// "a.b" and the function is the *types.Func for "b".
//
// The final return value is true only when call is non-nil, its callee is a
// selector expression, and the selected identifier resolves to a *types.Func.
// Plain calls such as "f()" carry no selector, so false is returned for them.
func selectorAndFunc(pass *analysis.Pass, call *ast.CallExpr) (*ast.SelectorExpr, *types.Func, bool) {
	if call != nil && call.Fun != nil {
		if sel, isSel := call.Fun.(*ast.SelectorExpr); isSel {
			if fn, isFunc := pass.TypesInfo.ObjectOf(sel.Sel).(*types.Func); isFunc {
				return sel, fn, true
			}
		}
	}
	return nil, nil, false
}
// ignoreCall reports whether call matches an entry in exclusions.
//
// Each name produced by namesForExcludeCheck is tried on its own ("fmt.Printf")
// and combined with each spelling of the call's first argument
// ("fmt.Fprintf(*bytes.Buffer)", "fmt.Fprintf(os.Stderr)"). Without the
// combined form, the argument-qualified exclusion entries could never match.
func ignoreCall(pass *analysis.Pass, call *ast.CallExpr) bool {
	names := namesForExcludeCheck(pass, call)
	if len(names) == 0 {
		return false
	}
	argNames := exclusionArgNames(pass, call)
	for _, name := range names {
		if exclusions[name] {
			return true
		}
		for _, arg := range argNames {
			if exclusions[name+"("+arg+")"] {
				return true
			}
		}
	}
	return false
}

// exclusionArgNames returns the spellings of call's first argument used by
// argument-qualified exclusions:
//   - the argument's type, e.g. "*bytes.Buffer";
//   - when the argument names a non-field variable, its package-qualified
//     name, e.g. "os.Stderr".
//
// It returns nil for calls without arguments.
func exclusionArgNames(pass *analysis.Pass, call *ast.CallExpr) []string {
	if call == nil || len(call.Args) == 0 {
		return nil
	}
	arg := call.Args[0]

	var names []string
	if t := pass.TypesInfo.TypeOf(arg); t != nil {
		names = append(names, t.String())
	}

	var id *ast.Ident
	switch e := arg.(type) {
	case *ast.Ident:
		id = e
	case *ast.SelectorExpr:
		id = e.Sel
	}
	if id != nil {
		if v, ok := pass.TypesInfo.ObjectOf(id).(*types.Var); ok && !v.IsField() && v.Pkg() != nil {
			names = append(names, v.Pkg().Path()+"."+v.Name())
		}
	}
	return names
}
// errorType is the interface underlying the predeclared error type. It is the
// target of the types.Implements check in isErrorType.
var errorType = types.Universe.Lookup("error").Type().Underlying().(*types.Interface)

// isErrorType reports whether t implements the error interface.
func isErrorType(t types.Type) bool {
	satisfiesError := types.Implements(t, errorType)
	return satisfiesError
}
// callReturnsError reports whether dropping call's results would discard an
// error. This is true for recover() and for any call with at least one result
// that errorsByArg flags as an error.
func callReturnsError(pass *analysis.Pass, call *ast.CallExpr) bool {
	if isRecover(pass, call) {
		return true
	}
	flags := errorsByArg(pass, call)
	for i := range flags {
		if flags[i] {
			return true
		}
	}
	return false
}
// errorsByArg describes, position by position, which results of call are
// errors. For the returned slice s:
//   - when the recorded type of call is a *types.Tuple, len(s) equals the
//     tuple's length; otherwise len(s) == 1;
//   - s[i] is true iff result i is a named or pointer type implementing error.
//
// When no type is recorded for call, or the result is any other kind of type,
// the single entry is false.
func errorsByArg(pass *analysis.Pass, call *ast.CallExpr) []bool {
	// Only named and pointer types are considered candidates for error results.
	flagResult := func(t types.Type) bool {
		switch t.(type) {
		case *types.Named, *types.Pointer:
			return isErrorType(t)
		default:
			return false
		}
	}

	resultType := pass.TypesInfo.Types[call].Type
	if tuple, isTuple := resultType.(*types.Tuple); isTuple {
		flags := make([]bool, tuple.Len())
		for i := range flags {
			flags[i] = flagResult(tuple.At(i).Type())
		}
		return flags
	}
	return []bool{flagResult(resultType)}
}
// isRecover reports whether call invokes the builtin recover, as opposed to
// a user-declared function that happens to be named recover.
func isRecover(pass *analysis.Pass, call *ast.CallExpr) bool {
	ident, isIdent := call.Fun.(*ast.Ident)
	if !isIdent {
		return false
	}
	if _, isBuiltin := pass.TypesInfo.Uses[ident].(*types.Builtin); !isBuiltin {
		return false
	}
	return ident.Name == "recover"
}
// namesForExcludeCheck returns the fully-qualified names under which call may
// appear in the exclusion list.
//
// Calls without a selector (like "f()") yield nil. Package-qualified calls
// (like "fmt.Printf()") and methods not reached through an interface yield
// just the function's full name. Methods reached through an interface yield
// one "(Iface).Method" name per interface walked by
// walkThroughEmbeddedInterfaces.
func namesForExcludeCheck(pass *analysis.Pass, call *ast.CallExpr) []string {
	sel, fn, ok := selectorAndFunc(pass, call)
	if !ok {
		return nil
	}
	name := fn.FullName()
	if name == "" {
		return nil
	}

	// Package-level functions (like fmt.Printf) have no receiver, so no
	// selection is recorded; the function's full name is all there is.
	selection, found := pass.TypesInfo.Selections[sel]
	if !found {
		return []string{name}
	}

	// When the method is not defined on an interface, fall back to the full name.
	ifaces, viaInterface := walkThroughEmbeddedInterfaces(selection)
	if !viaInterface {
		return []string{name}
	}

	// Vendored packages carry /vendor/ in their path and therefore will not
	// match the standard-library names in exclusions; supporting them would
	// need extra logic here.
	names := make([]string, 0, len(ifaces))
	for _, iface := range ifaces {
		names = append(names, fmt.Sprintf("(%s).%s", iface.String(), fn.Name()))
	}
	return names
}
// walkThroughEmbeddedInterfaces returns a slice of Interfaces that
// we need to walk through in order to reach the actual definition,
// in an Interface, of the method selected by the given selection.
//
// false will be returned in the second return value if:
//   - the right side of the selection is not a function
//   - the actual definition of the function is not in an Interface
//   - the chain of interfaces cannot be expressed with named types (for example
//     the method is promoted through an embedded alias of an interface literal);
//     callers then fall back to the plain function name instead of the
//     analysis crashing
//
// The returned slice will contain all the interface types that need
// to be walked through to reach the actual definition.
//
// For example, say we have:
//
//	type Inner interface {Method()}
//	type Middle interface {Inner}
//	type Outer interface {Middle}
//	type T struct {Outer}
//	type U struct {T}
//	type V struct {U}
//
// And then the selector:
//
//	V.Method
//
// We'll return [Outer, Middle, Inner] by first walking through the embedded structs
// until we reach the Outer interface, then descending through the embedded interfaces
// until we find the one that actually explicitly defines Method.
func walkThroughEmbeddedInterfaces(sel *types.Selection) ([]types.Type, bool) {
	fn, ok := sel.Obj().(*types.Func)
	if !ok {
		return nil, false
	}

	// Walk through the struct fields named by the selection's index path,
	// starting at the receiver. The last index selects the method itself and
	// is therefore skipped.
	indexes := sel.Index()
	if len(indexes) == 0 {
		return nil, false
	}
	currentT := sel.Recv()
	for _, fieldIndex := range indexes[:len(indexes)-1] {
		currentT = getTypeAtFieldIndex(currentT, fieldIndex)
	}

	// currentT is now either a type implementing the function, an Invalid type
	// (if the receiver is a package), or an interface. Only interface-defined
	// functions are of interest here.
	interfaceT, ok := maybeUnname(currentT).(*types.Interface)
	if !ok {
		return nil, false
	}

	// The first interface passed through is the one just found. The possibly
	// wrapping types.Named is kept because it is more useful to callers.
	result := []types.Type{currentT}

	// Descend through embedded interfaces until reaching the one that declares
	// the method explicitly.
	for !explicitlyDefinesMethod(interfaceT, fn) {
		namedInterfaceT, ok := getEmbeddedInterfaceDefiningMethod(interfaceT, fn)
		if !ok {
			// The method comes from an embedded interface that has no named
			// type to report; give up gracefully rather than panic.
			return nil, false
		}
		result = append(result, namedInterfaceT)
		interfaceT, ok = namedInterfaceT.Underlying().(*types.Interface)
		if !ok {
			return nil, false
		}
	}
	return result, true
}
func getTypeAtFieldIndex(startingAt types.Type, fieldIndex int) types.Type {
t := maybeUnname(maybeDereference(startingAt))
s, ok := t.(*types.Struct)
if !ok {
panic(fmt.Sprintf("cannot get Field of a type that is not a struct, got a %T", t))
}
return s.Field(fieldIndex).Type()
}
// getEmbeddedInterfaceDefiningMethod searches the interfaces directly embedded
// in interfaceT for one whose method set contains fn. If found, the
// types.Named wrapping that interface is returned along with true.
//
// Some embedded types are skipped because they cannot be returned as a
// *types.Named:
//   - embedded types that are not *types.Named (for example an alias of an
//     interface literal);
//   - named types whose underlying type is not an interface.
//
// If no suitable embedded interface is found, nil and false are returned.
func getEmbeddedInterfaceDefiningMethod(interfaceT *types.Interface, fn *types.Func) (*types.Named, bool) {
	for i := 0; i < interfaceT.NumEmbeddeds(); i++ {
		// EmbeddedType is used instead of the deprecated Embedded, which
		// cannot represent non-Named embedded types (it yields nil or panics).
		named, ok := interfaceT.EmbeddedType(i).(*types.Named)
		if !ok {
			continue
		}
		embeddedT, ok := named.Underlying().(*types.Interface)
		if !ok {
			continue
		}
		if definesMethod(embeddedT, fn) {
			return named, true
		}
	}
	return nil, false
}
// explicitlyDefinesMethod reports whether fn is declared directly in
// interfaceT's own method list, ignoring methods promoted from embedded
// interfaces.
func explicitlyDefinesMethod(interfaceT *types.Interface, fn *types.Func) bool {
	for i, n := 0, interfaceT.NumExplicitMethods(); i < n; i++ {
		if interfaceT.ExplicitMethod(i) == fn {
			return true
		}
	}
	return false
}
// definesMethod reports whether fn belongs to interfaceT's complete method
// set, including methods promoted from embedded interfaces.
func definesMethod(interfaceT *types.Interface, fn *types.Func) bool {
	found := false
	for i := 0; !found && i < interfaceT.NumMethods(); i++ {
		found = interfaceT.Method(i) == fn
	}
	return found
}
// maybeDereference returns the element type when t is a *types.Pointer, and
// t unchanged otherwise.
func maybeDereference(t types.Type) types.Type {
	switch ptr := t.(type) {
	case *types.Pointer:
		return ptr.Elem()
	default:
		return t
	}
}
// maybeUnname returns the underlying type when t is a *types.Named, and t
// unchanged otherwise.
func maybeUnname(t types.Type) types.Type {
	switch named := t.(type) {
	case *types.Named:
		return named.Underlying()
	default:
		return t
	}
}