diff --git a/beacon-chain/execution/options.go b/beacon-chain/execution/options.go index 3130b4b86..e09f67b7d 100644 --- a/beacon-chain/execution/options.go +++ b/beacon-chain/execution/options.go @@ -108,3 +108,10 @@ func WithFinalizedStateAtStartup(st state.BeaconState) Option { return nil } } + +func WithJwtId(jwtId string) Option { + return func(s *Service) error { + s.cfg.jwtId = jwtId + return nil + } +} diff --git a/beacon-chain/execution/service.go b/beacon-chain/execution/service.go index 386d66349..cc17b79aa 100644 --- a/beacon-chain/execution/service.go +++ b/beacon-chain/execution/service.go @@ -128,6 +128,7 @@ type config struct { currHttpEndpoint network.Endpoint headers []string finalizedStateAtStartup state.BeaconState + jwtId string } // Service fetches important information about the canonical diff --git a/beacon-chain/node/node.go b/beacon-chain/node/node.go index 37f2fbf42..1be5d2ef5 100644 --- a/beacon-chain/node/node.go +++ b/beacon-chain/node/node.go @@ -675,6 +675,7 @@ func (b *BeaconNode) registerPOWChainService() error { execution.WithStateGen(b.stateGen), execution.WithBeaconNodeStatsUpdater(bs), execution.WithFinalizedStateAtStartup(b.finalizedStateAtStartUp), + execution.WithJwtId(b.cliCtx.String(flags.JwtId.Name)), ) web3Service, err := execution.NewService(b.ctx, opts...) if err != nil { diff --git a/cmd/beacon-chain/flags/base.go b/cmd/beacon-chain/flags/base.go index 5f2c9402c..a23ce2ab7 100644 --- a/cmd/beacon-chain/flags/base.go +++ b/cmd/beacon-chain/flags/base.go @@ -54,6 +54,11 @@ var ( "This is not required if using an IPC connection.", Value: "", } + // JwtId is the id field of the JWT claims. The consensus layer client MAY use this to communicate a unique identifier for the individual consensus layer client + JwtId = &cli.StringFlag{ + Name: "jwt-id", + Usage: "JWT claims id. Could be used to identify the client", + } // DepositContractFlag defines a flag for the deposit contract address. DepositContractFlag = &cli.StringFlag{ Name: "deposit-contract", diff --git a/cmd/beacon-chain/main.go b/cmd/beacon-chain/main.go index b2a64ecd5..4298e98ab 100644 --- a/cmd/beacon-chain/main.go +++ b/cmd/beacon-chain/main.go @@ -136,6 +136,7 @@ var appFlags = []cli.Flag{ genesis.StatePath, genesis.BeaconAPIURL, flags.SlasherDirFlag, + flags.JwtId, } func init() { diff --git a/cmd/beacon-chain/usage.go b/cmd/beacon-chain/usage.go index ab7835af5..1d0e67dc6 100644 --- a/cmd/beacon-chain/usage.go +++ b/cmd/beacon-chain/usage.go @@ -129,6 +129,7 @@ var appHelpFlagGroups = []flagGroup{ flags.SlasherDirFlag, flags.LocalBlockValueBoost, flags.BlobRetentionEpoch, + flags.JwtId, checkpoint.BlockPath, checkpoint.StatePath, checkpoint.RemoteURL, diff --git a/network/auth.go b/network/auth.go index 2758bd07e..a17f81d7a 100644 --- a/network/auth.go +++ b/network/auth.go @@ -24,6 +24,7 @@ const DefaultRPCHTTPTimeout = time.Second * 30 type jwtTransport struct { underlyingTransport http.RoundTripper jwtSecret []byte + jwtId string } // RoundTrip ensures our transport implements http.RoundTripper interface from the @@ -32,12 +33,16 @@ type jwtTransport struct { // an JWT bearer token in the Authorization request header of every outgoing request // our HTTP client makes. func (t *jwtTransport) RoundTrip(req *http.Request) (*http.Response, error) { - token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + claims := jwt.MapClaims{ // Required claim for engine API auth. "iat" stands for issued at // and it must be a unix timestamp that is +/- 5 seconds from the current // timestamp at the moment the server verifies this value. "iat": time.Now().Unix(), - }) + } + if len(t.jwtId) > 0 { + claims["id"] = t.jwtId + } + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenString, err := token.SignedString(t.jwtSecret) if err != nil { return nil, errors.Wrap(err, "could not produce signed JWT token") diff --git a/network/auth_test.go b/network/auth_test.go index 4e350d0f9..c2f3b96f7 100644 --- a/network/auth_test.go +++ b/network/auth_test.go @@ -51,3 +51,91 @@ func TestJWTAuthTransport(t *testing.T) { _, err := client.Get(srv.URL) require.NoError(t, err) } + +func TestJWTWithId(t *testing.T) { + secret := bytesutil.PadTo([]byte("foo"), 32) + jwtId := "abc" + authTransport := &jwtTransport{ + underlyingTransport: http.DefaultTransport, + jwtSecret: secret, + jwtId: jwtId, + } + client := &http.Client{ + Timeout: DefaultRPCHTTPTimeout, + Transport: authTransport, + } + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqToken := r.Header.Get("Authorization") + splitToken := strings.Split(reqToken, "Bearer") + // The format should be `Bearer ${token}`. + require.Equal(t, 2, len(splitToken)) + reqToken = strings.TrimSpace(splitToken[1]) + token, err := jwt.Parse(reqToken, func(token *jwt.Token) (interface{}, error) { + // We should be doing HMAC signing. + _, ok := token.Method.(*jwt.SigningMethodHMAC) + require.Equal(t, true, ok) + return secret, nil + }) + require.NoError(t, err) + require.Equal(t, true, token.Valid) + claims, ok := token.Claims.(jwt.MapClaims) + require.Equal(t, true, ok) + item, ok := claims["iat"] + require.Equal(t, true, ok) + iat, ok := item.(float64) + require.Equal(t, true, ok) + issuedAt := time.Unix(int64(iat), 0) + // The claims should have an "iat" field (issued at) that is at most, 5 seconds ago. + since := time.Since(issuedAt) + require.Equal(t, true, since <= time.Second*5) + // check jwt claims id + id, ok := claims["id"] + require.Equal(t, true, ok) + require.Equal(t, id, jwtId) + })) + defer srv.Close() + _, err := client.Get(srv.URL) + require.NoError(t, err) +} + +func TestJWTWithoutId(t *testing.T) { + secret := bytesutil.PadTo([]byte("foo"), 32) + authTransport := &jwtTransport{ + underlyingTransport: http.DefaultTransport, + jwtSecret: secret, + } + client := &http.Client{ + Timeout: DefaultRPCHTTPTimeout, + Transport: authTransport, + } + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqToken := r.Header.Get("Authorization") + splitToken := strings.Split(reqToken, "Bearer") + // The format should be `Bearer ${token}`. + require.Equal(t, 2, len(splitToken)) + reqToken = strings.TrimSpace(splitToken[1]) + token, err := jwt.Parse(reqToken, func(token *jwt.Token) (interface{}, error) { + // We should be doing HMAC signing. + _, ok := token.Method.(*jwt.SigningMethodHMAC) + require.Equal(t, true, ok) + return secret, nil + }) + require.NoError(t, err) + require.Equal(t, true, token.Valid) + claims, ok := token.Claims.(jwt.MapClaims) + require.Equal(t, true, ok) + item, ok := claims["iat"] + require.Equal(t, true, ok) + iat, ok := item.(float64) + require.Equal(t, true, ok) + issuedAt := time.Unix(int64(iat), 0) + // The claims should have an "iat" field (issued at) that is at most, 5 seconds ago. + since := time.Since(issuedAt) + require.Equal(t, true, since <= time.Second*5) + _, ok = claims["id"] + require.Equal(t, false, ok) + })) + defer srv.Close() + _, err := client.Get(srv.URL) + require.NoError(t, err) +} diff --git a/network/endpoint.go b/network/endpoint.go index f7d240632..76057c72b 100644 --- a/network/endpoint.go +++ b/network/endpoint.go @@ -24,6 +24,7 @@ type Endpoint struct { type AuthorizationData struct { Method authorization.Method Value string + JwtId string } // Equals compares two endpoints for equality. @@ -37,7 +38,7 @@ func (e Endpoint) HttpClient() *http.Client { if e.Auth.Method != authorization.Bearer { return http.DefaultClient } - return NewHttpClientWithSecret(e.Auth.Value) + return NewHttpClientWithSecret(e.Auth.Value, e.Auth.JwtId) } // Equals compares two authorization data objects for equality. @@ -112,10 +113,11 @@ func Method(auth string) authorization.Method { // NewHttpClientWithSecret returns a http client that utilizes // jwt authentication. -func NewHttpClientWithSecret(secret string) *http.Client { +func NewHttpClientWithSecret(secret, id string) *http.Client { authTransport := &jwtTransport{ underlyingTransport: http.DefaultTransport, jwtSecret: []byte(secret), + jwtId: id, } return &http.Client{ Timeout: DefaultRPCHTTPTimeout, diff --git a/testing/middleware/builder/builder.go b/testing/middleware/builder/builder.go index bd0670c9a..c09b2b7cc 100644 --- a/testing/middleware/builder/builder.go +++ b/testing/middleware/builder/builder.go @@ -563,7 +563,7 @@ func (p *Builder) sendHttpRequest(req *http.Request, requestBytes []byte) (*http client := &http.Client{} if p.cfg.secret != "" { - client = network.NewHttpClientWithSecret(p.cfg.secret) + client = network.NewHttpClientWithSecret(p.cfg.secret, "") } proxyRes, err := client.Do(proxyReq) if err != nil { diff --git a/testing/middleware/engine-api-proxy/proxy.go b/testing/middleware/engine-api-proxy/proxy.go index 176a3886d..507c77ca3 100644 --- a/testing/middleware/engine-api-proxy/proxy.go +++ b/testing/middleware/engine-api-proxy/proxy.go @@ -249,7 +249,7 @@ func (p *Proxy) sendHttpRequest(req *http.Request, requestBytes []byte) (*http.R client := &http.Client{} if p.cfg.secret != "" { - client = network.NewHttpClientWithSecret(p.cfg.secret) + client = network.NewHttpClientWithSecret(p.cfg.secret, "") } proxyRes, err := client.Do(proxyReq) if err != nil {