mirror of
https://gitlab.com/pulsechaincom/prysm-pulse.git
synced 2025-01-11 12:10:05 +00:00
2a7a09b112
* add lock analyzer * fix locks * progress * fix failures * fix error log Co-authored-by: Raul Jordan <raul@prysmaticlabs.com>
644 lines
18 KiB
Go
644 lines
18 KiB
Go
// Package recursivelock implements a static analyzer that detects nested or recursive mutex Lock and RLock calls.
|
|
|
|
package recursivelock
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"go/ast"
|
|
"go/token"
|
|
"go/types"
|
|
|
|
"golang.org/x/tools/go/analysis"
|
|
"golang.org/x/tools/go/analysis/passes/inspect"
|
|
"golang.org/x/tools/go/ast/astutil"
|
|
"golang.org/x/tools/go/ast/inspector"
|
|
"golang.org/x/tools/go/types/typeutil"
|
|
)
|
|
|
|
// Analyzer runs static analysis.
|
|
var Analyzer = &analysis.Analyzer{
|
|
Name: "recursivelock",
|
|
Doc: "Checks for recursive or nested RLock calls",
|
|
Requires: []*analysis.Analyzer{inspect.Analyzer},
|
|
Run: run,
|
|
}
|
|
|
|
// Messages attached to the diagnostics this analyzer reports.
var (
	// errNestedRLock is reported when a read lock is re-acquired while held.
	errNestedRLock = errors.New("found recursive read lock call")
	// errNestedLock is reported when a write lock is re-acquired while held.
	errNestedLock = errors.New("found recursive lock call")
	// errNestedMixedLock is reported when Lock and RLock are nested on the
	// same selector.
	errNestedMixedLock = errors.New("found recursive mixed lock call")
)
|
|
|
|
// mode selects which flavour of mutex locking a check applies to.
type mode int

