// Copyright 2019 The go-ethereum Authors // This file is part of the go-ethereum library. // // The go-ethereum library is free software: you can redistribute it and/or modify // it under the terms of the GNU Lesser General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // The go-ethereum library is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU Lesser General Public License for more details. // // You should have received a copy of the GNU Lesser General Public License // along with the go-ethereum library. If not, see . package rpc import ( "context" "errors" "fmt" "reflect" "strings" "sync" "unicode" jsoniter "github.com/json-iterator/go" "github.com/ledgerwatch/erigon-lib/common/dbg" "github.com/ledgerwatch/log/v3" ) var ( contextType = reflect.TypeOf((*context.Context)(nil)).Elem() jsonStreamType = reflect.TypeOf(&jsoniter.Stream{}) errorType = reflect.TypeOf((*error)(nil)).Elem() subscriptionType = reflect.TypeOf(Subscription{}) stringType = reflect.TypeOf("") ) type serviceRegistry struct { mu sync.Mutex services map[string]service logger log.Logger } // service represents a registered object. 
type service struct { name string // name for service callbacks map[string]*callback // registered handlers subscriptions map[string]*callback // available subscriptions/notifications } // callback is a method callback which was registered in the server type callback struct { fn reflect.Value // the function rcvr reflect.Value // receiver object of method, set if fn is method argTypes []reflect.Type // input argument types hasCtx bool // method's first argument is a context (not included in argTypes) errPos int // err return idx, of -1 when method cannot return error isSubscribe bool // true if this is a subscription callback streamable bool // support JSON streaming (more efficient for large responses) logger log.Logger } func (r *serviceRegistry) registerName(name string, rcvr interface{}) error { rcvrVal := reflect.ValueOf(rcvr) if name == "" { return fmt.Errorf("no service name for type %s", rcvrVal.Type().String()) } callbacks := suitableCallbacks(rcvrVal, r.logger) if len(callbacks) == 0 { return fmt.Errorf("service %T doesn't have any suitable methods/subscriptions to expose", rcvr) } r.mu.Lock() defer r.mu.Unlock() if r.services == nil { r.services = make(map[string]service) } svc, ok := r.services[name] if !ok { svc = service{ name: name, callbacks: make(map[string]*callback), subscriptions: make(map[string]*callback), } r.services[name] = svc } for name, cb := range callbacks { if cb.isSubscribe { svc.subscriptions[name] = cb } else { svc.callbacks[name] = cb } } return nil } // callback returns the callback corresponding to the given RPC method name. func (r *serviceRegistry) callback(method string) *callback { elem := strings.SplitN(method, serviceMethodSeparator, 2) if len(elem) != 2 { return nil } r.mu.Lock() defer r.mu.Unlock() return r.services[elem[0]].callbacks[elem[1]] } // subscription returns a subscription callback in the given service. 
func (r *serviceRegistry) subscription(service, name string) *callback {
	r.mu.Lock()
	defer r.mu.Unlock()
	// An unknown service yields a zero-value service whose nil map reads as nil.
	return r.services[service].subscriptions[name]
}

// suitableCallbacks iterates over the methods of the given type. It determines if a method
// satisfies the criteria for a RPC callback or a subscription callback and adds it to the
// collection of callbacks. See server documentation for a summary of these criteria.
//
// Map keys are the method names with their first character lower-cased (see formatName).
func suitableCallbacks(receiver reflect.Value, logger log.Logger) map[string]*callback {
	typ := receiver.Type()
	callbacks := make(map[string]*callback, typ.NumMethod())
	for m := 0; m < typ.NumMethod(); m++ {
		method := typ.Method(m)
		if method.PkgPath != "" {
			continue // method not exported
		}
		name := formatName(method.Name)
		cb := newCallback(receiver, method.Func, name, logger)
		if cb == nil {
			continue // function invalid
		}
		callbacks[name] = cb
	}
	return callbacks
}

// newCallback turns fn (a function) into a callback object. It returns nil if the function
// is unsuitable as an RPC callback; the reason is logged through logger.
//
// Accepted return shapes are: nothing, one value, one error, or (value, error).
// Streamable methods (last parameter *jsoniter.Stream) must return exactly one value.
func newCallback(receiver, fn reflect.Value, name string, logger log.Logger) *callback {
	fntype := fn.Type()
	c := &callback{fn: fn, rcvr: receiver, errPos: -1, isSubscribe: isPubSub(fntype), logger: logger}
	// Determine parameter types; the receiver, a leading context.Context and a
	// trailing *jsoniter.Stream are excluded from argTypes.
	c.makeArgTypes()

	// Verify return types. The function must return at most one error
	// and/or one other non-error value.
	numOut := fntype.NumOut()
	if numOut > 2 {
		logger.Warn(fmt.Sprintf("Cannot register RPC callback [%s] - maximum 2 return values are allowed, got %d", name, numOut))
		return nil
	}

	// If an error is returned, it must be the last returned value.
	switch {
	case numOut == 1 && isErrorType(fntype.Out(0)):
		c.errPos = 0
	case numOut == 2:
		if isErrorType(fntype.Out(0)) || !isErrorType(fntype.Out(1)) {
			logger.Warn(fmt.Sprintf("Cannot register RPC callback [%s] - error must be the last return value", name))
			return nil
		}
		c.errPos = 1
	}

	// makeArgTypes already flagged methods whose last argument is *jsoniter.Stream
	// as streamable; such methods write their result into the stream and may only
	// return a single value.
	// NOTE(review): a single non-error return value is also accepted here although
	// the message says "(error)" — confirm whether that should be rejected.
	if numOut != 1 && c.streamable {
		logger.Warn(fmt.Sprintf("Cannot register RPC callback [%s] - streamable method may only return 1 value (error)", name))
		return nil
	}
	return c
}

// makeArgTypes composes the argTypes list from c.fn's parameters. As a side
// effect it sets c.hasCtx when the first non-receiver parameter is exactly
// context.Context, and c.streamable when the last parameter is *jsoniter.Stream.
// Neither of those parameters (nor the receiver) is included in argTypes.
func (c *callback) makeArgTypes() {
	fntype := c.fn.Type()
	// Skip receiver and context.Context parameter (if present).
	firstArg := 0
	if c.rcvr.IsValid() {
		firstArg++
	}
	if fntype.NumIn() > firstArg && fntype.In(firstArg) == contextType {
		c.hasCtx = true
		firstArg++
	}
	// Check if method is streamable.
	numArgs := fntype.NumIn()
	if numArgs > firstArg && fntype.In(numArgs-1) == jsonStreamType {
		c.streamable = true
		numArgs--
	}
	// Add all remaining parameters (except the json stream, if present).
	c.argTypes = make([]reflect.Type, numArgs-firstArg)
	for i := firstArg; i < numArgs; i++ {
		c.argTypes[i-firstArg] = fntype.In(i)
	}
}

// call invokes the callback with the given decoded arguments.
//
// The receiver, ctx (when hasCtx) and stream (when streamable) are added around
// args automatically. A panic inside the method is recovered, logged with a
// stack trace, and reported as the error "method handler crashed". When the
// method returns a non-nil error, res is nil.
func (c *callback) call(ctx context.Context, method string, args []reflect.Value, stream *jsoniter.Stream) (res interface{}, errRes error) {
	// Create the argument slice: [receiver] [ctx] args... [stream].
	fullargs := make([]reflect.Value, 0, 3+len(args))
	if c.rcvr.IsValid() {
		fullargs = append(fullargs, c.rcvr)
	}
	if c.hasCtx {
		fullargs = append(fullargs, reflect.ValueOf(ctx))
	}
	fullargs = append(fullargs, args...)
	if c.streamable {
		fullargs = append(fullargs, reflect.ValueOf(stream))
	}

	// Catch panic while running the callback.
	defer func() {
		if rec := recover(); rec != nil {
			c.logger.Error("RPC method " + method + " crashed: " + fmt.Sprintf("%v\n%s", rec, dbg.Stack()))
			errRes = errors.New("method handler crashed")
		}
	}()
	// Run the callback.
	results := c.fn.Call(fullargs)
	if len(results) == 0 {
		return nil, nil
	}
	if c.errPos >= 0 && !results[c.errPos].IsNil() {
		// Method has returned non-nil error value. Return a nil result rather
		// than a boxed zero reflect.Value, which would compare non-nil.
		err := results[c.errPos].Interface().(error)
		return nil, err
	}
	return results[0].Interface(), nil
}

// Is t context.Context or *context.Context?
func isContextType(t reflect.Type) bool {
	for t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	return t == contextType
}

// Does t satisfy the error interface? Pointer types are dereferenced first.
// NOTE(review): because of the dereference, a type whose Error method has a
// pointer receiver (e.g. a *MyErr return type) is reported as not an error —
// confirm this is intended.
func isErrorType(t reflect.Type) bool {
	for t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	return t.Implements(errorType)
}

// Is t Subscription or *Subscription?
func isSubscriptionType(t reflect.Type) bool {
	for t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	return t == subscriptionType
}

// isPubSub tests whether the given method has a context.Context (or a pointer to
// one) as its first non-receiver argument and returns the pair (Subscription, error).
func isPubSub(methodType reflect.Type) bool {
	// numIn(0) is the receiver type
	if methodType.NumIn() < 2 || methodType.NumOut() != 2 {
		return false
	}
	return isContextType(methodType.In(1)) &&
		isSubscriptionType(methodType.Out(0)) &&
		isErrorType(methodType.Out(1))
}

// formatName converts the first character of name to lowercase.
func formatName(name string) string {
	ret := []rune(name)
	if len(ret) > 0 {
		ret[0] = unicode.ToLower(ret[0])
	}
	return string(ret)
}