diff --git a/cmd/geth/chaincmd.go b/cmd/geth/chaincmd.go index 8bd3c57bc..d764dcc6a 100644 --- a/cmd/geth/chaincmd.go +++ b/cmd/geth/chaincmd.go @@ -231,8 +231,9 @@ func initGenesis(ctx *cli.Context) error { if err := json.NewDecoder(file).Decode(genesis); err != nil { utils.Fatalf("invalid genesis file: %v", err) } + // Open an initialise both full and light databases - stack, _ := makeFullNode(ctx) + stack, _ := makeConfigNode(ctx) defer stack.Close() for _, name := range []string{"chaindata", "lightchaindata"} { diff --git a/cmd/geth/config.go b/cmd/geth/config.go index 717fe012f..31ce605eb 100644 --- a/cmd/geth/config.go +++ b/cmd/geth/config.go @@ -28,6 +28,7 @@ import ( "github.com/ledgerwatch/turbo-geth/cmd/utils" "github.com/ledgerwatch/turbo-geth/eth" + "github.com/ethereum/go-ethereum/internal/ethapi" "github.com/ledgerwatch/turbo-geth/node" "github.com/ledgerwatch/turbo-geth/params" @@ -131,19 +132,20 @@ func makeConfigNode(ctx *cli.Context) (*node.Node, gethConfig) { return stack, cfg } -func makeFullNode(ctx *cli.Context) (*node.Node, *eth.Ethereum) { +func makeFullNode(ctx *cli.Context) (*node.Node, ethapi.Backend) { stack, cfg := makeConfigNode(ctx) - service := utils.RegisterEthService(stack, &cfg.Eth) + + backend := utils.RegisterEthService(stack, &cfg.Eth) // Configure GraphQL if required if ctx.GlobalIsSet(utils.GraphQLEnabledFlag.Name) { - utils.RegisterGraphQLService(stack, cfg.Node.GraphQLEndpoint(), cfg.Node.GraphQLCors, cfg.Node.GraphQLVirtualHosts, cfg.Node.HTTPTimeouts) + utils.RegisterGraphQLService(stack, backend, cfg.Node) } // Add the Ethereum Stats daemon if requested. if cfg.Ethstats.URL != "" { - utils.RegisterEthStatsService(stack, cfg.Ethstats.URL) + utils.RegisterEthStatsService(stack, backend, cfg.Ethstats.URL) } - return stack, service + return stack, backend } // dumpConfig is the dumpconfig command. diff --git a/cmd/geth/consolecmd.go b/cmd/geth/consolecmd.go index f70e0baab..666305abe 100644 --- a/cmd/geth/consolecmd.go +++ b/cmd/geth/consolecmd.go @@ -80,26 +80,13 @@ JavaScript API. See https://github.com/ledgerwatch/turbo-geth/wiki/JavaScript-Co func localConsole(ctx *cli.Context) error { // Create and start the node based on the CLI flags prepare(ctx) - stack, ethService := makeFullNode(ctx) + stack, backend := makeFullNode(ctx) - err := stack.Register(func(ctx *node.ServiceContext) (node.Service, error) { - diskdb, err := ctx.OpenDatabaseWithFreezer("chaindata", "") - if err != nil { - return nil, err - } - return service.New(diskdb, ethService.TxPool()), nil - }) - - if err != nil { - panic(err) - } - - node := stack - startNode(ctx, node) - defer node.Close() + startNode(ctx, stack, backend) + defer stack.Close() // Attach to the newly started node and start the JavaScript console - client, err := node.Attach() + client, err := stack.Attach() if err != nil { utils.Fatalf("Failed to attach to the inproc geth: %v", err) } @@ -206,12 +193,12 @@ func dialRPC(endpoint string) (*rpc.Client, error) { // everything down. func ephemeralConsole(ctx *cli.Context) error { // Create and start the node based on the CLI flags - node, _ := makeFullNode(ctx) - startNode(ctx, node) - defer node.Close() + stack, backend := makeFullNode(ctx) + startNode(ctx, stack, backend) + defer stack.Close() // Attach to the newly started node and start the JavaScript console - client, err := node.Attach() + client, err := stack.Attach() if err != nil { utils.Fatalf("Failed to attach to the inproc geth: %v", err) } diff --git a/cmd/geth/dao_test.go b/cmd/geth/dao_test.go index 766d440d4..cdb9e277b 100644 --- a/cmd/geth/dao_test.go +++ b/cmd/geth/dao_test.go @@ -120,8 +120,7 @@ func testDAOForkBlockNewChain(t *testing.T, test int, genesis string, expectBloc } else { // Force chain initialization args := []string{"--port", "0", "--maxpeers", "0", "--nodiscover", "--nat", "none", "--ipcdisable", "--datadir", datadir} - geth := runGeth(t, append(args, []string{"--exec", "2+2", "console"}...)...) - geth.WaitExit() + runGeth(t, append(args, []string{"--exec", "2+2", "console"}...)...).WaitExit() } // Retrieve the DAO config flag from the database path := filepath.Join(datadir, "tg", "chaindata") diff --git a/cmd/geth/main.go b/cmd/geth/main.go index 3c87c2d9d..de0ad72b0 100644 --- a/cmd/geth/main.go +++ b/cmd/geth/main.go @@ -174,8 +174,6 @@ var ( utils.LegacyRPCCORSDomainFlag, utils.LegacyRPCVirtualHostsFlag, utils.GraphQLEnabledFlag, - utils.GraphQLListenAddrFlag, - utils.GraphQLPortFlag, utils.GraphQLCORSDomainFlag, utils.GraphQLVirtualHostsFlag, utils.HTTPApiFlag, @@ -347,18 +345,20 @@ func geth(ctx *cli.Context) error { if args := ctx.Args(); len(args) > 0 { return fmt.Errorf("invalid command: %q", args[0]) } + prepare(ctx) - node, _ := makeFullNode(ctx) - defer node.Close() - startNode(ctx, node) - node.Wait() + stack, backend := makeFullNode(ctx) + defer stack.Close() + + startNode(ctx, stack, backend) + stack.Wait() return nil } // startNode boots up the system node and all registered protocols, after which // it unlocks any requested accounts, and starts the RPC/IPC interfaces and the // miner. -func startNode(ctx *cli.Context, stack *node.Node) { +func startNode(ctx *cli.Context, stack *node.Node, backend ethapi.Backend) { debug.Memsize.Add("node", stack) // Start up the node itself @@ -378,16 +378,6 @@ func startNode(ctx *cli.Context, stack *node.Node) { } ethClient := ethclient.NewClient(rpcClient) - // Set contract backend for ethereum service if local node - // is serving LES requests. - if ctx.GlobalInt(utils.LegacyLightServFlag.Name) > 0 || ctx.GlobalInt(utils.LightServeFlag.Name) > 0 { - var ethService *eth.Ethereum - if err := stack.Service(ðService); err != nil { - utils.Fatalf("Failed to retrieve ethereum service: %v", err) - } - ethService.SetContractBackend(ethClient) - } - go func() { // Open any wallets already attached for _, wallet := range stack.AccountManager().Wallets() { @@ -439,7 +429,7 @@ func startNode(ctx *cli.Context, stack *node.Node) { if timestamp := time.Unix(int64(done.Latest.Time), 0); time.Since(timestamp) < 10*time.Minute { log.Info("Synchronisation completed", "latestnum", done.Latest.Number, "latesthash", done.Latest.Hash(), "age", common.PrettyAge(timestamp)) - stack.Stop() + stack.Close() } } }() @@ -447,24 +437,28 @@ func startNode(ctx *cli.Context, stack *node.Node) { // Start auxiliary services if enabled if ctx.GlobalBool(utils.MiningEnabledFlag.Name) || ctx.GlobalBool(utils.DeveloperFlag.Name) { - var ethereum *eth.Ethereum - if err := stack.Service(ðereum); err != nil { + // Mining only makes sense if a full Ethereum node is running + if ctx.GlobalString(utils.SyncModeFlag.Name) == "light" { + utils.Fatalf("Light clients do not support mining") + } + ethBackend, ok := backend.(*eth.EthAPIBackend) + if !ok { utils.Fatalf("Ethereum service not running: %v", err) } + // Set the gas price to the limits from the CLI and start mining gasprice := utils.GlobalBig(ctx, utils.MinerGasPriceFlag.Name) if ctx.GlobalIsSet(utils.LegacyMinerGasPriceFlag.Name) && !ctx.GlobalIsSet(utils.MinerGasPriceFlag.Name) { gasprice = utils.GlobalBig(ctx, utils.LegacyMinerGasPriceFlag.Name) } - ethereum.TxPool().SetGasPrice(gasprice) - + ethBackend.TxPool().SetGasPrice(gasprice) + // start mining threads := ctx.GlobalInt(utils.MinerThreadsFlag.Name) if ctx.GlobalIsSet(utils.LegacyMinerThreadsFlag.Name) && !ctx.GlobalIsSet(utils.MinerThreadsFlag.Name) { threads = ctx.GlobalInt(utils.LegacyMinerThreadsFlag.Name) log.Warn("The flag --minerthreads is deprecated and will be removed in the future, please use --miner.threads") } - - if err := ethereum.StartMining(threads); err != nil { + if err := ethBackend.StartMining(threads); err != nil { utils.Fatalf("Failed to start mining: %v", err) } } diff --git a/cmd/geth/usage.go b/cmd/geth/usage.go index 3017e4e62..37fd0048c 100644 --- a/cmd/geth/usage.go +++ b/cmd/geth/usage.go @@ -148,8 +148,6 @@ var AppHelpFlagGroups = []flags.FlagGroup{ utils.WSApiFlag, utils.WSAllowedOriginsFlag, utils.GraphQLEnabledFlag, - utils.GraphQLListenAddrFlag, - utils.GraphQLPortFlag, utils.GraphQLCORSDomainFlag, utils.GraphQLVirtualHostsFlag, utils.RPCGlobalGasCap, @@ -235,6 +233,8 @@ var AppHelpFlagGroups = []flags.FlagGroup{ utils.LegacyWSApiFlag, utils.LegacyGpoBlocksFlag, utils.LegacyGpoPercentileFlag, + utils.LegacyGraphQLListenAddrFlag, + utils.LegacyGraphQLPortFlag, }, debug.DeprecatedFlags...), }, { diff --git a/cmd/p2psim/main.go b/cmd/p2psim/main.go index bf2776c11..5a22920fc 100644 --- a/cmd/p2psim/main.go +++ b/cmd/p2psim/main.go @@ -289,7 +289,7 @@ func createNode(ctx *cli.Context) error { config.PrivateKey = privKey } if services := ctx.String("services"); services != "" { - config.Services = strings.Split(services, ",") + config.Lifecycles = strings.Split(services, ",") } node, err := client.CreateNode(config) if err != nil { diff --git a/cmd/utils/cmd.go b/cmd/utils/cmd.go index 13d12e354..da825f6db 100644 --- a/cmd/utils/cmd.go +++ b/cmd/utils/cmd.go @@ -75,7 +75,7 @@ func StartNode(stack *node.Node) { defer signal.Stop(sigc) <-sigc log.Info("Got interrupt, shutting down...") - go stack.Stop() + go stack.Close() for i := 10; i > 0; i-- { <-sigc if i > 1 { diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go index 693d8e4e9..f867e97fb 100644 --- a/cmd/utils/flags.go +++ b/cmd/utils/flags.go @@ -19,7 +19,6 @@ package utils import ( "crypto/ecdsa" - "errors" "fmt" "io" "io/ioutil" @@ -561,6 +560,20 @@ var ( Usage: "API's offered over the HTTP-RPC interface", Value: "", } + GraphQLEnabledFlag = cli.BoolFlag{ + Name: "graphql", + Usage: "Enable GraphQL on the HTTP-RPC server. Note that GraphQL can only be started if an HTTP server is started as well.", + } + GraphQLCORSDomainFlag = cli.StringFlag{ + Name: "graphql.corsdomain", + Usage: "Comma separated list of domains from which to accept cross origin requests (browser enforced)", + Value: "", + } + GraphQLVirtualHostsFlag = cli.StringFlag{ + Name: "graphql.vhosts", + Usage: "Comma separated list of virtual hostnames from which to accept requests (server enforced). Accepts '*' wildcard.", + Value: strings.Join(node.DefaultConfig.GraphQLVirtualHosts, ","), + } WSEnabledFlag = cli.BoolFlag{ Name: "ws", Usage: "Enable the WS-RPC server", @@ -585,30 +598,6 @@ var ( Usage: "Origins from which to accept websockets requests", Value: "", } - GraphQLEnabledFlag = cli.BoolFlag{ - Name: "graphql", - Usage: "Enable the GraphQL server", - } - GraphQLListenAddrFlag = cli.StringFlag{ - Name: "graphql.addr", - Usage: "GraphQL server listening interface", - Value: node.DefaultGraphQLHost, - } - GraphQLPortFlag = cli.IntFlag{ - Name: "graphql.port", - Usage: "GraphQL server listening port", - Value: node.DefaultGraphQLPort, - } - GraphQLCORSDomainFlag = cli.StringFlag{ - Name: "graphql.corsdomain", - Usage: "Comma separated list of domains from which to accept cross origin requests (browser enforced)", - Value: "", - } - GraphQLVirtualHostsFlag = cli.StringFlag{ - Name: "graphql.vhosts", - Usage: "Comma separated list of virtual hostnames from which to accept requests (server enforced). Accepts '*' wildcard.", - Value: strings.Join(node.DefaultConfig.GraphQLVirtualHosts, ","), - } ExecFlag = cli.StringFlag{ Name: "exec", Usage: "Execute JavaScript statement", @@ -984,13 +973,6 @@ func setHTTP(ctx *cli.Context, cfg *node.Config) { // setGraphQL creates the GraphQL listener interface string from the set // command line flags, returning empty if the GraphQL endpoint is disabled. func setGraphQL(ctx *cli.Context, cfg *node.Config) { - if ctx.GlobalBool(GraphQLEnabledFlag.Name) && cfg.GraphQLHost == "" { - cfg.GraphQLHost = localhost - if ctx.GlobalIsSet(GraphQLListenAddrFlag.Name) { - cfg.GraphQLHost = ctx.GlobalString(GraphQLListenAddrFlag.Name) - } - } - cfg.GraphQLPort = ctx.GlobalInt(GraphQLPortFlag.Name) if ctx.GlobalIsSet(GraphQLCORSDomainFlag.Name) { cfg.GraphQLCors = splitAndTrim(ctx.GlobalString(GraphQLCORSDomainFlag.Name)) } @@ -1742,12 +1724,24 @@ func setDNSDiscoveryDefaults(cfg *eth.Config, genesis common.Hash) { } // RegisterEthService adds an Ethereum client to the stack. -func RegisterEthService(stack *node.Node, cfg *eth.Config) *eth.Ethereum { - fullNode := new(eth.Ethereum) - if err := stack.Register(func(ctx *node.ServiceContext) (node.Service, error) { - fullNodeInst, err := eth.New(ctx, cfg) - *fullNode = *fullNodeInst //nolint:govet - return fullNode, err +func RegisterEthService(stack *node.Node, cfg *eth.Config) ethapi.Backend { + backend, err := les.New(stack, cfg) + if err != nil { + Fatalf("Failed to register the Ethereum service: %v", err) + } + return backend.ApiBackend + } else { + backend, err := eth.New(stack, cfg) + if err != nil { + Fatalf("Failed to register the Ethereum service: %v", err) + } + if cfg.LightServ > 0 { + _, err := les.NewLesServer(stack, backend, cfg) + if err != nil { + Fatalf("Failed to create the LES server: %v", err) + } + } + return backend.APIBackend }); err != nil { Fatalf("Failed to register the Ethereum service: %v", err) } @@ -1756,31 +1750,15 @@ func RegisterEthService(stack *node.Node, cfg *eth.Config) *eth.Ethereum { // RegisterEthStatsService configures the Ethereum Stats daemon and adds it to // the given node. -func RegisterEthStatsService(stack *node.Node, url string) { - if err := stack.Register(func(ctx *node.ServiceContext) (node.Service, error) { - // Retrieve both eth and les services - var ethServ *eth.Ethereum - if err := ctx.Service(ðServ); err != nil { - return nil, err - } - - return ethstats.New(url, ethServ) - }); err != nil { +func RegisterEthStatsService(stack *node.Node, backend ethapi.Backend, url string) { + if err := ethstats.New(stack, backend, backend.Engine(), url); err != nil { Fatalf("Failed to register the Ethereum Stats service: %v", err) } } // RegisterGraphQLService is a utility function to construct a new service and register it against a node. -func RegisterGraphQLService(stack *node.Node, endpoint string, cors, vhosts []string, timeouts rpc.HTTPTimeouts) { - if err := stack.Register(func(ctx *node.ServiceContext) (node.Service, error) { - // Try to construct the GraphQL service backed by a full node - var ethServ *eth.Ethereum - if err := ctx.Service(ðServ); err == nil { - return graphql.New(ethServ.APIBackend, endpoint, cors, vhosts, timeouts) - } - // Well, this should not have happened, bail out - return nil, errors.New("no Ethereum service") - }); err != nil { +func RegisterGraphQLService(stack *node.Node, backend ethapi.Backend, cfg node.Config) { + if err := graphql.New(stack, backend, cfg.GraphQLCors, cfg.GraphQLVirtualHosts); err != nil { Fatalf("Failed to register the GraphQL service: %v", err) } } diff --git a/cmd/utils/flags_legacy.go b/cmd/utils/flags_legacy.go index fe0b0a073..216ac8414 100644 --- a/cmd/utils/flags_legacy.go +++ b/cmd/utils/flags_legacy.go @@ -90,6 +90,8 @@ var ( Name: "testnet", Usage: "Pre-configured test network (Deprecated: Please choose one of --goerli, --rinkeby, or --ropsten.)", } + + // (Deprecated May 2020, shown in aliased flags section) LegacyRPCEnabledFlag = cli.BoolFlag{ Name: "rpc", Usage: "Enable the HTTP-RPC server (deprecated, use --http)", @@ -159,6 +161,17 @@ var ( Usage: "Comma separated enode URLs for P2P v5 discovery bootstrap (light server, light nodes) (deprecated, use --bootnodes)", Value: "", } + + // (Deprecated July 2020, shown in aliased flags section) + LegacyGraphQLListenAddrFlag = cli.StringFlag{ + Name: "graphql.addr", + Usage: "GraphQL server listening interface (deprecated, graphql can only be enabled on the HTTP-RPC server endpoint, use --graphql)", + } + LegacyGraphQLPortFlag = cli.IntFlag{ + Name: "graphql.port", + Usage: "GraphQL server listening port (deprecated, graphql can only be enabled on the HTTP-RPC server endpoint, use --graphql)", + Value: node.DefaultHTTPPort, + } ) // showDeprecated displays deprecated flags that will be soon removed from the codebase. diff --git a/console/console_test.go b/console/console_test.go index 006410f9b..f41432ed6 100644 --- a/console/console_test.go +++ b/console/console_test.go @@ -114,7 +114,8 @@ func newTester(t *testing.T, confOverride func(*eth.Config)) *tester { if confOverride != nil { confOverride(ethConf) } - if err = stack.Register(func(ctx *node.ServiceContext) (node.Service, error) { return eth.New(ctx, ethConf) }); err != nil { + ethBackend, err := eth.New(stack, ethConf) + if err != nil { t.Fatalf("failed to register Ethereum protocol: %v", err) } // Start the node and assemble the JavaScript console around it @@ -140,13 +141,10 @@ func newTester(t *testing.T, confOverride func(*eth.Config)) *tester { t.Fatalf("failed to create JavaScript console: %v", err) } // Create the final tester and return - var ethereum *eth.Ethereum - stack.Service(ðereum) - return &tester{ workspace: workspace, stack: stack, - ethereum: ethereum, + ethereum: ethBackend, console: console, input: prompter, output: printer, diff --git a/eth/api_backend.go b/eth/api_backend.go index edf5d683b..7fa0f6158 100644 --- a/eth/api_backend.go +++ b/eth/api_backend.go @@ -24,6 +24,7 @@ import ( "github.com/ledgerwatch/turbo-geth/accounts" "github.com/ledgerwatch/turbo-geth/common" + "github.com/ethereum/go-ethereum/consensus" "github.com/ledgerwatch/turbo-geth/core" "github.com/ledgerwatch/turbo-geth/core/bloombits" "github.com/ledgerwatch/turbo-geth/core/rawdb" @@ -318,6 +319,10 @@ func (b *EthAPIBackend) Stats() (pending int, queued int) { func (b *EthAPIBackend) TxPoolContent() (map[common.Address]types.Transactions, map[common.Address]types.Transactions) { return b.eth.TxPool().Content() +func (b *EthAPIBackend) TxPool() *core.TxPool { +} + + return b.eth.TxPool() } func (b *EthAPIBackend) SubscribeNewTxsEvent(ch chan<- core.NewTxsEvent) event.Subscription { @@ -370,3 +375,19 @@ func (b *EthAPIBackend) ServiceFilter(ctx context.Context, session *bloombits.Ma go session.Multiplex(bloomRetrievalBatch, bloomRetrievalWait, b.eth.bloomRequests) } } + +func (b *EthAPIBackend) Engine() consensus.Engine { + return b.eth.engine +} + +func (b *EthAPIBackend) CurrentHeader() *types.Header { + return b.eth.blockchain.CurrentHeader() +} + +func (b *EthAPIBackend) Miner() *miner.Miner { + return b.eth.Miner() +} + +func (b *EthAPIBackend) StartMining(threads int) error { + return b.eth.StartMining(threads) +} diff --git a/eth/backend.go b/eth/backend.go index acd40587d..fb64a03d7 100644 --- a/eth/backend.go +++ b/eth/backend.go @@ -58,15 +58,6 @@ import ( "github.com/ledgerwatch/turbo-geth/rpc" ) -type LesServer interface { - Start(srvr *p2p.Server) - Stop() - APIs() []rpc.API - Protocols() []p2p.Protocol - SetBloomBitsIndexer(bbIndexer *core.ChainIndexer) - SetContractBackend(bind.ContractBackend) -} - // Ethereum implements the Ethereum full node service. type Ethereum struct { config *Config @@ -75,7 +66,6 @@ type Ethereum struct { txPool *core.TxPool blockchain *core.BlockChain protocolManager *ProtocolManager - lesServer LesServer dialCandidates enode.Iterator // DB interfaces @@ -99,26 +89,16 @@ type Ethereum struct { networkID uint64 netRPCService *ethapi.PublicNetAPI - lock sync.RWMutex // Protects the variadic fields (e.g. gas price and etherbase) + p2pServer *p2p.Server txPoolStarted bool } -func (s *Ethereum) AddLesServer(ls LesServer) { - s.lesServer = ls - ls.SetBloomBitsIndexer(s.bloomIndexer) -} - -// SetClient sets a rpc client which connecting to our local node. -func (s *Ethereum) SetContractBackend(backend bind.ContractBackend) { - // Pass the rpc client to les server if it is enabled. - if s.lesServer != nil { - s.lesServer.SetContractBackend(backend) - } + lock sync.RWMutex // Protects the variadic fields (e.g. gas price and etherbase) } // New creates a new Ethereum object (including the // initialisation of the common Ethereum object) -func New(ctx *node.ServiceContext, config *Config) (*Ethereum, error) { +func New(stack *node.Node, config *Config) (*Ethereum, error) { // Ensure configuration values are compatible and sane if config.SyncMode == downloader.LightSync { return nil, errors.New("can't run eth.Ethereum in light sync mode, use les.LightEthereum") @@ -150,7 +130,7 @@ func New(ctx *node.ServiceContext, config *Config) (*Ethereum, error) { } chainDb = ethdb.MustOpen("simulator") } else { - if chainDb, err = ctx.OpenDatabaseWithFreezer("chaindata", config.DatabaseFreezer); err != nil { + chainDb, err := stack.OpenDatabaseWithFreezer("chaindata", config.DatabaseCache, config.DatabaseHandles, config.DatabaseFreezer, "eth/db/chaindata/") return nil, err } } @@ -171,15 +151,16 @@ func New(ctx *node.ServiceContext, config *Config) (*Ethereum, error) { config: config, chainDb: chainDb, chainKV: chainDb.KV(), - eventMux: ctx.EventMux, - accountManager: ctx.AccountManager, - engine: CreateConsensusEngine(ctx, chainConfig, &config.Ethash, config.Miner.Notify, config.Miner.Noverify, chainDb), + eventMux: stack.EventMux(), + accountManager: stack.AccountManager(), + engine: CreateConsensusEngine(stack, chainConfig, &config.Ethash, config.Miner.Notify, config.Miner.Noverify, chainDb), closeBloomHandler: make(chan struct{}), networkID: config.NetworkID, gasPrice: config.Miner.GasPrice, etherbase: config.Miner.Etherbase, bloomRequests: make(chan chan *bloombits.Retrieval), bloomIndexer: NewBloomIndexer(chainDb, params.BloomBitsBlocks, params.BloomConfirms), + p2pServer: stack.Server(), } log.Info("Initialising Ethereum protocol", "versions", ProtocolVersions, "network", config.NetworkID) @@ -239,7 +220,7 @@ func New(ctx *node.ServiceContext, config *Config) (*Ethereum, error) { } if config.TxPool.Journal != "" { - config.TxPool.Journal = ctx.ResolvePath(config.TxPool.Journal) + config.TxPool.Journal = stack.ResolvePath(config.TxPool.Journal) } eth.txPool = core.NewTxPool(config.TxPool, chainConfig, chainDb, txCacher) @@ -255,13 +236,19 @@ func New(ctx *node.ServiceContext, config *Config) (*Ethereum, error) { if eth.protocolManager, err = NewProtocolManager(chainConfig, checkpoint, config.SyncMode, config.NetworkID, eth.eventMux, eth.txPool, eth.engine, eth.blockchain, chainDb, config.Whitelist); err != nil { return nil, err } + eth.miner = miner.New(eth, &config.Miner, chainConfig, eth.EventMux(), eth.engine, eth.isLocalBlock) eth.protocolManager.SetDataDir(ctx.Config.DataDir) if config.SyncMode != downloader.StagedSync { if err = eth.StartTxPool(); err != nil { return nil, err } + eth.APIBackend = &EthAPIBackend{stack.Config().ExtRPCEnabled(), eth, nil} + gpoParams := config.GPO + if gpoParams.Default == nil { + gpoParams.Default = config.Miner.GasPrice } + eth.APIBackend.gpo = gasprice.NewOracle(eth.APIBackend, gpoParams) if config.SyncMode != downloader.StagedSync { eth.miner = miner.New(eth, &config.Miner, chainConfig, eth.EventMux(), eth.engine, eth.isLocalBlock) @@ -277,11 +264,18 @@ func New(ctx *node.ServiceContext, config *Config) (*Ethereum, error) { eth.APIBackend.gpo = gasprice.NewOracle(eth.APIBackend, gpoParams) } - eth.dialCandidates, err = eth.setupDiscovery(&ctx.Config.P2P) + eth.dialCandidates, err = eth.setupDiscovery(&stack.Config().P2P) if err != nil { return nil, err } + // Start the RPC service + eth.netRPCService = ethapi.NewPublicNetAPI(eth.p2pServer, eth.NetVersion()) + + // Register the backend on the node + stack.RegisterAPIs(eth.APIs()) + stack.RegisterProtocols(eth.Protocols()) + stack.RegisterLifecycle(eth) return eth, nil } @@ -327,7 +321,7 @@ func makeExtraData(extra []byte) []byte { } // CreateConsensusEngine creates the required type of consensus engine instance for an Ethereum service -func CreateConsensusEngine(ctx *node.ServiceContext, chainConfig *params.ChainConfig, config *ethash.Config, notify []string, noverify bool, db ethdb.Database) consensus.Engine { +func CreateConsensusEngine(stack *node.Node, chainConfig *params.ChainConfig, config *ethash.Config, notify []string, noverify bool, db ethdb.Database) consensus.Engine { // If proof-of-authority is requested, set it up if chainConfig.Clique != nil { return clique.New(chainConfig.Clique, db) @@ -345,7 +339,7 @@ func CreateConsensusEngine(ctx *node.ServiceContext, chainConfig *params.ChainCo return ethash.NewShared() default: engine := ethash.New(ethash.Config{ - CacheDir: ctx.ResolvePath(config.CacheDir), + CacheDir: stack.ResolvePath(config.CacheDir), CachesInMem: config.CachesInMem, CachesOnDisk: config.CachesOnDisk, CachesLockMmap: config.CachesLockMmap, @@ -367,18 +361,9 @@ func (s *Ethereum) APIs() []rpc.API { } apis := ethapi.GetAPIs(s.APIBackend) - // Append any APIs exposed explicitly by the les server - if s.lesServer != nil { - apis = append(apis, s.lesServer.APIs()...) - } // Append any APIs exposed explicitly by the consensus engine apis = append(apis, s.engine.APIs(s.BlockChain())...) - // Append any APIs exposed explicitly by the les server - if s.lesServer != nil { - apis = append(apis, s.lesServer.APIs()...) - } - // Append all the local APIs and return return append(apis, []rpc.API{ { @@ -594,8 +579,9 @@ func (s *Ethereum) NetVersion() uint64 { return s.networkID } func (s *Ethereum) Downloader() *downloader.Downloader { return s.protocolManager.downloader } func (s *Ethereum) Synced() bool { return atomic.LoadUint32(&s.protocolManager.acceptTxs) == 1 } func (s *Ethereum) ArchiveMode() bool { return !s.config.Pruning } +func (s *Ethereum) BloomIndexer() *core.ChainIndexer { return s.bloomIndexer } -// Protocols implements node.Service, returning all the currently configured +// Protocols returns all the currently configured // network protocols to start. func (s *Ethereum) Protocols() []p2p.Protocol { protos := make([]p2p.Protocol, len(ProtocolVersions)) @@ -616,10 +602,10 @@ func (s *Ethereum) Protocols() []p2p.Protocol { return protos } -// Start implements node.Service, starting all internal goroutines needed by the +// Start implements node.Lifecycle, starting all internal goroutines needed by the // Ethereum protocol implementation. -func (s *Ethereum) Start(srvr *p2p.Server) error { - s.startEthEntryUpdate(srvr.LocalNode()) +func (s *Ethereum) Start() error { + s.startEthEntryUpdate(s.p2pServer.LocalNode()) // Start the bloom bits servicing goroutines if s.config.SyncMode != downloader.StagedSync { @@ -632,10 +618,10 @@ func (s *Ethereum) Start(srvr *p2p.Server) error { } // Figure out a max peers count based on the server limits - maxPeers := srvr.MaxPeers + maxPeers := s.p2pServer.MaxPeers if s.config.LightServ > 0 { - if s.config.LightPeers >= srvr.MaxPeers { - return fmt.Errorf("invalid peer config: light peer count (%d) >= total peer count (%d)", s.config.LightPeers, srvr.MaxPeers) + if s.config.LightPeers >= s.p2pServer.MaxPeers { + return fmt.Errorf("invalid peer config: light peer count (%d) >= total peer count (%d)", s.config.LightPeers, s.p2pServer.MaxPeers) } maxPeers -= s.config.LightPeers } @@ -643,11 +629,7 @@ func (s *Ethereum) Start(srvr *p2p.Server) error { withTxPool := s.config.SyncMode != downloader.StagedSync // Start the networking layer and the light server if requested if err := s.protocolManager.Start(maxPeers, withTxPool); err != nil { - return err - } - if s.lesServer != nil { - s.lesServer.Start(srvr) - } +// Stop implements node.Lifecycle, terminating all internal goroutines used by the return nil } @@ -686,9 +668,6 @@ func (s *Ethereum) StopTxPool() error { func (s *Ethereum) Stop() error { // Stop all the peer-related stuff first. s.protocolManager.Stop() - if s.lesServer != nil { - s.lesServer.Stop() - } // Then stop everything else. s.bloomIndexer.Close() diff --git a/ethclient/ethclient_test.go b/ethclient/ethclient_test.go index 8f3e59b6d..879e0ae9b 100644 --- a/ethclient/ethclient_test.go +++ b/ethclient/ethclient_test.go @@ -187,18 +187,19 @@ var ( func newTestBackend(t *testing.T) (*node.Node, []*types.Block) { // Generate test chain. genesis, blocks := generateTestChain() - - // Start Ethereum service. - var ethservice *eth.Ethereum + // Create node n, err := node.New(&node.Config{}) - n.Register(func(ctx *node.ServiceContext) (node.Service, error) { - config := ð.Config{Genesis: genesis} - config.Ethash.PowMode = ethash.ModeFake + if err != nil { + t.Fatalf("can't create new node: %v", err) + } + // Create Ethereum Service + config := ð.Config{Genesis: genesis} + config.Ethash.PowMode = ethash.ModeFake config.Pruning = false - ethservice, err = eth.New(ctx, config) - return ethservice, err - }) - + ethservice, err := eth.New(n, config) + if err != nil { + t.Fatalf("can't create new ethereum service: %v", err) + } // Import the test chain. if err := n.Start(); err != nil { t.Fatalf("can't start test node: %v", err) @@ -236,7 +237,7 @@ func generateTestChain() (*core.Genesis, []*types.Block) { func TestHeader(t *testing.T) { backend, chain := newTestBackend(t) client, _ := backend.Attach() - defer backend.Stop() + defer backend.Close() defer client.Close() tests := map[string]struct { @@ -280,7 +281,7 @@ func TestHeader(t *testing.T) { func TestBalanceAt(t *testing.T) { backend, _ := newTestBackend(t) client, _ := backend.Attach() - defer backend.Stop() + defer backend.Close() defer client.Close() tests := map[string]struct { @@ -326,7 +327,7 @@ func TestBalanceAt(t *testing.T) { func TestTransactionInBlockInterrupted(t *testing.T) { backend, _ := newTestBackend(t) client, _ := backend.Attach() - defer backend.Stop() + defer backend.Close() defer client.Close() ec := NewClient(client) @@ -344,7 +345,7 @@ func TestTransactionInBlockInterrupted(t *testing.T) { func TestChainID(t *testing.T) { backend, _ := newTestBackend(t) client, _ := backend.Attach() - defer backend.Stop() + defer backend.Close() defer client.Close() ec := NewClient(client) diff --git a/ethstats/ethstats.go b/ethstats/ethstats.go index bfa47cb0a..aab72a662 100644 --- a/ethstats/ethstats.go +++ b/ethstats/ethstats.go @@ -55,22 +55,33 @@ const ( chainHeadChanSize = 10 ) -type txPool interface { - // SubscribeNewTxsEvent should return an event subscription of - // NewTxsEvent and send events to the given channel. - SubscribeNewTxsEvent(chan<- core.NewTxsEvent) event.Subscription +// backend encompasses the bare-minimum functionality needed for ethstats reporting +type backend interface { + SubscribeChainHeadEvent(ch chan<- core.ChainHeadEvent) event.Subscription + SubscribeNewTxsEvent(ch chan<- core.NewTxsEvent) event.Subscription + CurrentHeader() *types.Header + HeaderByNumber(ctx context.Context, number rpc.BlockNumber) (*types.Header, error) + GetTd(ctx context.Context, hash common.Hash) *big.Int + Stats() (pending int, queued int) + Downloader() *downloader.Downloader } -type blockChain interface { - SubscribeChainHeadEvent(ch chan<- core.ChainHeadEvent) event.Subscription +// fullNodeBackend encompasses the functionality necessary for a full node +// reporting to ethstats +type fullNodeBackend interface { + backend + Miner() *miner.Miner + BlockByNumber(ctx context.Context, number rpc.BlockNumber) (*types.Block, error) + CurrentBlock() *types.Block + SuggestPrice(ctx context.Context) (*big.Int, error) } // Service implements an Ethereum netstats reporting daemon that pushes local // chain statistics up to a monitoring server. type Service struct { - server *p2p.Server // Peer-to-peer server to retrieve networking infos - eth *eth.Ethereum // Full Ethereum service if monitoring a full node - engine consensus.Engine // Consensus engine to retrieve variadic block fields + server *p2p.Server // Peer-to-peer server to retrieve networking infos + backend backend + engine consensus.Engine // Consensus engine to retrieve variadic block fields node string // Name of the node to display on the monitoring page pass string // Password to authorize access to the monitoring page @@ -81,47 +92,37 @@ type Service struct { } // New returns a monitoring service ready for stats reporting. -func New(url string, ethServ *eth.Ethereum) (*Service, error) { +func New(node *node.Node, backend backend, engine consensus.Engine, url string) error { // Parse the netstats connection url re := regexp.MustCompile("([^:@]*)(:([^@]*))?@(.+)") parts := re.FindStringSubmatch(url) if len(parts) != 5 { - return nil, fmt.Errorf("invalid netstats url: \"%s\", should be nodename:secret@host:port", url) + return fmt.Errorf("invalid netstats url: \"%s\", should be nodename:secret@host:port", url) } - // Assemble and return the stats service - var engine consensus.Engine - if ethServ != nil { - engine = ethServ.Engine() + ethstats := &Service{ + backend: backend, + engine: engine, + server: node.Server(), + node: parts[1], + pass: parts[3], + host: parts[4], + pongCh: make(chan struct{}), + histCh: make(chan []uint64, 1), } - return &Service{ - eth: ethServ, - engine: engine, - node: parts[1], - pass: parts[3], - host: parts[4], - pongCh: make(chan struct{}), - histCh: make(chan []uint64, 1), - }, nil + + node.RegisterLifecycle(ethstats) + return nil } -// Protocols implements node.Service, returning the P2P network protocols used -// by the stats service (nil as it doesn't use the devp2p overlay network). -func (s *Service) Protocols() []p2p.Protocol { return nil } - -// APIs implements node.Service, returning the RPC API endpoints provided by the -// stats service (nil as it doesn't provide any user callable APIs). -func (s *Service) APIs() []rpc.API { return nil } - -// Start implements node.Service, starting up the monitoring and reporting daemon. -func (s *Service) Start(server *p2p.Server) error { - s.server = server +// Start implements node.Lifecycle, starting up the monitoring and reporting daemon. +func (s *Service) Start() error { go s.loop() log.Info("Stats daemon started") return nil } -// Stop implements node.Service, terminating the monitoring and reporting daemon. +// Stop implements node.Lifecycle, terminating the monitoring and reporting daemon. func (s *Service) Stop() error { log.Info("Stats daemon stopped") return nil @@ -131,19 +132,12 @@ func (s *Service) Stop() error { // until termination. func (s *Service) loop() { // Subscribe to chain events to execute updates on - var blockchain blockChain - var txpool txPool - if s.eth != nil { - blockchain = s.eth.BlockChain() - txpool = s.eth.TxPool() - } - chainHeadCh := make(chan core.ChainHeadEvent, chainHeadChanSize) - headSub := blockchain.SubscribeChainHeadEvent(chainHeadCh) + headSub := s.backend.SubscribeChainHeadEvent(chainHeadCh) defer headSub.Unsubscribe() txEventCh := make(chan core.NewTxsEvent, txChanSize) - txSub := txpool.SubscribeNewTxsEvent(txEventCh) + txSub := s.backend.SubscribeNewTxsEvent(txEventCh) defer txSub.Unsubscribe() // Start a goroutine that exhausts the subscriptions to avoid events piling up @@ -549,13 +543,15 @@ func (s *Service) assembleBlockStats(block *types.Block) *blockStats { txs []txStats uncles []*types.Header ) - if s.eth != nil { - // Full nodes have all needed information available + + // check if backend is a full node + fullBackend, ok := s.backend.(fullNodeBackend) + if ok { if block == nil { - block = s.eth.BlockChain().CurrentBlock() + block = fullBackend.CurrentBlock() } header = block.Header() - td = s.eth.BlockChain().GetTd(header.Hash(), header.Number.Uint64()) + td = fullBackend.GetTd(context.Background(), header.Hash()) txs = make([]txStats, len(block.Transactions())) for i, tx := range block.Transactions() { @@ -563,6 +559,7 @@ func (s *Service) assembleBlockStats(block *types.Block) *blockStats { } uncles = block.Uncles() } + // Assemble and return the block stats author, _ := s.engine.Author(header) @@ -593,10 +590,7 @@ func (s *Service) reportHistory(conn *websocket.Conn, list []uint64) error { indexes = append(indexes, list...) } else { // No indexes requested, send back the top ones - var head int64 - if s.eth != nil { - head = s.eth.BlockChain().CurrentHeader().Number.Int64() - } + head := s.backend.CurrentHeader().Number.Int64() start := head - historyUpdateRange + 1 if start < 0 { start = 0 @@ -608,10 +602,14 @@ func (s *Service) reportHistory(conn *websocket.Conn, list []uint64) error { // Gather the batch of blocks to report history := make([]*blockStats, len(indexes)) for i, number := range indexes { + fullBackend, ok := s.backend.(fullNodeBackend) // Retrieve the next block if it's known to us var block *types.Block - if s.eth != nil { - block = s.eth.BlockChain().GetBlockByNumber(number) + if ok { + block, _ = fullBackend.BlockByNumber(context.Background(), rpc.BlockNumber(number)) // TODO ignore error here ? + } else { + if header, _ := s.backend.HeaderByNumber(context.Background(), rpc.BlockNumber(number)); header != nil { + block = types.NewBlockWithHeader(header) } // If we do have the block, add to the history and continue if block != nil { @@ -647,10 +645,7 @@ type pendStats struct { // it to the stats server. func (s *Service) reportPending(conn *websocket.Conn) error { // Retrieve the pending count from the local blockchain - var pending int - if s.eth != nil { - pending, _ = s.eth.TxPool().Stats() - } + pending, _ := s.backend.Stats() // Assemble the transaction stats and send it to the server log.Trace("Sending pending transactions to ethstats", "count", pending) @@ -677,7 +672,7 @@ type nodeStats struct { Uptime int `json:"uptime"` } -// reportPending retrieves various stats about the node at the networking and +// reportStats retrieves various stats about the node at the networking and // mining layer and reports it to the stats server. func (s *Service) reportStats(conn *websocket.Conn) error { // Gather the syncing and mining infos from the local miner instance @@ -687,14 +682,16 @@ func (s *Service) reportStats(conn *websocket.Conn) error { syncing bool gasprice int ) - if s.eth != nil { - mining = s.eth.Miner().Mining() - hashrate = int(s.eth.Miner().HashRate()) + // check if backend is a full node + fullBackend, ok := s.backend.(fullNodeBackend) + if ok { + mining = fullBackend.Miner().Mining() + hashrate = int(fullBackend.Miner().HashRate()) - sync := s.eth.Downloader().Progress() - syncing = s.eth.BlockChain().CurrentHeader().Number.Uint64() >= sync.HighestBlock + sync := fullBackend.Downloader().Progress() + syncing = fullBackend.CurrentHeader().Number.Uint64() >= sync.HighestBlock - price, _ := s.eth.APIBackend.SuggestPrice(context.Background()) + price, _ := fullBackend.SuggestPrice(context.Background()) gasprice = int(price.Uint64()) } // Assemble the node stats and send it to the server diff --git a/graphql/graphql_test.go b/graphql/graphql_test.go index 40b13187f..5ba9c9553 100644 --- a/graphql/graphql_test.go +++ b/graphql/graphql_test.go @@ -17,12 +17,118 @@ package graphql import ( + "fmt" + "io/ioutil" + "net/http" + "strings" "testing" + + "github.com/ethereum/go-ethereum/eth" + "github.com/ethereum/go-ethereum/node" + "github.com/stretchr/testify/assert" ) func TestBuildSchema(t *testing.T) { + stack, err := node.New(&node.DefaultConfig) + if err != nil { + t.Fatalf("could not create new node: %v", err) + } // Make sure the schema can be parsed and matched up to the object model. - if _, err := newHandler(nil); err != nil { + if err := newHandler(stack, nil, []string{}, []string{}); err != nil { t.Errorf("Could not construct GraphQL handler: %v", err) } } + +// Tests that a graphQL request is successfully handled when graphql is enabled on the specified endpoint +func TestGraphQLHTTPOnSamePort_GQLRequest_Successful(t *testing.T) { + stack := createNode(t, true) + defer stack.Close() + // start node + if err := stack.Start(); err != nil { + t.Fatalf("could not start node: %v", err) + } + // create http request + body := strings.NewReader("{\"query\": \"{block{number}}\",\"variables\": null}") + gqlReq, err := http.NewRequest(http.MethodGet, fmt.Sprintf("http://%s/graphql", "127.0.0.1:9393"), body) + if err != nil { + t.Error("could not issue new http request ", err) + } + gqlReq.Header.Set("Content-Type", "application/json") + // read from response + resp := doHTTPRequest(t, gqlReq) + bodyBytes, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("could not read from response body: %v", err) + } + expected := "{\"data\":{\"block\":{\"number\":\"0x0\"}}}" + assert.Equal(t, expected, string(bodyBytes)) +} + +// Tests that a graphQL request is not handled successfully when graphql is not enabled on the specified endpoint +func TestGraphQLHTTPOnSamePort_GQLRequest_Unsuccessful(t *testing.T) { + stack := createNode(t, false) + defer stack.Close() + if err := stack.Start(); err != nil { + t.Fatalf("could not start node: %v", err) + } + + // create http request + body := strings.NewReader("{\"query\": \"{block{number}}\",\"variables\": null}") + gqlReq, err := http.NewRequest(http.MethodPost, fmt.Sprintf("http://%s/graphql", "127.0.0.1:9393"), body) + if err != nil { + t.Error("could not issue new http request ", err) + } + gqlReq.Header.Set("Content-Type", "application/json") + // read from response + resp := doHTTPRequest(t, gqlReq) + bodyBytes, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("could not read from response body: %v", err) + } + // make sure the request is not handled successfully + assert.Equal(t, 404, resp.StatusCode) + assert.Equal(t, "404 page not found\n", string(bodyBytes)) +} + +func createNode(t *testing.T, gqlEnabled bool) *node.Node { + stack, err := node.New(&node.Config{ + HTTPHost: "127.0.0.1", + HTTPPort: 9393, + WSHost: "127.0.0.1", + WSPort: 9393, + }) + if err != nil { + t.Fatalf("could not create node: %v", err) + } + if !gqlEnabled { + return stack + } + + createGQLService(t, stack, "127.0.0.1:9393") + + return stack +} + +func createGQLService(t *testing.T, stack *node.Node, endpoint string) { + // create backend + ethBackend, err := eth.New(stack, ð.DefaultConfig) + if err != nil { + t.Fatalf("could not create eth backend: %v", err) + } + + // create gql service + err = New(stack, ethBackend.APIBackend, []string{}, []string{}) + if err != nil { + t.Fatalf("could not create graphql service: %v", err) + } +} + +func doHTTPRequest(t *testing.T, req *http.Request) *http.Response { + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + t.Fatal("could not issue a GET request to the given endpoint", err) + + } + return resp +} diff --git a/graphql/service.go b/graphql/service.go index d44f17593..a602e9421 100644 --- a/graphql/service.go +++ b/graphql/service.go @@ -19,8 +19,6 @@ package graphql import ( "fmt" "net" - "net/http" - "github.com/graph-gophers/graphql-go" "github.com/graph-gophers/graphql-go/relay" "github.com/ledgerwatch/turbo-geth/internal/ethapi" @@ -30,86 +28,30 @@ import ( "github.com/ledgerwatch/turbo-geth/rpc" ) -// Service encapsulates a GraphQL service. -type Service struct { - endpoint string // The host:port endpoint for this service. - cors []string // Allowed CORS domains - vhosts []string // Recognised vhosts - timeouts rpc.HTTPTimeouts // Timeout settings for HTTP requests. - backend ethapi.Backend // The backend that queries will operate on. - handler http.Handler // The `http.Handler` used to answer queries. - listener net.Listener // The listening socket. -} - // New constructs a new GraphQL service instance. -func New(backend ethapi.Backend, endpoint string, cors, vhosts []string, timeouts rpc.HTTPTimeouts) (*Service, error) { - return &Service{ - endpoint: endpoint, - cors: cors, - vhosts: vhosts, - timeouts: timeouts, - backend: backend, - }, nil -} - -// Protocols returns the list of protocols exported by this service. -func (s *Service) Protocols() []p2p.Protocol { return nil } - -// APIs returns the list of APIs exported by this service. -func (s *Service) APIs() []rpc.API { return nil } - -// Start is called after all services have been constructed and the networking -// layer was also initialized to spawn any goroutines required by the service. -func (s *Service) Start(server *p2p.Server) error { - var err error - s.handler, err = newHandler(s.backend) - if err != nil { - return err +func New(stack *node.Node, backend ethapi.Backend, cors, vhosts []string) error { + if backend == nil { + panic("missing backend") } - if s.listener, err = net.Listen("tcp", s.endpoint); err != nil { - return err - } - // create handler stack and wrap the graphql handler - handler := node.NewHTTPHandlerStack(s.handler, s.cors, s.vhosts) - // make sure timeout values are meaningful - node.CheckTimeouts(&s.timeouts) - // create http server - httpSrv := &http.Server{ - Handler: handler, - ReadTimeout: s.timeouts.ReadTimeout, - WriteTimeout: s.timeouts.WriteTimeout, - IdleTimeout: s.timeouts.IdleTimeout, - } - go httpSrv.Serve(s.listener) - log.Info("GraphQL endpoint opened", "url", fmt.Sprintf("http://%s", s.endpoint)) - return nil + // check if http server with given endpoint exists and enable graphQL on it + return newHandler(stack, backend, cors, vhosts) } // newHandler returns a new `http.Handler` that will answer GraphQL queries. // It additionally exports an interactive query browser on the / endpoint. -func newHandler(backend ethapi.Backend) (http.Handler, error) { +func newHandler(stack *node.Node, backend ethapi.Backend, cors, vhosts []string) error { q := Resolver{backend} s, err := graphql.ParseSchema(schema, &q) if err != nil { - return nil, err + return err } h := &relay.Handler{Schema: s} + handler := node.NewHTTPHandlerStack(h, cors, vhosts) - mux := http.NewServeMux() - mux.Handle("/", GraphiQL{}) - mux.Handle("/graphql", h) - mux.Handle("/graphql/", h) - return mux, nil -} + stack.RegisterHandler("GraphQL UI", "/graphql/ui", GraphiQL{}) + stack.RegisterHandler("GraphQL", "/graphql", handler) + stack.RegisterHandler("GraphQL", "/graphql/", handler) -// Stop terminates all goroutines belonging to the service, blocking until they -// are all terminated. -func (s *Service) Stop() error { - if s.listener != nil { - s.listener.Close() - s.listener = nil - log.Info("GraphQL endpoint closed", "url", fmt.Sprintf("http://%s", s.endpoint)) - } return nil } diff --git a/internal/ethapi/backend.go b/internal/ethapi/backend.go index 62de2bb6f..54cf8b138 100644 --- a/internal/ethapi/backend.go +++ b/internal/ethapi/backend.go @@ -45,14 +45,16 @@ type Backend interface { ChainDb() ethdb.Database AccountManager() *accounts.Manager ExtRPCEnabled() bool - RPCTxFeeCap() float64 // global tx fee cap for all transaction related APIs RPCGasCap() uint64 // global gas cap for eth_call over rpc: DoS protection + RPCTxFeeCap() float64 // global tx fee cap for all transaction related APIs // Blockchain API SetHead(number uint64) HeaderByNumber(ctx context.Context, number rpc.BlockNumber) (*types.Header, error) HeaderByHash(ctx context.Context, hash common.Hash) (*types.Header, error) HeaderByNumberOrHash(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) (*types.Header, error) + CurrentHeader() *types.Header + CurrentBlock() *types.Block BlockByNumber(ctx context.Context, number rpc.BlockNumber) (*types.Block, error) BlockByHash(ctx context.Context, hash common.Hash) (*types.Block, error) BlockByNumberOrHash(ctx context.Context, blockNrOrHash rpc.BlockNumberOrHash) (*types.Block, error) @@ -84,7 +86,7 @@ type Backend interface { SubscribeRemovedLogsEvent(ch chan<- core.RemovedLogsEvent) event.Subscription ChainConfig() *params.ChainConfig - CurrentBlock() *types.Block + Engine() consensus.Engine } func GetAPIs(apiBackend Backend) []rpc.API { diff --git a/miner/stresstest/stress_clique.go b/miner/stresstest/stress_clique.go index 439f5b5e5..fb75d3c3e 100644 --- a/miner/stresstest/stress_clique.go +++ b/miner/stresstest/stress_clique.go @@ -60,30 +60,31 @@ func main() { genesis := makeGenesis(faucets, sealers) var ( - nodes []*node.Node + nodes []*eth.Ethereum enodes []*enode.Node ) + for _, sealer := range sealers { // Start the node and wait until it's up - node, err := makeSealer(genesis) + stack, ethBackend, err := makeSealer(genesis) if err != nil { panic(err) } - defer node.Close() + defer stack.Close() - for node.Server().NodeInfo().Ports.Listener == 0 { + for stack.Server().NodeInfo().Ports.Listener == 0 { time.Sleep(250 * time.Millisecond) } - // Connect the node to al the previous ones + // Connect the node to all the previous ones for _, n := range enodes { - node.Server().AddPeer(n) + stack.Server().AddPeer(n) } - // Start tracking the node and it's enode - nodes = append(nodes, node) - enodes = append(enodes, node.Server().Self()) + // Start tracking the node and its enode + nodes = append(nodes, ethBackend) + enodes = append(enodes, stack.Server().Self()) // Inject the signer key and start sealing with it - store := node.AccountManager().Backends(keystore.KeyStoreType)[0].(*keystore.KeyStore) + store := stack.AccountManager().Backends(keystore.KeyStoreType)[0].(*keystore.KeyStore) signer, err := store.ImportECDSA(sealer, "") if err != nil { panic(err) @@ -92,15 +93,11 @@ func main() { panic(err) } } - // Iterate over all the nodes and start signing with them - time.Sleep(3 * time.Second) + // Iterate over all the nodes and start signing on them + time.Sleep(3 * time.Second) for _, node := range nodes { - var ethereum *eth.Ethereum - if err := node.Service(ðereum); err != nil { - panic(err) - } - if err := ethereum.StartMining(1); err != nil { + if err := node.StartMining(1); err != nil { panic(err) } } @@ -109,25 +106,22 @@ func main() { // Start injecting transactions from the faucet like crazy nonces := make([]uint64, len(faucets)) for { + // Pick a random signer node index := rand.Intn(len(faucets)) + backend := nodes[index%len(nodes)] - // Fetch the accessor for the relevant signer - var ethereum *eth.Ethereum - if err := nodes[index%len(nodes)].Service(ðereum); err != nil { - panic(err) - } // Create a self transaction and inject into the pool tx, err := types.SignTx(types.NewTransaction(nonces[index], crypto.PubkeyToAddress(faucets[index].PublicKey), new(big.Int), 21000, big.NewInt(100000000000), nil), types.HomesteadSigner{}, faucets[index]) if err != nil { panic(err) } - if err := ethereum.TxPool().AddLocal(tx); err != nil { + if err := backend.TxPool().AddLocal(tx); err != nil { panic(err) } nonces[index]++ // Wait if we're too saturated - if pend, _ := ethereum.TxPool().Stats(); pend > 2048 { + if pend, _ := backend.TxPool().Stats(); pend > 2048 { time.Sleep(100 * time.Millisecond) } } @@ -170,7 +164,7 @@ func makeGenesis(faucets []*ecdsa.PrivateKey, sealers []*ecdsa.PrivateKey) *core return genesis } -func makeSealer(genesis *core.Genesis) (*node.Node, error) { +func makeSealer(genesis *core.Genesis) (*node.Node, *eth.Ethereum, error) { // Define the basic configurations for the Ethereum node datadir, _ := ioutil.TempDir("", "") @@ -188,27 +182,28 @@ func makeSealer(genesis *core.Genesis) (*node.Node, error) { // Start the node and configure a full Ethereum node on it stack, err := node.New(config) if err != nil { - return nil, err + return nil, nil, err } - if err := stack.Register(func(ctx *node.ServiceContext) (node.Service, error) { - return eth.New(ctx, ð.Config{ - Genesis: genesis, - NetworkID: genesis.Config.ChainID.Uint64(), - SyncMode: downloader.FullSync, - DatabaseCache: 256, - DatabaseHandles: 256, - TxPool: core.DefaultTxPoolConfig, - GPO: eth.DefaultConfig.GPO, - Miner: miner.Config{ - GasFloor: genesis.GasLimit * 9 / 10, - GasCeil: genesis.GasLimit * 11 / 10, - GasPrice: big.NewInt(1), - Recommit: time.Second, - }, - }) - }); err != nil { - return nil, err + // Create and register the backend + ethBackend, err := eth.New(stack, ð.Config{ + Genesis: genesis, + NetworkId: genesis.Config.ChainID.Uint64(), + SyncMode: downloader.FullSync, + DatabaseCache: 256, + DatabaseHandles: 256, + TxPool: core.DefaultTxPoolConfig, + GPO: eth.DefaultConfig.GPO, + Miner: miner.Config{ + GasFloor: genesis.GasLimit * 9 / 10, + GasCeil: genesis.GasLimit * 11 / 10, + GasPrice: big.NewInt(1), + Recommit: time.Second, + }, + }) + if err != nil { + return nil, nil, err } - // Start the node and return if successful - return stack, stack.Start() + + err = stack.Start() + return stack, ethBackend, err } diff --git a/miner/stresstest/stress_ethash.go b/miner/stresstest/stress_ethash.go index 88731d714..d93a31e1e 100644 --- a/miner/stresstest/stress_ethash.go +++ b/miner/stresstest/stress_ethash.go @@ -65,43 +65,39 @@ func main() { genesis := makeGenesis(faucets) var ( - nodes []*node.Node + nodes []*eth.Ethereum enodes []*enode.Node ) for i := 0; i < n; i++ { // Start the node and wait until it's up - node, err := makeMiner(genesis) + stack, ethBackend, err := makeMiner(genesis) if err != nil { panic(err) } - defer node.Close() + defer stack.Close() - for node.Server().NodeInfo().Ports.Listener == 0 { + for stack.Server().NodeInfo().Ports.Listener == 0 { time.Sleep(250 * time.Millisecond) } - // Connect the node to al the previous ones + // Connect the node to all the previous ones for _, n := range enodes { - node.Server().AddPeer(n) + stack.Server().AddPeer(n) } - // Start tracking the node and it's enode - nodes = append(nodes, node) - enodes = append(enodes, node.Server().Self()) + // Start tracking the node and its enode + nodes = append(nodes, ethBackend) + enodes = append(enodes, stack.Server().Self()) // Inject the signer key and start sealing with it - store := node.AccountManager().Backends(keystore.KeyStoreType)[0].(*keystore.KeyStore) + store := stack.AccountManager().Backends(keystore.KeyStoreType)[0].(*keystore.KeyStore) if _, err := store.NewAccount(""); err != nil { panic(err) } } - // Iterate over all the nodes and start signing with them - time.Sleep(3 * time.Second) + // Iterate over all the nodes and start mining + time.Sleep(3 * time.Second) for _, node := range nodes { - var ethereum *eth.Ethereum - if err := node.Service(ðereum); err != nil { - panic(err) - } - if err := ethereum.StartMining(1); err != nil { + if err := node.StartMining(1); err != nil { panic(err) } } @@ -110,19 +106,16 @@ func main() { // Start injecting transactions from the faucets like crazy nonces := make([]uint64, len(faucets)) for { + // Pick a random mining node index := rand.Intn(len(faucets)) + backend := nodes[index%len(nodes)] - // Fetch the accessor for the relevant signer - var ethereum *eth.Ethereum - if err := nodes[index%len(nodes)].Service(ðereum); err != nil { - panic(err) - } // Create a self transaction and inject into the pool tx, err := types.SignTx(types.NewTransaction(nonces[index], crypto.PubkeyToAddress(faucets[index].PublicKey), new(big.Int), 21000, big.NewInt(100000000000+rand.Int63n(65536)), nil), types.HomesteadSigner{}, faucets[index]) if err != nil { panic(err) } - if err := ethereum.TxPool().AddLocal(tx); err != nil { + if err := backend.TxPool().AddLocal(tx); err != nil { panic(err) } nonces[index]++ @@ -153,7 +146,7 @@ func makeGenesis(faucets []*ecdsa.PrivateKey) *core.Genesis { return genesis } -func makeMiner(genesis *core.Genesis) (*node.Node, error) { +func makeMiner(genesis *core.Genesis) (*node.Node, *eth.Ethereum, error) { // Define the basic configurations for the Ethereum node datadir, _ := ioutil.TempDir("", "") @@ -169,13 +162,12 @@ func makeMiner(genesis *core.Genesis) (*node.Node, error) { NoUSB: true, UseLightweightKDF: true, } - // Start the node and configure a full Ethereum node on it + // Create the node and configure a full Ethereum node on it stack, err := node.New(config) if err != nil { - return nil, err + return nil, nil, err } - - ethConfig := ð.Config{ + ethBackend, err := eth.New(stack, ð.Config{ Genesis: genesis, NetworkID: genesis.Config.ChainID.Uint64(), SyncMode: downloader.FullSync, @@ -193,10 +185,13 @@ func makeMiner(genesis *core.Genesis) (*node.Node, error) { BlocksBeforePruning: 100, BlocksToPrune: 10, PruningTimeout: time.Second, + }) + if err != nil { + return nil, nil, err } - if err := stack.Register(func(ctx *node.ServiceContext) (node.Service, error) { - return eth.New(ctx, ethConfig) + err = stack.Start() + return stack, ethBackend, err }); err != nil { return nil, errors.Wrap(err, fmt.Sprintf("cannot register stress test miner. config %v", ethConfig)) } diff --git a/node/api.go b/node/api.go index 47e067ce7..deebcac95 100644 --- a/node/api.go +++ b/node/api.go @@ -28,21 +28,40 @@ import ( "github.com/ledgerwatch/turbo-geth/rpc" ) -// PrivateAdminAPI is the collection of administrative API methods exposed only -// over a secure RPC channel. -type PrivateAdminAPI struct { - node *Node // Node interfaced by this API +// apis returns the collection of built-in RPC APIs. +func (n *Node) apis() []rpc.API { + return []rpc.API{ + { + Namespace: "admin", + Version: "1.0", + Service: &privateAdminAPI{n}, + }, { + Namespace: "admin", + Version: "1.0", + Service: &publicAdminAPI{n}, + Public: true, + }, { + Namespace: "debug", + Version: "1.0", + Service: debug.Handler, + }, { + Namespace: "web3", + Version: "1.0", + Service: &publicWeb3API{n}, + Public: true, + }, + } } -// NewPrivateAdminAPI creates a new API definition for the private admin methods -// of the node itself. -func NewPrivateAdminAPI(node *Node) *PrivateAdminAPI { - return &PrivateAdminAPI{node: node} +// privateAdminAPI is the collection of administrative API methods exposed only +// over a secure RPC channel. +type privateAdminAPI struct { + node *Node // Node interfaced by this API } // AddPeer requests connecting to a remote node, and also maintaining the new // connection at all times, even reconnecting if it is lost. -func (api *PrivateAdminAPI) AddPeer(url string) (bool, error) { +func (api *privateAdminAPI) AddPeer(url string) (bool, error) { // Make sure the server is running, fail otherwise server := api.node.Server() if server == nil { @@ -58,7 +77,7 @@ func (api *PrivateAdminAPI) AddPeer(url string) (bool, error) { } // RemovePeer disconnects from a remote node if the connection exists -func (api *PrivateAdminAPI) RemovePeer(url string) (bool, error) { +func (api *privateAdminAPI) RemovePeer(url string) (bool, error) { // Make sure the server is running, fail otherwise server := api.node.Server() if server == nil { @@ -74,7 +93,7 @@ func (api *PrivateAdminAPI) RemovePeer(url string) (bool, error) { } // AddTrustedPeer allows a remote node to always connect, even if slots are full -func (api *PrivateAdminAPI) AddTrustedPeer(url string) (bool, error) { +func (api *privateAdminAPI) AddTrustedPeer(url string) (bool, error) { // Make sure the server is running, fail otherwise server := api.node.Server() if server == nil { @@ -90,7 +109,7 @@ func (api *PrivateAdminAPI) AddTrustedPeer(url string) (bool, error) { // RemoveTrustedPeer removes a remote node from the trusted peer set, but it // does not disconnect it automatically. -func (api *PrivateAdminAPI) RemoveTrustedPeer(url string) (bool, error) { +func (api *privateAdminAPI) RemoveTrustedPeer(url string) (bool, error) { // Make sure the server is running, fail otherwise server := api.node.Server() if server == nil { @@ -106,7 +125,7 @@ func (api *PrivateAdminAPI) RemoveTrustedPeer(url string) (bool, error) { // PeerEvents creates an RPC subscription which receives peer events from the // node's p2p.Server -func (api *PrivateAdminAPI) PeerEvents(ctx context.Context) (*rpc.Subscription, error) { +func (api *privateAdminAPI) PeerEvents(ctx context.Context) (*rpc.Subscription, error) { // Make sure the server is running, fail otherwise server := api.node.Server() if server == nil { @@ -143,14 +162,11 @@ func (api *PrivateAdminAPI) PeerEvents(ctx context.Context) (*rpc.Subscription, } // StartRPC starts the HTTP RPC API server. -func (api *PrivateAdminAPI) StartRPC(host *string, port *int, cors *string, apis *string, vhosts *string) (bool, error) { +func (api *privateAdminAPI) StartRPC(host *string, port *int, cors *string, apis *string, vhosts *string) (bool, error) { api.node.lock.Lock() defer api.node.lock.Unlock() - if api.node.httpHandler != nil { - return false, fmt.Errorf("HTTP RPC already running on %s", api.node.httpEndpoint) - } - + // Determine host and port. if host == nil { h := DefaultHTTPHost if api.node.config.HTTPHost != "" { @@ -162,57 +178,55 @@ func (api *PrivateAdminAPI) StartRPC(host *string, port *int, cors *string, apis port = &api.node.config.HTTPPort } - allowedOrigins := api.node.config.HTTPCors + // Determine config. + config := httpConfig{ + CorsAllowedOrigins: api.node.config.HTTPCors, + Vhosts: api.node.config.HTTPVirtualHosts, + Modules: api.node.config.HTTPModules, + } if cors != nil { - allowedOrigins = nil + config.CorsAllowedOrigins = nil for _, origin := range strings.Split(*cors, ",") { - allowedOrigins = append(allowedOrigins, strings.TrimSpace(origin)) + config.CorsAllowedOrigins = append(config.CorsAllowedOrigins, strings.TrimSpace(origin)) } } - - allowedVHosts := api.node.config.HTTPVirtualHosts if vhosts != nil { - allowedVHosts = nil + config.Vhosts = nil for _, vhost := range strings.Split(*host, ",") { - allowedVHosts = append(allowedVHosts, strings.TrimSpace(vhost)) + config.Vhosts = append(config.Vhosts, strings.TrimSpace(vhost)) } } - - modules := api.node.httpWhitelist if apis != nil { - modules = nil + config.Modules = nil for _, m := range strings.Split(*apis, ",") { - modules = append(modules, strings.TrimSpace(m)) + config.Modules = append(config.Modules, strings.TrimSpace(m)) } } - if err := api.node.startHTTP(fmt.Sprintf("%s:%d", *host, *port), api.node.rpcAPIs, modules, allowedOrigins, allowedVHosts, api.node.config.HTTPTimeouts, api.node.config.WSOrigins); err != nil { + if err := api.node.http.setListenAddr(*host, *port); err != nil { + return false, err + } + if err := api.node.http.enableRPC(api.node.rpcAPIs, config); err != nil { + return false, err + } + if err := api.node.http.start(); err != nil { return false, err } return true, nil } -// StopRPC terminates an already running HTTP RPC API endpoint. -func (api *PrivateAdminAPI) StopRPC() (bool, error) { - api.node.lock.Lock() - defer api.node.lock.Unlock() - - if api.node.httpHandler == nil { - return false, fmt.Errorf("HTTP RPC not running") - } - api.node.stopHTTP() +// StopRPC shuts down the HTTP server. +func (api *privateAdminAPI) StopRPC() (bool, error) { + api.node.http.stop() return true, nil } // StartWS starts the websocket RPC API server. -func (api *PrivateAdminAPI) StartWS(host *string, port *int, allowedOrigins *string, apis *string) (bool, error) { +func (api *privateAdminAPI) StartWS(host *string, port *int, allowedOrigins *string, apis *string) (bool, error) { api.node.lock.Lock() defer api.node.lock.Unlock() - if api.node.wsHandler != nil { - return false, fmt.Errorf("WebSocket RPC already running on %s", api.node.wsEndpoint) - } - + // Determine host and port. if host == nil { h := DefaultWSHost if api.node.config.WSHost != "" { @@ -224,43 +238,50 @@ func (api *PrivateAdminAPI) StartWS(host *string, port *int, allowedOrigins *str port = &api.node.config.WSPort } - origins := api.node.config.WSOrigins - if allowedOrigins != nil { - origins = nil - for _, origin := range strings.Split(*allowedOrigins, ",") { - origins = append(origins, strings.TrimSpace(origin)) - } + // Determine config. + config := wsConfig{ + Modules: api.node.config.WSModules, + Origins: api.node.config.WSOrigins, + // ExposeAll: api.node.config.WSExposeAll, } - - modules := api.node.config.WSModules if apis != nil { - modules = nil + config.Modules = nil for _, m := range strings.Split(*apis, ",") { - modules = append(modules, strings.TrimSpace(m)) + config.Modules = append(config.Modules, strings.TrimSpace(m)) + } + } + if allowedOrigins != nil { + config.Origins = nil + for _, origin := range strings.Split(*allowedOrigins, ",") { + config.Origins = append(config.Origins, strings.TrimSpace(origin)) } } - if err := api.node.startWS(fmt.Sprintf("%s:%d", *host, *port), api.node.rpcAPIs, modules, origins, api.node.config.WSExposeAll); err != nil { + // Enable WebSocket on the server. + server := api.node.wsServerForPort(*port) + if err := server.setListenAddr(*host, *port); err != nil { return false, err } - return true, nil -} - -// StopWS terminates an already running websocket RPC API endpoint. -func (api *PrivateAdminAPI) StopWS() (bool, error) { - api.node.lock.Lock() - defer api.node.lock.Unlock() - - if api.node.wsHandler == nil { - return false, fmt.Errorf("WebSocket RPC not running") + if err := server.enableWS(api.node.rpcAPIs, config); err != nil { + return false, err } - api.node.stopWS() + if err := server.start(); err != nil { + return false, err + } + api.node.http.log.Info("WebSocket endpoint opened", "url", api.node.WSEndpoint()) return true, nil } -// PublicAdminAPI is the collection of administrative API methods exposed over +// StopWS terminates all WebSocket servers. +func (api *privateAdminAPI) StopWS() (bool, error) { + api.node.http.stopWS() + api.node.ws.stop() + return true, nil +} + +// publicAdminAPI is the collection of administrative API methods exposed over // both secure and unsecure RPC channels. -type PublicAdminAPI struct { +type publicAdminAPI struct { node *Node // Node interfaced by this API } @@ -272,7 +293,7 @@ func NewPublicAdminAPI(node *Node) *PublicAdminAPI { // Peers retrieves all the information we know about each individual peer at the // protocol granularity. -func (api *PublicAdminAPI) Peers() ([]*p2p.PeerInfo, error) { +func (api *publicAdminAPI) Peers() ([]*p2p.PeerInfo, error) { server := api.node.Server() if server == nil { return nil, ErrNodeStopped @@ -282,7 +303,7 @@ func (api *PublicAdminAPI) Peers() ([]*p2p.PeerInfo, error) { // NodeInfo retrieves all the information we know about the host node at the // protocol granularity. -func (api *PublicAdminAPI) NodeInfo() (*p2p.NodeInfo, error) { +func (api *publicAdminAPI) NodeInfo() (*p2p.NodeInfo, error) { server := api.node.Server() if server == nil { return nil, ErrNodeStopped @@ -291,27 +312,22 @@ func (api *PublicAdminAPI) NodeInfo() (*p2p.NodeInfo, error) { } // Datadir retrieves the current data directory the node is using. -func (api *PublicAdminAPI) Datadir() string { +func (api *publicAdminAPI) Datadir() string { return api.node.DataDir() } -// PublicWeb3API offers helper utils -type PublicWeb3API struct { +// publicWeb3API offers helper utils +type publicWeb3API struct { stack *Node } -// NewPublicWeb3API creates a new Web3Service instance -func NewPublicWeb3API(stack *Node) *PublicWeb3API { - return &PublicWeb3API{stack} -} - // ClientVersion returns the node name -func (s *PublicWeb3API) ClientVersion() string { +func (s *publicWeb3API) ClientVersion() string { return s.stack.Server().Name } // Sha3 applies the ethereum sha3 implementation on the input. // It assumes the input is hex encoded. -func (s *PublicWeb3API) Sha3(input hexutil.Bytes) hexutil.Bytes { +func (s *publicWeb3API) Sha3(input hexutil.Bytes) hexutil.Bytes { return crypto.Keccak256(input) } diff --git a/node/api_test.go b/node/api_test.go new file mode 100644 index 000000000..e4c08962c --- /dev/null +++ b/node/api_test.go @@ -0,0 +1,350 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package node + +import ( + "bytes" + "io" + "net" + "net/http" + "net/url" + "strings" + "testing" + + "github.com/ethereum/go-ethereum/rpc" + "github.com/stretchr/testify/assert" +) + +// This test uses the admin_startRPC and admin_startWS APIs, +// checking whether the HTTP server is started correctly. +func TestStartRPC(t *testing.T) { + type test struct { + name string + cfg Config + fn func(*testing.T, *Node, *privateAdminAPI) + + // Checks. These run after the node is configured and all API calls have been made. + wantReachable bool // whether the HTTP server should be reachable at all + wantHandlers bool // whether RegisterHandler handlers should be accessible + wantRPC bool // whether JSON-RPC/HTTP should be accessible + wantWS bool // whether JSON-RPC/WS should be accessible + } + + tests := []test{ + { + name: "all off", + cfg: Config{}, + fn: func(t *testing.T, n *Node, api *privateAdminAPI) { + }, + wantReachable: false, + wantHandlers: false, + wantRPC: false, + wantWS: false, + }, + { + name: "rpc enabled through config", + cfg: Config{HTTPHost: "127.0.0.1"}, + fn: func(t *testing.T, n *Node, api *privateAdminAPI) { + }, + wantReachable: true, + wantHandlers: true, + wantRPC: true, + wantWS: false, + }, + { + name: "rpc enabled through API", + cfg: Config{}, + fn: func(t *testing.T, n *Node, api *privateAdminAPI) { + _, err := api.StartRPC(sp("127.0.0.1"), ip(0), nil, nil, nil) + assert.NoError(t, err) + }, + wantReachable: true, + wantHandlers: true, + wantRPC: true, + wantWS: false, + }, + { + name: "rpc start again after failure", + cfg: Config{}, + fn: func(t *testing.T, n *Node, api *privateAdminAPI) { + // Listen on a random port. + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal("can't listen:", err) + } + defer listener.Close() + port := listener.Addr().(*net.TCPAddr).Port + + // Now try to start RPC on that port. This should fail. + _, err = api.StartRPC(sp("127.0.0.1"), ip(port), nil, nil, nil) + if err == nil { + t.Fatal("StartRPC should have failed on port", port) + } + + // Try again after unblocking the port. It should work this time. + listener.Close() + _, err = api.StartRPC(sp("127.0.0.1"), ip(port), nil, nil, nil) + assert.NoError(t, err) + }, + wantReachable: true, + wantHandlers: true, + wantRPC: true, + wantWS: false, + }, + { + name: "rpc stopped through API", + cfg: Config{HTTPHost: "127.0.0.1"}, + fn: func(t *testing.T, n *Node, api *privateAdminAPI) { + _, err := api.StopRPC() + assert.NoError(t, err) + }, + wantReachable: false, + wantHandlers: false, + wantRPC: false, + wantWS: false, + }, + { + name: "rpc stopped twice", + cfg: Config{HTTPHost: "127.0.0.1"}, + fn: func(t *testing.T, n *Node, api *privateAdminAPI) { + _, err := api.StopRPC() + assert.NoError(t, err) + + _, err = api.StopRPC() + assert.NoError(t, err) + }, + wantReachable: false, + wantHandlers: false, + wantRPC: false, + wantWS: false, + }, + { + name: "ws enabled through config", + cfg: Config{WSHost: "127.0.0.1"}, + wantReachable: true, + wantHandlers: false, + wantRPC: false, + wantWS: true, + }, + { + name: "ws enabled through API", + cfg: Config{}, + fn: func(t *testing.T, n *Node, api *privateAdminAPI) { + _, err := api.StartWS(sp("127.0.0.1"), ip(0), nil, nil) + assert.NoError(t, err) + }, + wantReachable: true, + wantHandlers: false, + wantRPC: false, + wantWS: true, + }, + { + name: "ws stopped through API", + cfg: Config{WSHost: "127.0.0.1"}, + fn: func(t *testing.T, n *Node, api *privateAdminAPI) { + _, err := api.StopWS() + assert.NoError(t, err) + }, + wantReachable: false, + wantHandlers: false, + wantRPC: false, + wantWS: false, + }, + { + name: "ws stopped twice", + cfg: Config{WSHost: "127.0.0.1"}, + fn: func(t *testing.T, n *Node, api *privateAdminAPI) { + _, err := api.StopWS() + assert.NoError(t, err) + + _, err = api.StopWS() + assert.NoError(t, err) + }, + wantReachable: false, + wantHandlers: false, + wantRPC: false, + wantWS: false, + }, + { + name: "ws enabled after RPC", + cfg: Config{HTTPHost: "127.0.0.1"}, + fn: func(t *testing.T, n *Node, api *privateAdminAPI) { + wsport := n.http.port + _, err := api.StartWS(sp("127.0.0.1"), ip(wsport), nil, nil) + assert.NoError(t, err) + }, + wantReachable: true, + wantHandlers: true, + wantRPC: true, + wantWS: true, + }, + { + name: "ws enabled after RPC then stopped", + cfg: Config{HTTPHost: "127.0.0.1"}, + fn: func(t *testing.T, n *Node, api *privateAdminAPI) { + wsport := n.http.port + _, err := api.StartWS(sp("127.0.0.1"), ip(wsport), nil, nil) + assert.NoError(t, err) + + _, err = api.StopWS() + assert.NoError(t, err) + }, + wantReachable: true, + wantHandlers: true, + wantRPC: true, + wantWS: false, + }, + { + name: "rpc stopped with ws enabled", + fn: func(t *testing.T, n *Node, api *privateAdminAPI) { + _, err := api.StartRPC(sp("127.0.0.1"), ip(0), nil, nil, nil) + assert.NoError(t, err) + + wsport := n.http.port + _, err = api.StartWS(sp("127.0.0.1"), ip(wsport), nil, nil) + assert.NoError(t, err) + + _, err = api.StopRPC() + assert.NoError(t, err) + }, + wantReachable: false, + wantHandlers: false, + wantRPC: false, + wantWS: false, + }, + { + name: "rpc enabled after ws", + fn: func(t *testing.T, n *Node, api *privateAdminAPI) { + _, err := api.StartWS(sp("127.0.0.1"), ip(0), nil, nil) + assert.NoError(t, err) + + wsport := n.http.port + _, err = api.StartRPC(sp("127.0.0.1"), ip(wsport), nil, nil, nil) + assert.NoError(t, err) + }, + wantReachable: true, + wantHandlers: true, + wantRPC: true, + wantWS: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // Apply some sane defaults. + config := test.cfg + // config.Logger = testlog.Logger(t, log.LvlDebug) + config.NoUSB = true + config.P2P.NoDiscovery = true + + // Create Node. + stack, err := New(&config) + if err != nil { + t.Fatal("can't create node:", err) + } + defer stack.Close() + + // Register the test handler. + stack.RegisterHandler("test", "/test", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("OK")) + })) + + if err := stack.Start(); err != nil { + t.Fatal("can't start node:", err) + } + + // Run the API call hook. + if test.fn != nil { + test.fn(t, stack, &privateAdminAPI{stack}) + } + + // Check if the HTTP endpoints are available. + baseURL := stack.HTTPEndpoint() + reachable := checkReachable(baseURL) + handlersAvailable := checkBodyOK(baseURL + "/test") + rpcAvailable := checkRPC(baseURL) + wsAvailable := checkRPC(strings.Replace(baseURL, "http://", "ws://", 1)) + if reachable != test.wantReachable { + t.Errorf("HTTP server is %sreachable, want it %sreachable", not(reachable), not(test.wantReachable)) + } + if handlersAvailable != test.wantHandlers { + t.Errorf("RegisterHandler handlers %savailable, want them %savailable", not(handlersAvailable), not(test.wantHandlers)) + } + if rpcAvailable != test.wantRPC { + t.Errorf("HTTP RPC %savailable, want it %savailable", not(rpcAvailable), not(test.wantRPC)) + } + if wsAvailable != test.wantWS { + t.Errorf("WS RPC %savailable, want it %savailable", not(wsAvailable), not(test.wantWS)) + } + }) + } +} + +// checkReachable checks if the TCP endpoint in rawurl is open. +func checkReachable(rawurl string) bool { + u, err := url.Parse(rawurl) + if err != nil { + panic(err) + } + conn, err := net.Dial("tcp", u.Host) + if err != nil { + return false + } + conn.Close() + return true +} + +// checkBodyOK checks whether the given HTTP URL responds with 200 OK and body "OK". +func checkBodyOK(url string) bool { + resp, err := http.Get(url) + if err != nil { + return false + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + return false + } + buf := make([]byte, 2) + if _, err = io.ReadFull(resp.Body, buf); err != nil { + return false + } + return bytes.Equal(buf, []byte("OK")) +} + +// checkRPC checks whether JSON-RPC works against the given URL. +func checkRPC(url string) bool { + c, err := rpc.Dial(url) + if err != nil { + return false + } + defer c.Close() + + _, err = c.SupportedModules() + return err == nil +} + +// string/int pointer helpers. +func sp(s string) *string { return &s } +func ip(i int) *int { return &i } + +func not(ok bool) string { + if ok { + return "" + } + return "not " +} diff --git a/node/config.go b/node/config.go index 742713d56..e2da73c1c 100644 --- a/node/config.go +++ b/node/config.go @@ -162,15 +162,6 @@ type Config struct { // private APIs to untrusted users is a major security risk. WSExposeAll bool `toml:",omitempty"` - // GraphQLHost is the host interface on which to start the GraphQL server. If this - // field is empty, no GraphQL API endpoint will be started. - GraphQLHost string - - // GraphQLPort is the TCP port number on which to start the GraphQL server. The - // default zero value is/ valid and will pick a port number randomly (useful - // for ephemeral nodes). - GraphQLPort int `toml:",omitempty"` - // GraphQLCors is the Cross-Origin Resource Sharing header to send to requesting // clients. Please be aware that CORS is a browser enforced security, it's fully // useless for custom HTTP clients. @@ -255,15 +246,6 @@ func (c *Config) HTTPEndpoint() string { return fmt.Sprintf("%s:%d", c.HTTPHost, c.HTTPPort) } -// GraphQLEndpoint resolves a GraphQL endpoint based on the configured host interface -// and port parameters. -func (c *Config) GraphQLEndpoint() string { - if c.GraphQLHost == "" { - return "" - } - return fmt.Sprintf("%s:%d", c.GraphQLHost, c.GraphQLPort) -} - // DefaultHTTPEndpoint returns the HTTP endpoint used by default. func DefaultHTTPEndpoint() string { config := &Config{HTTPHost: DefaultHTTPHost, HTTPPort: DefaultHTTPPort} @@ -288,7 +270,7 @@ func DefaultWSEndpoint() string { // ExtRPCEnabled returns the indicator whether node enables the external // RPC(http, ws or graphql). func (c *Config) ExtRPCEnabled() bool { - return c.HTTPHost != "" || c.WSHost != "" || c.GraphQLHost != "" + return c.HTTPHost != "" || c.WSHost != "" } // NodeName returns the devp2p node identifier. diff --git a/node/defaults.go b/node/defaults.go index 3eefadf3b..cf2f99226 100644 --- a/node/defaults.go +++ b/node/defaults.go @@ -46,7 +46,6 @@ var DefaultConfig = Config{ HTTPTimeouts: rpc.DefaultHTTPTimeouts, WSPort: DefaultWSPort, WSModules: []string{"net", "web3"}, - GraphQLPort: DefaultGraphQLPort, GraphQLVirtualHosts: []string{"localhost"}, P2P: p2p.Config{ ListenAddr: ":30303", diff --git a/node/doc.go b/node/doc.go index 0f1cb811b..034bb5e6a 100644 --- a/node/doc.go +++ b/node/doc.go @@ -22,6 +22,43 @@ resources to provide RPC APIs. Services can also offer devp2p protocols, which a up to the devp2p network when the node instance is started. +Node Lifecycle + +The Node object has a lifecycle consisting of three basic states, INITIALIZING, RUNNING +and CLOSED. + + + ●───────┐ + New() + │ + ▼ + INITIALIZING ────Start()─┐ + │ │ + │ ▼ + Close() RUNNING + │ │ + ▼ │ + CLOSED ◀──────Close()─┘ + + +Creating a Node allocates basic resources such as the data directory and returns the node +in its INITIALIZING state. Lifecycle objects, RPC APIs and peer-to-peer networking +protocols can be registered in this state. Basic operations such as opening a key-value +database are permitted while initializing. + +Once everything is registered, the node can be started, which moves it into the RUNNING +state. Starting the node starts all registered Lifecycle objects and enables RPC and +peer-to-peer networking. Note that no additional Lifecycles, APIs or p2p protocols can be +registered while the node is running. + +Closing the node releases all held resources. The actions performed by Close depend on the +state it was in. When closing a node in INITIALIZING state, resources related to the data +directory are released. If the node was RUNNING, closing it also stops all Lifecycle +objects and shuts down RPC and peer-to-peer networking. + +You must always call Close on Node, even if the node was not started. + + Resources Managed By Node All file-system resources used by a node instance are located in a directory called the diff --git a/node/endpoints.go b/node/endpoints.go index e65b757a1..8ffbce94e 100644 --- a/node/endpoints.go +++ b/node/endpoints.go @@ -48,21 +48,6 @@ func StartHTTPEndpoint(endpoint string, timeouts rpc.HTTPTimeouts, handler http. return httpSrv, listener.Addr(), err } -// startWSEndpoint starts a websocket endpoint. -func startWSEndpoint(endpoint string, handler http.Handler) (*http.Server, net.Addr, error) { - // start the HTTP listener - var ( - listener net.Listener - err error - ) - if listener, err = net.Listen("tcp", endpoint); err != nil { - return nil, nil, err - } - wsSrv := &http.Server{Handler: handler} - go wsSrv.Serve(listener) - return wsSrv, listener.Addr(), err -} - // checkModuleAvailability checks that all names given in modules are actually // available API services. It assumes that the MetadataApi module ("rpc") is always available; // the registration of this "rpc" module happens in NewServer() and is thus common to all endpoints. diff --git a/node/errors.go b/node/errors.go index 2e0dadc4d..67547bf69 100644 --- a/node/errors.go +++ b/node/errors.go @@ -39,17 +39,6 @@ func convertFileLockError(err error) error { return err } -// DuplicateServiceError is returned during Node startup if a registered service -// constructor returns a service of the same type that was already started. -type DuplicateServiceError struct { - Kind reflect.Type -} - -// Error generates a textual representation of the duplicate service error. -func (e *DuplicateServiceError) Error() string { - return fmt.Sprintf("duplicate service: %v", e.Kind) -} - // StopError is returned if a Node fails to stop either any of its registered // services or itself. type StopError struct { diff --git a/node/lifecycle.go b/node/lifecycle.go new file mode 100644 index 000000000..0d5f9a068 --- /dev/null +++ b/node/lifecycle.go @@ -0,0 +1,31 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package node + +// Lifecycle encompasses the behavior of services that can be started and stopped +// on the node. Lifecycle management is delegated to the node, but it is the +// responsibility of the service-specific package to configure and register the +// service on the node using the `RegisterLifecycle` method. +type Lifecycle interface { + // Start is called after all services have been constructed and the networking + // layer was also initialized to spawn any goroutines required by the service. + Start() error + + // Stop terminates all goroutines belonging to the service, blocking until they + // are all terminated. + Stop() error +} diff --git a/node/node.go b/node/node.go index 648a6bc49..eef62c4dd 100644 --- a/node/node.go +++ b/node/node.go @@ -17,10 +17,8 @@ package node import ( - "context" "errors" "fmt" - "net" "net/http" "os" "path/filepath" @@ -40,36 +38,33 @@ import ( // Node is a container on which services can be registered. type Node struct { - eventmux *event.TypeMux // Event multiplexer used between the services of a stack - config *Config - accman *accounts.Manager - - ephemeralKeystore string // if non-empty, the key directory that will be removed by Stop - instanceDirLock fileutil.Releaser // prevents concurrent use of instance directory - - serverConfig p2p.Config - server *p2p.Server // Currently running P2P networking layer - - serviceFuncs []ServiceConstructor // Service constructors (in dependency order) - services map[reflect.Type]Service // Currently running services + eventmux *event.TypeMux + config *Config + accman *accounts.Manager + log log.Logger + ephemKeystore string // if non-empty, the key directory that will be removed by Stop + dirLock fileutil.Releaser // prevents concurrent use of instance directory + stop chan struct{} // Channel to wait for termination notifications + server *p2p.Server // Currently running P2P networking layer + startStopLock sync.Mutex // Start/Stop are protected by an additional lock + state int // Tracks state of node lifecycle + lock sync.Mutex + lifecycles []Lifecycle // All registered backends, services, and auxiliary services that have a lifecycle rpcAPIs []rpc.API // List of APIs currently provided by the node + http *httpServer // + ws *httpServer // + ipc *ipcServer // Stores information about the ipc http server inprocHandler *rpc.Server // In-process RPC request handler to process the API requests - ipcEndpoint string // IPC endpoint to listen at (empty = IPC disabled) - ipcListener net.Listener // IPC RPC listener socket to serve API requests - ipcHandler *rpc.Server // IPC RPC request handler to process the API requests + databases map[*closeTrackingDB]struct{} // All open databases +} - httpEndpoint string // HTTP endpoint (interface + port) to listen at (empty = HTTP disabled) - httpWhitelist []string // HTTP RPC modules to allow through this endpoint - httpListenerAddr net.Addr // Address of HTTP RPC listener socket serving API requests - httpServer *http.Server // HTTP RPC HTTP server - httpHandler *rpc.Server // HTTP RPC request handler to process the API requests - - wsEndpoint string // WebSocket endpoint (interface + port) to listen at (empty = WebSocket disabled) - wsListenerAddr net.Addr // Address of WebSocket RPC listener socket serving API requests - wsHTTPServer *http.Server // WebSocket RPC HTTP server - wsHandler *rpc.Server // WebSocket RPC request handler to process the API requests +const ( + initializingState = iota + runningState + closedState +) stop chan struct{} // Channel to wait for termination notifications lock sync.RWMutex @@ -90,6 +85,10 @@ func New(conf *Config) (*Node, error) { } conf.DataDir = absdatadir } + if conf.Logger == nil { + conf.Logger = log.New() + } + // Ensure that the instance name doesn't cause weird conflicts with // other files in the data directory. if strings.ContainsAny(conf.Name, `/\`) { @@ -101,43 +100,149 @@ func New(conf *Config) (*Node, error) { if strings.HasSuffix(conf.Name, ".ipc") { return nil, errors.New(`Config.Name cannot end in ".ipc"`) } - // Ensure that the AccountManager method works before the node has started. - // We rely on this in cmd/geth. + + node := &Node{ + config: conf, + inprocHandler: rpc.NewServer(), + eventmux: new(event.TypeMux), + log: conf.Logger, + stop: make(chan struct{}), + server: &p2p.Server{Config: conf.P2P}, + databases: make(map[*closeTrackingDB]struct{}), + } + + // Register built-in APIs. + node.rpcAPIs = append(node.rpcAPIs, node.apis()...) + + // Acquire the instance directory lock. + if err := node.openDataDir(); err != nil { + return nil, err + } + // Ensure that the AccountManager method works before the node has started. We rely on + // this in cmd/geth. am, ephemeralKeystore, err := makeAccountManager(conf) if err != nil { return nil, err } - if conf.Logger == nil { - conf.Logger = log.New() + node.accman = am + node.ephemKeystore = ephemeralKeystore + + // Initialize the p2p server. This creates the node key and discovery databases. + node.server.Config.PrivateKey = node.config.NodeKey() + node.server.Config.Name = node.config.NodeName() + node.server.Config.Logger = node.log + if node.server.Config.StaticNodes == nil { + node.server.Config.StaticNodes = node.config.StaticNodes() } - // Note: any interaction with Config that would create/touch files - // in the data directory or instance directory is delayed until Start. - return &Node{ - accman: am, - ephemeralKeystore: ephemeralKeystore, - config: conf, - serviceFuncs: []ServiceConstructor{}, - ipcEndpoint: conf.IPCEndpoint(), - httpEndpoint: conf.HTTPEndpoint(), - wsEndpoint: conf.WSEndpoint(), - eventmux: new(event.TypeMux), - log: conf.Logger, - }, nil + if node.server.Config.TrustedNodes == nil { + node.server.Config.TrustedNodes = node.config.TrustedNodes() + } + if node.server.Config.NodeDatabase == "" { + node.server.Config.NodeDatabase = node.config.NodeDB() + } + + // Configure RPC servers. + node.http = newHTTPServer(node.log, conf.HTTPTimeouts) + node.ws = newHTTPServer(node.log, rpc.DefaultHTTPTimeouts) + node.ipc = newIPCServer(node.log, conf.IPCEndpoint()) + + return node, nil +} + +// Start starts all registered lifecycles, RPC services and p2p networking. +// Node can only be started once. +func (n *Node) Start() error { + n.startStopLock.Lock() + defer n.startStopLock.Unlock() + + n.lock.Lock() + switch n.state { + case runningState: + n.lock.Unlock() + return ErrNodeRunning + case closedState: + n.lock.Unlock() + return ErrNodeStopped + } + n.state = runningState + err := n.startNetworking() + lifecycles := make([]Lifecycle, len(n.lifecycles)) + copy(lifecycles, n.lifecycles) + n.lock.Unlock() + + // Check if networking startup failed. + if err != nil { + n.doClose(nil) + return err + } + // Start all registered lifecycles. + var started []Lifecycle + for _, lifecycle := range lifecycles { + if err = lifecycle.Start(); err != nil { + break + } + started = append(started, lifecycle) + } + // Check if any lifecycle failed to start. + if err != nil { + n.stopServices(started) + n.doClose(nil) + } + return err } // Close stops the Node and releases resources acquired in // Node constructor New. func (n *Node) Close() error { - var errs []error + n.startStopLock.Lock() + defer n.startStopLock.Unlock() - // Terminate all subsystems and collect any errors - if err := n.Stop(); err != nil && err != ErrNodeStopped { - errs = append(errs, err) + n.lock.Lock() + state := n.state + n.lock.Unlock() + switch state { + case initializingState: + // The node was never started. + return n.doClose(nil) + case runningState: + // The node was started, release resources acquired by Start(). + var errs []error + if err := n.stopServices(n.lifecycles); err != nil { + errs = append(errs, err) + } + return n.doClose(errs) + case closedState: + return ErrNodeStopped + default: + panic(fmt.Sprintf("node is in unknown state %d", state)) } +} + +// doClose releases resources acquired by New(), collecting errors. +func (n *Node) doClose(errs []error) error { + // Close databases. This needs the lock because it needs to + // synchronize with OpenDatabase*. + n.lock.Lock() + n.state = closedState + errs = append(errs, n.closeDatabases()...) + n.lock.Unlock() + if err := n.accman.Close(); err != nil { errs = append(errs, err) } - // Report any errors that might have occurred + if n.ephemKeystore != "" { + if err := os.RemoveAll(n.ephemKeystore); err != nil { + errs = append(errs, err) + } + } + + // Release instance directory lock. + n.closeDataDir() + + // Unblock n.Wait. + close(n.stop) + + // Report any errors that might have occurred. switch len(errs) { case 0: return nil @@ -148,114 +253,50 @@ func (n *Node) Close() error { } } -// Register injects a new service into the node's stack. The service created by -// the passed constructor must be unique in its type with regard to sibling ones. -func (n *Node) Register(constructor ServiceConstructor) error { - n.lock.Lock() - defer n.lock.Unlock() - - if n.server != nil { - return ErrNodeRunning - } - n.serviceFuncs = append(n.serviceFuncs, constructor) - return nil -} - -// Start creates a live P2P node and starts running it. -func (n *Node) Start() error { - n.lock.Lock() - defer n.lock.Unlock() - - // Short circuit if the node's already running - if n.server != nil { - return ErrNodeRunning - } - if err := n.openDataDir(); err != nil { - return err - } - - // Initialize the p2p server. This creates the node key and - // discovery databases. - n.serverConfig = n.config.P2P - n.serverConfig.PrivateKey = n.config.NodeKey() - n.serverConfig.Name = n.config.NodeName() - n.serverConfig.Logger = n.log - if n.serverConfig.StaticNodes == nil { - n.serverConfig.StaticNodes = n.config.StaticNodes() - } - if n.serverConfig.TrustedNodes == nil { - n.serverConfig.TrustedNodes = n.config.TrustedNodes() - } - if n.serverConfig.NodeDatabase == "" { - n.serverConfig.NodeDatabase = n.config.NodeDB() - } - running := &p2p.Server{Config: n.serverConfig} - n.log.Info("Starting peer-to-peer node", "instance", n.serverConfig.Name) - - // Otherwise copy and specialize the P2P configuration - services := make(map[reflect.Type]Service) - for _, constructor := range n.serviceFuncs { - // Create a new context for the particular service - ctx := &ServiceContext{ - Config: *n.config, - services: make(map[reflect.Type]Service), - EventMux: n.eventmux, - AccountManager: n.accman, - } - for kind, s := range services { // copy needed for threaded access - ctx.services[kind] = s - } - // Construct and save the service - service, err := constructor(ctx) - if err != nil { - return err - } - kind := reflect.TypeOf(service) - if _, exists := services[kind]; exists { - return &DuplicateServiceError{Kind: kind} - } - services[kind] = service - } - // Gather the protocols and start the freshly assembled P2P server - for _, service := range services { - running.Protocols = append(running.Protocols, service.Protocols()...) - } - if err := running.Start(); err != nil { +// startNetworking starts all network endpoints. +func (n *Node) startNetworking() error { + n.log.Info("Starting peer-to-peer node", "instance", n.server.Name) + if err := n.server.Start(); err != nil { return convertFileLockError(err) } - // Start each of the services - var started []reflect.Type - for kind, service := range services { - // Start the next service, stopping all previous upon failure - if err := service.Start(running); err != nil { - for _, kind := range started { - services[kind].Stop() - } - running.Stop() - - return err - } - // Mark the service started for potential cleanup - started = append(started, kind) + err := n.startRPC() + if err != nil { + n.stopRPC() + n.server.Stop() } - // Lastly, start the configured RPC interfaces - if err := n.startRPC(services); err != nil { - for _, service := range services { - service.Stop() - } - running.Stop() - return err - } - // Finish initializing the startup - n.services = services - n.server = running - n.stop = make(chan struct{}) - return nil + return err } -// Config returns the configuration of node. -func (n *Node) Config() *Config { - return n.config +// containsLifecycle checks if 'lfs' contains 'l'. +func containsLifecycle(lfs []Lifecycle, l Lifecycle) bool { + for _, obj := range lfs { + if obj == l { + return true + } + } + return false +} + +// stopServices terminates running services, RPC and p2p networking. +// It is the inverse of Start. +func (n *Node) stopServices(running []Lifecycle) error { + n.stopRPC() + + // Stop running lifecycles in reverse order. + failure := &StopError{Services: make(map[reflect.Type]error)} + for i := len(running) - 1; i >= 0; i-- { + if err := running[i].Stop(); err != nil { + failure.Services[reflect.TypeOf(running[i])] = err + } + } + + // Stop p2p networking. + n.server.Stop() + + if len(failure.Services) > 0 { + return failure + } + return nil } func (n *Node) openDataDir() error { @@ -273,325 +314,189 @@ func (n *Node) openDataDir() error { if err != nil { return convertFileLockError(err) } - n.instanceDirLock = release + n.dirLock = release return nil } -// startRPC is a helper method to start all the various RPC endpoints during node +func (n *Node) closeDataDir() { + // Release instance directory lock. + if n.dirLock != nil { + if err := n.dirLock.Release(); err != nil { + n.log.Error("Can't release datadir lock", "err", err) + } + n.dirLock = nil + } +} + +// configureRPC is a helper method to configure all the various RPC endpoints during node // startup. It's not meant to be called at any time afterwards as it makes certain // assumptions about the state of the node. -func (n *Node) startRPC(services map[reflect.Type]Service) error { - // Gather all the possible APIs to surface - apis := n.apis() - for _, service := range services { - apis = append(apis, service.APIs()...) - } - // Start the various API endpoints, terminating all in case of errors - if err := n.startInProc(apis); err != nil { +func (n *Node) startRPC() error { + if err := n.startInProc(); err != nil { return err } - if err := n.startIPC(apis); err != nil { - n.stopInProc() - return err - } - if err := n.startHTTP(n.httpEndpoint, apis, n.config.HTTPModules, n.config.HTTPCors, n.config.HTTPVirtualHosts, n.config.HTTPTimeouts, n.config.WSOrigins); err != nil { - n.stopIPC() - n.stopInProc() - return err - } - // if endpoints are not the same, start separate servers - if n.httpEndpoint != n.wsEndpoint { - if err := n.startWS(n.wsEndpoint, apis, n.config.WSModules, n.config.WSOrigins, n.config.WSExposeAll); err != nil { - n.stopHTTP() - n.stopIPC() - n.stopInProc() + + // Configure IPC. + if n.ipc.endpoint != "" { + if err := n.ipc.start(n.rpcAPIs); err != nil { return err } } - // All API endpoints started successfully - n.rpcAPIs = apis - return nil + // Configure HTTP. + if n.config.HTTPHost != "" { + config := httpConfig{ + CorsAllowedOrigins: n.config.HTTPCors, + Vhosts: n.config.HTTPVirtualHosts, + Modules: n.config.HTTPModules, + } + if err := n.http.setListenAddr(n.config.HTTPHost, n.config.HTTPPort); err != nil { + return err + } + if err := n.http.enableRPC(n.rpcAPIs, config); err != nil { + return err + } + } + + // Configure WebSocket. + if n.config.WSHost != "" { + server := n.wsServerForPort(n.config.WSPort) + config := wsConfig{ + Modules: n.config.WSModules, + Origins: n.config.WSOrigins, + } + if err := server.setListenAddr(n.config.WSHost, n.config.WSPort); err != nil { + return err + } + if err := server.enableWS(n.rpcAPIs, config); err != nil { + return err + } + } + + if err := n.http.start(); err != nil { + return err + } + return n.ws.start() } -// startInProc initializes an in-process RPC endpoint. -func (n *Node) startInProc(apis []rpc.API) error { - // Register all the APIs exposed by the services - handler := rpc.NewServer() - for _, api := range apis { - if err := handler.RegisterName(api.Namespace, api.Service); err != nil { +func (n *Node) wsServerForPort(port int) *httpServer { + if n.config.HTTPHost == "" || n.http.port == port { + return n.http + } + return n.ws +} + +func (n *Node) stopRPC() { + n.http.stop() + n.ws.stop() + n.ipc.stop() + n.stopInProc() +} + +// startInProc registers all RPC APIs on the inproc server. +func (n *Node) startInProc() error { + for _, api := range n.rpcAPIs { + if err := n.inprocHandler.RegisterName(api.Namespace, api.Service); err != nil { return err } - n.log.Debug("InProc registered", "namespace", api.Namespace) } - n.inprocHandler = handler return nil } // stopInProc terminates the in-process RPC endpoint. func (n *Node) stopInProc() { - if n.inprocHandler != nil { - n.inprocHandler.Stop() - n.inprocHandler = nil - } + n.inprocHandler.Stop() } -// startIPC initializes and starts the IPC RPC endpoint. -func (n *Node) startIPC(apis []rpc.API) error { - if n.ipcEndpoint == "" { - return nil // IPC disabled. - } - listener, handler, err := rpc.StartIPCEndpoint(n.ipcEndpoint, apis) - if err != nil { - return err - } - n.ipcListener = listener - n.ipcHandler = handler - n.log.Info("IPC endpoint opened", "url", n.ipcEndpoint) - return nil +// Wait blocks until the node is closed. +func (n *Node) Wait() { + <-n.stop } -// stopIPC terminates the IPC RPC endpoint. -func (n *Node) stopIPC() { - if n.ipcListener != nil { - n.ipcListener.Close() - n.ipcListener = nil - - n.log.Info("IPC endpoint closed", "url", n.ipcEndpoint) - } - if n.ipcHandler != nil { - n.ipcHandler.Stop() - n.ipcHandler = nil - } -} - -// startHTTP initializes and starts the HTTP RPC endpoint. -func (n *Node) startHTTP(endpoint string, apis []rpc.API, modules []string, cors []string, vhosts []string, timeouts rpc.HTTPTimeouts, wsOrigins []string) error { - // Short circuit if the HTTP endpoint isn't being exposed - if endpoint == "" { - return nil - } - // register apis and create handler stack - srv := rpc.NewServer() - err := RegisterApisFromWhitelist(apis, modules, srv, false) - if err != nil { - return err - } - handler := NewHTTPHandlerStack(srv, cors, vhosts) - // wrap handler in WebSocket handler only if WebSocket port is the same as http rpc - if n.httpEndpoint == n.wsEndpoint { - handler = NewWebsocketUpgradeHandler(handler, srv.WebsocketHandler(wsOrigins)) - } - httpServer, addr, err := StartHTTPEndpoint(endpoint, timeouts, handler) - if err != nil { - return err - } - n.log.Info("HTTP endpoint opened", "url", fmt.Sprintf("http://%v/", addr), - "cors", strings.Join(cors, ","), - "vhosts", strings.Join(vhosts, ",")) - if n.httpEndpoint == n.wsEndpoint { - n.log.Info("WebSocket endpoint opened", "url", fmt.Sprintf("ws://%v", addr)) - } - // All listeners booted successfully - n.httpEndpoint = endpoint - n.httpListenerAddr = addr - n.httpServer = httpServer - n.httpHandler = srv - - return nil -} - -// stopHTTP terminates the HTTP RPC endpoint. -func (n *Node) stopHTTP() { - if n.httpServer != nil { - // Don't bother imposing a timeout here. - n.httpServer.Shutdown(context.Background()) //nolint:errcheck - n.log.Info("HTTP endpoint closed", "url", fmt.Sprintf("http://%v/", n.httpListenerAddr)) - } - if n.httpHandler != nil { - n.httpHandler.Stop() - n.httpHandler = nil - } -} - -// startWS initializes and starts the WebSocket RPC endpoint. -func (n *Node) startWS(endpoint string, apis []rpc.API, modules []string, wsOrigins []string, exposeAll bool) error { - // Short circuit if the WS endpoint isn't being exposed - if endpoint == "" { - return nil - } - - srv := rpc.NewServer() - handler := srv.WebsocketHandler(wsOrigins) - err := RegisterApisFromWhitelist(apis, modules, srv, exposeAll) - if err != nil { - return err - } - httpServer, addr, err := startWSEndpoint(endpoint, handler) - if err != nil { - return err - } - n.log.Info("WebSocket endpoint opened", "url", fmt.Sprintf("ws://%v", addr)) - // All listeners booted successfully - n.wsEndpoint = endpoint - n.wsListenerAddr = addr - n.wsHTTPServer = httpServer - n.wsHandler = srv - - return nil -} - -// stopWS terminates the WebSocket RPC endpoint. -func (n *Node) stopWS() { - if n.wsHTTPServer != nil { - // Don't bother imposing a timeout here. - n.wsHTTPServer.Shutdown(context.Background()) //nolint:errcheck - n.log.Info("WebSocket endpoint closed", "url", fmt.Sprintf("ws://%v", n.wsListenerAddr)) - } - if n.wsHandler != nil { - n.wsHandler.Stop() - n.wsHandler = nil - } -} - -// Stop terminates a running node along with all it's services. In the node was -// not started, an error is returned. -func (n *Node) Stop() error { +// RegisterLifecycle registers the given Lifecycle on the node. +func (n *Node) RegisterLifecycle(lifecycle Lifecycle) { n.lock.Lock() defer n.lock.Unlock() - // Short circuit if the node's not running - if n.server == nil { - return ErrNodeStopped + if n.state != initializingState { + panic("can't register lifecycle on running/stopped node") } + if containsLifecycle(n.lifecycles, lifecycle) { + panic(fmt.Sprintf("attempt to register lifecycle %T more than once", lifecycle)) + } + n.lifecycles = append(n.lifecycles, lifecycle) +} - // Terminate the API, services and the p2p server. - n.stopWS() - n.stopHTTP() - n.stopIPC() - n.rpcAPIs = nil - failure := &StopError{ - Services: make(map[reflect.Type]error), - } - for kind, service := range n.services { - if err := service.Stop(); err != nil { - failure.Services[kind] = err - } - } - n.server.Stop() - n.services = nil - n.server = nil +// RegisterProtocols adds backend's protocols to the node's p2p server. +func (n *Node) RegisterProtocols(protocols []p2p.Protocol) { + n.lock.Lock() + defer n.lock.Unlock() - // Release instance directory lock. - if n.instanceDirLock != nil { - if err := n.instanceDirLock.Release(); err != nil { - n.log.Error("Can't release datadir lock", "err", err) - } - n.instanceDirLock = nil - } - - // unblock n.Wait + if n.state != initializingState { close(n.stop) // Remove the keystore if it was created ephemerally. type closer interface { - Close() + panic("can't register protocols on running/stopped node") } + n.server.Protocols = append(n.server.Protocols, protocols...) +} +// RegisterAPIs registers the APIs a service provides on the node. - for _, api := range n.rpcAPIs { - if closeAPI, ok := api.Service.(closer); ok { - closeAPI.Close() - } - } +func (n *Node) RegisterAPIs(apis []rpc.API) { + n.lock.Lock() + defer n.lock.Unlock() - var keystoreErr error - if n.ephemeralKeystore != "" { - keystoreErr = os.RemoveAll(n.ephemeralKeystore) + if n.state != initializingState { + panic("can't register APIs on running/stopped node") } - - if len(failure.Services) > 0 { - return failure - } - if keystoreErr != nil { - return keystoreErr - } - return nil + n.rpcAPIs = append(n.rpcAPIs, apis...) } -// Wait blocks the thread until the node is stopped. If the node is not running -// at the time of invocation, the method immediately returns. -func (n *Node) Wait() { - n.lock.RLock() - if n.server == nil { - n.lock.RUnlock() - return - } - stop := n.stop - n.lock.RUnlock() +// RegisterHandler mounts a handler on the given path on the canonical HTTP server. +// +// The name of the handler is shown in a log message when the HTTP server starts +// and should be a descriptive term for the service provided by the handler. +func (n *Node) RegisterHandler(name, path string, handler http.Handler) { + n.lock.Lock() + defer n.lock.Unlock() - <-stop -} - -// Restart terminates a running node and boots up a new one in its place. If the -// node isn't running, an error is returned. -func (n *Node) Restart() error { - if err := n.Stop(); err != nil { - return err + if n.state != initializingState { + panic("can't register HTTP handler on running/stopped node") } - if err := n.Start(); err != nil { - return err - } - return nil + n.http.mux.Handle(path, handler) + n.http.handlerNames[path] = name } // Attach creates an RPC client attached to an in-process API handler. func (n *Node) Attach() (*rpc.Client, error) { - n.lock.RLock() - defer n.lock.RUnlock() - - if n.server == nil { - return nil, ErrNodeStopped - } return rpc.DialInProc(n.inprocHandler), nil } // RPCHandler returns the in-process RPC request handler. func (n *Node) RPCHandler() (*rpc.Server, error) { - n.lock.RLock() - defer n.lock.RUnlock() + n.lock.Lock() + defer n.lock.Unlock() - if n.inprocHandler == nil { + if n.state == closedState { return nil, ErrNodeStopped } return n.inprocHandler, nil } -// Server retrieves the currently running P2P network layer. This method is meant -// only to inspect fields of the currently running server, life cycle management -// should be left to this Node entity. -func (n *Node) Server() *p2p.Server { - n.lock.RLock() - defer n.lock.RUnlock() - - return n.server +// Config returns the configuration of node. +func (n *Node) Config() *Config { + return n.config } -// Service retrieves a currently running service registered of a specific type. -func (n *Node) Service(service interface{}) error { - n.lock.RLock() - defer n.lock.RUnlock() +// Server retrieves the currently running P2P network layer. This method is meant +// only to inspect fields of the currently running server. Callers should not +// start or stop the returned server. +func (n *Node) Server() *p2p.Server { + n.lock.Lock() + defer n.lock.Unlock() - // Short circuit if the node's not running - if n.server == nil { - return ErrNodeStopped - } - // Otherwise try to find the service to return - element := reflect.ValueOf(service).Elem() - if running, ok := n.services[element.Type()]; ok { - element.Set(reflect.ValueOf(running)) - return nil - } - return ErrServiceUnknown + return n.server } // DataDir retrieves the current datadir used by the protocol stack. @@ -612,29 +517,20 @@ func (n *Node) AccountManager() *accounts.Manager { // IPCEndpoint retrieves the current IPC endpoint used by the protocol stack. func (n *Node) IPCEndpoint() string { - return n.ipcEndpoint + return n.ipc.endpoint } -// HTTPEndpoint retrieves the current HTTP endpoint used by the protocol stack. +// HTTPEndpoint returns the URL of the HTTP server. func (n *Node) HTTPEndpoint() string { - n.lock.Lock() - defer n.lock.Unlock() - - if n.httpListenerAddr != nil { - return n.httpListenerAddr.String() - } - return n.httpEndpoint + return "http://" + n.http.listenAddr() } // WSEndpoint retrieves the current WS endpoint used by the protocol stack. func (n *Node) WSEndpoint() string { - n.lock.Lock() - defer n.lock.Unlock() - - if n.wsListenerAddr != nil { - return n.wsListenerAddr.String() + if n.http.wsAllowed() { + return "ws://" + n.http.listenAddr() } - return n.wsEndpoint + return "ws://" + n.ws.listenAddr() } // EventMux retrieves the event multiplexer used by all the network services in @@ -647,12 +543,48 @@ func (n *Node) EventMux() *event.TypeMux { // previous can be found) from within the node's instance directory. If the node is // ephemeral, a memory database is returned. func (n *Node) OpenDatabase(name string) (*ethdb.ObjectDatabase, error) { - if n.config.DataDir == "" { - return ethdb.NewMemDatabase(), nil + n.lock.Lock() + defer n.lock.Unlock() + if n.state == closedState { + return nil, ErrNodeStopped } + var db ethdb.Database + var err error + if n.config.DataDir == "" { + db = rawdb.NewMemoryDatabase() + } else { + db, err = rawdb.NewLevelDBDatabase(n.ResolvePath(name), cache, handles, namespace) + } + + if err == nil { + db = n.wrapDatabase(db) + } + return db, err +} + +// OpenDatabaseWithFreezer opens an existing database with the given name (or +// creates one if no previous can be found) from within the node's data directory, +// also attaching a chain freezer to it that moves ancient chain data from the +// database to immutable append-only files. If the node is an ephemeral one, a +// memory database is returned. +func (n *Node) OpenDatabaseWithFreezer(name string, cache, handles int, freezer, namespace string) (ethdb.Database, error) { + n.lock.Lock() + defer n.lock.Unlock() + if n.state == closedState { + return nil, ErrNodeStopped + } + + var db ethdb.Database + var err error + if n.config.DataDir == "" { + db = rawdb.NewMemoryDatabase() if n.config.Bolt { log.Info("Opening Database (Bolt)") + switch { + case freezer == "": + freezer = filepath.Join(root, "ancient") + case !filepath.IsAbs(freezer): return ethdb.Open(n.config.ResolvePath(name + "_bolt")) } @@ -665,49 +597,35 @@ func (n *Node) ResolvePath(x string) string { return n.config.ResolvePath(x) } -// apis returns the collection of RPC descriptors this node offers. -func (n *Node) apis() []rpc.API { - return []rpc.API{ - { - Namespace: "admin", - Version: "1.0", - Service: NewPrivateAdminAPI(n), - }, { - Namespace: "admin", - Version: "1.0", - Service: NewPublicAdminAPI(n), - Public: true, - }, { - Namespace: "debug", - Version: "1.0", - Service: debug.Handler, - }, { - Namespace: "web3", - Version: "1.0", - Service: NewPublicWeb3API(n), - Public: true, - }, - } +// closeTrackingDB wraps the Close method of a database. When the database is closed by the +// service, the wrapper removes it from the node's database map. This ensures that Node +// won't auto-close the database if it is closed by the service that opened it. +type closeTrackingDB struct { + ethdb.Database + n *Node } -// RegisterApisFromWhitelist checks the given modules' availability, generates a whitelist based on the allowed modules, -// and then registers all of the APIs exposed by the services. -func RegisterApisFromWhitelist(apis []rpc.API, modules []string, srv *rpc.Server, exposeAll bool) error { - if bad, available := checkModuleAvailability(modules, apis); len(bad) > 0 { - log.Error("Unavailable modules in HTTP API list", "unavailable", bad, "available", available) - } - // Generate the whitelist based on the allowed modules - whitelist := make(map[string]bool) - for _, module := range modules { - whitelist[module] = true - } - // Register all the APIs exposed by the services - for _, api := range apis { - if exposeAll || whitelist[api.Namespace] || (len(whitelist) == 0 && api.Public) { - if err := srv.RegisterName(api.Namespace, api.Service); err != nil { - return err - } +func (db *closeTrackingDB) Close() error { + db.n.lock.Lock() + delete(db.n.databases, db) + db.n.lock.Unlock() + return db.Database.Close() +} + +// wrapDatabase ensures the database will be auto-closed when Node is closed. +func (n *Node) wrapDatabase(db ethdb.Database) ethdb.Database { + wrapper := &closeTrackingDB{db, n} + n.databases[wrapper] = struct{}{} + return wrapper +} + +// closeDatabases closes all open databases. +func (n *Node) closeDatabases() (errors []error) { + for db := range n.databases { + delete(n.databases, db) + if err := db.Database.Close(); err != nil { + errors = append(errors, err) } } - return nil + return errors } diff --git a/node/node_example_test.go b/node/node_example_test.go index 95623e3fb..53a4cdfbc 100644 --- a/node/node_example_test.go +++ b/node/node_example_test.go @@ -21,26 +21,20 @@ import ( "log" "github.com/ledgerwatch/turbo-geth/node" - "github.com/ledgerwatch/turbo-geth/p2p" - "github.com/ledgerwatch/turbo-geth/rpc" ) -// SampleService is a trivial network service that can be attached to a node for +// SampleLifecycle is a trivial network service that can be attached to a node for // life cycle management. // -// The following methods are needed to implement a node.Service: -// - Protocols() []p2p.Protocol - devp2p protocols the service can communicate on -// - APIs() []rpc.API - api methods the service wants to expose on rpc channels +// The following methods are needed to implement a node.Lifecycle: // - Start() error - method invoked when the node is ready to start the service // - Stop() error - method invoked when the node terminates the service -type SampleService struct{} +type SampleLifecycle struct{} -func (s *SampleService) Protocols() []p2p.Protocol { return nil } -func (s *SampleService) APIs() []rpc.API { return nil } -func (s *SampleService) Start(*p2p.Server) error { fmt.Println("Service starting..."); return nil } -func (s *SampleService) Stop() error { fmt.Println("Service stopping..."); return nil } +func (s *SampleLifecycle) Start() error { fmt.Println("Service starting..."); return nil } +func (s *SampleLifecycle) Stop() error { fmt.Println("Service stopping..."); return nil } -func ExampleService() { +func ExampleLifecycle() { // Create a network node to run protocols with the default values. stack, err := node.New(&node.Config{}) if err != nil { @@ -48,29 +42,18 @@ func ExampleService() { } defer stack.Close() - // Create and register a simple network service. This is done through the definition - // of a node.ServiceConstructor that will instantiate a node.Service. The reason for - // the factory method approach is to support service restarts without relying on the - // individual implementations' support for such operations. - constructor := func(context *node.ServiceContext) (node.Service, error) { - return new(SampleService), nil - } - if err := stack.Register(constructor); err != nil { - log.Fatalf("Failed to register service: %v", err) - } + // Create and register a simple network Lifecycle. + service := new(SampleLifecycle) + stack.RegisterLifecycle(service) + // Boot up the entire protocol stack, do a restart and terminate if err := stack.Start(); err != nil { log.Fatalf("Failed to start the protocol stack: %v", err) } - if err := stack.Restart(); err != nil { - log.Fatalf("Failed to restart the protocol stack: %v", err) - } - if err := stack.Stop(); err != nil { + if err := stack.Close(); err != nil { log.Fatalf("Failed to stop the protocol stack: %v", err) } // Output: // Service starting... // Service stopping... - // Service starting... - // Service stopping... } diff --git a/node/node_test.go b/node/node_test.go index 59a4c9bc3..1e5ffa944 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -18,12 +18,15 @@ package node import ( "errors" + "fmt" + "io" "io/ioutil" + "net" "net/http" "os" "reflect" + "strings" "testing" - "time" "github.com/ledgerwatch/turbo-geth/crypto" "github.com/ledgerwatch/turbo-geth/p2p" @@ -43,20 +46,28 @@ func testNodeConfig() *Config { } } -// Tests that an empty protocol stack can be started, restarted and stopped. -func TestNodeLifeCycle(t *testing.T) { +// Tests that an empty protocol stack can be closed more than once. +func TestNodeCloseMultipleTimes(t *testing.T) { stack, err := New(testNodeConfig()) if err != nil { t.Fatalf("failed to create protocol stack: %v", err) } - defer stack.Close() + stack.Close() // Ensure that a stopped node can be stopped again for i := 0; i < 3; i++ { - if err := stack.Stop(); err != ErrNodeStopped { + if err := stack.Close(); err != ErrNodeStopped { t.Fatalf("iter %d: stop failure mismatch: have %v, want %v", i, err, ErrNodeStopped) } } +} + +func TestNodeStartMultipleTimes(t *testing.T) { + stack, err := New(testNodeConfig()) + if err != nil { + t.Fatalf("failed to create protocol stack: %v", err) + } + // Ensure that a node can be successfully started, but only once if err := stack.Start(); err != nil { t.Fatalf("failed to start node: %v", err) @@ -64,17 +75,11 @@ func TestNodeLifeCycle(t *testing.T) { if err := stack.Start(); err != ErrNodeRunning { t.Fatalf("start failure mismatch: have %v, want %v ", err, ErrNodeRunning) } - // Ensure that a node can be restarted arbitrarily many times - for i := 0; i < 3; i++ { - if err := stack.Restart(); err != nil { - t.Fatalf("iter %d: failed to restart node: %v", i, err) - } - } // Ensure that a node can be stopped, but only once - if err := stack.Stop(); err != nil { + if err := stack.Close(); err != nil { t.Fatalf("failed to stop node: %v", err) } - if err := stack.Stop(); err != ErrNodeStopped { + if err := stack.Close(); err != ErrNodeStopped { t.Fatalf("stop failure mismatch: have %v, want %v ", err, ErrNodeStopped) } } @@ -94,92 +99,152 @@ func TestNodeUsedDataDir(t *testing.T) { t.Fatalf("failed to create original protocol stack: %v", err) } defer original.Close() - if err := original.Start(); err != nil { t.Fatalf("failed to start original protocol stack: %v", err) } - defer original.Stop() // Create a second node based on the same data directory and ensure failure - duplicate, err := New(&Config{DataDir: dir}) - if err != nil { - t.Fatalf("failed to create duplicate protocol stack: %v", err) - } - defer duplicate.Close() - - if err := duplicate.Start(); err != ErrDatadirUsed { + _, err = New(&Config{DataDir: dir}) + if err != ErrDatadirUsed { t.Fatalf("duplicate datadir failure mismatch: have %v, want %v", err, ErrDatadirUsed) } } -// Tests whether services can be registered and duplicates caught. -func TestServiceRegistry(t *testing.T) { +// Tests whether a Lifecycle can be registered. +func TestLifecycleRegistry_Successful(t *testing.T) { stack, err := New(testNodeConfig()) if err != nil { t.Fatalf("failed to create protocol stack: %v", err) } defer stack.Close() - // Register a batch of unique services and ensure they start successfully - services := []ServiceConstructor{NewNoopServiceA, NewNoopServiceB, NewNoopServiceC} - for i, constructor := range services { - if err := stack.Register(constructor); err != nil { - t.Fatalf("service #%d: registration failed: %v", i, err) + noop := NewNoop() + stack.RegisterLifecycle(noop) + + if !containsLifecycle(stack.lifecycles, noop) { + t.Fatalf("lifecycle was not properly registered on the node, %v", err) + } +} + +// Tests whether a service's protocols can be registered properly on the node's p2p server. +func TestRegisterProtocols(t *testing.T) { + stack, err := New(testNodeConfig()) + if err != nil { + t.Fatalf("failed to create protocol stack: %v", err) + } + defer stack.Close() + + fs, err := NewFullService(stack) + if err != nil { + t.Fatalf("could not create full service: %v", err) + } + + for _, protocol := range fs.Protocols() { + if !containsProtocol(stack.server.Protocols, protocol) { + t.Fatalf("protocol %v was not successfully registered", protocol) } } - if err := stack.Start(); err != nil { - t.Fatalf("failed to start original service stack: %v", err) - } - if err := stack.Stop(); err != nil { - t.Fatalf("failed to stop original service stack: %v", err) - } - // Duplicate one of the services and retry starting the node - if err := stack.Register(NewNoopServiceB); err != nil { - t.Fatalf("duplicate registration failed: %v", err) - } - if err := stack.Start(); err == nil { - t.Fatalf("duplicate service started") - } else { - if _, ok := err.(*DuplicateServiceError); !ok { - t.Fatalf("duplicate error mismatch: have %v, want %v", err, DuplicateServiceError{}) + + for _, api := range fs.APIs() { + if !containsAPI(stack.rpcAPIs, api) { + t.Fatalf("api %v was not successfully registered", api) } } } -// Tests that registered services get started and stopped correctly. -func TestServiceLifeCycle(t *testing.T) { - stack, err := New(testNodeConfig()) - if err != nil { - t.Fatalf("failed to create protocol stack: %v", err) - } +// This test checks that open databases are closed with node. +func TestNodeCloseClosesDB(t *testing.T) { + stack, _ := New(testNodeConfig()) defer stack.Close() - // Register a batch of life-cycle instrumented services - services := map[string]InstrumentingWrapper{ - "A": InstrumentedServiceMakerA, - "B": InstrumentedServiceMakerB, - "C": InstrumentedServiceMakerC, + db, err := stack.OpenDatabase("mydb", 0, 0, "") + if err != nil { + t.Fatal("can't open DB:", err) } + if err = db.Put([]byte{}, []byte{}); err != nil { + t.Fatal("can't Put on open DB:", err) + } + + stack.Close() + if err = db.Put([]byte{}, []byte{}); err == nil { + t.Fatal("Put succeeded after node is closed") + } +} + +// This test checks that OpenDatabase can be used from within a Lifecycle Start method. +func TestNodeOpenDatabaseFromLifecycleStart(t *testing.T) { + stack, _ := New(testNodeConfig()) + defer stack.Close() + + var db ethdb.Database + var err error + stack.RegisterLifecycle(&InstrumentedService{ + startHook: func() { + db, err = stack.OpenDatabase("mydb", 0, 0, "") + if err != nil { + t.Fatal("can't open DB:", err) + } + }, + stopHook: func() { + db.Close() + }, + }) + + stack.Start() + stack.Close() +} + +// This test checks that OpenDatabase can be used from within a Lifecycle Stop method. +func TestNodeOpenDatabaseFromLifecycleStop(t *testing.T) { + stack, _ := New(testNodeConfig()) + defer stack.Close() + + stack.RegisterLifecycle(&InstrumentedService{ + stopHook: func() { + db, err := stack.OpenDatabase("mydb", 0, 0, "") + if err != nil { + t.Fatal("can't open DB:", err) + } + db.Close() + }, + }) + + stack.Start() + stack.Close() +} + +// Tests that registered Lifecycles get started and stopped correctly. +func TestLifecycleLifeCycle(t *testing.T) { + stack, _ := New(testNodeConfig()) + defer stack.Close() + started := make(map[string]bool) stopped := make(map[string]bool) - for id, maker := range services { - id := id // Closure for the constructor - constructor := func(*ServiceContext) (Service, error) { - return &InstrumentedService{ - startHook: func(*p2p.Server) { started[id] = true }, - stopHook: func() { stopped[id] = true }, - }, nil - } - if err := stack.Register(maker(constructor)); err != nil { - t.Fatalf("service %s: registration failed: %v", id, err) - } + // Create a batch of instrumented services + lifecycles := map[string]Lifecycle{ + "A": &InstrumentedService{ + startHook: func() { started["A"] = true }, + stopHook: func() { stopped["A"] = true }, + }, + "B": &InstrumentedService{ + startHook: func() { started["B"] = true }, + stopHook: func() { stopped["B"] = true }, + }, + "C": &InstrumentedService{ + startHook: func() { started["C"] = true }, + stopHook: func() { stopped["C"] = true }, + }, + } + // register lifecycles on node + for _, lifecycle := range lifecycles { + stack.RegisterLifecycle(lifecycle) } // Start the node and check that all services are running if err := stack.Start(); err != nil { t.Fatalf("failed to start protocol stack: %v", err) } - for id := range services { + for id := range lifecycles { if !started[id] { t.Fatalf("service %s: freshly started service not running", id) } @@ -188,470 +253,286 @@ func TestServiceLifeCycle(t *testing.T) { } } // Stop the node and check that all services have been stopped - if err := stack.Stop(); err != nil { + if err := stack.Close(); err != nil { t.Fatalf("failed to stop protocol stack: %v", err) } - for id := range services { + for id := range lifecycles { if !stopped[id] { t.Fatalf("service %s: freshly terminated service still running", id) } } } -// Tests that services are restarted cleanly as new instances. -func TestServiceRestarts(t *testing.T) { +// Tests that if a Lifecycle fails to start, all others started before it will be +// shut down. +func TestLifecycleStartupError(t *testing.T) { stack, err := New(testNodeConfig()) if err != nil { t.Fatalf("failed to create protocol stack: %v", err) } defer stack.Close() - // Define a service that does not support restarts - var ( - running bool - started int - ) - constructor := func(*ServiceContext) (Service, error) { - running = false - - return &InstrumentedService{ - startHook: func(*p2p.Server) { - if running { - panic("already running") - } - running = true - started++ - }, - }, nil - } - // Register the service and start the protocol stack - if err := stack.Register(constructor); err != nil { - t.Fatalf("failed to register the service: %v", err) - } - if err := stack.Start(); err != nil { - t.Fatalf("failed to start protocol stack: %v", err) - } - defer stack.Stop() - - if !running || started != 1 { - t.Fatalf("running/started mismatch: have %v/%d, want true/1", running, started) - } - // Restart the stack a few times and check successful service restarts - for i := 0; i < 3; i++ { - if err := stack.Restart(); err != nil { - t.Fatalf("iter %d: failed to restart stack: %v", i, err) - } - } - if !running || started != 4 { - t.Fatalf("running/started mismatch: have %v/%d, want true/4", running, started) - } -} - -// Tests that if a service fails to initialize itself, none of the other services -// will be allowed to even start. -func TestServiceConstructionAbortion(t *testing.T) { - stack, err := New(testNodeConfig()) - if err != nil { - t.Fatalf("failed to create protocol stack: %v", err) - } - defer stack.Close() - - // Define a batch of good services - services := map[string]InstrumentingWrapper{ - "A": InstrumentedServiceMakerA, - "B": InstrumentedServiceMakerB, - "C": InstrumentedServiceMakerC, - } started := make(map[string]bool) - for id, maker := range services { - id := id // Closure for the constructor - constructor := func(*ServiceContext) (Service, error) { - return &InstrumentedService{ - startHook: func(*p2p.Server) { started[id] = true }, - }, nil - } - if err := stack.Register(maker(constructor)); err != nil { - t.Fatalf("service %s: registration failed: %v", id, err) - } + stopped := make(map[string]bool) + + // Create a batch of instrumented services + lifecycles := map[string]Lifecycle{ + "A": &InstrumentedService{ + startHook: func() { started["A"] = true }, + stopHook: func() { stopped["A"] = true }, + }, + "B": &InstrumentedService{ + startHook: func() { started["B"] = true }, + stopHook: func() { stopped["B"] = true }, + }, + "C": &InstrumentedService{ + startHook: func() { started["C"] = true }, + stopHook: func() { stopped["C"] = true }, + }, } + // register lifecycles on node + for _, lifecycle := range lifecycles { + stack.RegisterLifecycle(lifecycle) + } + // Register a service that fails to construct itself failure := errors.New("fail") - failer := func(*ServiceContext) (Service, error) { - return nil, failure - } - if err := stack.Register(failer); err != nil { - t.Fatalf("failer registration failed: %v", err) - } - // Start the protocol stack and ensure none of the services get started - for i := 0; i < 100; i++ { - if err := stack.Start(); err != failure { - t.Fatalf("iter %d: stack startup failure mismatch: have %v, want %v", i, err, failure) - } - for id := range services { - if started[id] { - t.Fatalf("service %s: started should not have", id) - } - delete(started, id) - } - } -} + failer := &InstrumentedService{start: failure} + stack.RegisterLifecycle(failer) -// Tests that if a service fails to start, all others started before it will be -// shut down. -func TestServiceStartupAbortion(t *testing.T) { - stack, err := New(testNodeConfig()) - if err != nil { - t.Fatalf("failed to create protocol stack: %v", err) - } - defer stack.Close() - - // Register a batch of good services - services := map[string]InstrumentingWrapper{ - "A": InstrumentedServiceMakerA, - "B": InstrumentedServiceMakerB, - "C": InstrumentedServiceMakerC, - } - started := make(map[string]bool) - stopped := make(map[string]bool) - - for id, maker := range services { - id := id // Closure for the constructor - constructor := func(*ServiceContext) (Service, error) { - return &InstrumentedService{ - startHook: func(*p2p.Server) { started[id] = true }, - stopHook: func() { stopped[id] = true }, - }, nil - } - if err := stack.Register(maker(constructor)); err != nil { - t.Fatalf("service %s: registration failed: %v", id, err) - } - } - // Register a service that fails to start - failure := errors.New("fail") - failer := func(*ServiceContext) (Service, error) { - return &InstrumentedService{ - start: failure, - }, nil - } - if err := stack.Register(failer); err != nil { - t.Fatalf("failer registration failed: %v", err) - } // Start the protocol stack and ensure all started services stop - for i := 0; i < 100; i++ { - if err := stack.Start(); err != failure { - t.Fatalf("iter %d: stack startup failure mismatch: have %v, want %v", i, err, failure) - } - for id := range services { - if started[id] && !stopped[id] { - t.Fatalf("service %s: started but not stopped", id) - } - delete(started, id) - delete(stopped, id) + if err := stack.Start(); err != failure { + t.Fatalf("stack startup failure mismatch: have %v, want %v", err, failure) + } + for id := range lifecycles { + if started[id] && !stopped[id] { + t.Fatalf("service %s: started but not stopped", id) } + delete(started, id) + delete(stopped, id) } } -// Tests that even if a registered service fails to shut down cleanly, it does +// Tests that even if a registered Lifecycle fails to shut down cleanly, it does // not influence the rest of the shutdown invocations. -func TestServiceTerminationGuarantee(t *testing.T) { +func TestLifecycleTerminationGuarantee(t *testing.T) { stack, err := New(testNodeConfig()) if err != nil { t.Fatalf("failed to create protocol stack: %v", err) } defer stack.Close() - // Register a batch of good services - services := map[string]InstrumentingWrapper{ - "A": InstrumentedServiceMakerA, - "B": InstrumentedServiceMakerB, - "C": InstrumentedServiceMakerC, - } started := make(map[string]bool) stopped := make(map[string]bool) - for id, maker := range services { - id := id // Closure for the constructor - constructor := func(*ServiceContext) (Service, error) { - return &InstrumentedService{ - startHook: func(*p2p.Server) { started[id] = true }, - stopHook: func() { stopped[id] = true }, - }, nil - } - if err := stack.Register(maker(constructor)); err != nil { - t.Fatalf("service %s: registration failed: %v", id, err) - } + // Create a batch of instrumented services + lifecycles := map[string]Lifecycle{ + "A": &InstrumentedService{ + startHook: func() { started["A"] = true }, + stopHook: func() { stopped["A"] = true }, + }, + "B": &InstrumentedService{ + startHook: func() { started["B"] = true }, + stopHook: func() { stopped["B"] = true }, + }, + "C": &InstrumentedService{ + startHook: func() { started["C"] = true }, + stopHook: func() { stopped["C"] = true }, + }, } + // register lifecycles on node + for _, lifecycle := range lifecycles { + stack.RegisterLifecycle(lifecycle) + } + // Register a service that fails to shot down cleanly failure := errors.New("fail") - failer := func(*ServiceContext) (Service, error) { - return &InstrumentedService{ - stop: failure, - }, nil - } - if err := stack.Register(failer); err != nil { - t.Fatalf("failer registration failed: %v", err) - } + failer := &InstrumentedService{stop: failure} + stack.RegisterLifecycle(failer) + // Start the protocol stack, and ensure that a failing shut down terminates all - for i := 0; i < 100; i++ { - // Start the stack and make sure all is online - if err := stack.Start(); err != nil { - t.Fatalf("iter %d: failed to start protocol stack: %v", i, err) - } - for id := range services { - if !started[id] { - t.Fatalf("iter %d, service %s: service not running", i, id) - } - if stopped[id] { - t.Fatalf("iter %d, service %s: service already stopped", i, id) - } - } - // Stop the stack, verify failure and check all terminations - err := stack.Stop() - if err, ok := err.(*StopError); !ok { - t.Fatalf("iter %d: termination failure mismatch: have %v, want StopError", i, err) - } else { - failer := reflect.TypeOf(&InstrumentedService{}) - if err.Services[failer] != failure { - t.Fatalf("iter %d: failer termination failure mismatch: have %v, want %v", i, err.Services[failer], failure) - } - if len(err.Services) != 1 { - t.Fatalf("iter %d: failure count mismatch: have %d, want %d", i, len(err.Services), 1) - } - } - for id := range services { - if !stopped[id] { - t.Fatalf("iter %d, service %s: service not terminated", i, id) - } - delete(started, id) - delete(stopped, id) - } - } -} - -// TestServiceRetrieval tests that individual services can be retrieved. -func TestServiceRetrieval(t *testing.T) { - // Create a simple stack and register two service types - stack, err := New(testNodeConfig()) - if err != nil { - t.Fatalf("failed to create protocol stack: %v", err) - } - defer stack.Close() - - if err := stack.Register(NewNoopService); err != nil { - t.Fatalf("noop service registration failed: %v", err) - } - if err := stack.Register(NewInstrumentedService); err != nil { - t.Fatalf("instrumented service registration failed: %v", err) - } - // Make sure none of the services can be retrieved until started - var noopServ *NoopService - if err := stack.Service(&noopServ); err != ErrNodeStopped { - t.Fatalf("noop service retrieval mismatch: have %v, want %v", err, ErrNodeStopped) - } - var instServ *InstrumentedService - if err := stack.Service(&instServ); err != ErrNodeStopped { - t.Fatalf("instrumented service retrieval mismatch: have %v, want %v", err, ErrNodeStopped) - } - // Start the stack and ensure everything is retrievable now - if err := stack.Start(); err != nil { - t.Fatalf("failed to start stack: %v", err) - } - defer stack.Stop() - - if err := stack.Service(&noopServ); err != nil { - t.Fatalf("noop service retrieval mismatch: have %v, want %v", err, nil) - } - if err := stack.Service(&instServ); err != nil { - t.Fatalf("instrumented service retrieval mismatch: have %v, want %v", err, nil) - } -} - -// Tests that all protocols defined by individual services get launched. -func TestProtocolGather(t *testing.T) { - stack, err := New(testNodeConfig()) - if err != nil { - t.Fatalf("failed to create protocol stack: %v", err) - } - defer stack.Close() - - // Register a batch of services with some configured number of protocols - services := map[string]struct { - Count int - Maker InstrumentingWrapper - }{ - "zero": {0, InstrumentedServiceMakerA}, - "one": {1, InstrumentedServiceMakerB}, - "many": {10, InstrumentedServiceMakerC}, - } - for id, config := range services { - protocols := make([]p2p.Protocol, config.Count) - for i := 0; i < len(protocols); i++ { - protocols[i].Name = id - protocols[i].Version = uint(i) - } - constructor := func(*ServiceContext) (Service, error) { - return &InstrumentedService{ - protocols: protocols, - }, nil - } - if err := stack.Register(config.Maker(constructor)); err != nil { - t.Fatalf("service %s: registration failed: %v", id, err) - } - } - // Start the services and ensure all protocols start successfully + // Start the stack and make sure all is online if err := stack.Start(); err != nil { t.Fatalf("failed to start protocol stack: %v", err) } - defer stack.Stop() - - protocols := stack.Server().Protocols - if len(protocols) != 11 { - t.Fatalf("mismatching number of protocols launched: have %d, want %d", len(protocols), 26) - } - for id, config := range services { - for ver := 0; ver < config.Count; ver++ { - launched := false - for i := 0; i < len(protocols); i++ { - if protocols[i].Name == id && protocols[i].Version == uint(ver) { - launched = true - break - } - } - if !launched { - t.Errorf("configured protocol not launched: %s v%d", id, ver) - } + for id := range lifecycles { + if !started[id] { + t.Fatalf("service %s: service not running", id) + } + if stopped[id] { + t.Fatalf("service %s: service already stopped", id) } } + // Stop the stack, verify failure and check all terminations + err = stack.Close() + if err, ok := err.(*StopError); !ok { + t.Fatalf("termination failure mismatch: have %v, want StopError", err) + } else { + failer := reflect.TypeOf(&InstrumentedService{}) + if err.Services[failer] != failure { + t.Fatalf("failer termination failure mismatch: have %v, want %v", err.Services[failer], failure) + } + if len(err.Services) != 1 { + t.Fatalf("failure count mismatch: have %d, want %d", len(err.Services), 1) + } + } + for id := range lifecycles { + if !stopped[id] { + t.Fatalf("service %s: service not terminated", id) + } + delete(started, id) + delete(stopped, id) + } + + stack.server = &p2p.Server{} + stack.server.PrivateKey = testNodeKey } -// Tests that all APIs defined by individual services get exposed. -func TestAPIGather(t *testing.T) { - stack, err := New(testNodeConfig()) - if err != nil { - t.Fatalf("failed to create protocol stack: %v", err) - } - defer stack.Close() +// Tests whether a handler can be successfully mounted on the canonical HTTP server +// on the given path +func TestRegisterHandler_Successful(t *testing.T) { + node := createNode(t, 7878, 7979) - // Register a batch of services with some configured APIs - calls := make(chan string, 1) - makeAPI := func(result string) *OneMethodAPI { - return &OneMethodAPI{fun: func() { calls <- result }} - } - services := map[string]struct { - APIs []rpc.API - Maker InstrumentingWrapper - }{ - "Zero APIs": { - []rpc.API{}, InstrumentedServiceMakerA}, - "Single API": { - []rpc.API{ - {Namespace: "single", Version: "1", Service: makeAPI("single.v1"), Public: true}, - }, InstrumentedServiceMakerB}, - "Many APIs": { - []rpc.API{ - {Namespace: "multi", Version: "1", Service: makeAPI("multi.v1"), Public: true}, - {Namespace: "multi.v2", Version: "2", Service: makeAPI("multi.v2"), Public: true}, - {Namespace: "multi.v2.nested", Version: "2", Service: makeAPI("multi.v2.nested"), Public: true}, - }, InstrumentedServiceMakerC}, + // create and mount handler + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("success")) + }) + node.RegisterHandler("test", "/test", handler) + + // start node + if err := node.Start(); err != nil { + t.Fatalf("could not start node: %v", err) } - for id, config := range services { - config := config - constructor := func(*ServiceContext) (Service, error) { - return &InstrumentedService{apis: config.APIs}, nil - } - if err := stack.Register(config.Maker(constructor)); err != nil { - t.Fatalf("service %s: registration failed: %v", id, err) - } - } - // Start the services and ensure all API start successfully - if err := stack.Start(); err != nil { - t.Fatalf("failed to start protocol stack: %v", err) - } - defer stack.Stop() - - // Connect to the RPC server and verify the various registered endpoints - client, err := stack.Attach() - if err != nil { - t.Fatalf("failed to connect to the inproc API server: %v", err) - } - defer client.Close() - - tests := []struct { - Method string - Result string - }{ - {"single_theOneMethod", "single.v1"}, - {"multi_theOneMethod", "multi.v1"}, - {"multi.v2_theOneMethod", "multi.v2"}, - {"multi.v2.nested_theOneMethod", "multi.v2.nested"}, - } - for i, test := range tests { - if err := client.Call(nil, test.Method); err != nil { - t.Errorf("test %d: API request failed: %v", i, err) - } - select { - case result := <-calls: - if result != test.Result { - t.Errorf("test %d: result mismatch: have %s, want %s", i, result, test.Result) - } - case <-time.After(time.Second): - t.Fatalf("test %d: rpc execution timeout", i) - } - } -} - -func TestWebsocketHTTPOnSamePort_WebsocketRequest(t *testing.T) { - node := startHTTP(t) - defer node.stopHTTP() - - wsReq, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:7453", nil) + // create HTTP request + httpReq, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:7878/test", nil) if err != nil { t.Error("could not issue new http request ", err) } - wsReq.Header.Set("Connection", "upgrade") - wsReq.Header.Set("Upgrade", "websocket") - wsReq.Header.Set("Sec-WebSocket-Version", "13") - wsReq.Header.Set("Sec-Websocket-Key", "SGVsbG8sIHdvcmxkIQ==") - - resp := doHTTPRequest(t, wsReq) - assert.Equal(t, "websocket", resp.Header.Get("Upgrade")) -} - -func TestWebsocketHTTPOnSamePort_HTTPRequest(t *testing.T) { - node := startHTTP(t) - defer node.stopHTTP() - - httpReq, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:7453", nil) - if err != nil { - t.Error("could not issue new http request ", err) - } - httpReq.Header.Set("Accept-Encoding", "gzip") + // check response resp := doHTTPRequest(t, httpReq) - assert.Equal(t, "gzip", resp.Header.Get("Content-Encoding")) + buf := make([]byte, 7) + _, err = io.ReadFull(resp.Body, buf) + if err != nil { + t.Fatalf("could not read response: %v", err) + } + assert.Equal(t, "success", string(buf)) } -func startHTTP(t *testing.T) *Node { - conf := &Config{HTTPPort: 7453, WSPort: 7453} +// Tests that the given handler will not be successfully mounted since no HTTP server +// is enabled for RPC +func TestRegisterHandler_Unsuccessful(t *testing.T) { + node, err := New(&DefaultConfig) + if err != nil { + t.Fatalf("could not create new node: %v", err) + } + + // create and mount handler + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("success")) + }) + node.RegisterHandler("test", "/test", handler) +} + +// Tests whether websocket requests can be handled on the same port as a regular http server. +func TestWebsocketHTTPOnSamePort_WebsocketRequest(t *testing.T) { + node := startHTTP(t, 0, 0) + defer node.Close() + + ws := strings.Replace(node.HTTPEndpoint(), "http://", "ws://", 1) + + if node.WSEndpoint() != ws { + t.Fatalf("endpoints should be the same") + } + if !checkRPC(ws) { + t.Fatalf("ws request failed") + } + if !checkRPC(node.HTTPEndpoint()) { + t.Fatalf("http request failed") + } +} + +func TestWebsocketHTTPOnSeparatePort_WSRequest(t *testing.T) { + // try and get a free port + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal("can't listen:", err) + } + port := listener.Addr().(*net.TCPAddr).Port + listener.Close() + + node := startHTTP(t, 0, port) + defer node.Close() + + wsOnHTTP := strings.Replace(node.HTTPEndpoint(), "http://", "ws://", 1) + ws := fmt.Sprintf("ws://127.0.0.1:%d", port) + + if node.WSEndpoint() == wsOnHTTP { + t.Fatalf("endpoints should not be the same") + } + // ensure ws endpoint matches the expected endpoint + if node.WSEndpoint() != ws { + t.Fatalf("ws endpoint is incorrect: expected %s, got %s", ws, node.WSEndpoint()) + } + + if !checkRPC(ws) { + t.Fatalf("ws request failed") + } + if !checkRPC(node.HTTPEndpoint()) { + t.Fatalf("http request failed") + } + +} + +func createNode(t *testing.T, httpPort, wsPort int) *Node { + conf := &Config{ + HTTPHost: "127.0.0.1", + HTTPPort: httpPort, + WSHost: "127.0.0.1", + WSPort: wsPort, + } node, err := New(conf) if err != nil { - t.Error("could not create a new node ", err) + t.Fatalf("could not create a new node: %v", err) } + return node +} - err = node.startHTTP("127.0.0.1:7453", []rpc.API{}, []string{}, []string{}, []string{}, rpc.HTTPTimeouts{}, []string{}) +func startHTTP(t *testing.T, httpPort, wsPort int) *Node { + node := createNode(t, httpPort, wsPort) + err := node.Start() if err != nil { - t.Error("could not start http service on node ", err) + t.Fatalf("could not start http service on node: %v", err) } return node } func doHTTPRequest(t *testing.T, req *http.Request) *http.Response { - client := &http.Client{} + client := http.DefaultClient resp, err := client.Do(req) if err != nil { - t.Error("could not issue a GET request to the given endpoint", err) + t.Fatalf("could not issue a GET request to the given endpoint: %v", err) + } return resp } + +func containsProtocol(stackProtocols []p2p.Protocol, protocol p2p.Protocol) bool { + for _, a := range stackProtocols { + if reflect.DeepEqual(a, protocol) { + return true + } + } + return false +} + +func containsAPI(stackAPIs []rpc.API, api rpc.API) bool { + for _, a := range stackAPIs { + if reflect.DeepEqual(a, api) { + return true + } + } + return false +} diff --git a/node/rpcstack.go b/node/rpcstack.go index 4c2858e3a..5c9248354 100644 --- a/node/rpcstack.go +++ b/node/rpcstack.go @@ -18,17 +18,304 @@ package node import ( "compress/gzip" + "context" + "fmt" "io" "io/ioutil" "net" "net/http" + "sort" "strings" "sync" + "sync/atomic" "github.com/ledgerwatch/turbo-geth/log" + "github.com/ethereum/go-ethereum/rpc" "github.com/rs/cors" ) +// httpConfig is the JSON-RPC/HTTP configuration. +type httpConfig struct { + Modules []string + CorsAllowedOrigins []string + Vhosts []string +} + +// wsConfig is the JSON-RPC/Websocket configuration +type wsConfig struct { + Origins []string + Modules []string +} + +type rpcHandler struct { + http.Handler + server *rpc.Server +} + +type httpServer struct { + log log.Logger + timeouts rpc.HTTPTimeouts + mux http.ServeMux // registered handlers go here + + mu sync.Mutex + server *http.Server + listener net.Listener // non-nil when server is running + + // HTTP RPC handler things. + httpConfig httpConfig + httpHandler atomic.Value // *rpcHandler + + // WebSocket handler things. + wsConfig wsConfig + wsHandler atomic.Value // *rpcHandler + + // These are set by setListenAddr. + endpoint string + host string + port int + + handlerNames map[string]string +} + +func newHTTPServer(log log.Logger, timeouts rpc.HTTPTimeouts) *httpServer { + h := &httpServer{log: log, timeouts: timeouts, handlerNames: make(map[string]string)} + h.httpHandler.Store((*rpcHandler)(nil)) + h.wsHandler.Store((*rpcHandler)(nil)) + return h +} + +// setListenAddr configures the listening address of the server. +// The address can only be set while the server isn't running. +func (h *httpServer) setListenAddr(host string, port int) error { + h.mu.Lock() + defer h.mu.Unlock() + + if h.listener != nil && (host != h.host || port != h.port) { + return fmt.Errorf("HTTP server already running on %s", h.endpoint) + } + + h.host, h.port = host, port + h.endpoint = fmt.Sprintf("%s:%d", host, port) + return nil +} + +// listenAddr returns the listening address of the server. +func (h *httpServer) listenAddr() string { + h.mu.Lock() + defer h.mu.Unlock() + + if h.listener != nil { + return h.listener.Addr().String() + } + return h.endpoint +} + +// start starts the HTTP server if it is enabled and not already running. +func (h *httpServer) start() error { + h.mu.Lock() + defer h.mu.Unlock() + + if h.endpoint == "" || h.listener != nil { + return nil // already running or not configured + } + + // Initialize the server. + h.server = &http.Server{Handler: h} + if h.timeouts != (rpc.HTTPTimeouts{}) { + CheckTimeouts(&h.timeouts) + h.server.ReadTimeout = h.timeouts.ReadTimeout + h.server.WriteTimeout = h.timeouts.WriteTimeout + h.server.IdleTimeout = h.timeouts.IdleTimeout + } + + // Start the server. + listener, err := net.Listen("tcp", h.endpoint) + if err != nil { + // If the server fails to start, we need to clear out the RPC and WS + // configuration so they can be configured another time. + h.disableRPC() + h.disableWS() + return err + } + h.listener = listener + go h.server.Serve(listener) + + // if server is websocket only, return after logging + if h.wsAllowed() && !h.rpcAllowed() { + h.log.Info("WebSocket enabled", "url", fmt.Sprintf("ws://%v", listener.Addr())) + return nil + } + // Log http endpoint. + h.log.Info("HTTP server started", + "endpoint", listener.Addr(), + "cors", strings.Join(h.httpConfig.CorsAllowedOrigins, ","), + "vhosts", strings.Join(h.httpConfig.Vhosts, ","), + ) + + // Log all handlers mounted on server. + var paths []string + for path := range h.handlerNames { + paths = append(paths, path) + } + sort.Strings(paths) + logged := make(map[string]bool, len(paths)) + for _, path := range paths { + name := h.handlerNames[path] + if !logged[name] { + log.Info(name+" enabled", "url", "http://"+listener.Addr().String()+path) + logged[name] = true + } + } + return nil +} + +func (h *httpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + rpc := h.httpHandler.Load().(*rpcHandler) + if r.RequestURI == "/" { + // Serve JSON-RPC on the root path. + ws := h.wsHandler.Load().(*rpcHandler) + if ws != nil && isWebsocket(r) { + ws.ServeHTTP(w, r) + return + } + if rpc != nil { + rpc.ServeHTTP(w, r) + return + } + } else if rpc != nil { + // Requests to a path below root are handled by the mux, + // which has all the handlers registered via Node.RegisterHandler. + // These are made available when RPC is enabled. + h.mux.ServeHTTP(w, r) + return + } + w.WriteHeader(404) +} + +// stop shuts down the HTTP server. +func (h *httpServer) stop() { + h.mu.Lock() + defer h.mu.Unlock() + h.doStop() +} + +func (h *httpServer) doStop() { + if h.listener == nil { + return // not running + } + + // Shut down the server. + httpHandler := h.httpHandler.Load().(*rpcHandler) + wsHandler := h.httpHandler.Load().(*rpcHandler) + if httpHandler != nil { + h.httpHandler.Store((*rpcHandler)(nil)) + httpHandler.server.Stop() + } + if wsHandler != nil { + h.wsHandler.Store((*rpcHandler)(nil)) + wsHandler.server.Stop() + } + h.server.Shutdown(context.Background()) + h.listener.Close() + h.log.Info("HTTP server stopped", "endpoint", h.listener.Addr()) + + // Clear out everything to allow re-configuring it later. + h.host, h.port, h.endpoint = "", 0, "" + h.server, h.listener = nil, nil +} + +// enableRPC turns on JSON-RPC over HTTP on the server. +func (h *httpServer) enableRPC(apis []rpc.API, config httpConfig) error { + h.mu.Lock() + defer h.mu.Unlock() + + if h.rpcAllowed() { + return fmt.Errorf("JSON-RPC over HTTP is already enabled") + } + + // Create RPC server and handler. + srv := rpc.NewServer() + if err := RegisterApisFromWhitelist(apis, config.Modules, srv, false); err != nil { + return err + } + h.httpConfig = config + h.httpHandler.Store(&rpcHandler{ + Handler: NewHTTPHandlerStack(srv, config.CorsAllowedOrigins, config.Vhosts), + server: srv, + }) + return nil +} + +// disableRPC stops the HTTP RPC handler. This is internal, the caller must hold h.mu. +func (h *httpServer) disableRPC() bool { + handler := h.httpHandler.Load().(*rpcHandler) + if handler != nil { + h.httpHandler.Store((*rpcHandler)(nil)) + handler.server.Stop() + } + return handler != nil +} + +// enableWS turns on JSON-RPC over WebSocket on the server. +func (h *httpServer) enableWS(apis []rpc.API, config wsConfig) error { + h.mu.Lock() + defer h.mu.Unlock() + + if h.wsAllowed() { + return fmt.Errorf("JSON-RPC over WebSocket is already enabled") + } + + // Create RPC server and handler. + srv := rpc.NewServer() + if err := RegisterApisFromWhitelist(apis, config.Modules, srv, false); err != nil { + return err + } + h.wsConfig = config + h.wsHandler.Store(&rpcHandler{ + Handler: srv.WebsocketHandler(config.Origins), + server: srv, + }) + return nil +} + +// stopWS disables JSON-RPC over WebSocket and also stops the server if it only serves WebSocket. +func (h *httpServer) stopWS() { + h.mu.Lock() + defer h.mu.Unlock() + + if h.disableWS() { + if !h.rpcAllowed() { + h.doStop() + } + } +} + +// disableWS disables the WebSocket handler. This is internal, the caller must hold h.mu. +func (h *httpServer) disableWS() bool { + ws := h.wsHandler.Load().(*rpcHandler) + if ws != nil { + h.wsHandler.Store((*rpcHandler)(nil)) + ws.server.Stop() + } + return ws != nil +} + +// rpcAllowed returns true when JSON-RPC over HTTP is enabled. +func (h *httpServer) rpcAllowed() bool { + return h.httpHandler.Load().(*rpcHandler) != nil +} + +// wsAllowed returns true when JSON-RPC over WebSocket is enabled. +func (h *httpServer) wsAllowed() bool { + return h.wsHandler.Load().(*rpcHandler) != nil +} + +// isWebsocket checks the header of an http request for a websocket upgrade request. +func isWebsocket(r *http.Request) bool { + return strings.ToLower(r.Header.Get("Upgrade")) == "websocket" && + strings.ToLower(r.Header.Get("Connection")) == "upgrade" +} + // NewHTTPHandlerStack returns wrapped http-related handlers func NewHTTPHandlerStack(srv http.Handler, cors []string, vhosts []string) http.Handler { // Wrap the CORS-handler within a host-handler @@ -45,8 +332,8 @@ func newCorsHandler(srv http.Handler, allowedOrigins []string) http.Handler { c := cors.New(cors.Options{ AllowedOrigins: allowedOrigins, AllowedMethods: []string{http.MethodPost, http.MethodGet}, - MaxAge: 600, AllowedHeaders: []string{"*"}, + MaxAge: 600, }) return c.Handler(srv) } @@ -138,22 +425,68 @@ func newGzipHandler(next http.Handler) http.Handler { }) } -// NewWebsocketUpgradeHandler returns a websocket handler that serves an incoming request only if it contains an upgrade -// request to the websocket protocol. If not, serves the the request with the http handler. -func NewWebsocketUpgradeHandler(h http.Handler, ws http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if isWebsocket(r) { - ws.ServeHTTP(w, r) - log.Debug("serving websocket request") - return +type ipcServer struct { + log log.Logger + endpoint string + + mu sync.Mutex + listener net.Listener + srv *rpc.Server +} + +func newIPCServer(log log.Logger, endpoint string) *ipcServer { + return &ipcServer{log: log, endpoint: endpoint} +} + +// Start starts the httpServer's http.Server +func (is *ipcServer) start(apis []rpc.API) error { + is.mu.Lock() + defer is.mu.Unlock() + + if is.listener != nil { + return nil // already running + } + listener, srv, err := rpc.StartIPCEndpoint(is.endpoint, apis) + if err != nil { + return err + } + is.log.Info("IPC endpoint opened", "url", is.endpoint) + is.listener, is.srv = listener, srv + return nil +} + +func (is *ipcServer) stop() error { + is.mu.Lock() + defer is.mu.Unlock() + + if is.listener == nil { + return nil // not running + } + err := is.listener.Close() + is.srv.Stop() + is.listener, is.srv = nil, nil + is.log.Info("IPC endpoint closed", "url", is.endpoint) + return err +} + +// RegisterApisFromWhitelist checks the given modules' availability, generates a whitelist based on the allowed modules, +// and then registers all of the APIs exposed by the services. +func RegisterApisFromWhitelist(apis []rpc.API, modules []string, srv *rpc.Server, exposeAll bool) error { + if bad, available := checkModuleAvailability(modules, apis); len(bad) > 0 { + log.Error("Unavailable modules in HTTP API list", "unavailable", bad, "available", available) + } + // Generate the whitelist based on the allowed modules + whitelist := make(map[string]bool) + for _, module := range modules { + whitelist[module] = true + } + // Register all the APIs exposed by the services + for _, api := range apis { + if exposeAll || whitelist[api.Namespace] || (len(whitelist) == 0 && api.Public) { + if err := srv.RegisterName(api.Namespace, api.Service); err != nil { + return err + } } - - h.ServeHTTP(w, r) - }) -} - -// isWebsocket checks the header of an http request for a websocket upgrade request. -func isWebsocket(r *http.Request) bool { - return strings.ToLower(r.Header.Get("Upgrade")) == "websocket" && - strings.ToLower(r.Header.Get("Connection")) == "upgrade" + } + return nil } diff --git a/node/rpcstack_test.go b/node/rpcstack_test.go index 56ce4c171..3999b9073 100644 --- a/node/rpcstack_test.go +++ b/node/rpcstack_test.go @@ -1,38 +1,107 @@ +// Copyright 2020 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + package node import ( + "bytes" "net/http" - "net/http/httptest" "testing" "github.com/ledgerwatch/turbo-geth/rpc" "github.com/stretchr/testify/assert" ) -func TestNewWebsocketUpgradeHandler_websocket(t *testing.T) { - srv := rpc.NewServer() +// TestCorsHandler makes sure CORS are properly handled on the http server. +func TestCorsHandler(t *testing.T) { + srv := createAndStartServer(t, httpConfig{CorsAllowedOrigins: []string{"test", "test.com"}}, false, wsConfig{}) + defer srv.stop() - handler := NewWebsocketUpgradeHandler(nil, srv.WebsocketHandler([]string{})) - ts := httptest.NewServer(handler) - defer ts.Close() + resp := testRequest(t, "origin", "test.com", "", srv) + assert.Equal(t, "test.com", resp.Header.Get("Access-Control-Allow-Origin")) - responses := make(chan *http.Response) - go func(responses chan *http.Response) { - client := &http.Client{} - - req, _ := http.NewRequest(http.MethodGet, ts.URL, nil) - req.Header.Set("Connection", "upgrade") - req.Header.Set("Upgrade", "websocket") - req.Header.Set("Sec-WebSocket-Version", "13") - req.Header.Set("Sec-Websocket-Key", "SGVsbG8sIHdvcmxkIQ==") - - resp, err := client.Do(req) - if err != nil { - t.Error("could not issue a GET request to the test http server", err) - } - responses <- resp - }(responses) - - response := <-responses - assert.Equal(t, "websocket", response.Header.Get("Upgrade")) + resp2 := testRequest(t, "origin", "bad", "", srv) + assert.Equal(t, "", resp2.Header.Get("Access-Control-Allow-Origin")) +} + +// TestVhosts makes sure vhosts are properly handled on the http server. +func TestVhosts(t *testing.T) { + srv := createAndStartServer(t, httpConfig{Vhosts: []string{"test"}}, false, wsConfig{}) + defer srv.stop() + + resp := testRequest(t, "", "", "test", srv) + assert.Equal(t, resp.StatusCode, http.StatusOK) + + resp2 := testRequest(t, "", "", "bad", srv) + assert.Equal(t, resp2.StatusCode, http.StatusForbidden) +} + +// TestWebsocketOrigins makes sure the websocket origins are properly handled on the websocket server. +func TestWebsocketOrigins(t *testing.T) { + srv := createAndStartServer(t, httpConfig{}, true, wsConfig{Origins: []string{"test"}}) + defer srv.stop() + + dialer := websocket.DefaultDialer + _, _, err := dialer.Dial("ws://"+srv.listenAddr(), http.Header{ + "Content-type": []string{"application/json"}, + "Sec-WebSocket-Version": []string{"13"}, + "Origin": []string{"test"}, + }) + assert.NoError(t, err) + + _, _, err = dialer.Dial("ws://"+srv.listenAddr(), http.Header{ + "Content-type": []string{"application/json"}, + "Sec-WebSocket-Version": []string{"13"}, + "Origin": []string{"bad"}, + }) + assert.Error(t, err) +} + +func createAndStartServer(t *testing.T, conf httpConfig, ws bool, wsConf wsConfig) *httpServer { + t.Helper() + + srv := newHTTPServer(testlog.Logger(t, log.LvlDebug), rpc.DefaultHTTPTimeouts) + + assert.NoError(t, srv.enableRPC(nil, conf)) + if ws { + assert.NoError(t, srv.enableWS(nil, wsConf)) + } + assert.NoError(t, srv.setListenAddr("localhost", 0)) + assert.NoError(t, srv.start()) + + return srv +} + +func testRequest(t *testing.T, key, value, host string, srv *httpServer) *http.Response { + t.Helper() + + body := bytes.NewReader([]byte(`{"jsonrpc":"2.0","id":1,method":"rpc_modules"}`)) + req, _ := http.NewRequest("POST", "http://"+srv.listenAddr(), body) + req.Header.Set("content-type", "application/json") + if key != "" && value != "" { + req.Header.Set(key, value) + } + if host != "" { + req.Host = host + } + + client := http.DefaultClient + resp, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + return resp } diff --git a/node/service.go b/node/service.go deleted file mode 100644 index f845c9ff5..000000000 --- a/node/service.go +++ /dev/null @@ -1,129 +0,0 @@ -// Copyright 2015 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - -package node - -import ( - "reflect" - - "github.com/ledgerwatch/turbo-geth/accounts" - "github.com/ledgerwatch/turbo-geth/ethdb" - "github.com/ledgerwatch/turbo-geth/event" - "github.com/ledgerwatch/turbo-geth/log" - "github.com/ledgerwatch/turbo-geth/p2p" - "github.com/ledgerwatch/turbo-geth/rpc" -) - -// ServiceContext is a collection of service independent options inherited from -// the protocol stack, that is passed to all constructors to be optionally used; -// as well as utility methods to operate on the service environment. -type ServiceContext struct { - services map[reflect.Type]Service // Index of the already constructed services - Config Config - EventMux *event.TypeMux // Event multiplexer used for decoupled notifications - AccountManager *accounts.Manager // Account manager created by the node. -} - -// OpenDatabaseWithFreezer -// FIXME: implement the functionality -func (ctx *ServiceContext) OpenDatabaseWithFreezer(name string, freezer string) (*ethdb.ObjectDatabase, error) { - return ctx.OpenDatabase(name) -} - -// OpenDatabase opens an existing database with the given name (or creates one -// if no previous can be found) from within the node's data directory. If the -// node is an ephemeral one, a memory database is returned. -func (ctx *ServiceContext) OpenDatabase(name string) (*ethdb.ObjectDatabase, error) { - if ctx.Config.DataDir == "" { - return ethdb.NewMemDatabase(), nil - } - - if ctx.Config.Bolt { - log.Info("Opening Database (Bolt)") - return ethdb.Open(ctx.Config.ResolvePath(name + "_bolt")) - } - - log.Info("Opening Database (LMDB)") - return ethdb.Open(ctx.Config.ResolvePath(name)) - /* - if err != nil { - return nil, err - } - root := ctx.config.ResolvePath(name) - - FIXME: restore and move to OpenDatabaseWithFreezer - switch { - case freezer == "": - freezer = filepath.Join(root, "ancient") - case !filepath.IsAbs(freezer): - freezer = ctx.config.ResolvePath(freezer) - } - return ethdb.NewBoltDatabase(root) - */ -} - -// ResolvePath resolves a user path into the data directory if that was relative -// and if the user actually uses persistent storage. It will return an empty string -// for emphemeral storage and the user's own input for absolute paths. -func (ctx *ServiceContext) ResolvePath(path string) string { - return ctx.Config.ResolvePath(path) -} - -// Service retrieves a currently running service registered of a specific type. -func (ctx *ServiceContext) Service(service interface{}) error { - element := reflect.ValueOf(service).Elem() - if running, ok := ctx.services[element.Type()]; ok { - element.Set(reflect.ValueOf(running)) - return nil - } - return ErrServiceUnknown -} - -// ExtRPCEnabled returns the indicator whether node enables the external -// RPC(http, ws or graphql). -func (ctx *ServiceContext) ExtRPCEnabled() bool { - return ctx.Config.ExtRPCEnabled() -} - -// ServiceConstructor is the function signature of the constructors needed to be -// registered for service instantiation. -type ServiceConstructor func(ctx *ServiceContext) (Service, error) - -// Service is an individual protocol that can be registered into a node. -// -// Notes: -// -// • Service life-cycle management is delegated to the node. The service is allowed to -// initialize itself upon creation, but no goroutines should be spun up outside of the -// Start method. -// -// • Restart logic is not required as the node will create a fresh instance -// every time a service is started. -type Service interface { - // Protocols retrieves the P2P protocols the service wishes to start. - Protocols() []p2p.Protocol - - // APIs retrieves the list of RPC descriptors the service provides - APIs() []rpc.API - - // Start is called after all services have been constructed and the networking - // layer was also initialized to spawn any goroutines required by the service. - Start(server *p2p.Server) error - - // Stop terminates all goroutines belonging to the service, blocking until they - // are all terminated. - Stop() error -} diff --git a/node/service_test.go b/node/service_test.go deleted file mode 100644 index 44fe6c62c..000000000 --- a/node/service_test.go +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright 2015 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - -package node - -import ( - "fmt" - "io/ioutil" - "os" - "path/filepath" - "testing" -) - -// Tests that databases are correctly created persistent or ephemeral based on -// the configured service context. -func TestContextDatabases(t *testing.T) { - // Create a temporary folder and ensure no database is contained within - dir, err := ioutil.TempDir("", "") - if err != nil { - t.Fatalf("failed to create temporary data directory: %v", err) - } - defer os.RemoveAll(dir) - - if _, err := os.Stat(filepath.Join(dir, "database")); err == nil { - t.Fatalf("non-created database already exists") - } - // Request the opening/creation of a database and ensure it persists to disk - ctx := &ServiceContext{Config: Config{Name: "unit-test", DataDir: dir}} - db, err := ctx.OpenDatabase("persistent") - if err != nil { - t.Fatalf("failed to open persistent database: %v", err) - } - db.Close() - - if _, err := os.Stat(filepath.Join(dir, "unit-test", "persistent")); err != nil { - t.Fatalf("persistent database doesn't exists: %v", err) - } - // Request th opening/creation of an ephemeral database and ensure it's not persisted - ctx = &ServiceContext{Config: Config{DataDir: ""}} - db, err = ctx.OpenDatabase("ephemeral") - if err != nil { - t.Fatalf("failed to open ephemeral database: %v", err) - } - db.Close() - - if _, err := os.Stat(filepath.Join(dir, "ephemeral")); err == nil { - t.Fatalf("ephemeral database exists") - } -} - -// Tests that already constructed services can be retrieves by later ones. -func TestContextServices(t *testing.T) { - stack, err := New(testNodeConfig()) - if err != nil { - t.Fatalf("failed to create protocol stack: %v", err) - } - defer stack.Close() - // Define a verifier that ensures a NoopA is before it and NoopB after - verifier := func(ctx *ServiceContext) (Service, error) { - var objA *NoopServiceA - if ctx.Service(&objA) != nil { - return nil, fmt.Errorf("former service not found") - } - var objB *NoopServiceB - if err := ctx.Service(&objB); err != ErrServiceUnknown { - return nil, fmt.Errorf("latters lookup error mismatch: have %v, want %v", err, ErrServiceUnknown) - } - return new(NoopService), nil - } - // Register the collection of services - if err := stack.Register(NewNoopServiceA); err != nil { - t.Fatalf("former failed to register service: %v", err) - } - if err := stack.Register(verifier); err != nil { - t.Fatalf("failed to register service verifier: %v", err) - } - if err := stack.Register(NewNoopServiceB); err != nil { - t.Fatalf("latter failed to register service: %v", err) - } - // Start the protocol stack and ensure services are constructed in order - if err := stack.Start(); err != nil { - t.Fatalf("failed to start stack: %v", err) - } - defer stack.Stop() -} diff --git a/node/utils_test.go b/node/utils_test.go index 020447354..eea729531 100644 --- a/node/utils_test.go +++ b/node/utils_test.go @@ -20,61 +20,40 @@ package node import ( - "reflect" - "github.com/ledgerwatch/turbo-geth/p2p" "github.com/ledgerwatch/turbo-geth/rpc" ) -// NoopService is a trivial implementation of the Service interface. -type NoopService struct{} +// NoopLifecycle is a trivial implementation of the Service interface. +type NoopLifecycle struct{} -func (s *NoopService) Protocols() []p2p.Protocol { return nil } -func (s *NoopService) APIs() []rpc.API { return nil } -func (s *NoopService) Start(*p2p.Server) error { return nil } -func (s *NoopService) Stop() error { return nil } +func (s *NoopLifecycle) Start() error { return nil } +func (s *NoopLifecycle) Stop() error { return nil } -func NewNoopService(*ServiceContext) (Service, error) { return new(NoopService), nil } +func NewNoop() *Noop { + noop := new(Noop) + return noop +} -// Set of services all wrapping the base NoopService resulting in the same method +// Set of services all wrapping the base NoopLifecycle resulting in the same method // signatures but different outer types. -type NoopServiceA struct{ NoopService } -type NoopServiceB struct{ NoopService } -type NoopServiceC struct{ NoopService } +type Noop struct{ NoopLifecycle } -func NewNoopServiceA(*ServiceContext) (Service, error) { return new(NoopServiceA), nil } -func NewNoopServiceB(*ServiceContext) (Service, error) { return new(NoopServiceB), nil } -func NewNoopServiceC(*ServiceContext) (Service, error) { return new(NoopServiceC), nil } - -// InstrumentedService is an implementation of Service for which all interface +// InstrumentedService is an implementation of Lifecycle for which all interface // methods can be instrumented both return value as well as event hook wise. type InstrumentedService struct { + start error + stop error + + startHook func() + stopHook func() + protocols []p2p.Protocol - apis []rpc.API - start error - stop error - - protocolsHook func() - startHook func(*p2p.Server) - stopHook func() } -func NewInstrumentedService(*ServiceContext) (Service, error) { return new(InstrumentedService), nil } - -func (s *InstrumentedService) Protocols() []p2p.Protocol { - if s.protocolsHook != nil { - s.protocolsHook() - } - return s.protocols -} - -func (s *InstrumentedService) APIs() []rpc.API { - return s.apis -} - -func (s *InstrumentedService) Start(server *p2p.Server) error { +func (s *InstrumentedService) Start() error { if s.startHook != nil { - s.startHook(server) + s.startHook() } return s.start } @@ -86,48 +65,49 @@ func (s *InstrumentedService) Stop() error { return s.stop } -// InstrumentingWrapper is a method to specialize a service constructor returning -// a generic InstrumentedService into one returning a wrapping specific one. -type InstrumentingWrapper func(base ServiceConstructor) ServiceConstructor +type FullService struct{} -func InstrumentingWrapperMaker(base ServiceConstructor, kind reflect.Type) ServiceConstructor { - return func(ctx *ServiceContext) (Service, error) { - obj, err := base(ctx) - if err != nil { - return nil, err - } - wrapper := reflect.New(kind) - wrapper.Elem().Field(0).Set(reflect.ValueOf(obj).Elem()) +func NewFullService(stack *Node) (*FullService, error) { + fs := new(FullService) - return wrapper.Interface().(Service), nil + stack.RegisterProtocols(fs.Protocols()) + stack.RegisterAPIs(fs.APIs()) + stack.RegisterLifecycle(fs) + return fs, nil +} + +func (f *FullService) Start() error { return nil } + +func (f *FullService) Stop() error { return nil } + +func (f *FullService) Protocols() []p2p.Protocol { + return []p2p.Protocol{ + p2p.Protocol{ + Name: "test1", + Version: uint(1), + }, + p2p.Protocol{ + Name: "test2", + Version: uint(2), + }, } } -// Set of services all wrapping the base InstrumentedService resulting in the -// same method signatures but different outer types. -type InstrumentedServiceA struct{ InstrumentedService } -type InstrumentedServiceB struct{ InstrumentedService } -type InstrumentedServiceC struct{ InstrumentedService } - -func InstrumentedServiceMakerA(base ServiceConstructor) ServiceConstructor { - return InstrumentingWrapperMaker(base, reflect.TypeOf(InstrumentedServiceA{})) -} - -func InstrumentedServiceMakerB(base ServiceConstructor) ServiceConstructor { - return InstrumentingWrapperMaker(base, reflect.TypeOf(InstrumentedServiceB{})) -} - -func InstrumentedServiceMakerC(base ServiceConstructor) ServiceConstructor { - return InstrumentingWrapperMaker(base, reflect.TypeOf(InstrumentedServiceC{})) -} - -// OneMethodAPI is a single-method API handler to be returned by test services. -type OneMethodAPI struct { - fun func() -} - -func (api *OneMethodAPI) TheOneMethod() { - if api.fun != nil { - api.fun() +func (f *FullService) APIs() []rpc.API { + return []rpc.API{ + { + Namespace: "admin", + Version: "1.0", + }, + { + Namespace: "debug", + Version: "1.0", + Public: true, + }, + { + Namespace: "net", + Version: "1.0", + Public: true, + }, } } diff --git a/p2p/simulations/adapters/exec.go b/p2p/simulations/adapters/exec.go index 980f85840..613ebcef5 100644 --- a/p2p/simulations/adapters/exec.go +++ b/p2p/simulations/adapters/exec.go @@ -77,11 +77,11 @@ func (e *ExecAdapter) Name() string { // NewNode returns a new ExecNode using the given config func (e *ExecAdapter) NewNode(config *NodeConfig) (Node, error) { - if len(config.Services) == 0 { - return nil, errors.New("node must have at least one service") + if len(config.Lifecycles) == 0 { + return nil, errors.New("node must have at least one service lifecycle") } - for _, service := range config.Services { - if _, exists := serviceFuncs[service]; !exists { + for _, service := range config.Lifecycles { + if _, exists := lifecycleConstructorFuncs[service]; !exists { return nil, fmt.Errorf("unknown node service %q", service) } } @@ -265,7 +265,7 @@ func (n *ExecNode) waitForStartupJSON(ctx context.Context) (string, chan nodeSta func (n *ExecNode) execCommand() *exec.Cmd { return &exec.Cmd{ Path: reexec.Self(), - Args: []string{"p2p-node", strings.Join(n.Config.Node.Services, ","), n.ID.String()}, + Args: []string{"p2p-node", strings.Join(n.Config.Node.Lifecycles, ","), n.ID.String()}, } } @@ -402,7 +402,7 @@ func execP2PNode() { defer signal.Stop(sigc) <-sigc log.Info("Received SIGTERM, shutting down...") - stack.Stop() + stack.Close() }() stack.Wait() // Wait for the stack to exit. } @@ -436,44 +436,36 @@ func startExecNodeStack() (*node.Node, error) { return nil, fmt.Errorf("error creating node stack: %v", err) } - // register the services, collecting them into a map so we can wrap - // them in a snapshot service - services := make(map[string]node.Service, len(serviceNames)) + // Register the services, collecting them into a map so they can + // be accessed by the snapshot API. + services := make(map[string]node.Lifecycle, len(serviceNames)) for _, name := range serviceNames { - serviceFunc, exists := serviceFuncs[name] + lifecycleFunc, exists := lifecycleConstructorFuncs[name] if !exists { return nil, fmt.Errorf("unknown node service %q", err) } - constructor := func(nodeCtx *node.ServiceContext) (node.Service, error) { - ctx := &ServiceContext{ - RPCDialer: &wsRPCDialer{addrs: conf.PeerAddrs}, - NodeContext: nodeCtx, - Config: conf.Node, - } - if conf.Snapshots != nil { - ctx.Snapshot = conf.Snapshots[name] - } - service, err := serviceFunc(ctx) - if err != nil { - return nil, err - } - services[name] = service - return service, nil + ctx := &ServiceContext{ + RPCDialer: &wsRPCDialer{addrs: conf.PeerAddrs}, + Config: conf.Node, } - if err := stack.Register(constructor); err != nil { - return stack, fmt.Errorf("error registering service %q: %v", name, err) + if conf.Snapshots != nil { + ctx.Snapshot = conf.Snapshots[name] } + service, err := lifecycleFunc(ctx, stack) + if err != nil { + return nil, err + } + services[name] = service + stack.RegisterLifecycle(service) } - // register the snapshot service - err = stack.Register(func(ctx *node.ServiceContext) (node.Service, error) { - return &snapshotService{services}, nil - }) - if err != nil { - return stack, fmt.Errorf("error starting snapshot service: %v", err) - } + // Add the snapshot API. + stack.RegisterAPIs([]rpc.API{{ + Namespace: "simulation", + Version: "1.0", + Service: SnapshotAPI{services}, + }}) - // start the stack if err = stack.Start(); err != nil { err = fmt.Errorf("error starting stack: %v", err) } @@ -492,35 +484,9 @@ type nodeStartupJSON struct { NodeInfo *p2p.NodeInfo } -// snapshotService is a node.Service which wraps a list of services and -// exposes an API to generate a snapshot of those services -type snapshotService struct { - services map[string]node.Service -} - -func (s *snapshotService) APIs() []rpc.API { - return []rpc.API{{ - Namespace: "simulation", - Version: "1.0", - Service: SnapshotAPI{s.services}, - }} -} - -func (s *snapshotService) Protocols() []p2p.Protocol { - return nil -} - -func (s *snapshotService) Start(*p2p.Server) error { - return nil -} - -func (s *snapshotService) Stop() error { - return nil -} - // SnapshotAPI provides an RPC method to create snapshots of services type SnapshotAPI struct { - services map[string]node.Service + services map[string]node.Lifecycle } func (api SnapshotAPI) Snapshot() (map[string][]byte, error) { diff --git a/p2p/simulations/adapters/inproc.go b/p2p/simulations/adapters/inproc.go index add230b3b..570c60528 100644 --- a/p2p/simulations/adapters/inproc.go +++ b/p2p/simulations/adapters/inproc.go @@ -38,29 +38,21 @@ import ( // SimAdapter is a NodeAdapter which creates in-memory simulation nodes and // connects them using net.Pipe type SimAdapter struct { - pipe func() (net.Conn, net.Conn, error) - mtx sync.RWMutex - nodes map[enode.ID]*SimNode - services map[string]ServiceFunc + pipe func() (net.Conn, net.Conn, error) + mtx sync.RWMutex + nodes map[enode.ID]*SimNode + lifecycles LifecycleConstructors } // NewSimAdapter creates a SimAdapter which is capable of running in-memory // simulation nodes running any of the given services (the services to run on a // particular node are passed to the NewNode function in the NodeConfig) // the adapter uses a net.Pipe for in-memory simulated network connections -func NewSimAdapter(services map[string]ServiceFunc) *SimAdapter { +func NewSimAdapter(services LifecycleConstructors) *SimAdapter { return &SimAdapter{ - pipe: pipes.NetPipe, - nodes: make(map[enode.ID]*SimNode), - services: services, - } -} - -func NewTCPAdapter(services map[string]ServiceFunc) *SimAdapter { - return &SimAdapter{ - pipe: pipes.TCPPipe, - nodes: make(map[enode.ID]*SimNode), - services: services, + pipe: pipes.NetPipe, + nodes: make(map[enode.ID]*SimNode), + lifecycles: services, } } @@ -86,11 +78,11 @@ func (s *SimAdapter) NewNode(config *NodeConfig) (Node, error) { } // check the services are valid - if len(config.Services) == 0 { + if len(config.Lifecycles) == 0 { return nil, errors.New("node must have at least one service") } - for _, service := range config.Services { - if _, exists := s.services[service]; !exists { + for _, service := range config.Lifecycles { + if _, exists := s.lifecycles[service]; !exists { return nil, fmt.Errorf("unknown node service %q", service) } } @@ -120,7 +112,7 @@ func (s *SimAdapter) NewNode(config *NodeConfig) (Node, error) { config: config, node: n, adapter: s, - running: make(map[string]node.Service), + running: make(map[string]node.Lifecycle), } s.nodes[id] = simNode return simNode, nil @@ -156,11 +148,7 @@ func (s *SimAdapter) DialRPC(id enode.ID) (*rpc.Client, error) { if !ok { return nil, fmt.Errorf("unknown node: %s", id) } - handler, err := node.node.RPCHandler() - if err != nil { - return nil, err - } - return rpc.DialInProc(handler), nil + return node.node.Attach() } // GetNode returns the node with the given ID if it exists @@ -180,7 +168,7 @@ type SimNode struct { config *NodeConfig adapter *SimAdapter node *node.Node - running map[string]node.Service + running map[string]node.Lifecycle client *rpc.Client registerOnce sync.Once } @@ -228,7 +216,7 @@ func (sn *SimNode) ServeRPC(conn *websocket.Conn) error { // simulation_snapshot RPC method func (sn *SimNode) Snapshots() (map[string][]byte, error) { sn.lock.RLock() - services := make(map[string]node.Service, len(sn.running)) + services := make(map[string]node.Lifecycle, len(sn.running)) for name, service := range sn.running { services[name] = service } @@ -253,35 +241,30 @@ func (sn *SimNode) Snapshots() (map[string][]byte, error) { // Start registers the services and starts the underlying devp2p node func (sn *SimNode) Start(snapshots map[string][]byte) error { - newService := func(name string) func(ctx *node.ServiceContext) (node.Service, error) { - return func(nodeCtx *node.ServiceContext) (node.Service, error) { - ctx := &ServiceContext{ - RPCDialer: sn.adapter, - NodeContext: nodeCtx, - Config: sn.config, - } - if snapshots != nil { - ctx.Snapshot = snapshots[name] - } - serviceFunc := sn.adapter.services[name] - service, err := serviceFunc(ctx) - if err != nil { - return nil, err - } - sn.running[name] = service - return service, nil - } - } - // ensure we only register the services once in the case of the node // being stopped and then started again var regErr error sn.registerOnce.Do(func() { - for _, name := range sn.config.Services { - if err := sn.node.Register(newService(name)); err != nil { + for _, name := range sn.config.Lifecycles { + ctx := &ServiceContext{ + RPCDialer: sn.adapter, + Config: sn.config, + } + if snapshots != nil { + ctx.Snapshot = snapshots[name] + } + serviceFunc := sn.adapter.lifecycles[name] + service, err := serviceFunc(ctx, sn.node) + if err != nil { regErr = err break } + // if the service has already been registered, don't register it again. + if _, ok := sn.running[name]; ok { + continue + } + sn.running[name] = service + sn.node.RegisterLifecycle(service) } }) if regErr != nil { @@ -293,13 +276,12 @@ func (sn *SimNode) Start(snapshots map[string][]byte) error { } // create an in-process RPC client - handler, err := sn.node.RPCHandler() + client, err := sn.node.Attach() if err != nil { return err } - sn.lock.Lock() - sn.client = rpc.DialInProc(handler) + sn.client = client sn.lock.Unlock() return nil @@ -313,21 +295,21 @@ func (sn *SimNode) Stop() error { sn.client = nil } sn.lock.Unlock() - return sn.node.Stop() + return sn.node.Close() } // Service returns a running service by name -func (sn *SimNode) Service(name string) node.Service { +func (sn *SimNode) Service(name string) node.Lifecycle { sn.lock.RLock() defer sn.lock.RUnlock() return sn.running[name] } // Services returns a copy of the underlying services -func (sn *SimNode) Services() []node.Service { +func (sn *SimNode) Services() []node.Lifecycle { sn.lock.RLock() defer sn.lock.RUnlock() - services := make([]node.Service, 0, len(sn.running)) + services := make([]node.Lifecycle, 0, len(sn.running)) for _, service := range sn.running { services = append(services, service) } @@ -335,10 +317,10 @@ func (sn *SimNode) Services() []node.Service { } // ServiceMap returns a map by names of the underlying services -func (sn *SimNode) ServiceMap() map[string]node.Service { +func (sn *SimNode) ServiceMap() map[string]node.Lifecycle { sn.lock.RLock() defer sn.lock.RUnlock() - services := make(map[string]node.Service, len(sn.running)) + services := make(map[string]node.Lifecycle, len(sn.running)) for name, service := range sn.running { services[name] = service } diff --git a/p2p/simulations/adapters/types.go b/p2p/simulations/adapters/types.go index e46074315..8b28691ad 100644 --- a/p2p/simulations/adapters/types.go +++ b/p2p/simulations/adapters/types.go @@ -98,11 +98,11 @@ type NodeConfig struct { // Use an existing database instead of a temporary one if non-empty DataDir string - // Services are the names of the services which should be run when - // starting the node (for SimNodes it should be the names of services - // contained in SimAdapter.services, for other nodes it should be - // services registered by calling the RegisterService function) - Services []string + // Lifecycles are the names of the service lifecycles which should be run when + // starting the node (for SimNodes it should be the names of service lifecycles + // contained in SimAdapter.lifecycles, for other nodes it should be + // service lifecycles registered by calling the RegisterLifecycle function) + Lifecycles []string // Properties are the names of the properties this node should hold // within running services (e.g. "bootnode", "lightnode" or any custom values) @@ -139,7 +139,7 @@ func (n *NodeConfig) MarshalJSON() ([]byte, error) { confJSON := nodeConfigJSON{ ID: n.ID.String(), Name: n.Name, - Services: n.Services, + Services: n.Lifecycles, Properties: n.Properties, Port: n.Port, EnableMsgEvents: n.EnableMsgEvents, @@ -177,7 +177,7 @@ func (n *NodeConfig) UnmarshalJSON(data []byte) error { } n.Name = confJSON.Name - n.Services = confJSON.Services + n.Lifecycles = confJSON.Services n.Properties = confJSON.Properties n.Port = confJSON.Port n.EnableMsgEvents = confJSON.EnableMsgEvents @@ -235,9 +235,8 @@ func assignTCPPort() (uint16, error) { type ServiceContext struct { RPCDialer - NodeContext *node.ServiceContext - Config *NodeConfig - Snapshot []byte + Config *NodeConfig + Snapshot []byte } // RPCDialer is used when initialising services which need to connect to @@ -247,27 +246,29 @@ type RPCDialer interface { DialRPC(id enode.ID) (*rpc.Client, error) } -// Services is a collection of services which can be run in a simulation -type Services map[string]ServiceFunc +// LifecycleConstructor allows a Lifecycle to be constructed during node start-up. +// While the service-specific package usually takes care of Lifecycle creation and registration, +// for testing purposes, it is useful to be able to construct a Lifecycle on spot. +type LifecycleConstructor func(ctx *ServiceContext, stack *node.Node) (node.Lifecycle, error) -// ServiceFunc returns a node.Service which can be used to boot a devp2p node -type ServiceFunc func(ctx *ServiceContext) (node.Service, error) +// LifecycleConstructors stores LifecycleConstructor functions to call during node start-up. +type LifecycleConstructors map[string]LifecycleConstructor -// serviceFuncs is a map of registered services which are used to boot devp2p +// lifecycleConstructorFuncs is a map of registered services which are used to boot devp2p // nodes -var serviceFuncs = make(Services) +var lifecycleConstructorFuncs = make(LifecycleConstructors) -// RegisterServices registers the given Services which can then be used to +// RegisterLifecycles registers the given Services which can then be used to // start devp2p nodes using either the Exec or Docker adapters. // // It should be called in an init function so that it has the opportunity to // execute the services before main() is called. -func RegisterServices(services Services) { - for name, f := range services { - if _, exists := serviceFuncs[name]; exists { +func RegisterLifecycles(lifecycles LifecycleConstructors) { + for name, f := range lifecycles { + if _, exists := lifecycleConstructorFuncs[name]; exists { panic(fmt.Sprintf("node service already exists: %q", name)) } - serviceFuncs[name] = f + lifecycleConstructorFuncs[name] = f } // now we have registered the services, run reexec.Init() which will diff --git a/p2p/simulations/connect_test.go b/p2p/simulations/connect_test.go index ff60f0ef8..e3de9a783 100644 --- a/p2p/simulations/connect_test.go +++ b/p2p/simulations/connect_test.go @@ -26,8 +26,8 @@ import ( func newTestNetwork(t *testing.T, nodeCount int) (*Network, []enode.ID) { t.Helper() - adapter := adapters.NewSimAdapter(adapters.Services{ - "noopwoop": func(ctx *adapters.ServiceContext) (node.Service, error) { + adapter := adapters.NewSimAdapter(adapters.LifecycleConstructors{ + "noopwoop": func(ctx *adapters.ServiceContext, stack *node.Node) (node.Lifecycle, error) { return NewNoopService(nil), nil }, }) diff --git a/p2p/simulations/examples/ping-pong.go b/p2p/simulations/examples/ping-pong.go index 134e42d60..b1a8a9a3a 100644 --- a/p2p/simulations/examples/ping-pong.go +++ b/p2p/simulations/examples/ping-pong.go @@ -31,7 +31,6 @@ import ( "github.com/ledgerwatch/turbo-geth/p2p/enode" "github.com/ledgerwatch/turbo-geth/p2p/simulations" "github.com/ledgerwatch/turbo-geth/p2p/simulations/adapters" - "github.com/ledgerwatch/turbo-geth/rpc" ) var adapterType = flag.String("adapter", "sim", `node adapter to use (one of "sim", "exec" or "docker")`) @@ -45,12 +44,14 @@ func main() { log.Root().SetHandler(log.LvlFilterHandler(log.LvlTrace, log.StreamHandler(os.Stderr, log.TerminalFormat(false)))) // register a single ping-pong service - services := map[string]adapters.ServiceFunc{ - "ping-pong": func(ctx *adapters.ServiceContext) (node.Service, error) { - return newPingPongService(ctx.Config.ID), nil + services := map[string]adapters.LifecycleConstructor{ + "ping-pong": func(ctx *adapters.ServiceContext, stack *node.Node) (node.Lifecycle, error) { + pps := newPingPongService(ctx.Config.ID) + stack.RegisterProtocols(pps.Protocols()) + return pps, nil }, } - adapters.RegisterServices(services) + adapters.RegisterLifecycles(services) // create the NodeAdapter var adapter adapters.NodeAdapter @@ -110,11 +111,7 @@ func (p *pingPongService) Protocols() []p2p.Protocol { }} } -func (p *pingPongService) APIs() []rpc.API { - return nil -} - -func (p *pingPongService) Start(server *p2p.Server) error { +func (p *pingPongService) Start() error { p.log.Info("ping-pong service starting") return nil } diff --git a/p2p/simulations/http_test.go b/p2p/simulations/http_test.go index 94ad51c57..50c92e048 100644 --- a/p2p/simulations/http_test.go +++ b/p2p/simulations/http_test.go @@ -63,12 +63,15 @@ type testService struct { state atomic.Value } -func newTestService(ctx *adapters.ServiceContext) (node.Service, error) { +func newTestService(ctx *adapters.ServiceContext, stack *node.Node) (node.Lifecycle, error) { svc := &testService{ id: ctx.Config.ID, peers: make(map[enode.ID]*testPeer), } svc.state.Store(ctx.Snapshot) + + stack.RegisterProtocols(svc.Protocols()) + stack.RegisterAPIs(svc.APIs()) return svc, nil } @@ -125,7 +128,7 @@ func (t *testService) APIs() []rpc.API { }} } -func (t *testService) Start(server *p2p.Server) error { +func (t *testService) Start() error { return nil } @@ -287,7 +290,7 @@ func (t *TestAPI) Events(ctx context.Context) (*rpc.Subscription, error) { return rpcSub, nil } -var testServices = adapters.Services{ +var testServices = adapters.LifecycleConstructors{ "test": newTestService, } diff --git a/p2p/simulations/network.go b/p2p/simulations/network.go index e9456267d..4ff9bf64f 100644 --- a/p2p/simulations/network.go +++ b/p2p/simulations/network.go @@ -110,8 +110,8 @@ func (net *Network) NewNodeWithConfig(conf *adapters.NodeConfig) (*Node, error) } // if no services are configured, use the default service - if len(conf.Services) == 0 { - conf.Services = []string{net.DefaultService} + if len(conf.Lifecycles) == 0 { + conf.Lifecycles = []string{net.DefaultService} } // use the NodeAdapter to create the node @@ -913,19 +913,19 @@ func (net *Network) snapshot(addServices []string, removeServices []string) (*Sn snap.Nodes[i].Snapshots = snapshots for _, addSvc := range addServices { haveSvc := false - for _, svc := range snap.Nodes[i].Node.Config.Services { + for _, svc := range snap.Nodes[i].Node.Config.Lifecycles { if svc == addSvc { haveSvc = true break } } if !haveSvc { - snap.Nodes[i].Node.Config.Services = append(snap.Nodes[i].Node.Config.Services, addSvc) + snap.Nodes[i].Node.Config.Lifecycles = append(snap.Nodes[i].Node.Config.Lifecycles, addSvc) } } if len(removeServices) > 0 { var cleanedServices []string - for _, svc := range snap.Nodes[i].Node.Config.Services { + for _, svc := range snap.Nodes[i].Node.Config.Lifecycles { haveSvc := false for _, rmSvc := range removeServices { if rmSvc == svc { @@ -938,7 +938,7 @@ func (net *Network) snapshot(addServices []string, removeServices []string) (*Sn } } - snap.Nodes[i].Node.Config.Services = cleanedServices + snap.Nodes[i].Node.Config.Lifecycles = cleanedServices } } for _, conn := range net.Conns { diff --git a/p2p/simulations/network_test.go b/p2p/simulations/network_test.go index 3c0f3602f..6e2ec385e 100644 --- a/p2p/simulations/network_test.go +++ b/p2p/simulations/network_test.go @@ -41,8 +41,8 @@ func TestSnapshot(t *testing.T) { // create snapshot from ring network // this is a minimal service, whose protocol will take exactly one message OR close of connection before quitting - adapter := adapters.NewSimAdapter(adapters.Services{ - "noopwoop": func(ctx *adapters.ServiceContext) (node.Service, error) { + adapter := adapters.NewSimAdapter(adapters.LifecycleConstructors{ + "noopwoop": func(ctx *adapters.ServiceContext, stack *node.Node) (node.Lifecycle, error) { return NewNoopService(nil), nil }, }) @@ -165,8 +165,8 @@ OUTER: // PART II // load snapshot and verify that exactly same connections are formed - adapter = adapters.NewSimAdapter(adapters.Services{ - "noopwoop": func(ctx *adapters.ServiceContext) (node.Service, error) { + adapter = adapters.NewSimAdapter(adapters.LifecycleConstructors{ + "noopwoop": func(ctx *adapters.ServiceContext, stack *node.Node) (node.Lifecycle, error) { return NewNoopService(nil), nil }, }) @@ -256,8 +256,8 @@ OuterTwo: t.Run("conns after load", func(t *testing.T) { // Create new network. n := NewNetwork( - adapters.NewSimAdapter(adapters.Services{ - "noopwoop": func(ctx *adapters.ServiceContext) (node.Service, error) { + adapters.NewSimAdapter(adapters.LifecycleConstructors{ + "noopwoop": func(ctx *adapters.ServiceContext, stack *node.Node) (node.Lifecycle, error) { return NewNoopService(nil), nil }, }), @@ -288,7 +288,7 @@ OuterTwo: // with each other and that a snapshot fully represents the desired topology func TestNetworkSimulation(t *testing.T) { // create simulation network with 20 testService nodes - adapter := adapters.NewSimAdapter(adapters.Services{ + adapter := adapters.NewSimAdapter(adapters.LifecycleConstructors{ "test": newTestService, }) network := NewNetwork(adapter, &NetworkConfig{ @@ -437,7 +437,7 @@ func createTestNodesWithProperty(property string, count int, network *Network) ( // It then tests again whilst excluding a node ID from being returned. // If a node ID is not returned, or more node IDs than expected are returned, the test fails. func TestGetNodeIDs(t *testing.T) { - adapter := adapters.NewSimAdapter(adapters.Services{ + adapter := adapters.NewSimAdapter(adapters.LifecycleConstructors{ "test": newTestService, }) network := NewNetwork(adapter, &NetworkConfig{ @@ -486,7 +486,7 @@ func TestGetNodeIDs(t *testing.T) { // It then tests again whilst excluding a node from being returned. // If a node is not returned, or more nodes than expected are returned, the test fails. func TestGetNodes(t *testing.T) { - adapter := adapters.NewSimAdapter(adapters.Services{ + adapter := adapters.NewSimAdapter(adapters.LifecycleConstructors{ "test": newTestService, }) network := NewNetwork(adapter, &NetworkConfig{ @@ -534,7 +534,7 @@ func TestGetNodes(t *testing.T) { // TestGetNodesByID creates a set of nodes and attempts to retrieve a subset of them by ID // If a node is not returned, or more nodes than expected are returned, the test fails. func TestGetNodesByID(t *testing.T) { - adapter := adapters.NewSimAdapter(adapters.Services{ + adapter := adapters.NewSimAdapter(adapters.LifecycleConstructors{ "test": newTestService, }) network := NewNetwork(adapter, &NetworkConfig{ @@ -579,7 +579,7 @@ func TestGetNodesByID(t *testing.T) { // GetNodesByProperty is then checked for correctness by comparing the nodes returned to those initially created. // If a node with a property is not found, or more nodes than expected are returned, the test fails. func TestGetNodesByProperty(t *testing.T) { - adapter := adapters.NewSimAdapter(adapters.Services{ + adapter := adapters.NewSimAdapter(adapters.LifecycleConstructors{ "test": newTestService, }) network := NewNetwork(adapter, &NetworkConfig{ @@ -624,7 +624,7 @@ func TestGetNodesByProperty(t *testing.T) { // GetNodeIDsByProperty is then checked for correctness by comparing the node IDs returned to those initially created. // If a node ID with a property is not found, or more nodes IDs than expected are returned, the test fails. func TestGetNodeIDsByProperty(t *testing.T) { - adapter := adapters.NewSimAdapter(adapters.Services{ + adapter := adapters.NewSimAdapter(adapters.LifecycleConstructors{ "test": newTestService, }) network := NewNetwork(adapter, &NetworkConfig{ @@ -705,8 +705,8 @@ func benchmarkMinimalServiceTmp(b *testing.B) { // this is a minimal service, whose protocol will close a channel upon run of protocol // making it possible to bench the time it takes for the service to start and protocol actually to be run protoCMap := make(map[enode.ID]map[enode.ID]chan struct{}) - adapter := adapters.NewSimAdapter(adapters.Services{ - "noopwoop": func(ctx *adapters.ServiceContext) (node.Service, error) { + adapter := adapters.NewSimAdapter(adapters.LifecycleConstructors{ + "noopwoop": func(ctx *adapters.ServiceContext, stack *node.Node) (node.Lifecycle, error) { protoCMap[ctx.Config.ID] = make(map[enode.ID]chan struct{}) svc := NewNoopService(protoCMap[ctx.Config.ID]) return svc, nil diff --git a/p2p/simulations/test.go b/p2p/simulations/test.go index f169e58e0..3d3523cac 100644 --- a/p2p/simulations/test.go +++ b/p2p/simulations/test.go @@ -66,7 +66,7 @@ func (t *NoopService) APIs() []rpc.API { return []rpc.API{} } -func (t *NoopService) Start(server *p2p.Server) error { +func (t *NoopService) Start() error { return nil } diff --git a/p2p/testing/peerpool.go b/p2p/testing/peerpool.go deleted file mode 100644 index 73798cc7d..000000000 --- a/p2p/testing/peerpool.go +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2018 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - -package testing - -import ( - "fmt" - "sync" - - "github.com/ledgerwatch/turbo-geth/log" - "github.com/ledgerwatch/turbo-geth/p2p/enode" -) - -type TestPeer interface { - ID() enode.ID - Drop() -} - -// TestPeerPool is an example peerPool to demonstrate registration of peer connections -type TestPeerPool struct { - lock sync.Mutex - peers map[enode.ID]TestPeer -} - -func NewTestPeerPool() *TestPeerPool { - return &TestPeerPool{peers: make(map[enode.ID]TestPeer)} -} - -func (p *TestPeerPool) Add(peer TestPeer) { - p.lock.Lock() - defer p.lock.Unlock() - log.Trace(fmt.Sprintf("pp add peer %v", peer.ID())) - p.peers[peer.ID()] = peer - -} - -func (p *TestPeerPool) Remove(peer TestPeer) { - p.lock.Lock() - defer p.lock.Unlock() - delete(p.peers, peer.ID()) -} - -func (p *TestPeerPool) Has(id enode.ID) bool { - p.lock.Lock() - defer p.lock.Unlock() - _, ok := p.peers[id] - return ok -} - -func (p *TestPeerPool) Get(id enode.ID) TestPeer { - p.lock.Lock() - defer p.lock.Unlock() - return p.peers[id] -} diff --git a/p2p/testing/protocolsession.go b/p2p/testing/protocolsession.go deleted file mode 100644 index 40304239b..000000000 --- a/p2p/testing/protocolsession.go +++ /dev/null @@ -1,283 +0,0 @@ -// Copyright 2018 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - -package testing - -import ( - "errors" - "fmt" - "sync" - "time" - - "github.com/ledgerwatch/turbo-geth/log" - "github.com/ledgerwatch/turbo-geth/p2p" - "github.com/ledgerwatch/turbo-geth/p2p/enode" - "github.com/ledgerwatch/turbo-geth/p2p/simulations/adapters" -) - -var errTimedOut = errors.New("timed out") - -// ProtocolSession is a quasi simulation of a pivot node running -// a service and a number of dummy peers that can send (trigger) or -// receive (expect) messages -type ProtocolSession struct { - Server *p2p.Server - Nodes []*enode.Node - adapter *adapters.SimAdapter - events chan *p2p.PeerEvent -} - -// Exchange is the basic units of protocol tests -// the triggers and expects in the arrays are run immediately and asynchronously -// thus one cannot have multiple expects for the SAME peer with DIFFERENT message types -// because it's unpredictable which expect will receive which message -// (with expect #1 and #2, messages might be sent #2 and #1, and both expects will complain about wrong message code) -// an exchange is defined on a session -type Exchange struct { - Label string - Triggers []Trigger - Expects []Expect - Timeout time.Duration -} - -// Trigger is part of the exchange, incoming message for the pivot node -// sent by a peer -type Trigger struct { - Msg interface{} // type of message to be sent - Code uint64 // code of message is given - Peer enode.ID // the peer to send the message to - Timeout time.Duration // timeout duration for the sending -} - -// Expect is part of an exchange, outgoing message from the pivot node -// received by a peer -type Expect struct { - Msg interface{} // type of message to expect - Code uint64 // code of message is now given - Peer enode.ID // the peer that expects the message - Timeout time.Duration // timeout duration for receiving -} - -// Disconnect represents a disconnect event, used and checked by TestDisconnected -type Disconnect struct { - Peer enode.ID // discconnected peer - Error error // disconnect reason -} - -// trigger sends messages from peers -func (s *ProtocolSession) trigger(trig Trigger) error { - simNode, ok := s.adapter.GetNode(trig.Peer) - if !ok { - return fmt.Errorf("trigger: peer %v does not exist (1- %v)", trig.Peer, len(s.Nodes)) - } - mockNode, ok := simNode.Services()[0].(*mockNode) - if !ok { - return fmt.Errorf("trigger: peer %v is not a mock", trig.Peer) - } - - errc := make(chan error) - - go func() { - log.Trace(fmt.Sprintf("trigger %v (%v)....", trig.Msg, trig.Code)) - errc <- mockNode.Trigger(&trig) - log.Trace(fmt.Sprintf("triggered %v (%v)", trig.Msg, trig.Code)) - }() - - t := trig.Timeout - if t == time.Duration(0) { - t = 1000 * time.Millisecond - } - select { - case err := <-errc: - return err - case <-time.After(t): - return fmt.Errorf("timout expecting %v to send to peer %v", trig.Msg, trig.Peer) - } -} - -// expect checks an expectation of a message sent out by the pivot node -func (s *ProtocolSession) expect(exps []Expect) error { - // construct a map of expectations for each node - peerExpects := make(map[enode.ID][]Expect) - for _, exp := range exps { - if exp.Msg == nil { - return errors.New("no message to expect") - } - peerExpects[exp.Peer] = append(peerExpects[exp.Peer], exp) - } - - // construct a map of mockNodes for each node - mockNodes := make(map[enode.ID]*mockNode) - for nodeID := range peerExpects { - simNode, ok := s.adapter.GetNode(nodeID) - if !ok { - return fmt.Errorf("trigger: peer %v does not exist (1- %v)", nodeID, len(s.Nodes)) - } - mockNode, ok := simNode.Services()[0].(*mockNode) - if !ok { - return fmt.Errorf("trigger: peer %v is not a mock", nodeID) - } - mockNodes[nodeID] = mockNode - } - - // done chanell cancels all created goroutines when function returns - done := make(chan struct{}) - defer close(done) - // errc catches the first error from - errc := make(chan error) - - wg := &sync.WaitGroup{} - wg.Add(len(mockNodes)) - for nodeID, mockNode := range mockNodes { - nodeID := nodeID - mockNode := mockNode - go func() { - defer wg.Done() - - // Sum all Expect timeouts to give the maximum - // time for all expectations to finish. - // mockNode.Expect checks all received messages against - // a list of expected messages and timeout for each - // of them can not be checked separately. - var t time.Duration - for _, exp := range peerExpects[nodeID] { - if exp.Timeout == time.Duration(0) { - t += 2000 * time.Millisecond - } else { - t += exp.Timeout - } - } - alarm := time.NewTimer(t) - defer alarm.Stop() - - // expectErrc is used to check if error returned - // from mockNode.Expect is not nil and to send it to - // errc only in that case. - // done channel will be closed when function - expectErrc := make(chan error) - go func() { - select { - case expectErrc <- mockNode.Expect(peerExpects[nodeID]...): - case <-done: - case <-alarm.C: - } - }() - - select { - case err := <-expectErrc: - if err != nil { - select { - case errc <- err: - case <-done: - case <-alarm.C: - errc <- errTimedOut - } - } - case <-done: - case <-alarm.C: - errc <- errTimedOut - } - - }() - } - - go func() { - wg.Wait() - // close errc when all goroutines finish to return nill err from errc - close(errc) - }() - - return <-errc -} - -// TestExchanges tests a series of exchanges against the session -func (s *ProtocolSession) TestExchanges(exchanges ...Exchange) error { - for i, e := range exchanges { - if err := s.testExchange(e); err != nil { - return fmt.Errorf("exchange #%d %q: %v", i, e.Label, err) - } - log.Trace(fmt.Sprintf("exchange #%d %q: run successfully", i, e.Label)) - } - return nil -} - -// testExchange tests a single Exchange. -// Default timeout value is 2 seconds. -func (s *ProtocolSession) testExchange(e Exchange) error { - errc := make(chan error) - done := make(chan struct{}) - defer close(done) - - go func() { - for _, trig := range e.Triggers { - err := s.trigger(trig) - if err != nil { - errc <- err - return - } - } - - select { - case errc <- s.expect(e.Expects): - case <-done: - } - }() - - // time out globally or finish when all expectations satisfied - t := e.Timeout - if t == 0 { - t = 2000 * time.Millisecond - } - alarm := time.NewTimer(t) - defer alarm.Stop() - select { - case err := <-errc: - return err - case <-alarm.C: - return errTimedOut - } -} - -// TestDisconnected tests the disconnections given as arguments -// the disconnect structs describe what disconnect error is expected on which peer -func (s *ProtocolSession) TestDisconnected(disconnects ...*Disconnect) error { - expects := make(map[enode.ID]error) - for _, disconnect := range disconnects { - expects[disconnect.Peer] = disconnect.Error - } - - timeout := time.After(time.Second) - for len(expects) > 0 { - select { - case event := <-s.events: - if event.Type != p2p.PeerEventTypeDrop { - continue - } - expectErr, ok := expects[event.Peer] - if !ok { - continue - } - - if !(expectErr == nil && event.Error == "" || expectErr != nil && expectErr.Error() == event.Error) { - return fmt.Errorf("unexpected error on peer %v. expected '%v', got '%v'", event.Peer, expectErr, event.Error) - } - delete(expects, event.Peer) - case <-timeout: - return fmt.Errorf("timed out waiting for peers to disconnect") - } - } - return nil -} diff --git a/p2p/testing/protocoltester.go b/p2p/testing/protocoltester.go deleted file mode 100644 index ade299881..000000000 --- a/p2p/testing/protocoltester.go +++ /dev/null @@ -1,284 +0,0 @@ -// Copyright 2018 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - -/* -the p2p/testing package provides a unit test scheme to check simple -protocol message exchanges with one pivot node and a number of dummy peers -The pivot test node runs a node.Service, the dummy peers run a mock node -that can be used to send and receive messages -*/ - -package testing - -import ( - "bytes" - "crypto/ecdsa" - "fmt" - "io" - "io/ioutil" - "strings" - "sync" - - "github.com/ledgerwatch/turbo-geth/log" - "github.com/ledgerwatch/turbo-geth/node" - "github.com/ledgerwatch/turbo-geth/p2p" - "github.com/ledgerwatch/turbo-geth/p2p/enode" - "github.com/ledgerwatch/turbo-geth/p2p/simulations" - "github.com/ledgerwatch/turbo-geth/p2p/simulations/adapters" - "github.com/ledgerwatch/turbo-geth/rlp" - "github.com/ledgerwatch/turbo-geth/rpc" -) - -// ProtocolTester is the tester environment used for unit testing protocol -// message exchanges. It uses p2p/simulations framework -type ProtocolTester struct { - *ProtocolSession - network *simulations.Network -} - -// NewProtocolTester constructs a new ProtocolTester -// it takes as argument the pivot node id, the number of dummy peers and the -// protocol run function called on a peer connection by the p2p server -func NewProtocolTester(prvkey *ecdsa.PrivateKey, nodeCount int, run func(*p2p.Peer, p2p.MsgReadWriter) error) *ProtocolTester { - services := adapters.Services{ - "test": func(ctx *adapters.ServiceContext) (node.Service, error) { - return &testNode{run}, nil - }, - "mock": func(ctx *adapters.ServiceContext) (node.Service, error) { - return newMockNode(), nil - }, - } - adapter := adapters.NewSimAdapter(services) - net := simulations.NewNetwork(adapter, &simulations.NetworkConfig{}) - nodeConfig := &adapters.NodeConfig{ - PrivateKey: prvkey, - EnableMsgEvents: true, - Services: []string{"test"}, - } - if _, err := net.NewNodeWithConfig(nodeConfig); err != nil { - panic(err.Error()) - } - if err := net.Start(nodeConfig.ID); err != nil { - panic(err.Error()) - } - - node := net.GetNode(nodeConfig.ID).Node.(*adapters.SimNode) - peers := make([]*adapters.NodeConfig, nodeCount) - nodes := make([]*enode.Node, nodeCount) - for i := 0; i < nodeCount; i++ { - peers[i] = adapters.RandomNodeConfig() - peers[i].Services = []string{"mock"} - if _, err := net.NewNodeWithConfig(peers[i]); err != nil { - panic(fmt.Sprintf("error initializing peer %v: %v", peers[i].ID, err)) - } - if err := net.Start(peers[i].ID); err != nil { - panic(fmt.Sprintf("error starting peer %v: %v", peers[i].ID, err)) - } - nodes[i] = peers[i].Node() - } - events := make(chan *p2p.PeerEvent, 1000) - node.SubscribeEvents(events) - ps := &ProtocolSession{ - Server: node.Server(), - Nodes: nodes, - adapter: adapter, - events: events, - } - self := &ProtocolTester{ - ProtocolSession: ps, - network: net, - } - - self.Connect(nodeConfig.ID, peers...) - - return self -} - -// Stop stops the p2p server -func (t *ProtocolTester) Stop() { - t.Server.Stop() - t.network.Shutdown() -} - -// Connect brings up the remote peer node and connects it using the -// p2p/simulations network connection with the in memory network adapter -func (t *ProtocolTester) Connect(selfID enode.ID, peers ...*adapters.NodeConfig) { - for _, peer := range peers { - log.Trace(fmt.Sprintf("connect to %v", peer.ID)) - if err := t.network.Connect(selfID, peer.ID); err != nil { - panic(fmt.Sprintf("error connecting to peer %v: %v", peer.ID, err)) - } - } - -} - -// testNode wraps a protocol run function and implements the node.Service -// interface -type testNode struct { - run func(*p2p.Peer, p2p.MsgReadWriter) error -} - -func (t *testNode) Protocols() []p2p.Protocol { - return []p2p.Protocol{{ - Length: 100, - Run: t.run, - }} -} - -func (t *testNode) APIs() []rpc.API { - return nil -} - -func (t *testNode) Start(server *p2p.Server) error { - return nil -} - -func (t *testNode) Stop() error { - return nil -} - -// mockNode is a testNode which doesn't actually run a protocol, instead -// exposing channels so that tests can manually trigger and expect certain -// messages -type mockNode struct { - testNode - - trigger chan *Trigger - expect chan []Expect - err chan error - stop chan struct{} - stopOnce sync.Once -} - -func newMockNode() *mockNode { - mock := &mockNode{ - trigger: make(chan *Trigger), - expect: make(chan []Expect), - err: make(chan error), - stop: make(chan struct{}), - } - mock.testNode.run = mock.Run - return mock -} - -// Run is a protocol run function which just loops waiting for tests to -// instruct it to either trigger or expect a message from the peer -func (m *mockNode) Run(peer *p2p.Peer, rw p2p.MsgReadWriter) error { - for { - select { - case trig := <-m.trigger: - wmsg := Wrap(trig.Msg) - m.err <- p2p.Send(rw, trig.Code, wmsg) - case exps := <-m.expect: - m.err <- expectMsgs(rw, exps) - case <-m.stop: - return nil - } - } -} - -func (m *mockNode) Trigger(trig *Trigger) error { - m.trigger <- trig - return <-m.err -} - -func (m *mockNode) Expect(exp ...Expect) error { - m.expect <- exp - return <-m.err -} - -func (m *mockNode) Stop() error { - m.stopOnce.Do(func() { close(m.stop) }) - return nil -} - -func expectMsgs(rw p2p.MsgReadWriter, exps []Expect) error { - matched := make([]bool, len(exps)) - for { - msg, err := rw.ReadMsg() - if err != nil { - if err == io.EOF { - break - } - return err - } - actualContent, err := ioutil.ReadAll(msg.Payload) - if err != nil { - return err - } - var found bool - for i, exp := range exps { - if exp.Code == msg.Code && bytes.Equal(actualContent, mustEncodeMsg(Wrap(exp.Msg))) { - if matched[i] { - return fmt.Errorf("message #%d received two times", i) - } - matched[i] = true - found = true - break - } - } - if !found { - expected := make([]string, 0) - for i, exp := range exps { - if matched[i] { - continue - } - expected = append(expected, fmt.Sprintf("code %d payload %x", exp.Code, mustEncodeMsg(Wrap(exp.Msg)))) - } - return fmt.Errorf("unexpected message code %d payload %x, expected %s", msg.Code, actualContent, strings.Join(expected, " or ")) - } - done := true - for _, m := range matched { - if !m { - done = false - break - } - } - if done { - return nil - } - } - for i, m := range matched { - if !m { - return fmt.Errorf("expected message #%d not received", i) - } - } - return nil -} - -// mustEncodeMsg uses rlp to encode a message. -// In case of error it panics. -func mustEncodeMsg(msg interface{}) []byte { - contentEnc, err := rlp.EncodeToBytes(msg) - if err != nil { - panic("content encode error: " + err.Error()) - } - return contentEnc -} - -type WrappedMsg struct { - Context []byte - Size uint32 - Payload []byte -} - -func Wrap(msg interface{}) interface{} { - data, _ := rlp.EncodeToBytes(msg) - return &WrappedMsg{ - Size: uint32(len(data)), - Payload: data, - } -}