// The lock flavours the analyzer tracks.
const (
	// LockMode covers the Lock/Unlock pair.
	LockMode mode = iota
	// RLockMode covers the RLock/RUnlock pair.
	RLockMode
)
|
|
|
|
func (m mode) LockName() string {
|
|
switch m {
|
|
case LockMode:
|
|
return "Lock"
|
|
case RLockMode:
|
|
return "RLock"
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func (m mode) UnLockName() string {
|
|
switch m {
|
|
case LockMode:
|
|
return "Unlock"
|
|
case RLockMode:
|
|
return "RUnlock"
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func (m mode) ErrorFound() error {
|
|
switch m {
|
|
case LockMode:
|
|
return errNestedLock
|
|
case RLockMode:
|
|
return errNestedRLock
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func run(pass *analysis.Pass) (interface{}, error) {
|
|
inspect, ok := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
|
|
if !ok {
|
|
return nil, errors.New("analyzer is not type *inspector.Inspector")
|
|
}
|
|
|
|
nodeFilter := []ast.Node{
|
|
(*ast.GoStmt)(nil),
|
|
(*ast.CallExpr)(nil),
|
|
(*ast.DeferStmt)(nil),
|
|
(*ast.FuncDecl)(nil),
|
|
(*ast.FuncLit)(nil),
|
|
(*ast.File)(nil),
|
|
(*ast.IfStmt)(nil),
|
|
(*ast.ReturnStmt)(nil),
|
|
}
|
|
|
|
keepTrackOf := &tracker{
|
|
rLockTrack: &lockTracker{},
|
|
lockTrack: &lockTracker{},
|
|
}
|
|
inspect.Preorder(nodeFilter, func(node ast.Node) {
|
|
if keepTrackOf.rLockTrack.funcLitEnd.IsValid() && node.Pos() <= keepTrackOf.rLockTrack.funcLitEnd &&
|
|
keepTrackOf.lockTrack.funcLitEnd.IsValid() && node.Pos() <= keepTrackOf.lockTrack.funcLitEnd {
|
|
return
|
|
}
|
|
keepTrackOf.rLockTrack.funcLitEnd = token.NoPos
|
|
keepTrackOf.lockTrack.funcLitEnd = token.NoPos
|
|
|
|
if keepTrackOf.rLockTrack.deferEnd.IsValid() && node.Pos() > keepTrackOf.rLockTrack.deferEnd {
|
|
keepTrackOf.rLockTrack.deferEnd = token.NoPos
|
|
} else if keepTrackOf.rLockTrack.deferEnd.IsValid() {
|
|
return
|
|
}
|
|
if keepTrackOf.lockTrack.deferEnd.IsValid() && node.Pos() > keepTrackOf.lockTrack.deferEnd {
|
|
keepTrackOf.lockTrack.deferEnd = token.NoPos
|
|
} else if keepTrackOf.lockTrack.deferEnd.IsValid() {
|
|
return
|
|
}
|
|
|
|
if keepTrackOf.rLockTrack.retEnd.IsValid() && node.Pos() > keepTrackOf.rLockTrack.retEnd {
|
|
keepTrackOf.rLockTrack.retEnd = token.NoPos
|
|
keepTrackOf.rLockTrack.incFRU()
|
|
}
|
|
if keepTrackOf.lockTrack.retEnd.IsValid() && node.Pos() > keepTrackOf.lockTrack.retEnd {
|
|
keepTrackOf.lockTrack.retEnd = token.NoPos
|
|
keepTrackOf.lockTrack.incFRU()
|
|
}
|
|
keepTrackOf = stmtSelector(node, pass, keepTrackOf, inspect)
|
|
})
|
|
return nil, nil
|
|
}
|
|
|
|
func stmtSelector(node ast.Node, pass *analysis.Pass, keepTrackOf *tracker, inspect *inspector.Inspector) *tracker {
|
|
switch stmt := node.(type) {
|
|
case *ast.GoStmt:
|
|
keepTrackOf.rLockTrack.goroutinePos = stmt.Call.End()
|
|
keepTrackOf.lockTrack.goroutinePos = stmt.Call.End()
|
|
case *ast.CallExpr:
|
|
if stmt.End() == keepTrackOf.rLockTrack.goroutinePos ||
|
|
stmt.End() == keepTrackOf.lockTrack.goroutinePos {
|
|
keepTrackOf.rLockTrack.goroutinePos = 0
|
|
keepTrackOf.lockTrack.goroutinePos = 0
|
|
break
|
|
}
|
|
call := getCallInfo(pass.TypesInfo, stmt)
|
|
if call == nil {
|
|
break
|
|
}
|
|
selMap := mapSelTypes(stmt, pass)
|
|
if selMap == nil {
|
|
break
|
|
}
|
|
checkForRecLocks(node, pass, inspect, RLockMode, call, keepTrackOf.rLockTrack, selMap)
|
|
checkForRecLocks(node, pass, inspect, LockMode, call, keepTrackOf.lockTrack, selMap)
|
|
|
|
case *ast.File:
|
|
keepTrackOf = &tracker{
|
|
rLockTrack: &lockTracker{},
|
|
lockTrack: &lockTracker{},
|
|
}
|
|
|
|
case *ast.FuncDecl:
|
|
keepTrackOf = &tracker{
|
|
rLockTrack: &lockTracker{},
|
|
lockTrack: &lockTracker{},
|
|
}
|
|
keepTrackOf.rLockTrack.funcEnd = stmt.End()
|
|
|
|
case *ast.FuncLit:
|
|
if keepTrackOf.rLockTrack.funcLitEnd == token.NoPos {
|
|
keepTrackOf.rLockTrack.funcLitEnd = stmt.End()
|
|
}
|
|
if keepTrackOf.lockTrack.funcLitEnd == token.NoPos {
|
|
keepTrackOf.lockTrack.funcLitEnd = stmt.End()
|
|
}
|
|
case *ast.IfStmt:
|
|
stmts := stmt.Body.List
|
|
for i := 0; i < len(stmts); i++ {
|
|
keepTrackOf = stmtSelector(stmts[i], pass, keepTrackOf, inspect)
|
|
}
|
|
keepTrackOf = stmtSelector(stmt.Else, pass, keepTrackOf, inspect)
|
|
case *ast.DeferStmt:
|
|
call := getCallInfo(pass.TypesInfo, stmt.Call)
|
|
if keepTrackOf.rLockTrack.deferEnd == token.NoPos {
|
|
keepTrackOf.rLockTrack.deferEnd = stmt.End()
|
|
}
|
|
if keepTrackOf.lockTrack.deferEnd == token.NoPos {
|
|
keepTrackOf.lockTrack.deferEnd = stmt.End()
|
|
}
|
|
|
|
if call != nil && call.name == RLockMode.UnLockName() {
|
|
keepTrackOf.rLockTrack.deferredRUnlock = true
|
|
}
|
|
if call != nil && call.name == LockMode.UnLockName() {
|
|
keepTrackOf.lockTrack.deferredRUnlock = true
|
|
}
|
|
|
|
case *ast.ReturnStmt:
|
|
for i := 0; i < len(stmt.Results); i++ {
|
|
keepTrackOf = stmtSelector(stmt.Results[i], pass, keepTrackOf, inspect)
|
|
}
|
|
if keepTrackOf.rLockTrack.deferredRUnlock && keepTrackOf.rLockTrack.retEnd == token.NoPos {
|
|
keepTrackOf.rLockTrack.deincFRU()
|
|
keepTrackOf.rLockTrack.retEnd = stmt.End()
|
|
}
|
|
if keepTrackOf.lockTrack.deferredRUnlock && keepTrackOf.lockTrack.retEnd == token.NoPos {
|
|
keepTrackOf.lockTrack.deincFRU()
|
|
keepTrackOf.lockTrack.retEnd = stmt.End()
|
|
}
|
|
}
|
|
return keepTrackOf
|
|
}
|
|
|
|
type tracker struct {
|
|
rLockTrack *lockTracker
|
|
lockTrack *lockTracker
|
|
}
|
|
|
|
type lockTracker struct {
|
|
funcEnd token.Pos
|
|
retEnd token.Pos
|
|
deferEnd token.Pos
|
|
funcLitEnd token.Pos
|
|
goroutinePos token.Pos
|
|
deferredRUnlock bool
|
|
foundRLock int
|
|
rLockSelector *selIdentList
|
|
}
|
|
|
|
func (t lockTracker) String() string {
|
|
return fmt.Sprintf("funcEnd:%v\nretEnd:%v\ndeferEnd:%v\ndeferredRU:%v\nfoundRLock:%v\n", t.funcEnd, t.retEnd, t.deferEnd, t.deferredRUnlock, t.foundRLock)
|
|
}
|
|
|
|
func (t *lockTracker) deincFRU() {
|
|
if t.foundRLock > 0 {
|
|
t.foundRLock -= 1
|
|
}
|
|
}
|
|
func (t *lockTracker) incFRU() {
|
|
t.foundRLock += 1
|
|
}
|
|
|
|
func checkForRecLocks(node ast.Node, pass *analysis.Pass, inspect *inspector.Inspector, lockmode mode, call *callInfo,
|
|
lockTracker *lockTracker, selMap *selIdentList) {
|
|
name := call.name
|
|
if lockTracker.rLockSelector != nil {
|
|
if lockTracker.foundRLock > 0 {
|
|
if lockTracker.rLockSelector.isRelated(selMap, 0) {
|
|
pass.Reportf(
|
|
node.Pos(),
|
|
fmt.Sprintf(
|
|
"%v",
|
|
errNestedMixedLock,
|
|
),
|
|
)
|
|
}
|
|
if lockTracker.rLockSelector.isEqual(selMap, 0) {
|
|
pass.Reportf(
|
|
node.Pos(),
|
|
fmt.Sprintf(
|
|
"%v",
|
|
lockmode.ErrorFound(),
|
|
),
|
|
)
|
|
} else {
|
|
if stack := hasNestedlock(lockTracker.rLockSelector, lockTracker.goroutinePos, selMap, call, inspect, pass, make(map[string]bool),
|
|
lockmode.UnLockName()); stack != "" {
|
|
pass.Reportf(
|
|
node.Pos(),
|
|
fmt.Sprintf(
|
|
"%v\n%v",
|
|
lockmode.ErrorFound(),
|
|
stack,
|
|
),
|
|
)
|
|
}
|
|
}
|
|
}
|
|
if name == lockmode.UnLockName() && lockTracker.rLockSelector.isEqual(selMap, 1) {
|
|
lockTracker.deincFRU()
|
|
}
|
|
if name == lockmode.LockName() && lockTracker.foundRLock == 0 && lockTracker.rLockSelector.isEqual(selMap, 0) {
|
|
lockTracker.incFRU()
|
|
}
|
|
} else if name == lockmode.LockName() && lockTracker.foundRLock == 0 {
|
|
lockTracker.rLockSelector = selMap
|
|
lockTracker.incFRU()
|
|
}
|
|
}
|
|
|
|
// selIdentNode is one element of a selector expression together with its type
// information. For "a.b.c()", one node describes "a", the next "b", and so on.
type selIdentNode struct {
	// next is the element that follows this one in the selector, or nil.
	next *selIdentNode
	// this is the identifier of the element.
	this *ast.Ident
	// typObj is the object the identifier resolves to.
	typObj types.Object
}
|
|
|
|
// a list of selIdentNodes. Stores the information of an entire selector expression
|
|
// For example, each item in "a.b.c()" is stored as a node in this list, with the start node being "a"
|
|
type selIdentList struct {
|
|
start *selIdentNode
|
|
length int
|
|
current *selIdentNode // used for internal functions
|
|
currentIndex int // used for internal functions
|
|
}
|
|
|
|
// returns the next item in the list, and increments the counter keeping track of where we are in the list
|
|
func (s *selIdentList) next() (n *selIdentNode) {
|
|
n = s.current.next
|
|
if n != nil {
|
|
s.current = n
|
|
s.currentIndex++
|
|
}
|
|
return n
|
|
}
|
|
|
|
// reset resets the current node to the start node in the list
|
|
func (s *selIdentList) reset() {
|
|
s.current = s.start
|
|
s.currentIndex = 0
|
|
}
|
|
|
|
// isEqual returns true if two selIdentLists are equal to each other.
|
|
// The offset parameter tells how far in the list to check for equality.
|
|
// For example, a.b.c() and a.b.d() are equal with an offset of 1.
|
|
func (s *selIdentList) isEqual(s2 *selIdentList, offset int) bool {
|
|
if s2 == nil || (s.length != s2.length) {
|
|
return false
|
|
}
|
|
s.reset()
|
|
s2.reset()
|
|
for i := true; i; {
|
|
if !s.current.isEqual(s2.current) {
|
|
return false
|
|
}
|
|
if s.currentIndex < s.length-offset-1 && s.next() != nil {
|
|
s2.next()
|
|
} else {
|
|
i = false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
// isRelated checks if our selectors are of the same type and
|
|
// reference the same underlying object. If they do we check
|
|
// if the provided list is referencing a non-equal but related
|
|
// lock. Ex: Lock - RLock, RLock - Lock
|
|
// TODO: Use a generalizable method here instead of hardcoding
|
|
// the lock definitions here.
|
|
func (s *selIdentList) isRelated(s2 *selIdentList, offset int) bool {
|
|
if s2 == nil || (s.length != s2.length) {
|
|
return false
|
|
}
|
|
s.reset()
|
|
s2.reset()
|
|
for i := true; i; {
|
|
if !s.current.isEqual(s2.current) {
|
|
return false
|
|
}
|
|
if s.currentIndex < s.length-offset-1 && s.next() != nil {
|
|
s2.next()
|
|
} else {
|
|
i = false
|
|
}
|
|
// Only check if we are at the last index for
|
|
// related method calls.
|
|
if s.currentIndex == s.length-1 {
|
|
switch s.current.this.String() {
|
|
case LockMode.LockName():
|
|
if s2.current.this.String() == RLockMode.LockName() {
|
|
return true
|
|
}
|
|
case RLockMode.LockName():
|
|
if s2.current.this.String() == LockMode.LockName() {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// getSub returns the shared beginning selIdentList of s and s2,
|
|
// if s contains all elements (except the last) of s2,
|
|
// and returns nil otherwise.
|
|
// For example, if s represents "a.b.c.d()" and s2 represents
|
|
// "a.b.e()", getSub will return a selIdentList representing "a.b".
|
|
// getSub returns nil if s2's length is greater than that of s
|
|
func (s *selIdentList) getSub(s2 *selIdentList) *selIdentList {
|
|
if s2 == nil || s2.length > s.length {
|
|
return nil
|
|
}
|
|
s.reset()
|
|
s2.reset()
|
|
for i := true; i; {
|
|
if !s.current.isEqual(s2.current) {
|
|
return nil
|
|
}
|
|
if s2.currentIndex != s2.length-2 { // might want to add a selNode.prev() func
|
|
s.next()
|
|
s2.next()
|
|
} else {
|
|
i = false
|
|
}
|
|
}
|
|
return &selIdentList{
|
|
start: s.current,
|
|
length: s.length - s.currentIndex,
|
|
current: s.current,
|
|
currentIndex: 0,
|
|
}
|
|
}
|
|
|
|
// changeRoot changes the first selIdentNode of a selIdentList
|
|
// to one with given *ast.Ident and types.Object
|
|
func (s *selIdentList) changeRoot(r *ast.Ident, t types.Object) {
|
|
selNode := &selIdentNode{
|
|
this: r,
|
|
next: s.start.next,
|
|
typObj: t,
|
|
}
|
|
if s.start == s.current {
|
|
s.start = selNode
|
|
s.current = selNode
|
|
} else {
|
|
s.start = selNode
|
|
}
|
|
}
|
|
|
|
func (s selIdentList) String() (str string) {
|
|
var temp *selIdentNode = s.start
|
|
str = fmt.Sprintf("length: %v\n[\n", s.length)
|
|
for i := 0; temp != nil; i++ {
|
|
if i == s.currentIndex {
|
|
str += "*"
|
|
}
|
|
str += fmt.Sprintf("%v: %v\n", i, temp)
|
|
temp = temp.next
|
|
}
|
|
str += "]"
|
|
return str
|
|
}
|
|
|
|
func (s *selIdentNode) isEqual(s2 *selIdentNode) bool {
|
|
return (s.this.Name == s2.this.Name) && (s.typObj == s2.typObj)
|
|
}
|
|
|
|
func (s selIdentNode) String() string {
|
|
return fmt.Sprintf("{ ident: '%v', type: '%v' }", s.this, s.typObj)
|
|
}
|
|
|
|
// mapSelTypes returns a selIdentList representation of the given call expression
|
|
func mapSelTypes(c *ast.CallExpr, pass *analysis.Pass) *selIdentList {
|
|
list := &selIdentList{}
|
|
valid := list.recurMapSelTypes(c.Fun, nil, pass.TypesInfo)
|
|
if !valid {
|
|
return nil
|
|
}
|
|
return list
|
|
}
|
|
|
|
// recursively identifies the type of each identity node in a selector expression
|
|
func (l *selIdentList) recurMapSelTypes(e ast.Expr, next *selIdentNode, t *types.Info) bool {
|
|
expr := astutil.Unparen(e)
|
|
l.length++
|
|
s := &selIdentNode{next: next}
|
|
switch stmt := expr.(type) {
|
|
case *ast.Ident:
|
|
s.this = stmt
|
|
s.typObj = t.ObjectOf(stmt)
|
|
case *ast.SelectorExpr:
|
|
s.this = stmt.Sel
|
|
if sel, ok := t.Selections[stmt]; ok {
|
|
s.typObj = sel.Obj() // method or field
|
|
} else {
|
|
s.typObj = t.Uses[stmt.Sel] // qualified identifier?
|
|
}
|
|
return l.recurMapSelTypes(stmt.X, s, t)
|
|
default:
|
|
return false
|
|
}
|
|
l.current = s
|
|
l.start = s
|
|
return true
|
|
}
|
|
|
|
// callInfo describes a call expression and the function or method it invokes.
type callInfo struct {
	// call is the call expression itself.
	call *ast.CallExpr
	// id is the string representation of the callee's type object.
	id string
	// name is the callee's ID: its name when exported, otherwise the
	// package-qualified name.
	name string
}
|
|
|
|
// getCallInfo returns a *callInfo struct with call info
|
|
func getCallInfo(tInfo *types.Info, call *ast.CallExpr) (c *callInfo) {
|
|
c = &callInfo{}
|
|
c.call = call
|
|
f := typeutil.Callee(tInfo, call)
|
|
if f == nil {
|
|
return nil
|
|
}
|
|
if _, isBuiltin := f.(*types.Builtin); isBuiltin {
|
|
return nil
|
|
}
|
|
s, ok := f.Type().(*types.Signature)
|
|
if ok && interfaceMethod(s) {
|
|
return nil
|
|
}
|
|
c.id = f.String()
|
|
c.name = f.Id()
|
|
return c
|
|
}
|
|
|
|
// interfaceMethod reports whether the signature belongs to a method whose
// receiver type is an interface.
func interfaceMethod(s *types.Signature) bool {
	if recv := s.Recv(); recv != nil {
		return types.IsInterface(recv.Type())
	}
	return false
}
|
|
|
|
// hasNestedlock returns a stack trace of the nested or recursive lock within the declaration of a function/method call (given by call).
|
|
// If the call expression does not contain a nested or recursive lock, hasNestedlock returns an empty string.
|
|
// hasNestedlock finds a nested or recursive lock by recursively calling itself on any functions called by the function/method represented
|
|
// by callInfo.
|
|
func hasNestedlock(fullRLockSelector *selIdentList, goPos token.Pos, compareMap *selIdentList, call *callInfo, inspect *inspector.Inspector,
|
|
pass *analysis.Pass, hist map[string]bool, lockName string) (retStack string) {
|
|
var rLockSelector *selIdentList
|
|
f := pass.Fset
|
|
tInfo := pass.TypesInfo
|
|
cH := callHelper{
|
|
call: call.call,
|
|
fset: pass.Fset,
|
|
}
|
|
var node ast.Node = cH.identifyFuncLitBlock(cH.call.Fun) // this seems a bit redundant
|
|
var recv *ast.Ident
|
|
if node == (*ast.BlockStmt)(nil) {
|
|
subMap := fullRLockSelector.getSub(compareMap)
|
|
if subMap != nil {
|
|
rLockSelector = subMap
|
|
} else {
|
|
return "" // if this is not a local function literal call, and the selectors don't match up, then we can just return
|
|
}
|
|
node = findCallDeclarationNode(call, inspect, pass.TypesInfo)
|
|
if node == (*ast.FuncDecl)(nil) {
|
|
return ""
|
|
} else if castedNode, ok := node.(*ast.FuncDecl); ok && castedNode.Recv != nil {
|
|
recv = castedNode.Recv.List[0].Names[0]
|
|
rLockSelector.changeRoot(recv, pass.TypesInfo.ObjectOf(recv))
|
|
}
|
|
} else {
|
|
rLockSelector = fullRLockSelector // no need to find a submap, since this is a local function call
|
|
}
|
|
addition := fmt.Sprintf("\t%q at %v\n", call.name, f.Position(call.call.Pos()))
|
|
ast.Inspect(node, func(iNode ast.Node) bool {
|
|
switch stmt := iNode.(type) {
|
|
case *ast.GoStmt:
|
|
goPos = stmt.End()
|
|
case *ast.CallExpr:
|
|
if stmt.End() == goPos {
|
|
goPos = 0
|
|
return false
|
|
}
|
|
c := getCallInfo(tInfo, stmt)
|
|
if c == nil {
|
|
return false
|
|
}
|
|
name := c.name
|
|
selMap := mapSelTypes(stmt, pass)
|
|
if rLockSelector.isEqual(selMap, 0) || rLockSelector.isRelated(selMap, 0) { // if the method found is an RLock method
|
|
retStack += addition + fmt.Sprintf("\t%q at %v\n", name, f.Position(iNode.Pos()))
|
|
} else if name != lockName { // name should not equal the previousName to prevent infinite recursive loop
|
|
nt := c.id
|
|
if !hist[nt] { // make sure we are not in an infinite recursive loop
|
|
hist[nt] = true
|
|
stack := hasNestedlock(rLockSelector, goPos, selMap, c, inspect, pass, hist, lockName)
|
|
delete(hist, nt)
|
|
if stack != "" {
|
|
retStack += addition + stack
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return true
|
|
})
|
|
return retStack
|
|
}
|
|
|
|
// findCallDeclarationNode takes a callInfo struct and inspects the AST of the package
|
|
// to find a matching method or function declaration. It returns this declaration of type *ast.FuncDecl
|
|
func findCallDeclarationNode(c *callInfo, inspect *inspector.Inspector, tInfo *types.Info) *ast.FuncDecl {
|
|
var retNode *ast.FuncDecl = nil
|
|
nodeFilter := []ast.Node{
|
|
(*ast.FuncDecl)(nil),
|
|
}
|
|
inspect.Preorder(nodeFilter, func(node ast.Node) {
|
|
funcDec, ok := node.(*ast.FuncDecl)
|
|
if !ok {
|
|
return
|
|
}
|
|
compareId := tInfo.ObjectOf(funcDec.Name).String()
|
|
if c.id == compareId {
|
|
retNode = funcDec
|
|
}
|
|
})
|
|
return retNode
|
|
}
|
|
|
|
// callHelper carries a call expression and the file set it belongs to.
type callHelper struct {
	// call is the call expression being examined.
	call *ast.CallExpr
	// fset is the file set of the package under analysis.
	fset *token.FileSet
}
|
|
|
|
// identifyFuncLitBlock returns the AST block statement of the function literal called by the given expression,
|
|
// or nil if no function literal block statement could be identified.
|
|
func (c callHelper) identifyFuncLitBlock(expr ast.Expr) *ast.BlockStmt {
|
|
switch stmt := expr.(type) {
|
|
case *ast.FuncLit:
|
|
return stmt.Body
|
|
case *ast.Ident:
|
|
if stmt.Obj != nil {
|
|
switch objDecl := stmt.Obj.Decl.(type) {
|
|
case *ast.ValueSpec:
|
|
identIndex := findIdentIndex(stmt, objDecl.Names)
|
|
if identIndex != -1 && len(objDecl.Names) == len(objDecl.Values) {
|
|
value := objDecl.Values[identIndex]
|
|
return c.identifyFuncLitBlock(value)
|
|
}
|
|
case *ast.AssignStmt:
|
|
exprIndex := findIdentIndexFromExpr(stmt, objDecl.Lhs)
|
|
if exprIndex != -1 && len(objDecl.Lhs) == len(objDecl.Rhs) { // only deals with simple func lit assignments
|
|
value := objDecl.Rhs[exprIndex]
|
|
return c.identifyFuncLitBlock(value)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// findIdentIndex returns the index of the first identifier in exprs whose name
// equals id's name, or -1 if there is none.
func findIdentIndex(id *ast.Ident, exprs []*ast.Ident) int {
	for i := 0; i < len(exprs); i++ {
		if exprs[i].Name == id.Name {
			return i
		}
	}
	return -1
}
|
|
|
|
// findIdentIndexFromExpr returns the index of the first expression in exprs
// that is an identifier with the same name as id, or -1 if there is none.
// Expressions that are not identifiers are skipped.
func findIdentIndexFromExpr(id *ast.Ident, exprs []ast.Expr) int {
	for i := range exprs {
		candidate, isIdent := exprs[i].(*ast.Ident)
		if !isIdent {
			continue
		}
		if candidate.Name == id.Name {
			return i
		}
	}
	return -1
}
|