Add --jwt-id flag (#13218)

* add jwt-id flag

* optimize unit test for jwt-id

* Add jwt-id to help text

* gofmt

---------

Co-authored-by: Preston Van Loon <pvanloon@offchainlabs.com>
This commit is contained in:
Brandon Liu 2023-12-06 03:02:25 +08:00 committed by GitHub
parent 705e98e3c3
commit c78d698d89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 117 additions and 6 deletions

View File

@ -108,3 +108,10 @@ func WithFinalizedStateAtStartup(st state.BeaconState) Option {
return nil
}
}
func WithJwtId(jwtId string) Option {
return func(s *Service) error {
s.cfg.jwtId = jwtId
return nil
}
}

View File

@ -128,6 +128,7 @@ type config struct {
currHttpEndpoint network.Endpoint
headers []string
finalizedStateAtStartup state.BeaconState
jwtId string
}
// Service fetches important information about the canonical

View File

@ -675,6 +675,7 @@ func (b *BeaconNode) registerPOWChainService() error {
execution.WithStateGen(b.stateGen),
execution.WithBeaconNodeStatsUpdater(bs),
execution.WithFinalizedStateAtStartup(b.finalizedStateAtStartUp),
execution.WithJwtId(b.cliCtx.String(flags.JwtId.Name)),
)
web3Service, err := execution.NewService(b.ctx, opts...)
if err != nil {

View File

@ -54,6 +54,11 @@ var (
"This is not required if using an IPC connection.",
Value: "",
}
// JwtId is the id field of the JWT claims. The consensus layer client MAY use this to communicate a unique identifier for the individual consensus layer client
JwtId = &cli.StringFlag{
Name: "jwt-id",
Usage: "JWT claims id. Could be used to identify the client",
}
// DepositContractFlag defines a flag for the deposit contract address.
DepositContractFlag = &cli.StringFlag{
Name: "deposit-contract",

View File

@ -136,6 +136,7 @@ var appFlags = []cli.Flag{
genesis.StatePath,
genesis.BeaconAPIURL,
flags.SlasherDirFlag,
flags.JwtId,
}
func init() {

View File

@ -129,6 +129,7 @@ var appHelpFlagGroups = []flagGroup{
flags.SlasherDirFlag,
flags.LocalBlockValueBoost,
flags.BlobRetentionEpoch,
flags.JwtId,
checkpoint.BlockPath,
checkpoint.StatePath,
checkpoint.RemoteURL,

View File

@ -24,6 +24,7 @@ const DefaultRPCHTTPTimeout = time.Second * 30
type jwtTransport struct {
underlyingTransport http.RoundTripper
jwtSecret []byte
jwtId string
}
// RoundTrip ensures our transport implements http.RoundTripper interface from the
@ -32,12 +33,16 @@ type jwtTransport struct {
// an JWT bearer token in the Authorization request header of every outgoing request
// our HTTP client makes.
func (t *jwtTransport) RoundTrip(req *http.Request) (*http.Response, error) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
claims := jwt.MapClaims{
// Required claim for engine API auth. "iat" stands for issued at
// and it must be a unix timestamp that is +/- 5 seconds from the current
// timestamp at the moment the server verifies this value.
"iat": time.Now().Unix(),
})
}
if len(t.jwtId) > 0 {
claims["id"] = t.jwtId
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString(t.jwtSecret)
if err != nil {
return nil, errors.Wrap(err, "could not produce signed JWT token")

View File

@ -51,3 +51,91 @@ func TestJWTAuthTransport(t *testing.T) {
_, err := client.Get(srv.URL)
require.NoError(t, err)
}
func TestJWTWithId(t *testing.T) {
secret := bytesutil.PadTo([]byte("foo"), 32)
jwtId := "abc"
authTransport := &jwtTransport{
underlyingTransport: http.DefaultTransport,
jwtSecret: secret,
jwtId: jwtId,
}
client := &http.Client{
Timeout: DefaultRPCHTTPTimeout,
Transport: authTransport,
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reqToken := r.Header.Get("Authorization")
splitToken := strings.Split(reqToken, "Bearer")
// The format should be `Bearer ${token}`.
require.Equal(t, 2, len(splitToken))
reqToken = strings.TrimSpace(splitToken[1])
token, err := jwt.Parse(reqToken, func(token *jwt.Token) (interface{}, error) {
// We should be doing HMAC signing.
_, ok := token.Method.(*jwt.SigningMethodHMAC)
require.Equal(t, true, ok)
return secret, nil
})
require.NoError(t, err)
require.Equal(t, true, token.Valid)
claims, ok := token.Claims.(jwt.MapClaims)
require.Equal(t, true, ok)
item, ok := claims["iat"]
require.Equal(t, true, ok)
iat, ok := item.(float64)
require.Equal(t, true, ok)
issuedAt := time.Unix(int64(iat), 0)
// The claims should have an "iat" field (issued at) that is at most, 5 seconds ago.
since := time.Since(issuedAt)
require.Equal(t, true, since <= time.Second*5)
// check jwt claims id
id, ok := claims["id"]
require.Equal(t, true, ok)
require.Equal(t, id, jwtId)
}))
defer srv.Close()
_, err := client.Get(srv.URL)
require.NoError(t, err)
}
func TestJWTWithoutId(t *testing.T) {
secret := bytesutil.PadTo([]byte("foo"), 32)
authTransport := &jwtTransport{
underlyingTransport: http.DefaultTransport,
jwtSecret: secret,
}
client := &http.Client{
Timeout: DefaultRPCHTTPTimeout,
Transport: authTransport,
}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reqToken := r.Header.Get("Authorization")
splitToken := strings.Split(reqToken, "Bearer")
// The format should be `Bearer ${token}`.
require.Equal(t, 2, len(splitToken))
reqToken = strings.TrimSpace(splitToken[1])
token, err := jwt.Parse(reqToken, func(token *jwt.Token) (interface{}, error) {
// We should be doing HMAC signing.
_, ok := token.Method.(*jwt.SigningMethodHMAC)
require.Equal(t, true, ok)
return secret, nil
})
require.NoError(t, err)
require.Equal(t, true, token.Valid)
claims, ok := token.Claims.(jwt.MapClaims)
require.Equal(t, true, ok)
item, ok := claims["iat"]
require.Equal(t, true, ok)
iat, ok := item.(float64)
require.Equal(t, true, ok)
issuedAt := time.Unix(int64(iat), 0)
// The claims should have an "iat" field (issued at) that is at most, 5 seconds ago.
since := time.Since(issuedAt)
require.Equal(t, true, since <= time.Second*5)
_, ok = claims["id"]
require.Equal(t, false, ok)
}))
defer srv.Close()
_, err := client.Get(srv.URL)
require.NoError(t, err)
}

