prysm-pulse/tools/analyzers/recursivelock/analyzer.go
Nishant Das 2a7a09b112
Add Lock Analyzer (#10430)
* add lock analyzer

* fix locks

* progress

* fix failures

* fix error log

Co-authored-by: Raul Jordan <raul@prysmaticlabs.com>
2022-04-05 16:39:48 +00:00

644 lines
18 KiB
Go

// Package recursivelock implements an analyzer tool for detecting nested or recursive mutex lock and read lock statements.
package recursivelock
import (
"errors"
"fmt"
"go/ast"
"go/token"
"go/types"
"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/analysis/passes/inspect"
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/ast/inspector"
"golang.org/x/tools/go/types/typeutil"
)
// Analyzer is the recursivelock analysis pass. It requires the result of the
// inspect analyzer and performs its checks in run.
var Analyzer = &analysis.Analyzer{
Name: "recursivelock",
Doc: "Checks for recursive or nested RLock calls",
Requires: []*analysis.Analyzer{inspect.Analyzer},
Run: run,
}
// Errors whose messages are used as diagnostics when a lock is found to be
// acquired again while it is already being tracked.
var (
	errNestedRLock     = errors.New("found recursive read lock call")
	errNestedLock      = errors.New("found recursive lock call")
	errNestedMixedLock = errors.New("found recursive mixed lock call")
)

// mode selects which pair of lock/unlock methods is being tracked.
type mode int

const (
	// LockMode tracks Lock/Unlock.
	LockMode mode = iota
	// RLockMode tracks RLock/RUnlock.
	RLockMode
)

// LockName returns the name of the acquiring method for this mode, or an
// empty string when the mode is not one of the declared constants.
func (m mode) LockName() string {
	if m == LockMode {
		return "Lock"
	}
	if m == RLockMode {
		return "RLock"
	}
	return ""
}

// UnLockName returns the name of the releasing method for this mode, or an
// empty string when the mode is not one of the declared constants.
func (m mode) UnLockName() string {
	if m == LockMode {
		return "Unlock"
	}
	if m == RLockMode {
		return "RUnlock"
	}
	return ""
}

// ErrorFound returns the error describing a recursive acquisition in this
// mode, or nil when the mode is not one of the declared constants.
func (m mode) ErrorFound() error {
	if m == LockMode {
		return errNestedLock
	}
	if m == RLockMode {
		return errNestedRLock
	}
	return nil
}
// run is the entry point of the analyzer. It walks the selected node types of
// the package in preorder, maintains one lockTracker per lock mode, and hands
// every node that is not being skipped to stmtSelector.
// It always returns a nil result; the only error is a missing inspector.
func run(pass *analysis.Pass) (interface{}, error) {
// NOTE: from here on the local variable inspect shadows the imported package.
inspect, ok := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
if !ok {
return nil, errors.New("analyzer is not type *inspector.Inspector")
}
// Only these node kinds are delivered to the callback below.
nodeFilter := []ast.Node{
(*ast.GoStmt)(nil),
(*ast.CallExpr)(nil),
(*ast.DeferStmt)(nil),
(*ast.FuncDecl)(nil),
(*ast.FuncLit)(nil),
(*ast.File)(nil),
(*ast.IfStmt)(nil),
(*ast.ReturnStmt)(nil),
}
keepTrackOf := &tracker{
rLockTrack: &lockTracker{},
lockTrack: &lockTracker{},
}
inspect.Preorder(nodeFilter, func(node ast.Node) {
// Skip nodes that start at or before the recorded end of a function literal
// (both trackers must have that end recorded).
if keepTrackOf.rLockTrack.funcLitEnd.IsValid() && node.Pos() <= keepTrackOf.rLockTrack.funcLitEnd &&
keepTrackOf.lockTrack.funcLitEnd.IsValid() && node.Pos() <= keepTrackOf.lockTrack.funcLitEnd {
return
}
// Past the function literal (or none recorded): clear the markers.
keepTrackOf.rLockTrack.funcLitEnd = token.NoPos
keepTrackOf.lockTrack.funcLitEnd = token.NoPos
// Nodes that are not past the recorded end of a defer statement are skipped;
// the marker is cleared by the first node that starts after it.
if keepTrackOf.rLockTrack.deferEnd.IsValid() && node.Pos() > keepTrackOf.rLockTrack.deferEnd {
keepTrackOf.rLockTrack.deferEnd = token.NoPos
} else if keepTrackOf.rLockTrack.deferEnd.IsValid() {
return
}
if keepTrackOf.lockTrack.deferEnd.IsValid() && node.Pos() > keepTrackOf.lockTrack.deferEnd {
keepTrackOf.lockTrack.deferEnd = token.NoPos
} else if keepTrackOf.lockTrack.deferEnd.IsValid() {
return
}
// Once the walk is past a recorded return statement, clear the marker and
// raise the lock count again (stmtSelector lowers it when it records retEnd).
if keepTrackOf.rLockTrack.retEnd.IsValid() && node.Pos() > keepTrackOf.rLockTrack.retEnd {
keepTrackOf.rLockTrack.retEnd = token.NoPos
keepTrackOf.rLockTrack.incFRU()
}
if keepTrackOf.lockTrack.retEnd.IsValid() && node.Pos() > keepTrackOf.lockTrack.retEnd {
keepTrackOf.lockTrack.retEnd = token.NoPos
keepTrackOf.lockTrack.incFRU()
}
// stmtSelector may replace the tracker (on a new file or function declaration).
keepTrackOf = stmtSelector(node, pass, keepTrackOf, inspect)
})
return nil, nil
}
// stmtSelector updates the tracking state for a single AST node and returns
// the tracker to use for the following nodes (a fresh one when a new file or
// function declaration starts, otherwise the one that was passed in).
// Node types that have no case below leave the tracker untouched.
func stmtSelector(node ast.Node, pass *analysis.Pass, keepTrackOf *tracker, inspect *inspector.Inspector) *tracker {
switch stmt := node.(type) {
case *ast.GoStmt:
// Remember where the call started by the go statement ends, so that the
// matching CallExpr can be recognised and skipped below.
keepTrackOf.rLockTrack.goroutinePos = stmt.Call.End()
keepTrackOf.lockTrack.goroutinePos = stmt.Call.End()
case *ast.CallExpr:
// A call ending at the recorded goroutine position is not checked.
if stmt.End() == keepTrackOf.rLockTrack.goroutinePos ||
stmt.End() == keepTrackOf.lockTrack.goroutinePos {
keepTrackOf.rLockTrack.goroutinePos = 0
keepTrackOf.lockTrack.goroutinePos = 0
break
}
// getCallInfo returns nil for callees it does not handle.
call := getCallInfo(pass.TypesInfo, stmt)
if call == nil {
break
}
// mapSelTypes returns nil when the callee expression is not a plain
// identifier/selector chain.
selMap := mapSelTypes(stmt, pass)
if selMap == nil {
break
}
// Check the call once per lock mode, each against its own tracker.
checkForRecLocks(node, pass, inspect, RLockMode, call, keepTrackOf.rLockTrack, selMap)
checkForRecLocks(node, pass, inspect, LockMode, call, keepTrackOf.lockTrack, selMap)
case *ast.File:
// New file: start with empty trackers.
keepTrackOf = &tracker{
rLockTrack: &lockTracker{},
lockTrack: &lockTracker{},
}
case *ast.FuncDecl:
// New function declaration: start with empty trackers.
keepTrackOf = &tracker{
rLockTrack: &lockTracker{},
lockTrack: &lockTracker{},
}
// Only the read-lock tracker records the end of the function.
keepTrackOf.rLockTrack.funcEnd = stmt.End()
case *ast.FuncLit:
// Record the end of the function literal unless one is already recorded.
if keepTrackOf.rLockTrack.funcLitEnd == token.NoPos {
keepTrackOf.rLockTrack.funcLitEnd = stmt.End()
}
if keepTrackOf.lockTrack.funcLitEnd == token.NoPos {
keepTrackOf.lockTrack.funcLitEnd = stmt.End()
}
case *ast.IfStmt:
// Feed each statement of the if body, then the else branch, through this
// function. Only statement types with a case in this switch have an effect.
stmts := stmt.Body.List
for i := 0; i < len(stmts); i++ {
keepTrackOf = stmtSelector(stmts[i], pass, keepTrackOf, inspect)
}
keepTrackOf = stmtSelector(stmt.Else, pass, keepTrackOf, inspect)
case *ast.DeferStmt:
call := getCallInfo(pass.TypesInfo, stmt.Call)
// Record the end of the defer statement unless one is already recorded.
if keepTrackOf.rLockTrack.deferEnd == token.NoPos {
keepTrackOf.rLockTrack.deferEnd = stmt.End()
}
if keepTrackOf.lockTrack.deferEnd == token.NoPos {
keepTrackOf.lockTrack.deferEnd = stmt.End()
}
// Flag the tracker whose unlock method name matches the deferred call.
if call != nil && call.name == RLockMode.UnLockName() {
keepTrackOf.rLockTrack.deferredRUnlock = true
}
if call != nil && call.name == LockMode.UnLockName() {
keepTrackOf.lockTrack.deferredRUnlock = true
}
case *ast.ReturnStmt:
// Process the returned expressions first.
for i := 0; i < len(stmt.Results); i++ {
keepTrackOf = stmtSelector(stmt.Results[i], pass, keepTrackOf, inspect)
}
// With a deferred unlock flagged and no return already recorded, lower the
// lock count and remember where this return statement ends.
if keepTrackOf.rLockTrack.deferredRUnlock && keepTrackOf.rLockTrack.retEnd == token.NoPos {
keepTrackOf.rLockTrack.deincFRU()
keepTrackOf.rLockTrack.retEnd = stmt.End()
}
if keepTrackOf.lockTrack.deferredRUnlock && keepTrackOf.lockTrack.retEnd == token.NoPos {
keepTrackOf.lockTrack.deincFRU()
keepTrackOf.lockTrack.retEnd = stmt.End()
}
}
return keepTrackOf
}
// tracker bundles the per-mode tracking state used during the AST walk.
type tracker struct {
// rLockTrack is the state used for RLockMode checks.
rLockTrack *lockTracker
// lockTrack is the state used for LockMode checks.
lockTrack *lockTracker
}
// lockTracker holds the walk state for one lock mode.
type lockTracker struct {
// funcEnd is the end position of the current function declaration
// (stmtSelector only sets it on the read-lock tracker).
funcEnd token.Pos
// retEnd is the end position of a recorded return statement, or token.NoPos.
retEnd token.Pos
// deferEnd is the end position of a recorded defer statement, or token.NoPos.
deferEnd token.Pos
// funcLitEnd is the end position of a recorded function literal, or token.NoPos.
funcLitEnd token.Pos
// goroutinePos is the end position of the call started by a go statement.
goroutinePos token.Pos
// deferredRUnlock is set when a deferred call has this mode's unlock name.
deferredRUnlock bool
// foundRLock is the counter raised by incFRU and lowered by deincFRU.
foundRLock int
// rLockSelector is the selector of the first lock call recorded by checkForRecLocks.
rLockSelector *selIdentList
}
// String renders the position markers, the deferred-unlock flag and the lock
// counter of the tracker, one per line.
func (t lockTracker) String() string {
	const layout = "funcEnd:%v\nretEnd:%v\ndeferEnd:%v\ndeferredRU:%v\nfoundRLock:%v\n"
	return fmt.Sprintf(layout, t.funcEnd, t.retEnd, t.deferEnd, t.deferredRUnlock, t.foundRLock)
}
// deincFRU lowers the lock counter by one; the counter never drops below zero.
func (t *lockTracker) deincFRU() {
	if t.foundRLock <= 0 {
		return
	}
	t.foundRLock--
}
// incFRU raises the lock counter by one.
func (t *lockTracker) incFRU() {
	t.foundRLock++
}
// checkForRecLocks inspects one call against the state held in lockTracker
// for the given lock mode.
//
// While no lock selector has been recorded, a call named like the mode's lock
// method records its selector and raises the lock counter. Once a selector is
// recorded and the counter is positive, the call is reported at node when it
// is a related Lock/RLock call on the same selector, when it is the same lock
// call again, or when hasNestedlock finds such a call inside the callee.
// Afterwards the counter is lowered for a matching unlock call and raised for
// a matching lock call seen while the counter is zero.
//
// Diagnostics are passed to Reportf as arguments of a constant format string,
// so a '%' inside the message (for example in a file path of the stack trace)
// is printed verbatim instead of being interpreted as a formatting verb.
func checkForRecLocks(node ast.Node, pass *analysis.Pass, inspect *inspector.Inspector, lockmode mode, call *callInfo,
	lockTracker *lockTracker, selMap *selIdentList) {
	name := call.name
	if lockTracker.rLockSelector != nil {
		if lockTracker.foundRLock > 0 {
			// Same selector, but Lock where RLock is held (or the reverse).
			if lockTracker.rLockSelector.isRelated(selMap, 0) {
				pass.Reportf(node.Pos(), "%v", errNestedMixedLock)
			}
			if lockTracker.rLockSelector.isEqual(selMap, 0) {
				// The very same lock call again.
				pass.Reportf(node.Pos(), "%v", lockmode.ErrorFound())
			} else {
				// Some other call: look for the lock inside its declaration.
				if stack := hasNestedlock(lockTracker.rLockSelector, lockTracker.goroutinePos, selMap, call, inspect, pass, make(map[string]bool),
					lockmode.UnLockName()); stack != "" {
					pass.Reportf(node.Pos(), "%v\n%v", lockmode.ErrorFound(), stack)
				}
			}
		}
		// Offset 1 ignores the final element, so x.Unlock matches a recorded x.Lock.
		if name == lockmode.UnLockName() && lockTracker.rLockSelector.isEqual(selMap, 1) {
			lockTracker.deincFRU()
		}
		if name == lockmode.LockName() && lockTracker.foundRLock == 0 && lockTracker.rLockSelector.isEqual(selMap, 0) {
			lockTracker.incFRU()
		}
	} else if name == lockmode.LockName() && lockTracker.foundRLock == 0 {
		// First lock call seen by this tracker: remember its selector.
		lockTracker.rLockSelector = selMap
		lockTracker.incFRU()
	}
}
// selIdentNode stores the AST and type information of a single item in a selector expression.
// For example, in "a.b.c()", a selIdentNode might store the information for "a".
type selIdentNode struct {
// next is the following item of the selector expression, or nil for the last one.
next *selIdentNode
// this is the identifier of the item.
this *ast.Ident
// typObj is the object the identifier resolves to according to the type information.
typObj types.Object
}
// selIdentList is a linked list of selIdentNodes that stores the information of an entire selector expression.
// For example, each item in "a.b.c()" is stored as a node in this list, with the start node being "a".
type selIdentList struct {
// start is the first node of the list.
start *selIdentNode
// length is the number of items in the selector expression.
length int
current *selIdentNode // used for internal functions
currentIndex int // used for internal functions
}
// next moves the cursor one node forward and returns the node it now points
// at. At the end of the list the cursor stays where it is and nil is returned.
func (s *selIdentList) next() (n *selIdentNode) {
	n = s.current.next
	if n == nil {
		return nil
	}
	s.current = n
	s.currentIndex++
	return n
}
// reset moves the cursor back to the first node of the list.
func (s *selIdentList) reset() {
	s.current, s.currentIndex = s.start, 0
}
// isEqual reports whether s and s2 have the same length and the same nodes.
// The trailing offset nodes are left out of the comparison (the first node is
// always compared), so a.b.c() and a.b.d() are equal with an offset of 1.
// Both cursors are reset and moved by the comparison.
func (s *selIdentList) isEqual(s2 *selIdentList, offset int) bool {
	if s2 == nil || s.length != s2.length {
		return false
	}
	s.reset()
	s2.reset()
	limit := s.length - offset - 1
	for {
		if !s.current.isEqual(s2.current) {
			return false
		}
		// Stop once the last compared index is reached or s has no more nodes.
		if s.currentIndex >= limit || s.next() == nil {
			return true
		}
		s2.next()
	}
}
// isRelated reports whether s and s2 have the same length, agree on the nodes
// in front of the final one, and end in a non-equal but related pair of lock
// methods: Lock in one list and RLock in the other.
// Both cursors are reset and moved by the comparison.
// TODO: Use a generalizable method here instead of hardcoding
// the lock definitions here.
func (s *selIdentList) isRelated(s2 *selIdentList, offset int) bool {
	if s2 == nil || s.length != s2.length {
		return false
	}
	s.reset()
	s2.reset()
	last := s.length - 1
	for advancing := true; advancing; {
		if !s.current.isEqual(s2.current) {
			return false
		}
		advancing = s.currentIndex < last-offset && s.next() != nil
		if advancing {
			s2.next()
		}
		// The method names are only looked at once the cursor is on the final node.
		if s.currentIndex != last {
			continue
		}
		mine, theirs := s.current.this.String(), s2.current.this.String()
		if mine == LockMode.LockName() && theirs == RLockMode.LockName() {
			return true
		}
		if mine == RLockMode.LockName() && theirs == LockMode.LockName() {
			return true
		}
	}
	return false
}
// getSub matches the receiver part of s2 (every element except the last)
// against the beginning of s. On success it returns a new selIdentList that
// starts at the node of s matching the second-to-last element of s2 and runs
// to the end of s; the nodes themselves are shared with s.
// For example, if s represents "a.b.c.d()" and s2 represents "a.b.e()",
// getSub returns a list representing "b.c.d".
//
// getSub returns nil if s2 is nil, if s2 has fewer than two elements (it then
// has no receiver part to match), if s2 is longer than s, or if the receiver
// part of s2 differs from the beginning of s.
// Both cursors are reset and moved by the comparison.
func (s *selIdentList) getSub(s2 *selIdentList) *selIdentList {
	// Without the length < 2 guard the stop condition below (index == length-2)
	// could never be reached and the loop would not terminate.
	if s2 == nil || s2.length < 2 || s2.length > s.length {
		return nil
	}
	s.reset()
	s2.reset()
	for {
		if !s.current.isEqual(s2.current) {
			return nil
		}
		if s2.currentIndex == s2.length-2 { // might want to add a selNode.prev() func
			break
		}
		// A list shorter than its recorded length cannot be matched any further.
		if s.next() == nil || s2.next() == nil {
			return nil
		}
	}
	return &selIdentList{
		start:        s.current,
		length:       s.length - s.currentIndex,
		current:      s.current,
		currentIndex: 0,
	}
}
// changeRoot replaces the first node of the list with a new node built from
// the given identifier and object. The new node links to the old second node;
// the cursor follows along if it was pointing at the old first node.
func (s *selIdentList) changeRoot(r *ast.Ident, t types.Object) {
	root := &selIdentNode{this: r, typObj: t, next: s.start.next}
	if s.current == s.start {
		s.current = root
	}
	s.start = root
}
// String renders the length of the list followed by one line per node; the
// node at the cursor index is prefixed with "*".
func (s selIdentList) String() (str string) {
	str = fmt.Sprintf("length: %v\n[\n", s.length)
	idx := 0
	for n := s.start; n != nil; n = n.next {
		if idx == s.currentIndex {
			str += "*"
		}
		str += fmt.Sprintf("%v: %v\n", idx, n)
		idx++
	}
	return str + "]"
}
// isEqual reports whether both nodes carry the same identifier name and
// resolve to the same object.
func (s *selIdentNode) isEqual(s2 *selIdentNode) bool {
	if s.this.Name != s2.this.Name {
		return false
	}
	return s.typObj == s2.typObj
}
// String renders the identifier and the object of the node.
func (s selIdentNode) String() string {
	const layout = "{ ident: '%v', type: '%v' }"
	return fmt.Sprintf(layout, s.this, s.typObj)
}
// mapSelTypes builds the selIdentList for the callee expression of the given
// call. It returns nil when recurMapSelTypes cannot map the expression.
func mapSelTypes(c *ast.CallExpr, pass *analysis.Pass) *selIdentList {
	var list selIdentList
	if ok := list.recurMapSelTypes(c.Fun, nil, pass.TypesInfo); !ok {
		return nil
	}
	return &list
}
// recurMapSelTypes walks a selector expression from its outermost selector
// down to the root identifier, creating one node per item and linking each new
// node in front of the previous one (next is the node that follows e).
// When the root identifier is reached it becomes the start and the cursor of
// the list and true is returned. An expression that is neither an identifier
// nor a selector (after removing parentheses) makes it return false.
// The length of the list grows by one for every expression visited.
func (l *selIdentList) recurMapSelTypes(e ast.Expr, next *selIdentNode, t *types.Info) bool {
	for {
		l.length++
		switch x := astutil.Unparen(e).(type) {
		case *ast.SelectorExpr:
			node := &selIdentNode{next: next, this: x.Sel}
			if sel, ok := t.Selections[x]; ok {
				node.typObj = sel.Obj() // method or field
			} else {
				node.typObj = t.Uses[x.Sel] // qualified identifier?
			}
			// Continue with the operand; the node just built follows it.
			e, next = x.X, node
		case *ast.Ident:
			root := &selIdentNode{next: next, this: x, typObj: t.ObjectOf(x)}
			l.current = root
			l.start = root
			return true
		default:
			return false
		}
	}
}
// callInfo describes a call expression together with the identity of its callee,
// as filled in by getCallInfo.
type callInfo struct {
// call is the call expression itself.
call *ast.CallExpr
id string // String representation of the type object
name string // type ID [either the name (if the function is exported) or the package/name if otherwise] of the function/method
}
// getCallInfo resolves the callee of the given call and returns its details.
// It returns nil when the callee cannot be resolved, when it is a builtin, or
// when it is a method with an interface receiver.
func getCallInfo(tInfo *types.Info, call *ast.CallExpr) (c *callInfo) {
	callee := typeutil.Callee(tInfo, call)
	if callee == nil {
		return nil
	}
	if _, builtin := callee.(*types.Builtin); builtin {
		return nil
	}
	if sig, ok := callee.Type().(*types.Signature); ok && interfaceMethod(sig) {
		return nil
	}
	return &callInfo{
		call: call,
		id:   callee.String(),
		name: callee.Id(),
	}
}
// interfaceMethod reports whether the signature has a receiver whose type is
// an interface.
func interfaceMethod(s *types.Signature) bool {
	if r := s.Recv(); r != nil {
		return types.IsInterface(r.Type())
	}
	return false
}
// hasNestedlock returns a stack trace of the nested or recursive lock within the declaration of a function/method call (given by call).
// If the call expression does not contain a nested or recursive lock, hasNestedlock returns an empty string.
// hasNestedlock finds a nested or recursive lock by recursively calling itself on any functions called by the function/method represented
// by callInfo.
//
// fullRLockSelector is the selector of the tracked lock and compareMap the selector of the call being examined.
// goPos is the end position of a pending go statement; the call ending there is not followed.
// hist holds the ids of the callees currently being examined, so that recursive functions do not cause endless recursion.
// lockName is the name of the unlock method of the tracked mode; calls with that name are not followed.
func hasNestedlock(fullRLockSelector *selIdentList, goPos token.Pos, compareMap *selIdentList, call *callInfo, inspect *inspector.Inspector,
	pass *analysis.Pass, hist map[string]bool, lockName string) (retStack string) {
	var rLockSelector *selIdentList
	f := pass.Fset
	tInfo := pass.TypesInfo
	cH := callHelper{
		call: call.call,
		fset: pass.Fset,
	}
	var node ast.Node
	if block := cH.identifyFuncLitBlock(cH.call.Fun); block != nil {
		// no need to find a submap, since this is a local function call
		rLockSelector = fullRLockSelector
		node = block
	} else {
		// if this is not a local function literal call, and the selectors don't match up, then we can just return
		rLockSelector = fullRLockSelector.getSub(compareMap)
		if rLockSelector == nil {
			return ""
		}
		decl := findCallDeclarationNode(call, inspect, tInfo)
		if decl == nil {
			return ""
		}
		node = decl
		// Inside a method the lock is reached through the receiver, so the root of
		// the selector is replaced by the receiver identifier. A receiver may be
		// declared without a name (func (*T) m()); its Names list is then empty and
		// the selector is left unchanged instead of indexing out of range.
		if decl.Recv != nil && len(decl.Recv.List) > 0 && len(decl.Recv.List[0].Names) > 0 {
			recv := decl.Recv.List[0].Names[0]
			rLockSelector.changeRoot(recv, tInfo.ObjectOf(recv))
		}
	}
	addition := fmt.Sprintf("\t%q at %v\n", call.name, f.Position(call.call.Pos()))
	ast.Inspect(node, func(iNode ast.Node) bool {
		switch stmt := iNode.(type) {
		case *ast.GoStmt:
			goPos = stmt.End()
		case *ast.CallExpr:
			// the call started by a go statement is not followed
			if stmt.End() == goPos {
				goPos = 0
				return false
			}
			c := getCallInfo(tInfo, stmt)
			if c == nil {
				return false
			}
			name := c.name
			selMap := mapSelTypes(stmt, pass)
			if rLockSelector.isEqual(selMap, 0) || rLockSelector.isRelated(selMap, 0) { // if the method found is an RLock method
				retStack += addition + fmt.Sprintf("\t%q at %v\n", name, f.Position(iNode.Pos()))
			} else if name != lockName { // name should not equal the previousName to prevent infinite recursive loop
				nt := c.id
				if !hist[nt] { // make sure we are not in an infinite recursive loop
					hist[nt] = true
					stack := hasNestedlock(rLockSelector, goPos, selMap, c, inspect, pass, hist, lockName)
					delete(hist, nt)
					if stack != "" {
						retStack += addition + stack
					}
				}
			}
		}
		return true
	})
	return retStack
}
// findCallDeclarationNode looks through the function declarations visited by
// the inspector for one whose object string equals the id of the given call.
// It returns that declaration (the last one if several match) or nil.
func findCallDeclarationNode(c *callInfo, inspect *inspector.Inspector, tInfo *types.Info) *ast.FuncDecl {
	var found *ast.FuncDecl
	onlyFuncDecls := []ast.Node{(*ast.FuncDecl)(nil)}
	inspect.Preorder(onlyFuncDecls, func(n ast.Node) {
		decl, isDecl := n.(*ast.FuncDecl)
		if !isDecl {
			return
		}
		if tInfo.ObjectOf(decl.Name).String() == c.id {
			found = decl
		}
	})
	return found
}
// callHelper carries a call expression and a file set; identifyFuncLitBlock is
// defined on it.
type callHelper struct {
// call is the call expression being examined.
call *ast.CallExpr
// fset is the file set of the analysis pass.
fset *token.FileSet
}
// identifyFuncLitBlock returns the body of the function literal that the given
// expression denotes: either the expression is a function literal itself, or
// it is an identifier whose declaration (a value spec or an assignment with as
// many values as names) assigns one to it. It returns nil in every other case.
func (c callHelper) identifyFuncLitBlock(expr ast.Expr) *ast.BlockStmt {
	if lit, isLit := expr.(*ast.FuncLit); isLit {
		return lit.Body
	}
	ident, isIdent := expr.(*ast.Ident)
	if !isIdent || ident.Obj == nil {
		return nil
	}
	switch decl := ident.Obj.Decl.(type) {
	case *ast.ValueSpec:
		idx := findIdentIndex(ident, decl.Names)
		if idx == -1 || len(decl.Names) != len(decl.Values) {
			return nil
		}
		return c.identifyFuncLitBlock(decl.Values[idx])
	case *ast.AssignStmt:
		idx := findIdentIndexFromExpr(ident, decl.Lhs)
		// only deals with simple func lit assignments
		if idx == -1 || len(decl.Lhs) != len(decl.Rhs) {
			return nil
		}
		return c.identifyFuncLitBlock(decl.Rhs[idx])
	}
	return nil
}
// findIdentIndex returns the index of the first identifier in exprs that has
// the same name as id, or -1 if there is none.
func findIdentIndex(id *ast.Ident, exprs []*ast.Ident) int {
	for i := 0; i < len(exprs); i++ {
		if exprs[i].Name == id.Name {
			return i
		}
	}
	return -1
}
// findIdentIndexFromExpr returns the index of the first expression in exprs
// that is an identifier with the same name as id, or -1 if there is none.
// Expressions that are not identifiers are ignored.
func findIdentIndexFromExpr(id *ast.Ident, exprs []ast.Expr) int {
	for i := range exprs {
		candidate, isIdent := exprs[i].(*ast.Ident)
		if !isIdent {
			continue
		}
		if candidate.Name == id.Name {
			return i
		}
	}
	return -1
}