mirror of
https://gitlab.com/pulsechaincom/erigon-pulse.git
synced 2024-12-21 19:20:39 +00:00
Granular rpc control (Allow list for RPC daemon) (#1341)
This commit is contained in:
parent
546b91f47e
commit
ed9672620b
@ -305,6 +305,33 @@ WARN [11-05|09:03:47.911] Served conn=127.0.0.
|
||||
WARN [11-05|09:03:47.911] Served conn=127.0.0.1:59754 method=eth_newPendingTransactionFilter reqid=6 t="9.053µs" err="the method eth_newPendingTransactionFilter does not exist/is not available"
|
||||
```
|
||||
|
||||
## Allowing only specific methods (Allowlist)
|
||||
|
||||
In some cases you might want to only allow certain methods in the namespaces
|
||||
and hide others. That is possible with `rpc.accessList` flag.
|
||||
|
||||
1. Create a file, say, `rules.json`
|
||||
|
||||
2. Add the following content
|
||||
|
||||
```json
|
||||
{
|
||||
"allow": [
|
||||
"net_version",
|
||||
"web3_eth_getBlockByHash"
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
3. Provide this file to the rpcdaemon using `--rpc.accessList` flag
|
||||
|
||||
```
|
||||
> rpcdaemon --private.api.addr=localhost:9090 --http.api=eth,debug,net,web3 --rpc.accessList=rules.json
|
||||
```
|
||||
|
||||
Now only these two methods are available.
|
||||
|
||||
|
||||
## For Developers
|
||||
|
||||
### Code generation
|
||||
|
@ -18,22 +18,23 @@ import (
|
||||
)
|
||||
|
||||
type Flags struct {
|
||||
PrivateApiAddr string
|
||||
Chaindata string
|
||||
SnapshotDir string
|
||||
SnapshotMode string
|
||||
HttpListenAddress string
|
||||
TLSCertfile string
|
||||
TLSCACert string
|
||||
TLSKeyFile string
|
||||
HttpPort int
|
||||
HttpCORSDomain []string
|
||||
HttpVirtualHost []string
|
||||
API []string
|
||||
Gascap uint64
|
||||
MaxTraces uint64
|
||||
TraceType string
|
||||
WebsocketEnabled bool
|
||||
PrivateApiAddr string
|
||||
Chaindata string
|
||||
SnapshotDir string
|
||||
SnapshotMode string
|
||||
HttpListenAddress string
|
||||
TLSCertfile string
|
||||
TLSCACert string
|
||||
TLSKeyFile string
|
||||
HttpPort int
|
||||
HttpCORSDomain []string
|
||||
HttpVirtualHost []string
|
||||
API []string
|
||||
Gascap uint64
|
||||
MaxTraces uint64
|
||||
TraceType string
|
||||
WebsocketEnabled bool
|
||||
RpcAllowListFilePath string
|
||||
}
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
@ -76,6 +77,11 @@ func RootCommand() (*cobra.Command, *Flags) {
|
||||
rootCmd.PersistentFlags().Uint64Var(&cfg.MaxTraces, "trace.maxtraces", 200, "Sets a limit on traces that can be returned in trace_filter")
|
||||
rootCmd.PersistentFlags().StringVar(&cfg.TraceType, "trace.type", "parity", "Specify the type of tracing [geth|parity*] (experimental)")
|
||||
rootCmd.PersistentFlags().BoolVar(&cfg.WebsocketEnabled, "ws", false, "Enable Websockets")
|
||||
rootCmd.PersistentFlags().StringVar(&cfg.RpcAllowListFilePath, "rpc.accessList", "", "Specify granular (method-by-method) API allowlist")
|
||||
|
||||
if err := rootCmd.MarkPersistentFlagFilename("rpc.accessList", "json"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return rootCmd, cfg
|
||||
}
|
||||
@ -124,12 +130,17 @@ func StartRpcServer(ctx context.Context, cfg Flags, rpcAPI []rpc.API) error {
|
||||
httpEndpoint := fmt.Sprintf("%s:%d", cfg.HttpListenAddress, cfg.HttpPort)
|
||||
|
||||
srv := rpc.NewServer()
|
||||
|
||||
allowListForRPC, err := parseAllowListForRPC(cfg.RpcAllowListFilePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
srv.SetAllowList(allowListForRPC)
|
||||
|
||||
if err := node.RegisterApisFromWhitelist(rpcAPI, cfg.API, srv, false); err != nil {
|
||||
return fmt.Errorf("could not start register RPC apis: %w", err)
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
httpHandler := node.NewHTTPHandlerStack(srv, cfg.HttpCORSDomain, cfg.HttpVirtualHost)
|
||||
var wsHandler http.Handler
|
||||
if cfg.WebsocketEnabled {
|
||||
|
43
cmd/rpcdaemon/cli/rpc_allow_list.go
Normal file
43
cmd/rpcdaemon/cli/rpc_allow_list.go
Normal file
@ -0,0 +1,43 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/ledgerwatch/turbo-geth/rpc"
|
||||
)
|
||||
|
||||
type allowListFile struct {
|
||||
Allow rpc.AllowList `json:"allow"`
|
||||
}
|
||||
|
||||
func parseAllowListForRPC(path string) (rpc.AllowList, error) {
|
||||
path = strings.TrimSpace(path)
|
||||
if path == "" { // no file is provided
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
file.Close() //nolint: errcheck
|
||||
}()
|
||||
|
||||
fileContents, err := ioutil.ReadAll(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var allowListFileObj allowListFile
|
||||
|
||||
err = json.Unmarshal(fileContents, &allowListFileObj)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return allowListFileObj.Allow, nil
|
||||
}
|
@ -1,487 +0,0 @@
|
||||
// Copyright 2016 The go-ethereum Authors
|
||||
// This file is part of the go-ethereum library.
|
||||
//
|
||||
// The go-ethereum library is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Lesser General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// The go-ethereum library is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Lesser General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Lesser General Public License
|
||||
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package console
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/dop251/goja"
|
||||
"github.com/ledgerwatch/turbo-geth/accounts/scwallet"
|
||||
"github.com/ledgerwatch/turbo-geth/accounts/usbwallet"
|
||||
"github.com/ledgerwatch/turbo-geth/common/hexutil"
|
||||
"github.com/ledgerwatch/turbo-geth/console/prompt"
|
||||
"github.com/ledgerwatch/turbo-geth/internal/jsre"
|
||||
"github.com/ledgerwatch/turbo-geth/rpc"
|
||||
)
|
||||
|
||||
// bridge is a collection of JavaScript utility methods to bride the .js runtime
|
||||
// environment and the Go RPC connection backing the remote method calls.
|
||||
type bridge struct {
|
||||
client *rpc.Client // RPC client to execute Ethereum requests through
|
||||
prompter prompt.UserPrompter // Input prompter to allow interactive user feedback
|
||||
printer io.Writer // Output writer to serialize any display strings to
|
||||
}
|
||||
|
||||
// newBridge creates a new JavaScript wrapper around an RPC client.
|
||||
func newBridge(client *rpc.Client, prompter prompt.UserPrompter, printer io.Writer) *bridge {
|
||||
return &bridge{
|
||||
client: client,
|
||||
prompter: prompter,
|
||||
printer: printer,
|
||||
}
|
||||
}
|
||||
|
||||
func getJeth(vm *goja.Runtime) *goja.Object {
|
||||
jeth := vm.Get("jeth")
|
||||
if jeth == nil {
|
||||
panic(vm.ToValue("jeth object does not exist"))
|
||||
}
|
||||
return jeth.ToObject(vm)
|
||||
}
|
||||
|
||||
// NewAccount is a wrapper around the personal.newAccount RPC method that uses a
|
||||
// non-echoing password prompt to acquire the passphrase and executes the original
|
||||
// RPC method (saved in jeth.newAccount) with it to actually execute the RPC call.
|
||||
func (b *bridge) NewAccount(call jsre.Call) (goja.Value, error) {
|
||||
var (
|
||||
password string
|
||||
confirm string
|
||||
err error
|
||||
)
|
||||
switch {
|
||||
// No password was specified, prompt the user for it
|
||||
case len(call.Arguments) == 0:
|
||||
if password, err = b.prompter.PromptPassword("Passphrase: "); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if confirm, err = b.prompter.PromptPassword("Repeat passphrase: "); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if password != confirm {
|
||||
return nil, fmt.Errorf("passwords don't match")
|
||||
}
|
||||
// A single string password was specified, use that
|
||||
case len(call.Arguments) == 1 && call.Argument(0).ToString() != nil:
|
||||
password = call.Argument(0).ToString().String()
|
||||
default:
|
||||
return nil, fmt.Errorf("expected 0 or 1 string argument")
|
||||
}
|
||||
// Password acquired, execute the call and return
|
||||
newAccount, callable := goja.AssertFunction(getJeth(call.VM).Get("newAccount"))
|
||||
if !callable {
|
||||
return nil, fmt.Errorf("jeth.newAccount is not callable")
|
||||
}
|
||||
ret, err := newAccount(goja.Null(), call.VM.ToValue(password))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// OpenWallet is a wrapper around personal.openWallet which can interpret and
|
||||
// react to certain error messages, such as the Trezor PIN matrix request.
|
||||
func (b *bridge) OpenWallet(call jsre.Call) (goja.Value, error) {
|
||||
// Make sure we have a wallet specified to open
|
||||
if call.Argument(0).ToObject(call.VM).ClassName() != "String" {
|
||||
return nil, fmt.Errorf("first argument must be the wallet URL to open")
|
||||
}
|
||||
wallet := call.Argument(0)
|
||||
|
||||
var passwd goja.Value
|
||||
if goja.IsUndefined(call.Argument(1)) || goja.IsNull(call.Argument(1)) {
|
||||
passwd = call.VM.ToValue("")
|
||||
} else {
|
||||
passwd = call.Argument(1)
|
||||
}
|
||||
// Open the wallet and return if successful in itself
|
||||
openWallet, callable := goja.AssertFunction(getJeth(call.VM).Get("openWallet"))
|
||||
if !callable {
|
||||
return nil, fmt.Errorf("jeth.openWallet is not callable")
|
||||
}
|
||||
val, err := openWallet(goja.Null(), wallet, passwd)
|
||||
if err == nil {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// Wallet open failed, report error unless it's a PIN or PUK entry
|
||||
switch {
|
||||
case strings.HasSuffix(err.Error(), usbwallet.ErrTrezorPINNeeded.Error()):
|
||||
val, err = b.readPinAndReopenWallet(call)
|
||||
if err == nil {
|
||||
return val, nil
|
||||
}
|
||||
val, err = b.readPassphraseAndReopenWallet(call)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
case strings.HasSuffix(err.Error(), scwallet.ErrPairingPasswordNeeded.Error()):
|
||||
// PUK input requested, fetch from the user and call open again
|
||||
var input string
|
||||
input, err = b.prompter.PromptPassword("Please enter the pairing password: ")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
passwd = call.VM.ToValue(input)
|
||||
if val, err = openWallet(goja.Null(), wallet, passwd); err != nil {
|
||||
if !strings.HasSuffix(err.Error(), scwallet.ErrPINNeeded.Error()) {
|
||||
return nil, err
|
||||
} else {
|
||||
// PIN input requested, fetch from the user and call open again
|
||||
input, err = b.prompter.PromptPassword("Please enter current PIN: ")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if val, err = openWallet(goja.Null(), wallet, call.VM.ToValue(input)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case strings.HasSuffix(err.Error(), scwallet.ErrPINUnblockNeeded.Error()):
|
||||
// PIN unblock requested, fetch PUK and new PIN from the user
|
||||
var pukpin string
|
||||
var input string
|
||||
input, err = b.prompter.PromptPassword("Please enter current PUK: ")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pukpin = input
|
||||
input, err = b.prompter.PromptPassword("Please enter new PIN: ")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pukpin += input
|
||||
|
||||
if val, err = openWallet(goja.Null(), wallet, call.VM.ToValue(pukpin)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
case strings.HasSuffix(err.Error(), scwallet.ErrPINNeeded.Error()):
|
||||
// PIN input requested, fetch from the user and call open again
|
||||
var input string
|
||||
input, err = b.prompter.PromptPassword("Please enter current PIN: ")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if val, err = openWallet(goja.Null(), wallet, call.VM.ToValue(input)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
default:
|
||||
// Unknown error occurred, drop to the user
|
||||
return nil, err
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (b *bridge) readPassphraseAndReopenWallet(call jsre.Call) (goja.Value, error) {
|
||||
wallet := call.Argument(0)
|
||||
input, err := b.prompter.PromptPassword("Please enter your passphrase: ")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
openWallet, callable := goja.AssertFunction(getJeth(call.VM).Get("openWallet"))
|
||||
if !callable {
|
||||
return nil, fmt.Errorf("jeth.openWallet is not callable")
|
||||
}
|
||||
return openWallet(goja.Null(), wallet, call.VM.ToValue(input))
|
||||
}
|
||||
|
||||
func (b *bridge) readPinAndReopenWallet(call jsre.Call) (goja.Value, error) {
|
||||
wallet := call.Argument(0)
|
||||
// Trezor PIN matrix input requested, display the matrix to the user and fetch the data
|
||||
fmt.Fprintf(b.printer, "Look at the device for number positions\n\n")
|
||||
fmt.Fprintf(b.printer, "7 | 8 | 9\n")
|
||||
fmt.Fprintf(b.printer, "--+---+--\n")
|
||||
fmt.Fprintf(b.printer, "4 | 5 | 6\n")
|
||||
fmt.Fprintf(b.printer, "--+---+--\n")
|
||||
fmt.Fprintf(b.printer, "1 | 2 | 3\n\n")
|
||||
|
||||
input, err := b.prompter.PromptPassword("Please enter current PIN: ")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
openWallet, callable := goja.AssertFunction(getJeth(call.VM).Get("openWallet"))
|
||||
if !callable {
|
||||
return nil, fmt.Errorf("jeth.openWallet is not callable")
|
||||
}
|
||||
return openWallet(goja.Null(), wallet, call.VM.ToValue(input))
|
||||
}
|
||||
|
||||
// UnlockAccount is a wrapper around the personal.unlockAccount RPC method that
|
||||
// uses a non-echoing password prompt to acquire the passphrase and executes the
|
||||
// original RPC method (saved in jeth.unlockAccount) with it to actually execute
|
||||
// the RPC call.
|
||||
func (b *bridge) UnlockAccount(call jsre.Call) (goja.Value, error) {
|
||||
if len(call.Arguments) < 1 {
|
||||
return nil, fmt.Errorf("usage: unlockAccount(account, [ password, duration ])")
|
||||
}
|
||||
|
||||
account := call.Argument(0)
|
||||
// Make sure we have an account specified to unlock.
|
||||
if goja.IsUndefined(account) || goja.IsNull(account) || account.ExportType().Kind() != reflect.String {
|
||||
return nil, fmt.Errorf("first argument must be the account to unlock")
|
||||
}
|
||||
|
||||
// If password is not given or is the null value, prompt the user for it.
|
||||
var passwd goja.Value
|
||||
if goja.IsUndefined(call.Argument(1)) || goja.IsNull(call.Argument(1)) {
|
||||
fmt.Fprintf(b.printer, "Unlock account %s\n", account)
|
||||
input, err := b.prompter.PromptPassword("Passphrase: ")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
passwd = call.VM.ToValue(input)
|
||||
} else {
|
||||
if call.Argument(1).ExportType().Kind() != reflect.String {
|
||||
return nil, fmt.Errorf("password must be a string")
|
||||
}
|
||||
passwd = call.Argument(1)
|
||||
}
|
||||
|
||||
// Third argument is the duration how long the account should be unlocked.
|
||||
duration := goja.Null()
|
||||
if !goja.IsUndefined(call.Argument(2)) && !goja.IsNull(call.Argument(2)) {
|
||||
if !isNumber(call.Argument(2)) {
|
||||
return nil, fmt.Errorf("unlock duration must be a number")
|
||||
}
|
||||
duration = call.Argument(2)
|
||||
}
|
||||
|
||||
// Send the request to the backend and return.
|
||||
unlockAccount, callable := goja.AssertFunction(getJeth(call.VM).Get("unlockAccount"))
|
||||
if !callable {
|
||||
return nil, fmt.Errorf("jeth.unlockAccount is not callable")
|
||||
}
|
||||
return unlockAccount(goja.Null(), account, passwd, duration)
|
||||
}
|
||||
|
||||
// Sign is a wrapper around the personal.sign RPC method that uses a non-echoing password
|
||||
// prompt to acquire the passphrase and executes the original RPC method (saved in
|
||||
// jeth.sign) with it to actually execute the RPC call.
|
||||
func (b *bridge) Sign(call jsre.Call) (goja.Value, error) {
|
||||
if nArgs := len(call.Arguments); nArgs < 2 {
|
||||
return nil, fmt.Errorf("usage: sign(message, account, [ password ])")
|
||||
}
|
||||
var (
|
||||
message = call.Argument(0)
|
||||
account = call.Argument(1)
|
||||
passwd = call.Argument(2)
|
||||
)
|
||||
|
||||
if goja.IsUndefined(message) || message.ExportType().Kind() != reflect.String {
|
||||
return nil, fmt.Errorf("first argument must be the message to sign")
|
||||
}
|
||||
if goja.IsUndefined(account) || account.ExportType().Kind() != reflect.String {
|
||||
return nil, fmt.Errorf("second argument must be the account to sign with")
|
||||
}
|
||||
|
||||
// if the password is not given or null ask the user and ensure password is a string
|
||||
if goja.IsUndefined(passwd) || goja.IsNull(passwd) {
|
||||
fmt.Fprintf(b.printer, "Give password for account %s\n", account)
|
||||
input, err := b.prompter.PromptPassword("Password: ")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
passwd = call.VM.ToValue(input)
|
||||
} else if passwd.ExportType().Kind() != reflect.String {
|
||||
return nil, fmt.Errorf("third argument must be the password to unlock the account")
|
||||
}
|
||||
|
||||
// Send the request to the backend and return
|
||||
sign, callable := goja.AssertFunction(getJeth(call.VM).Get("sign"))
|
||||
if !callable {
|
||||
return nil, fmt.Errorf("jeth.sign is not callable")
|
||||
}
|
||||
return sign(goja.Null(), message, account, passwd)
|
||||
}
|
||||
|
||||
// Sleep will block the console for the specified number of seconds.
|
||||
func (b *bridge) Sleep(call jsre.Call) (goja.Value, error) {
|
||||
if nArgs := len(call.Arguments); nArgs < 1 {
|
||||
return nil, fmt.Errorf("usage: sleep(<number of seconds>)")
|
||||
}
|
||||
sleepObj := call.Argument(0)
|
||||
if goja.IsUndefined(sleepObj) || goja.IsNull(sleepObj) || !isNumber(sleepObj) {
|
||||
return nil, fmt.Errorf("usage: sleep(<number of seconds>)")
|
||||
}
|
||||
sleep := sleepObj.ToFloat()
|
||||
time.Sleep(time.Duration(sleep * float64(time.Second)))
|
||||
return call.VM.ToValue(true), nil
|
||||
}
|
||||
|
||||
// SleepBlocks will block the console for a specified number of new blocks optionally
|
||||
// until the given timeout is reached.
|
||||
func (b *bridge) SleepBlocks(call jsre.Call) (goja.Value, error) {
|
||||
// Parse the input parameters for the sleep.
|
||||
var (
|
||||
blocks = int64(0)
|
||||
sleep = int64(9999999999999999) // indefinitely
|
||||
)
|
||||
nArgs := len(call.Arguments)
|
||||
if nArgs == 0 {
|
||||
return nil, fmt.Errorf("usage: sleepBlocks(<n blocks>[, max sleep in seconds])")
|
||||
}
|
||||
if nArgs >= 1 {
|
||||
if goja.IsNull(call.Argument(0)) || goja.IsUndefined(call.Argument(0)) || !isNumber(call.Argument(0)) {
|
||||
return nil, fmt.Errorf("expected number as first argument")
|
||||
}
|
||||
blocks = call.Argument(0).ToInteger()
|
||||
}
|
||||
if nArgs >= 2 {
|
||||
if goja.IsNull(call.Argument(1)) || goja.IsUndefined(call.Argument(1)) || !isNumber(call.Argument(1)) {
|
||||
return nil, fmt.Errorf("expected number as second argument")
|
||||
}
|
||||
sleep = call.Argument(1).ToInteger()
|
||||
}
|
||||
|
||||
// Poll the current block number until either it or a timeout is reached.
|
||||
deadline := time.Now().Add(time.Duration(sleep) * time.Second)
|
||||
var lastNumber hexutil.Uint64
|
||||
if err := b.client.Call(&lastNumber, "eth_blockNumber"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for time.Now().Before(deadline) {
|
||||
var number hexutil.Uint64
|
||||
if err := b.client.Call(&number, "eth_blockNumber"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if number != lastNumber {
|
||||
lastNumber = number
|
||||
blocks--
|
||||
}
|
||||
if blocks <= 0 {
|
||||
break
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
return call.VM.ToValue(true), nil
|
||||
}
|
||||
|
||||
type jsonrpcCall struct {
|
||||
ID int64
|
||||
Method string
|
||||
Params []interface{}
|
||||
}
|
||||
|
||||
// Send implements the web3 provider "send" method.
|
||||
func (b *bridge) Send(call jsre.Call) (goja.Value, error) {
|
||||
// Remarshal the request into a Go value.
|
||||
reqVal, err := call.Argument(0).ToObject(call.VM).MarshalJSON()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var (
|
||||
rawReq = string(reqVal)
|
||||
dec = json.NewDecoder(strings.NewReader(rawReq))
|
||||
reqs []jsonrpcCall
|
||||
batch bool
|
||||
)
|
||||
dec.UseNumber() // avoid float64s
|
||||
if rawReq[0] == '[' {
|
||||
batch = true
|
||||
dec.Decode(&reqs)
|
||||
} else {
|
||||
batch = false
|
||||
reqs = make([]jsonrpcCall, 1)
|
||||
dec.Decode(&reqs[0])
|
||||
}
|
||||
|
||||
// Execute the requests.
|
||||
resps := make([]*goja.Object, 0, len(reqs))
|
||||
for _, req := range reqs {
|
||||
resp := call.VM.NewObject()
|
||||
resp.Set("jsonrpc", "2.0") //nolint:errcheck
|
||||
resp.Set("id", req.ID) //nolint:errcheck
|
||||
|
||||
var result json.RawMessage
|
||||
if err = b.client.Call(&result, req.Method, req.Params...); err == nil {
|
||||
if result == nil {
|
||||
// Special case null because it is decoded as an empty
|
||||
// raw message for some reason.
|
||||
resp.Set("result", goja.Null()) // nolint:errcheck
|
||||
} else {
|
||||
JSON := call.VM.Get("JSON").ToObject(call.VM)
|
||||
parse, callable := goja.AssertFunction(JSON.Get("parse"))
|
||||
if !callable {
|
||||
return nil, fmt.Errorf("JSON.parse is not a function")
|
||||
}
|
||||
resultVal, err := parse(goja.Null(), call.VM.ToValue(string(result)))
|
||||
if err != nil {
|
||||
setError(resp, -32603, err.Error(), nil)
|
||||
} else {
|
||||
resp.Set("result", resultVal)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
code := -32603
|
||||
var data interface{}
|
||||
if err, ok := err.(rpc.Error); ok {
|
||||
code = err.ErrorCode()
|
||||
}
|
||||
if err, ok := err.(rpc.DataError); ok {
|
||||
data = err.ErrorData()
|
||||
}
|
||||
setError(resp, code, err.Error(), data)
|
||||
}
|
||||
resps = append(resps, resp)
|
||||
}
|
||||
// Return the responses either to the callback (if supplied)
|
||||
// or directly as the return value.
|
||||
var result goja.Value
|
||||
if batch {
|
||||
result = call.VM.ToValue(resps)
|
||||
} else {
|
||||
result = resps[0]
|
||||
}
|
||||
if fn, isFunc := goja.AssertFunction(call.Argument(1)); isFunc {
|
||||
_, err := fn(goja.Null(), goja.Null(), result)
|
||||
return goja.Undefined(), err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func setError(resp *goja.Object, code int, msg string, data interface{}) {
|
||||
err := make(map[string]interface{})
|
||||
err["code"] = code
|
||||
err["message"] = msg
|
||||
if data != nil {
|
||||
err["data"] = data
|
||||
}
|
||||
resp.Set("error", err) //nolint:errcheck
|
||||
}
|
||||
|
||||
// isNumber returns true if input value is a JS number.
|
||||
func isNumber(v goja.Value) bool {
|
||||
k := v.ExportType().Kind()
|
||||
return k >= reflect.Int && k <= reflect.Float64
|
||||
}
|
||||
|
||||
func getObject(vm *goja.Runtime, name string) *goja.Object {
|
||||
v := vm.Get(name)
|
||||
if v == nil {
|
||||
return nil
|
||||
}
|
||||
return v.ToObject(vm)
|
||||
}
|
@ -1,49 +0,0 @@
|
||||
// Copyright 2020 The go-ethereum Authors
|
||||
// This file is part of the go-ethereum library.
|
||||
//
|
||||
// The go-ethereum library is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Lesser General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// The go-ethereum library is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Lesser General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Lesser General Public License
|
||||
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
//nolint:errcheck
|
||||
package console
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/dop251/goja"
|
||||
"github.com/ledgerwatch/turbo-geth/internal/jsre"
|
||||
)
|
||||
|
||||
// TestUndefinedAsParam ensures that personal functions can receive
|
||||
// `undefined` as a parameter.
|
||||
func TestUndefinedAsParam(t *testing.T) {
|
||||
b := bridge{}
|
||||
call := jsre.Call{}
|
||||
call.Arguments = []goja.Value{goja.Undefined()}
|
||||
|
||||
b.UnlockAccount(call)
|
||||
b.Sign(call)
|
||||
b.Sleep(call)
|
||||
}
|
||||
|
||||
// TestNullAsParam ensures that personal functions can receive
|
||||
// `null` as a parameter.
|
||||
func TestNullAsParam(t *testing.T) {
|
||||
b := bridge{}
|
||||
call := jsre.Call{}
|
||||
call.Arguments = []goja.Value{goja.Null()}
|
||||
|
||||
b.UnlockAccount(call)
|
||||
b.Sign(call)
|
||||
b.Sleep(call)
|
||||
}
|
@ -1,487 +0,0 @@
|
||||
// Copyright 2016 The go-ethereum Authors
|
||||
// This file is part of the go-ethereum library.
|
||||
//
|
||||
// The go-ethereum library is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Lesser General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// The go-ethereum library is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Lesser General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Lesser General Public License
|
||||
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
// nolint:errcheck
|
||||
package console
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/dop251/goja"
|
||||
"github.com/ledgerwatch/turbo-geth/console/prompt"
|
||||
"github.com/ledgerwatch/turbo-geth/internal/jsre"
|
||||
"github.com/ledgerwatch/turbo-geth/internal/jsre/deps"
|
||||
"github.com/ledgerwatch/turbo-geth/internal/web3ext"
|
||||
"github.com/ledgerwatch/turbo-geth/rpc"
|
||||
"github.com/mattn/go-colorable"
|
||||
"github.com/peterh/liner"
|
||||
)
|
||||
|
||||
var (
|
||||
// u: unlock, s: signXX, sendXX, n: newAccount, i: importXX
|
||||
passwordRegexp = regexp.MustCompile(`personal.[nusi]`)
|
||||
onlyWhitespace = regexp.MustCompile(`^\s*$`)
|
||||
exit = regexp.MustCompile(`^\s*exit\s*;*\s*$`)
|
||||
)
|
||||
|
||||
// HistoryFile is the file within the data directory to store input scrollback.
|
||||
const HistoryFile = "history"
|
||||
|
||||
// DefaultPrompt is the default prompt line prefix to use for user input querying.
|
||||
const DefaultPrompt = "> "
|
||||
|
||||
// Config is the collection of configurations to fine tune the behavior of the
|
||||
// JavaScript console.
|
||||
type Config struct {
|
||||
DataDir string // Data directory to store the console history at
|
||||
DocRoot string // Filesystem path from where to load JavaScript files from
|
||||
Client *rpc.Client // RPC client to execute Ethereum requests through
|
||||
Prompt string // Input prompt prefix string (defaults to DefaultPrompt)
|
||||
Prompter prompt.UserPrompter // Input prompter to allow interactive user feedback (defaults to TerminalPrompter)
|
||||
Printer io.Writer // Output writer to serialize any display strings to (defaults to os.Stdout)
|
||||
Preload []string // Absolute paths to JavaScript files to preload
|
||||
}
|
||||
|
||||
// Console is a JavaScript interpreted runtime environment. It is a fully fledged
|
||||
// JavaScript console attached to a running node via an external or in-process RPC
|
||||
// client.
|
||||
type Console struct {
|
||||
client *rpc.Client // RPC client to execute Ethereum requests through
|
||||
jsre *jsre.JSRE // JavaScript runtime environment running the interpreter
|
||||
prompt string // Input prompt prefix string
|
||||
prompter prompt.UserPrompter // Input prompter to allow interactive user feedback
|
||||
histPath string // Absolute path to the console scrollback history
|
||||
history []string // Scroll history maintained by the console
|
||||
printer io.Writer // Output writer to serialize any display strings to
|
||||
}
|
||||
|
||||
// New initializes a JavaScript interpreted runtime environment and sets defaults
|
||||
// with the config struct.
|
||||
func New(config Config) (*Console, error) {
|
||||
// Handle unset config values gracefully
|
||||
if config.Prompter == nil {
|
||||
config.Prompter = prompt.Stdin
|
||||
}
|
||||
if config.Prompt == "" {
|
||||
config.Prompt = DefaultPrompt
|
||||
}
|
||||
if config.Printer == nil {
|
||||
config.Printer = colorable.NewColorableStdout()
|
||||
}
|
||||
|
||||
// Initialize the console and return
|
||||
console := &Console{
|
||||
client: config.Client,
|
||||
jsre: jsre.New(config.DocRoot, config.Printer),
|
||||
prompt: config.Prompt,
|
||||
prompter: config.Prompter,
|
||||
printer: config.Printer,
|
||||
histPath: filepath.Join(config.DataDir, HistoryFile),
|
||||
}
|
||||
if err := os.MkdirAll(config.DataDir, 0700); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := console.init(config.Preload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return console, nil
|
||||
}
|
||||
|
||||
// init retrieves the available APIs from the remote RPC provider and initializes
|
||||
// the console's JavaScript namespaces based on the exposed modules.
|
||||
func (c *Console) init(preload []string) error {
|
||||
c.initConsoleObject()
|
||||
|
||||
// Initialize the JavaScript <-> Go RPC bridge.
|
||||
bridge := newBridge(c.client, c.prompter, c.printer)
|
||||
if err := c.initWeb3(bridge); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.initExtensions(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add bridge overrides for web3.js functionality.
|
||||
c.jsre.Do(func(vm *goja.Runtime) {
|
||||
c.initAdmin(vm, bridge)
|
||||
c.initPersonal(vm, bridge)
|
||||
})
|
||||
|
||||
// Preload JavaScript files.
|
||||
for _, path := range preload {
|
||||
if err := c.jsre.Exec(path); err != nil {
|
||||
failure := err.Error()
|
||||
if gojaErr, ok := err.(*goja.Exception); ok {
|
||||
failure = gojaErr.String()
|
||||
}
|
||||
return fmt.Errorf("%s: %v", path, failure)
|
||||
}
|
||||
}
|
||||
|
||||
// Configure the input prompter for history and tab completion.
|
||||
if c.prompter != nil {
|
||||
if content, err := ioutil.ReadFile(c.histPath); err != nil {
|
||||
c.prompter.SetHistory(nil)
|
||||
} else {
|
||||
c.history = strings.Split(string(content), "\n")
|
||||
c.prompter.SetHistory(c.history)
|
||||
}
|
||||
c.prompter.SetWordCompleter(c.AutoCompleteInput)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Console) initConsoleObject() {
|
||||
c.jsre.Do(func(vm *goja.Runtime) {
|
||||
console := vm.NewObject()
|
||||
console.Set("log", c.consoleOutput)
|
||||
console.Set("error", c.consoleOutput)
|
||||
vm.Set("console", console)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Console) initWeb3(bridge *bridge) error {
|
||||
bnJS := string(deps.MustAsset("bignumber.js"))
|
||||
web3JS := string(deps.MustAsset("web3.js"))
|
||||
if err := c.jsre.Compile("bignumber.js", bnJS); err != nil {
|
||||
return fmt.Errorf("bignumber.js: %v", err)
|
||||
}
|
||||
if err := c.jsre.Compile("web3.js", web3JS); err != nil {
|
||||
return fmt.Errorf("web3.js: %v", err)
|
||||
}
|
||||
if _, err := c.jsre.Run("var Web3 = require('web3');"); err != nil {
|
||||
return fmt.Errorf("web3 require: %v", err)
|
||||
}
|
||||
var err error
|
||||
c.jsre.Do(func(vm *goja.Runtime) {
|
||||
transport := vm.NewObject()
|
||||
transport.Set("send", jsre.MakeCallback(vm, bridge.Send))
|
||||
transport.Set("sendAsync", jsre.MakeCallback(vm, bridge.Send))
|
||||
vm.Set("_consoleWeb3Transport", transport)
|
||||
_, err = vm.RunString("var web3 = new Web3(_consoleWeb3Transport)")
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// initExtensions loads and registers web3.js extensions.
|
||||
func (c *Console) initExtensions() error {
|
||||
// Compute aliases from server-provided modules.
|
||||
apis, err := c.client.SupportedModules()
|
||||
if err != nil {
|
||||
return fmt.Errorf("api modules: %v", err)
|
||||
}
|
||||
aliases := map[string]struct{}{"eth": {}, "personal": {}}
|
||||
for api := range apis {
|
||||
if api == "web3" {
|
||||
continue
|
||||
}
|
||||
aliases[api] = struct{}{}
|
||||
if file, ok := web3ext.Modules[api]; ok {
|
||||
if err = c.jsre.Compile(api+".js", file); err != nil {
|
||||
return fmt.Errorf("%s.js: %v", api, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply aliases.
|
||||
c.jsre.Do(func(vm *goja.Runtime) {
|
||||
web3 := getObject(vm, "web3")
|
||||
for name := range aliases {
|
||||
if v := web3.Get(name); v != nil {
|
||||
vm.Set(name, v)
|
||||
}
|
||||
}
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// initAdmin creates additional admin APIs implemented by the bridge.
|
||||
func (c *Console) initAdmin(vm *goja.Runtime, bridge *bridge) {
|
||||
if admin := getObject(vm, "admin"); admin != nil {
|
||||
admin.Set("sleepBlocks", jsre.MakeCallback(vm, bridge.SleepBlocks))
|
||||
admin.Set("sleep", jsre.MakeCallback(vm, bridge.Sleep))
|
||||
admin.Set("clearHistory", c.clearHistory)
|
||||
}
|
||||
}
|
||||
|
||||
// initPersonal redirects account-related API methods through the bridge.
|
||||
//
|
||||
// If the console is in interactive mode and the 'personal' API is available, override
|
||||
// the openWallet, unlockAccount, newAccount and sign methods since these require user
|
||||
// interaction. The original web3 callbacks are stored in 'jeth'. These will be called
|
||||
// by the bridge after the prompt and send the original web3 request to the backend.
|
||||
func (c *Console) initPersonal(vm *goja.Runtime, bridge *bridge) {
|
||||
personal := getObject(vm, "personal")
|
||||
if personal == nil || c.prompter == nil {
|
||||
return
|
||||
}
|
||||
jeth := vm.NewObject()
|
||||
vm.Set("jeth", jeth)
|
||||
jeth.Set("openWallet", personal.Get("openWallet"))
|
||||
jeth.Set("unlockAccount", personal.Get("unlockAccount"))
|
||||
jeth.Set("newAccount", personal.Get("newAccount"))
|
||||
jeth.Set("sign", personal.Get("sign"))
|
||||
personal.Set("openWallet", jsre.MakeCallback(vm, bridge.OpenWallet))
|
||||
personal.Set("unlockAccount", jsre.MakeCallback(vm, bridge.UnlockAccount))
|
||||
personal.Set("newAccount", jsre.MakeCallback(vm, bridge.NewAccount))
|
||||
personal.Set("sign", jsre.MakeCallback(vm, bridge.Sign))
|
||||
}
|
||||
|
||||
func (c *Console) clearHistory() {
|
||||
c.history = nil
|
||||
c.prompter.ClearHistory()
|
||||
if err := os.Remove(c.histPath); err != nil {
|
||||
fmt.Fprintln(c.printer, "can't delete history file:", err)
|
||||
} else {
|
||||
fmt.Fprintln(c.printer, "history file deleted")
|
||||
}
|
||||
}
|
||||
|
||||
// consoleOutput is an override for the console.log and console.error methods to
|
||||
// stream the output into the configured output stream instead of stdout.
|
||||
func (c *Console) consoleOutput(call goja.FunctionCall) goja.Value {
|
||||
var output []string
|
||||
for _, argument := range call.Arguments {
|
||||
output = append(output, fmt.Sprintf("%v", argument))
|
||||
}
|
||||
fmt.Fprintln(c.printer, strings.Join(output, " "))
|
||||
return goja.Null()
|
||||
}
|
||||
|
||||
// AutoCompleteInput is a pre-assembled word completer to be used by the user
|
||||
// input prompter to provide hints to the user about the methods available.
|
||||
func (c *Console) AutoCompleteInput(line string, pos int) (string, []string, string) {
|
||||
// No completions can be provided for empty inputs
|
||||
if len(line) == 0 || pos == 0 {
|
||||
return "", nil, ""
|
||||
}
|
||||
// Chunck data to relevant part for autocompletion
|
||||
// E.g. in case of nested lines eth.getBalance(eth.coinb<tab><tab>
|
||||
start := pos - 1
|
||||
for ; start > 0; start-- {
|
||||
// Skip all methods and namespaces (i.e. including the dot)
|
||||
if line[start] == '.' || (line[start] >= 'a' && line[start] <= 'z') || (line[start] >= 'A' && line[start] <= 'Z') {
|
||||
continue
|
||||
}
|
||||
// Handle web3 in a special way (i.e. other numbers aren't auto completed)
|
||||
if start >= 3 && line[start-3:start] == "web3" {
|
||||
start -= 3
|
||||
continue
|
||||
}
|
||||
// We've hit an unexpected character, autocomplete form here
|
||||
start++
|
||||
break
|
||||
}
|
||||
return line[:start], c.jsre.CompleteKeywords(line[start:pos]), line[pos:]
|
||||
}
|
||||
|
||||
// Welcome show summary of current Geth instance and some metadata about the
|
||||
// console's available modules.
|
||||
func (c *Console) Welcome() {
|
||||
message := "Welcome to the Geth JavaScript console!\n\n"
|
||||
|
||||
// Print some generic Geth metadata
|
||||
if res, err := c.jsre.Run(`
|
||||
var message = "instance: " + web3.version.node + "\n";
|
||||
try {
|
||||
message += "coinbase: " + eth.coinbase + "\n";
|
||||
} catch (err) {}
|
||||
message += "at block: " + eth.blockNumber + " (" + new Date(1000 * eth.getBlock(eth.blockNumber).timestamp) + ")\n";
|
||||
try {
|
||||
message += " datadir: " + admin.datadir + "\n";
|
||||
} catch (err) {}
|
||||
message
|
||||
`); err == nil {
|
||||
message += res.String()
|
||||
}
|
||||
// List all the supported modules for the user to call
|
||||
if apis, err := c.client.SupportedModules(); err == nil {
|
||||
modules := make([]string, 0, len(apis))
|
||||
for api, version := range apis {
|
||||
modules = append(modules, fmt.Sprintf("%s:%s", api, version))
|
||||
}
|
||||
sort.Strings(modules)
|
||||
message += " modules: " + strings.Join(modules, " ") + "\n"
|
||||
}
|
||||
fmt.Fprintln(c.printer, message)
|
||||
}
|
||||
|
||||
// Evaluate executes code and pretty prints the result to the specified output
|
||||
// stream.
|
||||
func (c *Console) Evaluate(statement string) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
fmt.Fprintf(c.printer, "[native] error: %v\n", r)
|
||||
}
|
||||
}()
|
||||
c.jsre.Evaluate(statement, c.printer)
|
||||
}
|
||||
|
||||
// Interactive starts an interactive user session, where input is propted from
|
||||
// the configured user prompter.
|
||||
func (c *Console) Interactive() {
|
||||
var (
|
||||
prompt = c.prompt // the current prompt line (used for multi-line inputs)
|
||||
indents = 0 // the current number of input indents (used for multi-line inputs)
|
||||
input = "" // the current user input
|
||||
inputLine = make(chan string, 1) // receives user input
|
||||
inputErr = make(chan error, 1) // receives liner errors
|
||||
requestLine = make(chan string) // requests a line of input
|
||||
interrupt = make(chan os.Signal, 1)
|
||||
)
|
||||
|
||||
// Monitor Ctrl-C. While liner does turn on the relevant terminal mode bits to avoid
|
||||
// the signal, a signal can still be received for unsupported terminals. Unfortunately
|
||||
// there is no way to cancel the line reader when this happens. The readLines
|
||||
// goroutine will be leaked in this case.
|
||||
signal.Notify(interrupt, syscall.SIGINT, syscall.SIGTERM)
|
||||
defer signal.Stop(interrupt)
|
||||
|
||||
// The line reader runs in a separate goroutine.
|
||||
go c.readLines(inputLine, inputErr, requestLine)
|
||||
defer close(requestLine)
|
||||
|
||||
for {
|
||||
// Send the next prompt, triggering an input read.
|
||||
requestLine <- prompt
|
||||
|
||||
select {
|
||||
case <-interrupt:
|
||||
fmt.Fprintln(c.printer, "caught interrupt, exiting")
|
||||
return
|
||||
|
||||
case err := <-inputErr:
|
||||
if err == liner.ErrPromptAborted && indents > 0 {
|
||||
// When prompting for multi-line input, the first Ctrl-C resets
|
||||
// the multi-line state.
|
||||
prompt, indents, input = c.prompt, 0, ""
|
||||
continue
|
||||
}
|
||||
return
|
||||
|
||||
case line := <-inputLine:
|
||||
// User input was returned by the prompter, handle special cases.
|
||||
if indents <= 0 && exit.MatchString(line) {
|
||||
return
|
||||
}
|
||||
if onlyWhitespace.MatchString(line) {
|
||||
continue
|
||||
}
|
||||
// Append the line to the input and check for multi-line interpretation.
|
||||
input += line + "\n"
|
||||
indents = countIndents(input)
|
||||
if indents <= 0 {
|
||||
prompt = c.prompt
|
||||
} else {
|
||||
prompt = strings.Repeat(".", indents*3) + " "
|
||||
}
|
||||
// If all the needed lines are present, save the command and run it.
|
||||
if indents <= 0 {
|
||||
if len(input) > 0 && input[0] != ' ' && !passwordRegexp.MatchString(input) {
|
||||
if command := strings.TrimSpace(input); len(c.history) == 0 || command != c.history[len(c.history)-1] {
|
||||
c.history = append(c.history, command)
|
||||
if c.prompter != nil {
|
||||
c.prompter.AppendHistory(command)
|
||||
}
|
||||
}
|
||||
}
|
||||
c.Evaluate(input)
|
||||
input = ""
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// readLines runs in its own goroutine, prompting for input.
|
||||
func (c *Console) readLines(input chan<- string, errc chan<- error, prompt <-chan string) {
|
||||
for p := range prompt {
|
||||
line, err := c.prompter.PromptInput(p)
|
||||
if err != nil {
|
||||
errc <- err
|
||||
} else {
|
||||
input <- line
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// countIndents returns the number of identations for the given input.
|
||||
// In case of invalid input such as var a = } the result can be negative.
|
||||
func countIndents(input string) int {
|
||||
var (
|
||||
indents = 0
|
||||
inString = false
|
||||
strOpenChar = ' ' // keep track of the string open char to allow var str = "I'm ....";
|
||||
charEscaped = false // keep track if the previous char was the '\' char, allow var str = "abc\"def";
|
||||
)
|
||||
|
||||
for _, c := range input {
|
||||
switch c {
|
||||
case '\\':
|
||||
// indicate next char as escaped when in string and previous char isn't escaping this backslash
|
||||
if !charEscaped && inString {
|
||||
charEscaped = true
|
||||
}
|
||||
case '\'', '"':
|
||||
if inString && !charEscaped && strOpenChar == c { // end string
|
||||
inString = false
|
||||
} else if !inString && !charEscaped { // begin string
|
||||
inString = true
|
||||
strOpenChar = c
|
||||
}
|
||||
charEscaped = false
|
||||
case '{', '(':
|
||||
if !inString { // ignore brackets when in string, allow var str = "a{"; without indenting
|
||||
indents++
|
||||
}
|
||||
charEscaped = false
|
||||
case '}', ')':
|
||||
if !inString {
|
||||
indents--
|
||||
}
|
||||
charEscaped = false
|
||||
default:
|
||||
charEscaped = false
|
||||
}
|
||||
}
|
||||
|
||||
return indents
|
||||
}
|
||||
|
||||
// Execute runs the JavaScript file specified as the argument.
|
||||
func (c *Console) Execute(path string) error {
|
||||
return c.jsre.Exec(path)
|
||||
}
|
||||
|
||||
// Stop cleans up the console and terminates the runtime environment.
|
||||
func (c *Console) Stop(graceful bool) error {
|
||||
if err := ioutil.WriteFile(c.histPath, []byte(strings.Join(c.history, "\n")), 0600); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.Chmod(c.histPath, 0600); err != nil { // Force 0600, even if it was different previously
|
||||
return err
|
||||
}
|
||||
c.jsre.Stop(graceful)
|
||||
return nil
|
||||
}
|
@ -1,346 +0,0 @@
|
||||
// Copyright 2015 The go-ethereum Authors
|
||||
// This file is part of the go-ethereum library.
|
||||
//
|
||||
// The go-ethereum library is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Lesser General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// The go-ethereum library is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Lesser General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Lesser General Public License
|
||||
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package console
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
|
||||
"github.com/ledgerwatch/turbo-geth/common"
|
||||
"github.com/ledgerwatch/turbo-geth/consensus/ethash"
|
||||
"github.com/ledgerwatch/turbo-geth/console/prompt"
|
||||
"github.com/ledgerwatch/turbo-geth/core"
|
||||
"github.com/ledgerwatch/turbo-geth/eth"
|
||||
"github.com/ledgerwatch/turbo-geth/internal/jsre"
|
||||
"github.com/ledgerwatch/turbo-geth/miner"
|
||||
"github.com/ledgerwatch/turbo-geth/node"
|
||||
)
|
||||
|
||||
const (
|
||||
testInstance = "console-tester"
|
||||
testAddress = "0x8605cdbbdb6d264aa742e77020dcbc58fcdce182"
|
||||
)
|
||||
|
||||
// hookedPrompter implements UserPrompter to simulate use input via channels.
|
||||
type hookedPrompter struct {
|
||||
scheduler chan string
|
||||
}
|
||||
|
||||
func (p *hookedPrompter) PromptInput(prompt string) (string, error) {
|
||||
// Send the prompt to the tester
|
||||
select {
|
||||
case p.scheduler <- prompt:
|
||||
case <-time.After(time.Second):
|
||||
return "", errors.New("prompt timeout")
|
||||
}
|
||||
// Retrieve the response and feed to the console
|
||||
select {
|
||||
case input := <-p.scheduler:
|
||||
return input, nil
|
||||
case <-time.After(time.Second):
|
||||
return "", errors.New("input timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func (p *hookedPrompter) PromptPassword(prompt string) (string, error) {
|
||||
return "", errors.New("not implemented")
|
||||
}
|
||||
func (p *hookedPrompter) PromptConfirm(prompt string) (bool, error) {
|
||||
return false, errors.New("not implemented")
|
||||
}
|
||||
func (p *hookedPrompter) SetHistory(history []string) {}
|
||||
func (p *hookedPrompter) AppendHistory(command string) {}
|
||||
func (p *hookedPrompter) ClearHistory() {}
|
||||
func (p *hookedPrompter) SetWordCompleter(completer prompt.WordCompleter) {}
|
||||
|
||||
// tester is a console test environment for the console tests to operate on.
|
||||
type tester struct {
|
||||
workspace string
|
||||
stack *node.Node
|
||||
ethereum *eth.Ethereum
|
||||
console *Console
|
||||
input *hookedPrompter
|
||||
output *bytes.Buffer
|
||||
}
|
||||
|
||||
// newTester creates a test environment based on which the console can operate.
|
||||
// Please ensure you call Close() on the returned tester to avoid leaks.
|
||||
func newTester(t *testing.T, confOverride func(*eth.Config)) *tester {
|
||||
t.Helper()
|
||||
// Create a temporary storage for the node keys and initialize it
|
||||
workspace, err := ioutil.TempDir("", "console-tester-")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create temporary keystore: %v", err)
|
||||
}
|
||||
|
||||
// Create a networkless protocol stack and start an Ethereum service within
|
||||
stack, err := node.New(&node.Config{DataDir: workspace, UseLightweightKDF: true, Name: testInstance})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create node: %v", err)
|
||||
}
|
||||
ethConf := ð.Config{
|
||||
Genesis: core.DeveloperGenesisBlock(15, common.Address{}),
|
||||
Miner: miner.Config{
|
||||
Etherbase: common.HexToAddress(testAddress),
|
||||
},
|
||||
Ethash: ethash.Config{
|
||||
PowMode: ethash.ModeTest,
|
||||
},
|
||||
Pruning: false,
|
||||
}
|
||||
spew.Dump(ethConf)
|
||||
if confOverride != nil {
|
||||
confOverride(ethConf)
|
||||
}
|
||||
ethBackend, err := eth.New(stack, ethConf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to register Ethereum protocol: %v", err)
|
||||
}
|
||||
// Start the node and assemble the JavaScript console around it
|
||||
if err = stack.Start(); err != nil {
|
||||
t.Fatalf("failed to start test stack: %v", err)
|
||||
}
|
||||
client, err := stack.Attach()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to attach to node: %v", err)
|
||||
}
|
||||
prompter := &hookedPrompter{scheduler: make(chan string)}
|
||||
printer := new(bytes.Buffer)
|
||||
|
||||
console, err := New(Config{
|
||||
DataDir: stack.DataDir(),
|
||||
DocRoot: "testdata",
|
||||
Client: client,
|
||||
Prompter: prompter,
|
||||
Printer: printer,
|
||||
Preload: []string{"preload.js"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create JavaScript console: %v", err)
|
||||
}
|
||||
// Create the final tester and return
|
||||
return &tester{
|
||||
workspace: workspace,
|
||||
stack: stack,
|
||||
ethereum: ethBackend,
|
||||
console: console,
|
||||
input: prompter,
|
||||
output: printer,
|
||||
}
|
||||
}
|
||||
|
||||
// Close cleans up any temporary data folders and held resources.
|
||||
func (env *tester) Close(t *testing.T) {
|
||||
if err := env.console.Stop(false); err != nil {
|
||||
t.Errorf("failed to stop embedded console: %v", err)
|
||||
}
|
||||
if err := env.stack.Close(); err != nil {
|
||||
t.Errorf("failed to tear down embedded node: %v", err)
|
||||
}
|
||||
os.RemoveAll(env.workspace)
|
||||
}
|
||||
|
||||
// Tests that the node lists the correct welcome message, notably that it contains
|
||||
// the instance name, coinbase account, block number, data directory and supported
|
||||
// console modules.
|
||||
func TestWelcome(t *testing.T) {
|
||||
tester := newTester(t, nil)
|
||||
defer tester.Close(t)
|
||||
|
||||
tester.console.Welcome()
|
||||
|
||||
output := tester.output.String()
|
||||
if want := "Welcome"; !strings.Contains(output, want) {
|
||||
t.Fatalf("console output missing welcome message: have\n%s\nwant also %s", output, want)
|
||||
}
|
||||
if want := fmt.Sprintf("instance: %s", testInstance); !strings.Contains(output, want) {
|
||||
t.Fatalf("console output missing instance: have\n%s\nwant also %s", output, want)
|
||||
}
|
||||
//if want := fmt.Sprintf("coinbase: %s", testAddress); !strings.Contains(output, want) {
|
||||
// t.Fatalf("console output missing coinbase: have\n%s\nwant also %s", output, want)
|
||||
//}
|
||||
if want := "at block: 0"; !strings.Contains(output, want) {
|
||||
t.Fatalf("console output missing sync status: have\n%s\nwant also %s", output, want)
|
||||
}
|
||||
if want := fmt.Sprintf("datadir: %s", tester.workspace); !strings.Contains(output, want) {
|
||||
t.Fatalf("console output missing coinbase: have\n%s\nwant also %s", output, want)
|
||||
}
|
||||
}
|
||||
|
||||
// Tests that JavaScript statement evaluation works as intended.
|
||||
func TestEvaluate(t *testing.T) {
|
||||
tester := newTester(t, nil)
|
||||
defer tester.Close(t)
|
||||
|
||||
tester.console.Evaluate("2 + 2")
|
||||
if output := tester.output.String(); !strings.Contains(output, "4") {
|
||||
t.Fatalf("statement evaluation failed: have %s, want %s", output, "4")
|
||||
}
|
||||
}
|
||||
|
||||
// Tests that the console can be used in interactive mode.
|
||||
func TestInteractive(t *testing.T) {
|
||||
// Create a tester and run an interactive console in the background
|
||||
tester := newTester(t, nil)
|
||||
defer tester.Close(t)
|
||||
|
||||
go tester.console.Interactive()
|
||||
|
||||
// Wait for a prompt and send a statement back
|
||||
select {
|
||||
case <-tester.input.scheduler:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("initial prompt timeout")
|
||||
}
|
||||
select {
|
||||
case tester.input.scheduler <- "2+2":
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("input feedback timeout")
|
||||
}
|
||||
// Wait for the second prompt and ensure first statement was evaluated
|
||||
select {
|
||||
case <-tester.input.scheduler:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("secondary prompt timeout")
|
||||
}
|
||||
if output := tester.output.String(); !strings.Contains(output, "4") {
|
||||
t.Fatalf("statement evaluation failed: have %s, want %s", output, "4")
|
||||
}
|
||||
}
|
||||
|
||||
// Tests that preloaded JavaScript files have been executed before user is given
|
||||
// input.
|
||||
func TestPreload(t *testing.T) {
|
||||
tester := newTester(t, nil)
|
||||
defer tester.Close(t)
|
||||
|
||||
tester.console.Evaluate("preloaded")
|
||||
if output := tester.output.String(); !strings.Contains(output, "some-preloaded-string") {
|
||||
t.Fatalf("preloaded variable missing: have %s, want %s", output, "some-preloaded-string")
|
||||
}
|
||||
}
|
||||
|
||||
// Tests that JavaScript scripts can be executes from the configured asset path.
|
||||
func TestExecute(t *testing.T) {
|
||||
tester := newTester(t, nil)
|
||||
defer tester.Close(t)
|
||||
|
||||
tester.console.Execute("exec.js")
|
||||
|
||||
tester.console.Evaluate("execed")
|
||||
if output := tester.output.String(); !strings.Contains(output, "some-executed-string") {
|
||||
t.Fatalf("execed variable missing: have %s, want %s", output, "some-executed-string")
|
||||
}
|
||||
}
|
||||
|
||||
// Tests that the JavaScript objects returned by statement executions are properly
|
||||
// pretty printed instead of just displaying "[object]".
|
||||
func TestPrettyPrint(t *testing.T) {
|
||||
tester := newTester(t, nil)
|
||||
defer tester.Close(t)
|
||||
|
||||
tester.console.Evaluate("obj = {int: 1, string: 'two', list: [3, 3, 3], obj: {null: null, func: function(){}}}")
|
||||
|
||||
// Define some specially formatted fields
|
||||
var (
|
||||
one = jsre.NumberColor("1")
|
||||
two = jsre.StringColor("\"two\"")
|
||||
three = jsre.NumberColor("3")
|
||||
null = jsre.SpecialColor("null")
|
||||
fun = jsre.FunctionColor("function()")
|
||||
)
|
||||
// Assemble the actual output we're after and verify
|
||||
want := `{
|
||||
int: ` + one + `,
|
||||
list: [` + three + `, ` + three + `, ` + three + `],
|
||||
obj: {
|
||||
null: ` + null + `,
|
||||
func: ` + fun + `
|
||||
},
|
||||
string: ` + two + `
|
||||
}
|
||||
`
|
||||
if output := tester.output.String(); output != want {
|
||||
t.Fatalf("pretty print mismatch: have %s, want %s", output, want)
|
||||
}
|
||||
}
|
||||
|
||||
// Tests that the JavaScript exceptions are properly formatted and colored.
|
||||
func TestPrettyError(t *testing.T) {
|
||||
tester := newTester(t, nil)
|
||||
defer tester.Close(t)
|
||||
tester.console.Evaluate("throw 'hello'")
|
||||
|
||||
want := jsre.ErrorColor("hello") + "\n\tat <eval>:1:7(1)\n\n"
|
||||
if output := tester.output.String(); output != want {
|
||||
t.Fatalf("pretty error mismatch: have %s, want %s", output, want)
|
||||
}
|
||||
}
|
||||
|
||||
// Tests that tests if the number of indents for JS input is calculated correct.
|
||||
func TestIndenting(t *testing.T) {
|
||||
testCases := []struct {
|
||||
input string
|
||||
expectedIndentCount int
|
||||
}{
|
||||
{`var a = 1;`, 0},
|
||||
{`"some string"`, 0},
|
||||
{`"some string with (parenthesis`, 0},
|
||||
{`"some string with newline
|
||||
("`, 0},
|
||||
{`function v(a,b) {}`, 0},
|
||||
{`function f(a,b) { var str = "asd("; };`, 0},
|
||||
{`function f(a) {`, 1},
|
||||
{`function f(a, function(b) {`, 2},
|
||||
{`function f(a, function(b) {
|
||||
var str = "a)}";
|
||||
});`, 0},
|
||||
{`function f(a,b) {
|
||||
var str = "a{b(" + a, ", " + b;
|
||||
}`, 0},
|
||||
{`var str = "\"{"`, 0},
|
||||
{`var str = "'("`, 0},
|
||||
{`var str = "\\{"`, 0},
|
||||
{`var str = "\\\\{"`, 0},
|
||||
{`var str = 'a"{`, 0},
|
||||
{`var obj = {`, 1},
|
||||
{`var obj = { {a:1`, 2},
|
||||
{`var obj = { {a:1}`, 1},
|
||||
{`var obj = { {a:1}, b:2}`, 0},
|
||||
{`var obj = {}`, 0},
|
||||
{`var obj = {
|
||||
a: 1, b: 2
|
||||
}`, 0},
|
||||
{`var test = }`, -1},
|
||||
{`var str = "a\""; var obj = {`, 1},
|
||||
}
|
||||
|
||||
for i, tt := range testCases {
|
||||
counted := countIndents(tt.input)
|
||||
if counted != tt.expectedIndentCount {
|
||||
t.Errorf("test %d: invalid indenting: have %d, want %d", i, counted, tt.expectedIndentCount)
|
||||
}
|
||||
}
|
||||
}
|
1
console/testdata/exec.js
vendored
1
console/testdata/exec.js
vendored
@ -1 +0,0 @@
|
||||
var execed = "some-executed-string";
|
1
console/testdata/preload.js
vendored
1
console/testdata/preload.js
vendored
@ -1 +0,0 @@
|
||||
var preloaded = "some-preloaded-string";
|
@ -207,7 +207,7 @@ func (api *privateAdminAPI) StartRPC(host *string, port *int, cors *string, apis
|
||||
if err := api.node.http.setListenAddr(*host, *port); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if err := api.node.http.enableRPC(api.node.rpcAPIs, config); err != nil {
|
||||
if err := api.node.http.enableRPC(api.node.rpcAPIs, config, nil); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if err := api.node.http.start(); err != nil {
|
||||
@ -263,7 +263,7 @@ func (api *privateAdminAPI) StartWS(host *string, port *int, allowedOrigins *str
|
||||
if err := server.setListenAddr(*host, *port); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if err := server.enableWS(api.node.rpcAPIs, config); err != nil {
|
||||
if err := server.enableWS(api.node.rpcAPIs, config, nil); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if err := server.start(); err != nil {
|
||||
|
11
node/node.go
11
node/node.go
@ -58,6 +58,8 @@ type Node struct {
|
||||
ipc *ipcServer // Stores information about the ipc http server
|
||||
inprocHandler *rpc.Server // In-process RPC request handler to process the API requests
|
||||
|
||||
rpcAllowList rpc.AllowList // list of RPC methods explicitly allowed for this RPC node
|
||||
|
||||
databases []ethdb.Closer
|
||||
}
|
||||
|
||||
@ -342,6 +344,11 @@ func (n *Node) closeDataDir() {
|
||||
}
|
||||
}
|
||||
|
||||
// SetAllowListForRPC sets granular allow list for exposed RPC methods
|
||||
func (n *Node) SetAllowListForRPC(allowList rpc.AllowList) {
|
||||
n.rpcAllowList = allowList
|
||||
}
|
||||
|
||||
// configureRPC is a helper method to configure all the various RPC endpoints during node
|
||||
// startup. It's not meant to be called at any time afterwards as it makes certain
|
||||
// assumptions about the state of the node.
|
||||
@ -367,7 +374,7 @@ func (n *Node) startRPC() error {
|
||||
if err := n.http.setListenAddr(n.config.HTTPHost, n.config.HTTPPort); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := n.http.enableRPC(n.rpcAPIs, config); err != nil {
|
||||
if err := n.http.enableRPC(n.rpcAPIs, config, n.rpcAllowList); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -382,7 +389,7 @@ func (n *Node) startRPC() error {
|
||||
if err := server.setListenAddr(n.config.WSHost, n.config.WSPort); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := server.enableWS(n.rpcAPIs, config); err != nil {
|
||||
if err := server.enableWS(n.rpcAPIs, config, n.rpcAllowList); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -484,7 +484,6 @@ func TestWebsocketHTTPOnSeparatePort_WSRequest(t *testing.T) {
|
||||
if !checkRPC(node.HTTPEndpoint()) {
|
||||
t.Fatalf("http request failed")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func createNode(t *testing.T, httpPort, wsPort int) *Node {
|
||||
|
@ -227,7 +227,7 @@ func (h *httpServer) doStop() {
|
||||
}
|
||||
|
||||
// enableRPC turns on JSON-RPC over HTTP on the server.
|
||||
func (h *httpServer) enableRPC(apis []rpc.API, config httpConfig) error {
|
||||
func (h *httpServer) enableRPC(apis []rpc.API, config httpConfig, allowList rpc.AllowList) error {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
@ -237,6 +237,7 @@ func (h *httpServer) enableRPC(apis []rpc.API, config httpConfig) error {
|
||||
|
||||
// Create RPC server and handler.
|
||||
srv := rpc.NewServer()
|
||||
srv.SetAllowList(allowList)
|
||||
if err := RegisterApisFromWhitelist(apis, config.Modules, srv, false); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -259,7 +260,7 @@ func (h *httpServer) disableRPC() bool {
|
||||
}
|
||||
|
||||
// enableWS turns on JSON-RPC over WebSocket on the server.
|
||||
func (h *httpServer) enableWS(apis []rpc.API, config wsConfig) error {
|
||||
func (h *httpServer) enableWS(apis []rpc.API, config wsConfig, allowList rpc.AllowList) error {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
@ -269,6 +270,7 @@ func (h *httpServer) enableWS(apis []rpc.API, config wsConfig) error {
|
||||
|
||||
// Create RPC server and handler.
|
||||
srv := rpc.NewServer()
|
||||
srv.SetAllowList(allowList)
|
||||
if err := RegisterApisFromWhitelist(apis, config.Modules, srv, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -18,7 +18,10 @@ package node
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ledgerwatch/turbo-geth/internal/testlog"
|
||||
@ -89,14 +92,16 @@ func TestIsWebsocket(t *testing.T) {
|
||||
assert.True(t, isWebsocket(r))
|
||||
}
|
||||
|
||||
func createAndStartServer(t *testing.T, conf httpConfig, ws bool, wsConf wsConfig) *httpServer {
|
||||
func createAndStartServerWithAllowList(t *testing.T, conf httpConfig, ws bool, wsConf wsConfig) *httpServer {
|
||||
t.Helper()
|
||||
|
||||
srv := newHTTPServer(testlog.Logger(t, log.LvlDebug), rpc.DefaultHTTPTimeouts)
|
||||
|
||||
assert.NoError(t, srv.enableRPC(nil, conf))
|
||||
allowList := rpc.AllowList(map[string]struct{}{"net_version": {}}) //don't allow RPC modules
|
||||
|
||||
assert.NoError(t, srv.enableRPC(nil, conf, allowList))
|
||||
if ws {
|
||||
assert.NoError(t, srv.enableWS(nil, wsConf))
|
||||
assert.NoError(t, srv.enableWS(nil, wsConf, allowList))
|
||||
}
|
||||
assert.NoError(t, srv.setListenAddr("localhost", 0))
|
||||
assert.NoError(t, srv.start())
|
||||
@ -104,10 +109,48 @@ func createAndStartServer(t *testing.T, conf httpConfig, ws bool, wsConf wsConfi
|
||||
return srv
|
||||
}
|
||||
|
||||
func createAndStartServer(t *testing.T, conf httpConfig, ws bool, wsConf wsConfig) *httpServer {
|
||||
t.Helper()
|
||||
|
||||
srv := newHTTPServer(testlog.Logger(t, log.LvlDebug), rpc.DefaultHTTPTimeouts)
|
||||
|
||||
assert.NoError(t, srv.enableRPC(nil, conf, nil))
|
||||
if ws {
|
||||
assert.NoError(t, srv.enableWS(nil, wsConf, nil))
|
||||
}
|
||||
assert.NoError(t, srv.setListenAddr("localhost", 0))
|
||||
assert.NoError(t, srv.start())
|
||||
|
||||
return srv
|
||||
}
|
||||
|
||||
func TestAllowList(t *testing.T) {
|
||||
srv := createAndStartServerWithAllowList(t, httpConfig{}, false, wsConfig{})
|
||||
defer srv.stop()
|
||||
|
||||
assert.False(t, testCustomRequest(t, srv, "rpc_modules"))
|
||||
}
|
||||
|
||||
func testCustomRequest(t *testing.T, srv *httpServer, method string) bool {
|
||||
body := bytes.NewReader([]byte(fmt.Sprintf(`{"jsonrpc":"2.0","id":1,"method":"%s"}`, method)))
|
||||
req, _ := http.NewRequest("POST", "http://"+srv.listenAddr(), body)
|
||||
req.Header.Set("content-type", "application/json")
|
||||
|
||||
client := http.DefaultClient
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
respBody, err := ioutil.ReadAll(resp.Body)
|
||||
assert.NoError(t, err)
|
||||
|
||||
return !strings.Contains(string(respBody), "error")
|
||||
}
|
||||
|
||||
func testRequest(t *testing.T, key, value, host string, srv *httpServer) *http.Response {
|
||||
t.Helper()
|
||||
|
||||
body := bytes.NewReader([]byte(`{"jsonrpc":"2.0","id":1,method":"rpc_modules"}`))
|
||||
body := bytes.NewReader([]byte(`{"jsonrpc":"2.0","id":1,"method":"rpc_modules"}`))
|
||||
req, _ := http.NewRequest("POST", "http://"+srv.listenAddr(), body)
|
||||
req.Header.Set("content-type", "application/json")
|
||||
if key != "" && value != "" {
|
||||
|
35
rpc/allow_list.go
Normal file
35
rpc/allow_list.go
Normal file
@ -0,0 +1,35 @@
|
||||
package rpc
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type AllowList map[string]struct{}
|
||||
|
||||
func (a *AllowList) UnmarshalJSON(data []byte) error {
|
||||
var keys []string
|
||||
err := json.Unmarshal(data, &keys)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
realA := make(map[string]struct{})
|
||||
|
||||
for _, k := range keys {
|
||||
realA[k] = struct{}{}
|
||||
}
|
||||
|
||||
*a = realA
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarshalJSON returns *m as the JSON encoding of
|
||||
func (a *AllowList) MarshalJSON() ([]byte, error) {
|
||||
var realA map[string]struct{} = *a
|
||||
keys := make([]string, len(realA))
|
||||
i := 0
|
||||
for key := range realA {
|
||||
keys[i] = key
|
||||
i++
|
||||
}
|
||||
return json.Marshal(keys)
|
||||
}
|
23
rpc/allow_list_test.go
Normal file
23
rpc/allow_list_test.go
Normal file
@ -0,0 +1,23 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAllowListMarshaling(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
func TestAllowListUnmarshaling(t *testing.T) {
|
||||
allowListJSON := `[ "one", "two", "three" ]`
|
||||
|
||||
var allowList AllowList
|
||||
err := json.Unmarshal([]byte(allowListJSON), &allowList)
|
||||
assert.NoError(t, err, "should unmarshal successfully")
|
||||
|
||||
m := map[string]struct{}{"one": {}, "two": {}, "three": {}}
|
||||
assert.Equal(t, allowList, AllowList(m))
|
||||
}
|
@ -74,9 +74,10 @@ type BatchElem struct {
|
||||
|
||||
// Client represents a connection to an RPC server.
|
||||
type Client struct {
|
||||
idgen func() ID // for subscriptions
|
||||
isHTTP bool
|
||||
services *serviceRegistry
|
||||
idgen func() ID // for subscriptions
|
||||
isHTTP bool
|
||||
services *serviceRegistry
|
||||
methodAllowList AllowList
|
||||
|
||||
idCounter uint32
|
||||
|
||||
@ -111,7 +112,7 @@ type clientConn struct {
|
||||
|
||||
func (c *Client) newClientConn(conn ServerCodec) *clientConn {
|
||||
ctx := context.WithValue(context.Background(), clientContextKey{}, c)
|
||||
handler := newHandler(ctx, conn, c.idgen, c.services)
|
||||
handler := newHandler(ctx, conn, c.idgen, c.services, c.methodAllowList)
|
||||
return &clientConn{conn, handler}
|
||||
}
|
||||
|
||||
|
@ -62,6 +62,8 @@ type handler struct {
|
||||
log log.Logger
|
||||
allowSubscribe bool
|
||||
|
||||
allowList AllowList // a list of explicitly allowed methods, if empty -- everything is allowed
|
||||
|
||||
subLock sync.Mutex
|
||||
serverSubs map[ID]*Subscription
|
||||
}
|
||||
@ -71,7 +73,7 @@ type callProc struct {
|
||||
notifiers []*Notifier
|
||||
}
|
||||
|
||||
func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry) *handler {
|
||||
func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry, allowList AllowList) *handler {
|
||||
rootCtx, cancelRoot := context.WithCancel(connCtx)
|
||||
h := &handler{
|
||||
reg: reg,
|
||||
@ -84,6 +86,7 @@ func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *
|
||||
allowSubscribe: true,
|
||||
serverSubs: make(map[ID]*Subscription),
|
||||
log: log.Root(),
|
||||
allowList: allowList,
|
||||
}
|
||||
if conn.remoteAddr() != "" {
|
||||
h.log = h.log.New("conn", conn.remoteAddr())
|
||||
@ -314,6 +317,14 @@ func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMess
|
||||
}
|
||||
}
|
||||
|
||||
func (h *handler) isMethodAllowedByGranularControl(method string) bool {
|
||||
if len(h.allowList) == 0 {
|
||||
return true
|
||||
}
|
||||
_, ok := h.allowList[method]
|
||||
return ok
|
||||
}
|
||||
|
||||
// handleCall processes method calls.
|
||||
func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage {
|
||||
if msg.isSubscribe() {
|
||||
@ -322,7 +333,7 @@ func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage
|
||||
var callb *callback
|
||||
if msg.isUnsubscribe() {
|
||||
callb = h.unsubscribeCb
|
||||
} else {
|
||||
} else if h.isMethodAllowedByGranularControl(msg.Method) {
|
||||
callb = h.reg.callback(msg.Method)
|
||||
}
|
||||
if callb == nil {
|
||||
|
@ -42,10 +42,11 @@ const (
|
||||
|
||||
// Server is an RPC server.
|
||||
type Server struct {
|
||||
services serviceRegistry
|
||||
idgen func() ID
|
||||
run int32
|
||||
codecs mapset.Set
|
||||
services serviceRegistry
|
||||
methodAllowList AllowList
|
||||
idgen func() ID
|
||||
run int32
|
||||
codecs mapset.Set
|
||||
}
|
||||
|
||||
// NewServer creates a new server instance with no registered handlers.
|
||||
@ -58,6 +59,11 @@ func NewServer() *Server {
|
||||
return server
|
||||
}
|
||||
|
||||
// SetAllowList sets the allow list for methods that are handled by this server
|
||||
func (s *Server) SetAllowList(allowList AllowList) {
|
||||
s.methodAllowList = allowList
|
||||
}
|
||||
|
||||
// RegisterName creates a service for the given receiver type under the given name. When no
|
||||
// methods on the given receiver match the criteria to be either a RPC method or a
|
||||
// subscription an error is returned. Otherwise a new service is created and added to the
|
||||
@ -97,7 +103,7 @@ func (s *Server) serveSingleRequest(ctx context.Context, codec ServerCodec) {
|
||||
return
|
||||
}
|
||||
|
||||
h := newHandler(ctx, codec, s.idgen, &s.services)
|
||||
h := newHandler(ctx, codec, s.idgen, &s.services, s.methodAllowList)
|
||||
h.allowSubscribe = false
|
||||
defer h.close(io.EOF, nil)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user