View File

@ -24,6 +24,7 @@ type Endpoint struct {
type AuthorizationData struct {
Method authorization.Method
Value string
JwtId string
}
// Equals compares two endpoints for equality.
@ -37,7 +38,7 @@ func (e Endpoint) HttpClient() *http.Client {
if e.Auth.Method != authorization.Bearer {
return http.DefaultClient
}
return NewHttpClientWithSecret(e.Auth.Value)
return NewHttpClientWithSecret(e.Auth.Value, e.Auth.JwtId)
}
// Equals compares two authorization data objects for equality.
@ -112,10 +113,11 @@ func Method(auth string) authorization.Method {
// NewHttpClientWithSecret returns a http client that utilizes
// jwt authentication.
func NewHttpClientWithSecret(secret string) *http.Client {
func NewHttpClientWithSecret(secret, id string) *http.Client {
authTransport := &jwtTransport{
underlyingTransport: http.DefaultTransport,
jwtSecret: []byte(secret),
jwtId: id,
}
return &http.Client{
Timeout: DefaultRPCHTTPTimeout,

View File

@ -563,7 +563,7 @@ func (p *Builder) sendHttpRequest(req *http.Request, requestBytes []byte) (*http
client := &http.Client{}
if p.cfg.secret != "" {
client = network.NewHttpClientWithSecret(p.cfg.secret)
client = network.NewHttpClientWithSecret(p.cfg.secret, "")
}
proxyRes, err := client.Do(proxyReq)
if err != nil {

View File

@ -249,7 +249,7 @@ func (p *Proxy) sendHttpRequest(req *http.Request, requestBytes []byte) (*http.R
client := &http.Client{}
if p.cfg.secret != "" {
client = network.NewHttpClientWithSecret(p.cfg.secret)
client = network.NewHttpClientWithSecret(p.cfg.secret, "")
}
proxyRes, err := client.Do(proxyReq)
if err != nil {