diff --git a/beacon-chain/node/BUILD.bazel b/beacon-chain/node/BUILD.bazel index 9e848c590..27feb2318 100644 --- a/beacon-chain/node/BUILD.bazel +++ b/beacon-chain/node/BUILD.bazel @@ -85,6 +85,7 @@ go_test( "//beacon-chain/monitor:go_default_library", "//cmd:go_default_library", "//cmd/beacon-chain/flags:go_default_library", + "//config/features:go_default_library", "//config/fieldparams:go_default_library", "//config/params:go_default_library", "//consensus-types/primitives:go_default_library", diff --git a/beacon-chain/node/node.go b/beacon-chain/node/node.go index dd2d9c7e1..a0c66b51e 100644 --- a/beacon-chain/node/node.go +++ b/beacon-chain/node/node.go @@ -117,6 +117,9 @@ func New(cliCtx *cli.Context, opts ...Option) (*BeaconNode, error) { return nil, err } prereqs.WarnIfPlatformNotSupported(cliCtx.Context) + if hasNetworkFlag(cliCtx) && cliCtx.IsSet(cmd.ChainConfigFileFlag.Name) { + return nil, fmt.Errorf("%s cannot be passed concurrently with network flag", cmd.ChainConfigFileFlag.Name) + } if err := features.ConfigureBeaconChain(cliCtx); err != nil { return nil, err } @@ -970,3 +973,14 @@ func (b *BeaconNode) registerBuilderService() error { } return b.services.RegisterService(svc) } + +func hasNetworkFlag(cliCtx *cli.Context) bool { + for _, flag := range features.NetworkFlags { + for _, name := range flag.Names() { + if cliCtx.IsSet(name) { + return true + } + } + } + return false +} diff --git a/beacon-chain/node/node_test.go b/beacon-chain/node/node_test.go index 9ae39ae62..2c7448971 100644 --- a/beacon-chain/node/node_test.go +++ b/beacon-chain/node/node_test.go @@ -18,6 +18,7 @@ import ( "github.com/prysmaticlabs/prysm/v3/beacon-chain/monitor" "github.com/prysmaticlabs/prysm/v3/cmd" "github.com/prysmaticlabs/prysm/v3/cmd/beacon-chain/flags" + "github.com/prysmaticlabs/prysm/v3/config/features" fieldparams "github.com/prysmaticlabs/prysm/v3/config/fieldparams" "github.com/prysmaticlabs/prysm/v3/config/params" ethpb "github.com/prysmaticlabs/prysm/v3/proto/prysm/v1alpha1" @@ -171,3 +172,45 @@ func TestMonitor_RegisteredCorrectly(t *testing.T) { require.Equal(t, true, mService.TrackedValidators[2]) require.Equal(t, false, mService.TrackedValidators[100]) } + +func Test_hasNetworkFlag(t *testing.T) { + tests := []struct { + name string + networkName string + networkValue string + want bool + }{ + { + name: "Prater testnet", + networkName: features.PraterTestnet.Name, + networkValue: "prater", + want: true, + }, + { + name: "Mainnet", + networkName: features.Mainnet.Name, + networkValue: "mainnet", + want: true, + }, + { + name: "No network flag", + networkName: "", + networkValue: "", + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + set := flag.NewFlagSet("test", 0) + set.String(tt.networkName, tt.networkValue, tt.name) + + cliCtx := cli.NewContext(&cli.App{}, set, nil) + err := cliCtx.Set(tt.networkName, tt.networkValue) + require.NoError(t, err) + + if got := hasNetworkFlag(cliCtx); got != tt.want { + t.Errorf("hasNetworkFlag() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/config/features/flags.go b/config/features/flags.go index 468c0c7f2..150e64be3 100644 --- a/config/features/flags.go +++ b/config/features/flags.go @@ -179,3 +179,11 @@ var BeaconChainFlags = append(deprecatedBeaconFlags, append(deprecatedFlags, []c var E2EBeaconChainFlags = []string{ "--dev", } + +// NetworkFlags contains a list of network flags. +var NetworkFlags = []cli.Flag{ + Mainnet, + PraterTestnet, + RopstenTestnet, + SepoliaTestnet, +}