Granular rpc control (Allow list for RPC daemon) (#1341)

This commit is contained in:
Igor Mandrigin 2020-11-10 10:08:42 +01:00 committed by GitHub
parent 546b91f47e
commit ed9672620b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 248 additions and 1411 deletions

View File

@ -305,6 +305,33 @@ WARN [11-05|09:03:47.911] Served conn=127.0.0.
WARN [11-05|09:03:47.911] Served conn=127.0.0.1:59754 method=eth_newPendingTransactionFilter reqid=6 t="9.053µs" err="the method eth_newPendingTransactionFilter does not exist/is not available"
```
## Allowing only specific methods (Allowlist)
In some cases you might want to only allow certain methods in the namespaces
and hide others. That is possible with `rpc.accessList` flag.
1. Create a file, say, `rules.json`
2. Add the following content
```json
{
"allow": [
"net_version",
"web3_eth_getBlockByHash"
]
}
```
3. Provide this file to the rpcdaemon using `--rpc.accessList` flag
```
> rpcdaemon --private.api.addr=localhost:9090 --http.api=eth,debug,net,web3 --rpc.accessList=rules.json
```
Now only these two methods are available.
## For Developers
### Code generation

View File

@ -18,22 +18,23 @@ import (
)
type Flags struct {
PrivateApiAddr string
Chaindata string
SnapshotDir string
SnapshotMode string
HttpListenAddress string
TLSCertfile string
TLSCACert string
TLSKeyFile string
HttpPort int
HttpCORSDomain []string
HttpVirtualHost []string
API []string
Gascap uint64
MaxTraces uint64
TraceType string
WebsocketEnabled bool
PrivateApiAddr string
Chaindata string
SnapshotDir string
SnapshotMode string
HttpListenAddress string
TLSCertfile string
TLSCACert string
TLSKeyFile string
HttpPort int
HttpCORSDomain []string
HttpVirtualHost []string
API []string
Gascap uint64
MaxTraces uint64
TraceType string
WebsocketEnabled bool
RpcAllowListFilePath string
}
var rootCmd = &cobra.Command{
@ -76,6 +77,11 @@ func RootCommand() (*cobra.Command, *Flags) {
rootCmd.PersistentFlags().Uint64Var(&cfg.MaxTraces, "trace.maxtraces", 200, "Sets a limit on traces that can be returned in trace_filter")
rootCmd.PersistentFlags().StringVar(&cfg.TraceType, "trace.type", "parity", "Specify the type of tracing [geth|parity*] (experimental)")
rootCmd.PersistentFlags().BoolVar(&cfg.WebsocketEnabled, "ws", false, "Enable Websockets")
rootCmd.PersistentFlags().StringVar(&cfg.RpcAllowListFilePath, "rpc.accessList", "", "Specify granular (method-by-method) API allowlist")
if err := rootCmd.MarkPersistentFlagFilename("rpc.accessList", "json"); err != nil {
panic(err)
}
return rootCmd, cfg
}
@ -124,12 +130,17 @@ func StartRpcServer(ctx context.Context, cfg Flags, rpcAPI []rpc.API) error {
httpEndpoint := fmt.Sprintf("%s:%d", cfg.HttpListenAddress, cfg.HttpPort)
srv := rpc.NewServer()
allowListForRPC, err := parseAllowListForRPC(cfg.RpcAllowListFilePath)
if err != nil {
return err
}
srv.SetAllowList(allowListForRPC)
if err := node.RegisterApisFromWhitelist(rpcAPI, cfg.API, srv, false); err != nil {
return fmt.Errorf("could not start register RPC apis: %w", err)
}
var err error
httpHandler := node.NewHTTPHandlerStack(srv, cfg.HttpCORSDomain, cfg.HttpVirtualHost)
var wsHandler http.Handler
if cfg.WebsocketEnabled {

View File

@ -0,0 +1,43 @@
package cli
import (
"encoding/json"
"io/ioutil"
"os"
"strings"
"github.com/ledgerwatch/turbo-geth/rpc"
)
type allowListFile struct {
Allow rpc.AllowList `json:"allow"`
}
func parseAllowListForRPC(path string) (rpc.AllowList, error) {
path = strings.TrimSpace(path)
if path == "" { // no file is provided
return nil, nil
}
file, err := os.Open(path)
if err != nil {
return nil, err
}
defer func() {
file.Close() //nolint: errcheck
}()
fileContents, err := ioutil.ReadAll(file)
if err != nil {
return nil, err
}
var allowListFileObj allowListFile
err = json.Unmarshal(fileContents, &allowListFileObj)
if err != nil {
return nil, err
}
return allowListFileObj.Allow, nil
}

View File

@ -1,487 +0,0 @@
// Copyright 2016 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package console
import (
"encoding/json"
"fmt"
"io"
"reflect"
"strings"
"time"
"github.com/dop251/goja"
"github.com/ledgerwatch/turbo-geth/accounts/scwallet"
"github.com/ledgerwatch/turbo-geth/accounts/usbwallet"
"github.com/ledgerwatch/turbo-geth/common/hexutil"
"github.com/ledgerwatch/turbo-geth/console/prompt"
"github.com/ledgerwatch/turbo-geth/internal/jsre"
"github.com/ledgerwatch/turbo-geth/rpc"
)
// bridge is a collection of JavaScript utility methods to bride the .js runtime
// environment and the Go RPC connection backing the remote method calls.
type bridge struct {
client *rpc.Client // RPC client to execute Ethereum requests through
prompter prompt.UserPrompter // Input prompter to allow interactive user feedback
printer io.Writer // Output writer to serialize any display strings to
}
// newBridge creates a new JavaScript wrapper around an RPC client.
func newBridge(client *rpc.Client, prompter prompt.UserPrompter, printer io.Writer) *bridge {
return &bridge{
client: client,
prompter: prompter,
printer: printer,
}
}
func getJeth(vm *goja.Runtime) *goja.Object {
jeth := vm.Get("jeth")
if jeth == nil {
panic(vm.ToValue("jeth object does not exist"))
}
return jeth.ToObject(vm)
}
// NewAccount is a wrapper around the personal.newAccount RPC method that uses a
// non-echoing password prompt to acquire the passphrase and executes the original
// RPC method (saved in jeth.newAccount) with it to actually execute the RPC call.
func (b *bridge) NewAccount(call jsre.Call) (goja.Value, error) {
var (
password string
confirm string
err error
)
switch {
// No password was specified, prompt the user for it
case len(call.Arguments) == 0:
if password, err = b.prompter.PromptPassword("Passphrase: "); err != nil {
return nil, err
}
if confirm, err = b.prompter.PromptPassword("Repeat passphrase: "); err != nil {
return nil, err
}
if password != confirm {
return nil, fmt.Errorf("passwords don't match")
}
// A single string password was specified, use that
case len(call.Arguments) == 1 && call.Argument(0).ToString() != nil:
password = call.Argument(0).ToString().String()
default:
return nil, fmt.Errorf("expected 0 or 1 string argument")
}
// Password acquired, execute the call and return
newAccount, callable := goja.AssertFunction(getJeth(call.VM).Get("newAccount"))
if !callable {
return nil, fmt.Errorf("jeth.newAccount is not callable")
}
ret, err := newAccount(goja.Null(), call.VM.ToValue(password))
if err != nil {
return nil, err
}
return ret, nil
}
// OpenWallet is a wrapper around personal.openWallet which can interpret and
// react to certain error messages, such as the Trezor PIN matrix request.
func (b *bridge) OpenWallet(call jsre.Call) (goja.Value, error) {
// Make sure we have a wallet specified to open
if call.Argument(0).ToObject(call.VM).ClassName() != "String" {
return nil, fmt.Errorf("first argument must be the wallet URL to open")
}
wallet := call.Argument(0)
var passwd goja.Value
if goja.IsUndefined(call.Argument(1)) || goja.IsNull(call.Argument(1)) {
passwd = call.VM.ToValue("")
} else {
passwd = call.Argument(1)
}
// Open the wallet and return if successful in itself
openWallet, callable := goja.AssertFunction(getJeth(call.VM).Get("openWallet"))
if !callable {
return nil, fmt.Errorf("jeth.openWallet is not callable")
}
val, err := openWallet(goja.Null(), wallet, passwd)
if err == nil {
return val, nil
}
// Wallet open failed, report error unless it's a PIN or PUK entry
switch {
case strings.HasSuffix(err.Error(), usbwallet.ErrTrezorPINNeeded.Error()):
val, err = b.readPinAndReopenWallet(call)
if err == nil {
return val, nil
}
val, err = b.readPassphraseAndReopenWallet(call)
if err != nil {
return nil, err
}
case strings.HasSuffix(err.Error(), scwallet.ErrPairingPasswordNeeded.Error()):
// PUK input requested, fetch from the user and call open again
var input string
input, err = b.prompter.PromptPassword("Please enter the pairing password: ")
if err != nil {
return nil, err
}
passwd = call.VM.ToValue(input)
if val, err = openWallet(goja.Null(), wallet, passwd); err != nil {
if !strings.HasSuffix(err.Error(), scwallet.ErrPINNeeded.Error()) {
return nil, err
} else {
// PIN input requested, fetch from the user and call open again
input, err = b.prompter.PromptPassword("Please enter current PIN: ")
if err != nil {
return nil, err
}
if val, err = openWallet(goja.Null(), wallet, call.VM.ToValue(input)); err != nil {
return nil, err
}
}
}
case strings.HasSuffix(err.Error(), scwallet.ErrPINUnblockNeeded.Error()):
// PIN unblock requested, fetch PUK and new PIN from the user
var pukpin string
var input string
input, err = b.prompter.PromptPassword("Please enter current PUK: ")
if err != nil {
return nil, err
}
pukpin = input
input, err = b.prompter.PromptPassword("Please enter new PIN: ")
if err != nil {
return nil, err
}
pukpin += input
if val, err = openWallet(goja.Null(), wallet, call.VM.ToValue(pukpin)); err != nil {
return nil, err
}
case strings.HasSuffix(err.Error(), scwallet.ErrPINNeeded.Error()):
// PIN input requested, fetch from the user and call open again
var input string
input, err = b.prompter.PromptPassword("Please enter current PIN: ")
if err != nil {
return nil, err
}
if val, err = openWallet(goja.Null(), wallet, call.VM.ToValue(input)); err != nil {
return nil, err
}
default:
// Unknown error occurred, drop to the user
return nil, err
}
return val, nil
}
func (b *bridge) readPassphraseAndReopenWallet(call jsre.Call) (goja.Value, error) {
wallet := call.Argument(0)
input, err := b.prompter.PromptPassword("Please enter your passphrase: ")
if err != nil {
return nil, err
}
openWallet, callable := goja.AssertFunction(getJeth(call.VM).Get("openWallet"))
if !callable {
return nil, fmt.Errorf("jeth.openWallet is not callable")
}
return openWallet(goja.Null(), wallet, call.VM.ToValue(input))
}
func (b *bridge) readPinAndReopenWallet(call jsre.Call) (goja.Value, error) {
wallet := call.Argument(0)
// Trezor PIN matrix input requested, display the matrix to the user and fetch the data
fmt.Fprintf(b.printer, "Look at the device for number positions\n\n")
fmt.Fprintf(b.printer, "7 | 8 | 9\n")
fmt.Fprintf(b.printer, "--+---+--\n")
fmt.Fprintf(b.printer, "4 | 5 | 6\n")
fmt.Fprintf(b.printer, "--+---+--\n")
fmt.Fprintf(b.printer, "1 | 2 | 3\n\n")
input, err := b.prompter.PromptPassword("Please enter current PIN: ")
if err != nil {
return nil, err
}
openWallet, callable := goja.AssertFunction(getJeth(call.VM).Get("openWallet"))
if !callable {
return nil, fmt.Errorf("jeth.openWallet is not callable")
}
return openWallet(goja.Null(), wallet, call.VM.ToValue(input))
}
// UnlockAccount is a wrapper around the personal.unlockAccount RPC method that
// uses a non-echoing password prompt to acquire the passphrase and executes the
// original RPC method (saved in jeth.unlockAccount) with it to actually execute
// the RPC call.
func (b *bridge) UnlockAccount(call jsre.Call) (goja.Value, error) {
if len(call.Arguments) < 1 {
return nil, fmt.Errorf("usage: unlockAccount(account, [ password, duration ])")
}
account := call.Argument(0)
// Make sure we have an account specified to unlock.
if goja.IsUndefined(account) || goja.IsNull(account) || account.ExportType().Kind() != reflect.String {
return nil, fmt.Errorf("first argument must be the account to unlock")
}
// If password is not given or is the null value, prompt the user for it.
var passwd goja.Value
if goja.IsUndefined(call.Argument(1)) || goja.IsNull(call.Argument(1)) {
fmt.Fprintf(b.printer, "Unlock account %s\n", account)
input, err := b.prompter.PromptPassword("Passphrase: ")
if err != nil {
return nil, err
}
passwd = call.VM.ToValue(input)
} else {
if call.Argument(1).ExportType().Kind() != reflect.String {
return nil, fmt.Errorf("password must be a string")
}
passwd = call.Argument(1)
}
// Third argument is the duration how long the account should be unlocked.
duration := goja.Null()
if !goja.IsUndefined(call.Argument(2)) && !goja.IsNull(call.Argument(2)) {
if !isNumber(call.Argument(2)) {
return nil, fmt.Errorf("unlock duration must be a number")
}
duration = call.Argument(2)
}
// Send the request to the backend and return.
unlockAccount, callable := goja.AssertFunction(getJeth(call.VM).Get("unlockAccount"))
if !callable {
return nil, fmt.Errorf("jeth.unlockAccount is not callable")
}
return unlockAccount(goja.Null(), account, passwd, duration)
}
// Sign is a wrapper around the personal.sign RPC method that uses a non-echoing password
// prompt to acquire the passphrase and executes the original RPC method (saved in
// jeth.sign) with it to actually execute the RPC call.
func (b *bridge) Sign(call jsre.Call) (goja.Value, error) {
if nArgs := len(call.Arguments); nArgs < 2 {
return nil, fmt.Errorf("usage: sign(message, account, [ password ])")
}
var (
message = call.Argument(0)
account = call.Argument(1)
passwd = call.Argument(2)
)
if goja.IsUndefined(message) || message.ExportType().Kind() != reflect.String {
return nil, fmt.Errorf("first argument must be the message to sign")
}
if goja.IsUndefined(account) || account.ExportType().Kind() != reflect.String {
return nil, fmt.Errorf("second argument must be the account to sign with")
}
// if the password is not given or null ask the user and ensure password is a string
if goja.IsUndefined(passwd) || goja.IsNull(passwd) {
fmt.Fprintf(b.printer, "Give password for account %s\n", account)
input, err := b.prompter.PromptPassword("Password: ")
if err != nil {
return nil, err
}
passwd = call.VM.ToValue(input)
} else if passwd.ExportType().Kind() != reflect.String {
return nil, fmt.Errorf("third argument must be the password to unlock the account")
}
// Send the request to the backend and return
sign, callable := goja.AssertFunction(getJeth(call.VM).Get("sign"))
if !callable {
return nil, fmt.Errorf("jeth.sign is not callable")
}
return sign(goja.Null(), message, account, passwd)
}
// Sleep will block the console for the specified number of seconds.
func (b *bridge) Sleep(call jsre.Call) (goja.Value, error) {
if nArgs := len(call.Arguments); nArgs < 1 {
return nil, fmt.Errorf("usage: sleep(<number of seconds>)")
}
sleepObj := call.Argument(0)
if goja.IsUndefined(sleepObj) || goja.IsNull(sleepObj) || !isNumber(sleepObj) {
return nil, fmt.Errorf("usage: sleep(<number of seconds>)")
}
sleep := sleepObj.ToFloat()
time.Sleep(time.Duration(sleep * float64(time.Second)))
return call.VM.ToValue(true), nil
}
// SleepBlocks will block the console for a specified number of new blocks optionally
// until the given timeout is reached.
func (b *bridge) SleepBlocks(call jsre.Call) (goja.Value, error) {
// Parse the input parameters for the sleep.
var (
blocks = int64(0)
sleep = int64(9999999999999999) // indefinitely
)
nArgs := len(call.Arguments)
if nArgs == 0 {
return nil, fmt.Errorf("usage: sleepBlocks(<n blocks>[, max sleep in seconds])")
}
if nArgs >= 1 {
if goja.IsNull(call.Argument(0)) || goja.IsUndefined(call.Argument(0)) || !isNumber(call.Argument(0)) {
return nil, fmt.Errorf("expected number as first argument")
}
blocks = call.Argument(0).ToInteger()
}
if nArgs >= 2 {
if goja.IsNull(call.Argument(1)) || goja.IsUndefined(call.Argument(1)) || !isNumber(call.Argument(1)) {
return nil, fmt.Errorf("expected number as second argument")
}
sleep = call.Argument(1).ToInteger()
}
// Poll the current block number until either it or a timeout is reached.
deadline := time.Now().Add(time.Duration(sleep) * time.Second)
var lastNumber hexutil.Uint64
if err := b.client.Call(&lastNumber, "eth_blockNumber"); err != nil {
return nil, err
}
for time.Now().Before(deadline) {
var number hexutil.Uint64
if err := b.client.Call(&number, "eth_blockNumber"); err != nil {
return nil, err
}
if number != lastNumber {
lastNumber = number
blocks--
}
if blocks <= 0 {
break
}
time.Sleep(time.Second)
}
return call.VM.ToValue(true), nil
}
type jsonrpcCall struct {
ID int64
Method string
Params []interface{}
}
// Send implements the web3 provider "send" method.
func (b *bridge) Send(call jsre.Call) (goja.Value, error) {
// Remarshal the request into a Go value.
reqVal, err := call.Argument(0).ToObject(call.VM).MarshalJSON()
if err != nil {
return nil, err
}
var (
rawReq = string(reqVal)
dec = json.NewDecoder(strings.NewReader(rawReq))
reqs []jsonrpcCall
batch bool
)
dec.UseNumber() // avoid float64s
if rawReq[0] == '[' {
batch = true
dec.Decode(&reqs)
} else {
batch = false
reqs = make([]jsonrpcCall, 1)
dec.Decode(&reqs[0])
}
// Execute the requests.
resps := make([]*goja.Object, 0, len(reqs))
for _, req := range reqs {
resp := call.VM.NewObject()
resp.Set("jsonrpc", "2.0") //nolint:errcheck
resp.Set("id", req.ID) //nolint:errcheck
var result json.RawMessage
if err = b.client.Call(&result, req.Method, req.Params...); err == nil {
if result == nil {
// Special case null because it is decoded as an empty
// raw message for some reason.
resp.Set("result", goja.Null()) // nolint:errcheck
} else {
JSON := call.VM.Get("JSON").ToObject(call.VM)
parse, callable := goja.AssertFunction(JSON.Get("parse"))
if !callable {
return nil, fmt.Errorf("JSON.parse is not a function")
}
resultVal, err := parse(goja.Null(), call.VM.ToValue(string(result)))
if err != nil {
setError(resp, -32603, err.Error(), nil)
} else {
resp.Set("result", resultVal)
}
}
} else {
code := -32603
var data interface{}
if err, ok := err.(rpc.Error); ok {
code = err.ErrorCode()
}
if err, ok := err.(rpc.DataError); ok {
data = err.ErrorData()
}
setError(resp, code, err.Error(), data)
}
resps = append(resps, resp)
}
// Return the responses either to the callback (if supplied)
// or directly as the return value.
var result goja.Value
if batch {
result = call.VM.ToValue(resps)
} else {
result = resps[0]
}
if fn, isFunc := goja.AssertFunction(call.Argument(1)); isFunc {
_, err := fn(goja.Null(), goja.Null(), result)
return goja.Undefined(), err
}
return result, nil
}
func setError(resp *goja.Object, code int, msg string, data interface{}) {
err := make(map[string]interface{})
err["code"] = code
err["message"] = msg
if data != nil {
err["data"] = data
}
resp.Set("error", err) //nolint:errcheck
}
// isNumber returns true if input value is a JS number.
func isNumber(v goja.Value) bool {
k := v.ExportType().Kind()
return k >= reflect.Int && k <= reflect.Float64
}
func getObject(vm *goja.Runtime, name string) *goja.Object {
v := vm.Get(name)
if v == nil {
return nil
}
return v.ToObject(vm)
}

View File

@ -1,49 +0,0 @@
// Copyright 2020 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
//nolint:errcheck
package console
import (
"testing"
"github.com/dop251/goja"
"github.com/ledgerwatch/turbo-geth/internal/jsre"
)
// TestUndefinedAsParam ensures that personal functions can receive
// `undefined` as a parameter.
func TestUndefinedAsParam(t *testing.T) {
b := bridge{}
call := jsre.Call{}
call.Arguments = []goja.Value{goja.Undefined()}
b.UnlockAccount(call)
b.Sign(call)
b.Sleep(call)
}
// TestNullAsParam ensures that personal functions can receive
// `null` as a parameter.
func TestNullAsParam(t *testing.T) {
b := bridge{}
call := jsre.Call{}
call.Arguments = []goja.Value{goja.Null()}
b.UnlockAccount(call)
b.Sign(call)
b.Sleep(call)
}

View File

@ -1,487 +0,0 @@
// Copyright 2016 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
// nolint:errcheck
package console
import (
"fmt"
"io"
"io/ioutil"
"os"
"os/signal"
"path/filepath"
"regexp"
"sort"
"strings"
"syscall"
"github.com/dop251/goja"
"github.com/ledgerwatch/turbo-geth/console/prompt"
"github.com/ledgerwatch/turbo-geth/internal/jsre"
"github.com/ledgerwatch/turbo-geth/internal/jsre/deps"
"github.com/ledgerwatch/turbo-geth/internal/web3ext"
"github.com/ledgerwatch/turbo-geth/rpc"
"github.com/mattn/go-colorable"
"github.com/peterh/liner"
)
var (
// u: unlock, s: signXX, sendXX, n: newAccount, i: importXX
passwordRegexp = regexp.MustCompile(`personal.[nusi]`)
onlyWhitespace = regexp.MustCompile(`^\s*$`)
exit = regexp.MustCompile(`^\s*exit\s*;*\s*$`)
)
// HistoryFile is the file within the data directory to store input scrollback.
const HistoryFile = "history"
// DefaultPrompt is the default prompt line prefix to use for user input querying.
const DefaultPrompt = "> "
// Config is the collection of configurations to fine tune the behavior of the
// JavaScript console.
type Config struct {
DataDir string // Data directory to store the console history at
DocRoot string // Filesystem path from where to load JavaScript files from
Client *rpc.Client // RPC client to execute Ethereum requests through
Prompt string // Input prompt prefix string (defaults to DefaultPrompt)
Prompter prompt.UserPrompter // Input prompter to allow interactive user feedback (defaults to TerminalPrompter)
Printer io.Writer // Output writer to serialize any display strings to (defaults to os.Stdout)
Preload []string // Absolute paths to JavaScript files to preload
}
// Console is a JavaScript interpreted runtime environment. It is a fully fledged
// JavaScript console attached to a running node via an external or in-process RPC
// client.
type Console struct {
client *rpc.Client // RPC client to execute Ethereum requests through
jsre *jsre.JSRE // JavaScript runtime environment running the interpreter
prompt string // Input prompt prefix string
prompter prompt.UserPrompter // Input prompter to allow interactive user feedback
histPath string // Absolute path to the console scrollback history
history []string // Scroll history maintained by the console
printer io.Writer // Output writer to serialize any display strings to
}
// New initializes a JavaScript interpreted runtime environment and sets defaults
// with the config struct.
func New(config Config) (*Console, error) {
// Handle unset config values gracefully
if config.Prompter == nil {
config.Prompter = prompt.Stdin
}
if config.Prompt == "" {
config.Prompt = DefaultPrompt
}
if config.Printer == nil {
config.Printer = colorable.NewColorableStdout()
}
// Initialize the console and return
console := &Console{
client: config.Client,
jsre: jsre.New(config.DocRoot, config.Printer),
prompt: config.Prompt,
prompter: config.Prompter,
printer: config.Printer,
histPath: filepath.Join(config.DataDir, HistoryFile),
}
if err := os.MkdirAll(config.DataDir, 0700); err != nil {
return nil, err
}
if err := console.init(config.Preload); err != nil {
return nil, err
}
return console, nil
}
// init retrieves the available APIs from the remote RPC provider and initializes
// the console's JavaScript namespaces based on the exposed modules.
func (c *Console) init(preload []string) error {
c.initConsoleObject()
// Initialize the JavaScript <-> Go RPC bridge.
bridge := newBridge(c.client, c.prompter, c.printer)
if err := c.initWeb3(bridge); err != nil {
return err
}
if err := c.initExtensions(); err != nil {
return err
}
// Add bridge overrides for web3.js functionality.
c.jsre.Do(func(vm *goja.Runtime) {
c.initAdmin(vm, bridge)
c.initPersonal(vm, bridge)
})
// Preload JavaScript files.
for _, path := range preload {
if err := c.jsre.Exec(path); err != nil {
failure := err.Error()
if gojaErr, ok := err.(*goja.Exception); ok {
failure = gojaErr.String()
}
return fmt.Errorf("%s: %v", path, failure)
}
}
// Configure the input prompter for history and tab completion.
if c.prompter != nil {
if content, err := ioutil.ReadFile(c.histPath); err != nil {
c.prompter.SetHistory(nil)
} else {
c.history = strings.Split(string(content), "\n")
c.prompter.SetHistory(c.history)
}
c.prompter.SetWordCompleter(c.AutoCompleteInput)
}
return nil
}
func (c *Console) initConsoleObject() {
c.jsre.Do(func(vm *goja.Runtime) {
console := vm.NewObject()
console.Set("log", c.consoleOutput)
console.Set("error", c.consoleOutput)
vm.Set("console", console)
})
}
func (c *Console) initWeb3(bridge *bridge) error {
bnJS := string(deps.MustAsset("bignumber.js"))
web3JS := string(deps.MustAsset("web3.js"))
if err := c.jsre.Compile("bignumber.js", bnJS); err != nil {
return fmt.Errorf("bignumber.js: %v", err)
}
if err := c.jsre.Compile("web3.js", web3JS); err != nil {
return fmt.Errorf("web3.js: %v", err)
}
if _, err := c.jsre.Run("var Web3 = require('web3');"); err != nil {
return fmt.Errorf("web3 require: %v", err)
}
var err error
c.jsre.Do(func(vm *goja.Runtime) {
transport := vm.NewObject()
transport.Set("send", jsre.MakeCallback(vm, bridge.Send))
transport.Set("sendAsync", jsre.MakeCallback(vm, bridge.Send))
vm.Set("_consoleWeb3Transport", transport)
_, err = vm.RunString("var web3 = new Web3(_consoleWeb3Transport)")
})
return err
}
// initExtensions loads and registers web3.js extensions.
func (c *Console) initExtensions() error {
// Compute aliases from server-provided modules.
apis, err := c.client.SupportedModules()
if err != nil {
return fmt.Errorf("api modules: %v", err)
}
aliases := map[string]struct{}{"eth": {}, "personal": {}}
for api := range apis {
if api == "web3" {
continue
}
aliases[api] = struct{}{}
if file, ok := web3ext.Modules[api]; ok {
if err = c.jsre.Compile(api+".js", file); err != nil {
return fmt.Errorf("%s.js: %v", api, err)
}
}
}
// Apply aliases.
c.jsre.Do(func(vm *goja.Runtime) {
web3 := getObject(vm, "web3")
for name := range aliases {
if v := web3.Get(name); v != nil {
vm.Set(name, v)
}
}
})
return nil
}
// initAdmin creates additional admin APIs implemented by the bridge.
func (c *Console) initAdmin(vm *goja.Runtime, bridge *bridge) {
if admin := getObject(vm, "admin"); admin != nil {
admin.Set("sleepBlocks", jsre.MakeCallback(vm, bridge.SleepBlocks))
admin.Set("sleep", jsre.MakeCallback(vm, bridge.Sleep))
admin.Set("clearHistory", c.clearHistory)
}
}
// initPersonal redirects account-related API methods through the bridge.
//
// If the console is in interactive mode and the 'personal' API is available, override
// the openWallet, unlockAccount, newAccount and sign methods since these require user
// interaction. The original web3 callbacks are stored in 'jeth'. These will be called
// by the bridge after the prompt and send the original web3 request to the backend.
func (c *Console) initPersonal(vm *goja.Runtime, bridge *bridge) {
personal := getObject(vm, "personal")
if personal == nil || c.prompter == nil {
return
}
jeth := vm.NewObject()
vm.Set("jeth", jeth)
jeth.Set("openWallet", personal.Get("openWallet"))
jeth.Set("unlockAccount", personal.Get("unlockAccount"))
jeth.Set("newAccount", personal.Get("newAccount"))
jeth.Set("sign", personal.Get("sign"))
personal.Set("openWallet", jsre.MakeCallback(vm, bridge.OpenWallet))
personal.Set("unlockAccount", jsre.MakeCallback(vm, bridge.UnlockAccount))
personal.Set("newAccount", jsre.MakeCallback(vm, bridge.NewAccount))
personal.Set("sign", jsre.MakeCallback(vm, bridge.Sign))
}
func (c *Console) clearHistory() {
c.history = nil
c.prompter.ClearHistory()
if err := os.Remove(c.histPath); err != nil {
fmt.Fprintln(c.printer, "can't delete history file:", err)
} else {
fmt.Fprintln(c.printer, "history file deleted")
}
}
// consoleOutput is an override for the console.log and console.error methods to
// stream the output into the configured output stream instead of stdout.
func (c *Console) consoleOutput(call goja.FunctionCall) goja.Value {
var output []string
for _, argument := range call.Arguments {
output = append(output, fmt.Sprintf("%v", argument))
}
fmt.Fprintln(c.printer, strings.Join(output, " "))
return goja.Null()
}
// AutoCompleteInput is a pre-assembled word completer to be used by the user
// input prompter to provide hints to the user about the methods available.
func (c *Console) AutoCompleteInput(line string, pos int) (string, []string, string) {
// No completions can be provided for empty inputs
if len(line) == 0 || pos == 0 {
return "", nil, ""
}
// Chunck data to relevant part for autocompletion
// E.g. in case of nested lines eth.getBalance(eth.coinb<tab><tab>
start := pos - 1
for ; start > 0; start-- {
// Skip all methods and namespaces (i.e. including the dot)
if line[start] == '.' || (line[start] >= 'a' && line[start] <= 'z') || (line[start] >= 'A' && line[start] <= 'Z') {
continue
}
// Handle web3 in a special way (i.e. other numbers aren't auto completed)
if start >= 3 && line[start-3:start] == "web3" {
start -= 3
continue
}
// We've hit an unexpected character, autocomplete form here
start++
break
}
return line[:start], c.jsre.CompleteKeywords(line[start:pos]), line[pos:]
}
// Welcome show summary of current Geth instance and some metadata about the
// console's available modules.
func (c *Console) Welcome() {
message := "Welcome to the Geth JavaScript console!\n\n"
// Print some generic Geth metadata
if res, err := c.jsre.Run(`
var message = "instance: " + web3.version.node + "\n";
try {
message += "coinbase: " + eth.coinbase + "\n";
} catch (err) {}
message += "at block: " + eth.blockNumber + " (" + new Date(1000 * eth.getBlock(eth.blockNumber).timestamp) + ")\n";
try {
message += " datadir: " + admin.datadir + "\n";
} catch (err) {}
message
`); err == nil {
message += res.String()
}
// List all the supported modules for the user to call
if apis, err := c.client.SupportedModules(); err == nil {
modules := make([]string, 0, len(apis))
for api, version := range apis {
modules = append(modules, fmt.Sprintf("%s:%s", api, version))
}
sort.Strings(modules)
message += " modules: " + strings.Join(modules, " ") + "\n"
}
fmt.Fprintln(c.printer, message)
}
// Evaluate executes code and pretty prints the result to the specified output
// stream.
func (c *Console) Evaluate(statement string) {
defer func() {
if r := recover(); r != nil {
fmt.Fprintf(c.printer, "[native] error: %v\n", r)
}
}()
c.jsre.Evaluate(statement, c.printer)
}
// Interactive starts an interactive user session, where input is propted from
// the configured user prompter.
func (c *Console) Interactive() {
var (
prompt = c.prompt // the current prompt line (used for multi-line inputs)
indents = 0 // the current number of input indents (used for multi-line inputs)
input = "" // the current user input
inputLine = make(chan string, 1) // receives user input
inputErr = make(chan error, 1) // receives liner errors
requestLine = make(chan string) // requests a line of input
interrupt = make(chan os.Signal, 1)
)
// Monitor Ctrl-C. While liner does turn on the relevant terminal mode bits to avoid
// the signal, a signal can still be received for unsupported terminals. Unfortunately
// there is no way to cancel the line reader when this happens. The readLines
// goroutine will be leaked in this case.
signal.Notify(interrupt, syscall.SIGINT, syscall.SIGTERM)
defer signal.Stop(interrupt)
// The line reader runs in a separate goroutine.
go c.readLines(inputLine, inputErr, requestLine)
defer close(requestLine)
for {
// Send the next prompt, triggering an input read.
requestLine <- prompt
select {
case <-interrupt:
fmt.Fprintln(c.printer, "caught interrupt, exiting")
return
case err := <-inputErr:
if err == liner.ErrPromptAborted && indents > 0 {
// When prompting for multi-line input, the first Ctrl-C resets
// the multi-line state.
prompt, indents, input = c.prompt, 0, ""
continue
}
return
case line := <-inputLine:
// User input was returned by the prompter, handle special cases.
if indents <= 0 && exit.MatchString(line) {
return
}
if onlyWhitespace.MatchString(line) {
continue
}
// Append the line to the input and check for multi-line interpretation.
input += line + "\n"
indents = countIndents(input)
if indents <= 0 {
prompt = c.prompt
} else {
prompt = strings.Repeat(".", indents*3) + " "
}
// If all the needed lines are present, save the command and run it.
if indents <= 0 {
if len(input) > 0 && input[0] != ' ' && !passwordRegexp.MatchString(input) {
if command := strings.TrimSpace(input); len(c.history) == 0 || command != c.history[len(c.history)-1] {
c.history = append(c.history, command)
if c.prompter != nil {
c.prompter.AppendHistory(command)
}
}
}
c.Evaluate(input)
input = ""
}
}
}
}
// readLines runs in its own goroutine, prompting for input.
func (c *Console) readLines(input chan<- string, errc chan<- error, prompt <-chan string) {
for p := range prompt {
line, err := c.prompter.PromptInput(p)
if err != nil {
errc <- err
} else {
input <- line
}
}
}
// countIndents returns the number of identations for the given input.
// In case of invalid input such as var a = } the result can be negative.
func countIndents(input string) int {
var (
indents = 0
inString = false
strOpenChar = ' ' // keep track of the string open char to allow var str = "I'm ....";
charEscaped = false // keep track if the previous char was the '\' char, allow var str = "abc\"def";
)
for _, c := range input {
switch c {
case '\\':
// indicate next char as escaped when in string and previous char isn't escaping this backslash
if !charEscaped && inString {
charEscaped = true
}
case '\'', '"':
if inString && !charEscaped && strOpenChar == c { // end string
inString = false
} else if !inString && !charEscaped { // begin string
inString = true
strOpenChar = c
}
charEscaped = false
case '{', '(':
if !inString { // ignore brackets when in string, allow var str = "a{"; without indenting
indents++
}
charEscaped = false
case '}', ')':
if !inString {
indents--
}
charEscaped = false
default:
charEscaped = false
}
}
return indents
}
// Execute runs the JavaScript file specified as the argument.
func (c *Console) Execute(path string) error {
return c.jsre.Exec(path)
}
// Stop cleans up the console and terminates the runtime environment.
func (c *Console) Stop(graceful bool) error {
if err := ioutil.WriteFile(c.histPath, []byte(strings.Join(c.history, "\n")), 0600); err != nil {
return err
}
if err := os.Chmod(c.histPath, 0600); err != nil { // Force 0600, even if it was different previously
return err
}
c.jsre.Stop(graceful)
return nil
}

View File

@ -1,346 +0,0 @@
// Copyright 2015 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
package console
import (
"bytes"
"errors"
"fmt"
"io/ioutil"
"os"
"strings"
"testing"
"time"
"github.com/davecgh/go-spew/spew"
"github.com/ledgerwatch/turbo-geth/common"
"github.com/ledgerwatch/turbo-geth/consensus/ethash"
"github.com/ledgerwatch/turbo-geth/console/prompt"
"github.com/ledgerwatch/turbo-geth/core"
"github.com/ledgerwatch/turbo-geth/eth"
"github.com/ledgerwatch/turbo-geth/internal/jsre"
"github.com/ledgerwatch/turbo-geth/miner"
"github.com/ledgerwatch/turbo-geth/node"
)
const (
testInstance = "console-tester"
testAddress = "0x8605cdbbdb6d264aa742e77020dcbc58fcdce182"
)
// hookedPrompter implements UserPrompter to simulate use input via channels.
type hookedPrompter struct {
scheduler chan string
}
func (p *hookedPrompter) PromptInput(prompt string) (string, error) {
// Send the prompt to the tester
select {
case p.scheduler <- prompt:
case <-time.After(time.Second):
return "", errors.New("prompt timeout")
}
// Retrieve the response and feed to the console
select {
case input := <-p.scheduler:
return input, nil
case <-time.After(time.Second):
return "", errors.New("input timeout")
}
}
func (p *hookedPrompter) PromptPassword(prompt string) (string, error) {
return "", errors.New("not implemented")
}
func (p *hookedPrompter) PromptConfirm(prompt string) (bool, error) {
return false, errors.New("not implemented")
}
func (p *hookedPrompter) SetHistory(history []string) {}
func (p *hookedPrompter) AppendHistory(command string) {}
func (p *hookedPrompter) ClearHistory() {}
func (p *hookedPrompter) SetWordCompleter(completer prompt.WordCompleter) {}
// tester is a console test environment for the console tests to operate on.
type tester struct {
workspace string
stack *node.Node
ethereum *eth.Ethereum
console *Console
input *hookedPrompter
output *bytes.Buffer
}
// newTester creates a test environment based on which the console can operate.
// Please ensure you call Close() on the returned tester to avoid leaks.
func newTester(t *testing.T, confOverride func(*eth.Config)) *tester {
t.Helper()
// Create a temporary storage for the node keys and initialize it
workspace, err := ioutil.TempDir("", "console-tester-")
if err != nil {
t.Fatalf("failed to create temporary keystore: %v", err)
}
// Create a networkless protocol stack and start an Ethereum service within
stack, err := node.New(&node.Config{DataDir: workspace, UseLightweightKDF: true, Name: testInstance})
if err != nil {
t.Fatalf("failed to create node: %v", err)
}
ethConf := &eth.Config{
Genesis: core.DeveloperGenesisBlock(15, common.Address{}),
Miner: miner.Config{
Etherbase: common.HexToAddress(testAddress),
},
Ethash: ethash.Config{
PowMode: ethash.ModeTest,
},
Pruning: false,
}
spew.Dump(ethConf)
if confOverride != nil {
confOverride(ethConf)
}
ethBackend, err := eth.New(stack, ethConf)
if err != nil {
t.Fatalf("failed to register Ethereum protocol: %v", err)
}
// Start the node and assemble the JavaScript console around it
if err = stack.Start(); err != nil {
t.Fatalf("failed to start test stack: %v", err)
}
client, err := stack.Attach()
if err != nil {
t.Fatalf("failed to attach to node: %v", err)
}
prompter := &hookedPrompter{scheduler: make(chan string)}
printer := new(bytes.Buffer)
console, err := New(Config{
DataDir: stack.DataDir(),
DocRoot: "testdata",
Client: client,
Prompter: prompter,
Printer: printer,
Preload: []string{"preload.js"},
})
if err != nil {
t.Fatalf("failed to create JavaScript console: %v", err)
}
// Create the final tester and return
return &tester{
workspace: workspace,
stack: stack,
ethereum: ethBackend,
console: console,
input: prompter,
output: printer,
}
}
// Close cleans up any temporary data folders and held resources.
func (env *tester) Close(t *testing.T) {
if err := env.console.Stop(false); err != nil {
t.Errorf("failed to stop embedded console: %v", err)
}
if err := env.stack.Close(); err != nil {
t.Errorf("failed to tear down embedded node: %v", err)
}
os.RemoveAll(env.workspace)
}
// Tests that the node lists the correct welcome message, notably that it contains
// the instance name, coinbase account, block number, data directory and supported
// console modules.
func TestWelcome(t *testing.T) {
tester := newTester(t, nil)
defer tester.Close(t)
tester.console.Welcome()
output := tester.output.String()
if want := "Welcome"; !strings.Contains(output, want) {
t.Fatalf("console output missing welcome message: have\n%s\nwant also %s", output, want)
}
if want := fmt.Sprintf("instance: %s", testInstance); !strings.Contains(output, want) {
t.Fatalf("console output missing instance: have\n%s\nwant also %s", output, want)
}
//if want := fmt.Sprintf("coinbase: %s", testAddress); !strings.Contains(output, want) {
// t.Fatalf("console output missing coinbase: have\n%s\nwant also %s", output, want)
//}
if want := "at block: 0"; !strings.Contains(output, want) {
t.Fatalf("console output missing sync status: have\n%s\nwant also %s", output, want)
}
if want := fmt.Sprintf("datadir: %s", tester.workspace); !strings.Contains(output, want) {
t.Fatalf("console output missing coinbase: have\n%s\nwant also %s", output, want)
}
}
// Tests that JavaScript statement evaluation works as intended.
func TestEvaluate(t *testing.T) {
tester := newTester(t, nil)
defer tester.Close(t)
tester.console.Evaluate("2 + 2")
if output := tester.output.String(); !strings.Contains(output, "4") {
t.Fatalf("statement evaluation failed: have %s, want %s", output, "4")
}
}
// Tests that the console can be used in interactive mode.
func TestInteractive(t *testing.T) {
// Create a tester and run an interactive console in the background
tester := newTester(t, nil)
defer tester.Close(t)
go tester.console.Interactive()
// Wait for a prompt and send a statement back
select {
case <-tester.input.scheduler:
case <-time.After(time.Second):
t.Fatalf("initial prompt timeout")
}
select {
case tester.input.scheduler <- "2+2":
case <-time.After(time.Second):
t.Fatalf("input feedback timeout")
}
// Wait for the second prompt and ensure first statement was evaluated
select {
case <-tester.input.scheduler:
case <-time.After(time.Second):
t.Fatalf("secondary prompt timeout")
}
if output := tester.output.String(); !strings.Contains(output, "4") {
t.Fatalf("statement evaluation failed: have %s, want %s", output, "4")
}
}
// Tests that preloaded JavaScript files have been executed before user is given
// input.
func TestPreload(t *testing.T) {
tester := newTester(t, nil)
defer tester.Close(t)
tester.console.Evaluate("preloaded")
if output := tester.output.String(); !strings.Contains(output, "some-preloaded-string") {
t.Fatalf("preloaded variable missing: have %s, want %s", output, "some-preloaded-string")
}
}
// Tests that JavaScript scripts can be executes from the configured asset path.
func TestExecute(t *testing.T) {
tester := newTester(t, nil)
defer tester.Close(t)
tester.console.Execute("exec.js")
tester.console.Evaluate("execed")
if output := tester.output.String(); !strings.Contains(output, "some-executed-string") {
t.Fatalf("execed variable missing: have %s, want %s", output, "some-executed-string")
}
}
// Tests that the JavaScript objects returned by statement executions are properly
// pretty printed instead of just displaying "[object]".
func TestPrettyPrint(t *testing.T) {
tester := newTester(t, nil)
defer tester.Close(t)
tester.console.Evaluate("obj = {int: 1, string: 'two', list: [3, 3, 3], obj: {null: null, func: function(){}}}")
// Define some specially formatted fields
var (
one = jsre.NumberColor("1")
two = jsre.StringColor("\"two\"")
three = jsre.NumberColor("3")
null = jsre.SpecialColor("null")
fun = jsre.FunctionColor("function()")
)
// Assemble the actual output we're after and verify
want := `{
int: ` + one + `,
list: [` + three + `, ` + three + `, ` + three + `],
obj: {
null: ` + null + `,
func: ` + fun + `
},
string: ` + two + `
}
`
if output := tester.output.String(); output != want {
t.Fatalf("pretty print mismatch: have %s, want %s", output, want)
}
}
// Tests that the JavaScript exceptions are properly formatted and colored.
func TestPrettyError(t *testing.T) {
tester := newTester(t, nil)
defer tester.Close(t)
tester.console.Evaluate("throw 'hello'")
want := jsre.ErrorColor("hello") + "\n\tat <eval>:1:7(1)\n\n"
if output := tester.output.String(); output != want {
t.Fatalf("pretty error mismatch: have %s, want %s", output, want)
}
}
// Tests that tests if the number of indents for JS input is calculated correct.
func TestIndenting(t *testing.T) {
testCases := []struct {
input string
expectedIndentCount int
}{
{`var a = 1;`, 0},
{`"some string"`, 0},
{`"some string with (parenthesis`, 0},
{`"some string with newline
("`, 0},
{`function v(a,b) {}`, 0},
{`function f(a,b) { var str = "asd("; };`, 0},
{`function f(a) {`, 1},
{`function f(a, function(b) {`, 2},
{`function f(a, function(b) {
var str = "a)}";
});`, 0},
{`function f(a,b) {
var str = "a{b(" + a, ", " + b;
}`, 0},
{`var str = "\"{"`, 0},
{`var str = "'("`, 0},
{`var str = "\\{"`, 0},
{`var str = "\\\\{"`, 0},
{`var str = 'a"{`, 0},
{`var obj = {`, 1},
{`var obj = { {a:1`, 2},
{`var obj = { {a:1}`, 1},
{`var obj = { {a:1}, b:2}`, 0},
{`var obj = {}`, 0},
{`var obj = {
a: 1, b: 2
}`, 0},
{`var test = }`, -1},
{`var str = "a\""; var obj = {`, 1},
}
for i, tt := range testCases {
counted := countIndents(tt.input)
if counted != tt.expectedIndentCount {
t.Errorf("test %d: invalid indenting: have %d, want %d", i, counted, tt.expectedIndentCount)
}
}
}

View File

@ -1 +0,0 @@
var execed = "some-executed-string";

View File

@ -1 +0,0 @@
var preloaded = "some-preloaded-string";

View File

@ -207,7 +207,7 @@ func (api *privateAdminAPI) StartRPC(host *string, port *int, cors *string, apis
if err := api.node.http.setListenAddr(*host, *port); err != nil {
return false, err
}
if err := api.node.http.enableRPC(api.node.rpcAPIs, config); err != nil {
if err := api.node.http.enableRPC(api.node.rpcAPIs, config, nil); err != nil {
return false, err
}
if err := api.node.http.start(); err != nil {
@ -263,7 +263,7 @@ func (api *privateAdminAPI) StartWS(host *string, port *int, allowedOrigins *str
if err := server.setListenAddr(*host, *port); err != nil {
return false, err
}
if err := server.enableWS(api.node.rpcAPIs, config); err != nil {
if err := server.enableWS(api.node.rpcAPIs, config, nil); err != nil {
return false, err
}
if err := server.start(); err != nil {

View File

@ -58,6 +58,8 @@ type Node struct {
ipc *ipcServer // Stores information about the ipc http server
inprocHandler *rpc.Server // In-process RPC request handler to process the API requests
rpcAllowList rpc.AllowList // list of RPC methods explicitly allowed for this RPC node
databases []ethdb.Closer
}
@ -342,6 +344,11 @@ func (n *Node) closeDataDir() {
}
}
// SetAllowListForRPC sets granular allow list for exposed RPC methods
func (n *Node) SetAllowListForRPC(allowList rpc.AllowList) {
n.rpcAllowList = allowList
}
// configureRPC is a helper method to configure all the various RPC endpoints during node
// startup. It's not meant to be called at any time afterwards as it makes certain
// assumptions about the state of the node.
@ -367,7 +374,7 @@ func (n *Node) startRPC() error {
if err := n.http.setListenAddr(n.config.HTTPHost, n.config.HTTPPort); err != nil {
return err
}
if err := n.http.enableRPC(n.rpcAPIs, config); err != nil {
if err := n.http.enableRPC(n.rpcAPIs, config, n.rpcAllowList); err != nil {
return err
}
}
@ -382,7 +389,7 @@ func (n *Node) startRPC() error {
if err := server.setListenAddr(n.config.WSHost, n.config.WSPort); err != nil {
return err
}
if err := server.enableWS(n.rpcAPIs, config); err != nil {
if err := server.enableWS(n.rpcAPIs, config, n.rpcAllowList); err != nil {
return err
}
}

View File

@ -484,7 +484,6 @@ func TestWebsocketHTTPOnSeparatePort_WSRequest(t *testing.T) {
if !checkRPC(node.HTTPEndpoint()) {
t.Fatalf("http request failed")
}
}
func createNode(t *testing.T, httpPort, wsPort int) *Node {

View File

@ -227,7 +227,7 @@ func (h *httpServer) doStop() {
}
// enableRPC turns on JSON-RPC over HTTP on the server.
func (h *httpServer) enableRPC(apis []rpc.API, config httpConfig) error {
func (h *httpServer) enableRPC(apis []rpc.API, config httpConfig, allowList rpc.AllowList) error {
h.mu.Lock()
defer h.mu.Unlock()
@ -237,6 +237,7 @@ func (h *httpServer) enableRPC(apis []rpc.API, config httpConfig) error {
// Create RPC server and handler.
srv := rpc.NewServer()
srv.SetAllowList(allowList)
if err := RegisterApisFromWhitelist(apis, config.Modules, srv, false); err != nil {
return err
}
@ -259,7 +260,7 @@ func (h *httpServer) disableRPC() bool {
}
// enableWS turns on JSON-RPC over WebSocket on the server.
func (h *httpServer) enableWS(apis []rpc.API, config wsConfig) error {
func (h *httpServer) enableWS(apis []rpc.API, config wsConfig, allowList rpc.AllowList) error {
h.mu.Lock()
defer h.mu.Unlock()
@ -269,6 +270,7 @@ func (h *httpServer) enableWS(apis []rpc.API, config wsConfig) error {
// Create RPC server and handler.
srv := rpc.NewServer()
srv.SetAllowList(allowList)
if err := RegisterApisFromWhitelist(apis, config.Modules, srv, false); err != nil {
return err
}

View File

@ -18,7 +18,10 @@ package node
import (
"bytes"
"fmt"
"io/ioutil"
"net/http"
"strings"
"testing"
"github.com/ledgerwatch/turbo-geth/internal/testlog"
@ -89,14 +92,16 @@ func TestIsWebsocket(t *testing.T) {
assert.True(t, isWebsocket(r))
}
func createAndStartServer(t *testing.T, conf httpConfig, ws bool, wsConf wsConfig) *httpServer {
func createAndStartServerWithAllowList(t *testing.T, conf httpConfig, ws bool, wsConf wsConfig) *httpServer {
t.Helper()
srv := newHTTPServer(testlog.Logger(t, log.LvlDebug), rpc.DefaultHTTPTimeouts)
assert.NoError(t, srv.enableRPC(nil, conf))
allowList := rpc.AllowList(map[string]struct{}{"net_version": {}}) //don't allow RPC modules
assert.NoError(t, srv.enableRPC(nil, conf, allowList))
if ws {
assert.NoError(t, srv.enableWS(nil, wsConf))
assert.NoError(t, srv.enableWS(nil, wsConf, allowList))
}
assert.NoError(t, srv.setListenAddr("localhost", 0))
assert.NoError(t, srv.start())
@ -104,10 +109,48 @@ func createAndStartServer(t *testing.T, conf httpConfig, ws bool, wsConf wsConfi
return srv
}
func createAndStartServer(t *testing.T, conf httpConfig, ws bool, wsConf wsConfig) *httpServer {
t.Helper()
srv := newHTTPServer(testlog.Logger(t, log.LvlDebug), rpc.DefaultHTTPTimeouts)
assert.NoError(t, srv.enableRPC(nil, conf, nil))
if ws {
assert.NoError(t, srv.enableWS(nil, wsConf, nil))
}
assert.NoError(t, srv.setListenAddr("localhost", 0))
assert.NoError(t, srv.start())
return srv
}
func TestAllowList(t *testing.T) {
srv := createAndStartServerWithAllowList(t, httpConfig{}, false, wsConfig{})
defer srv.stop()
assert.False(t, testCustomRequest(t, srv, "rpc_modules"))
}
func testCustomRequest(t *testing.T, srv *httpServer, method string) bool {
body := bytes.NewReader([]byte(fmt.Sprintf(`{"jsonrpc":"2.0","id":1,"method":"%s"}`, method)))
req, _ := http.NewRequest("POST", "http://"+srv.listenAddr(), body)
req.Header.Set("content-type", "application/json")
client := http.DefaultClient
resp, err := client.Do(req)
if err != nil {
return false
}
respBody, err := ioutil.ReadAll(resp.Body)
assert.NoError(t, err)
return !strings.Contains(string(respBody), "error")
}
func testRequest(t *testing.T, key, value, host string, srv *httpServer) *http.Response {
t.Helper()
body := bytes.NewReader([]byte(`{"jsonrpc":"2.0","id":1,method":"rpc_modules"}`))
body := bytes.NewReader([]byte(`{"jsonrpc":"2.0","id":1,"method":"rpc_modules"}`))
req, _ := http.NewRequest("POST", "http://"+srv.listenAddr(), body)
req.Header.Set("content-type", "application/json")
if key != "" && value != "" {

35
rpc/allow_list.go Normal file
View File

@ -0,0 +1,35 @@
package rpc
import "encoding/json"
type AllowList map[string]struct{}
func (a *AllowList) UnmarshalJSON(data []byte) error {
var keys []string
err := json.Unmarshal(data, &keys)
if err != nil {
return err
}
realA := make(map[string]struct{})
for _, k := range keys {
realA[k] = struct{}{}
}
*a = realA
return nil
}
// MarshalJSON returns *m as the JSON encoding of
func (a *AllowList) MarshalJSON() ([]byte, error) {
var realA map[string]struct{} = *a
keys := make([]string, len(realA))
i := 0
for key := range realA {
keys[i] = key
i++
}
return json.Marshal(keys)
}

23
rpc/allow_list_test.go Normal file
View File

@ -0,0 +1,23 @@
package rpc
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
)
func TestAllowListMarshaling(t *testing.T) {
}
func TestAllowListUnmarshaling(t *testing.T) {
allowListJSON := `[ "one", "two", "three" ]`
var allowList AllowList
err := json.Unmarshal([]byte(allowListJSON), &allowList)
assert.NoError(t, err, "should unmarshal successfully")
m := map[string]struct{}{"one": {}, "two": {}, "three": {}}
assert.Equal(t, allowList, AllowList(m))
}

View File

@ -74,9 +74,10 @@ type BatchElem struct {
// Client represents a connection to an RPC server.
type Client struct {
idgen func() ID // for subscriptions
isHTTP bool
services *serviceRegistry
idgen func() ID // for subscriptions
isHTTP bool
services *serviceRegistry
methodAllowList AllowList
idCounter uint32
@ -111,7 +112,7 @@ type clientConn struct {
func (c *Client) newClientConn(conn ServerCodec) *clientConn {
ctx := context.WithValue(context.Background(), clientContextKey{}, c)
handler := newHandler(ctx, conn, c.idgen, c.services)
handler := newHandler(ctx, conn, c.idgen, c.services, c.methodAllowList)
return &clientConn{conn, handler}
}

View File

@ -62,6 +62,8 @@ type handler struct {
log log.Logger
allowSubscribe bool
allowList AllowList // a list of explicitly allowed methods, if empty -- everything is allowed
subLock sync.Mutex
serverSubs map[ID]*Subscription
}
@ -71,7 +73,7 @@ type callProc struct {
notifiers []*Notifier
}
func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry) *handler {
func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry, allowList AllowList) *handler {
rootCtx, cancelRoot := context.WithCancel(connCtx)
h := &handler{
reg: reg,
@ -84,6 +86,7 @@ func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *
allowSubscribe: true,
serverSubs: make(map[ID]*Subscription),
log: log.Root(),
allowList: allowList,
}
if conn.remoteAddr() != "" {
h.log = h.log.New("conn", conn.remoteAddr())
@ -314,6 +317,14 @@ func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMess
}
}
func (h *handler) isMethodAllowedByGranularControl(method string) bool {
if len(h.allowList) == 0 {
return true
}
_, ok := h.allowList[method]
return ok
}
// handleCall processes method calls.
func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage {
if msg.isSubscribe() {
@ -322,7 +333,7 @@ func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage
var callb *callback
if msg.isUnsubscribe() {
callb = h.unsubscribeCb
} else {
} else if h.isMethodAllowedByGranularControl(msg.Method) {
callb = h.reg.callback(msg.Method)
}
if callb == nil {

View File

@ -42,10 +42,11 @@ const (
// Server is an RPC server.
type Server struct {
services serviceRegistry
idgen func() ID
run int32
codecs mapset.Set
services serviceRegistry
methodAllowList AllowList
idgen func() ID
run int32
codecs mapset.Set
}
// NewServer creates a new server instance with no registered handlers.
@ -58,6 +59,11 @@ func NewServer() *Server {
return server
}
// SetAllowList sets the allow list for methods that are handled by this server
func (s *Server) SetAllowList(allowList AllowList) {
s.methodAllowList = allowList
}
// RegisterName creates a service for the given receiver type under the given name. When no
// methods on the given receiver match the criteria to be either a RPC method or a
// subscription an error is returned. Otherwise a new service is created and added to the
@ -97,7 +103,7 @@ func (s *Server) serveSingleRequest(ctx context.Context, codec ServerCodec) {
return
}
h := newHandler(ctx, codec, s.idgen, &s.services)
h := newHandler(ctx, codec, s.idgen, &s.services, s.methodAllowList)
h.allowSubscribe = false
defer h.close(io.EOF, nil)