mirror of
https://gitlab.com/pulsechaincom/prysm-pulse.git
synced 2024-12-21 19:20:38 +00:00
Add --jwt-id
flag (#13218)
* add jwt-id flag * optimize unit test for jwt-id * Add jwt-id to help text * gofmt --------- Co-authored-by: Preston Van Loon <pvanloon@offchainlabs.com>
This commit is contained in:
parent
705e98e3c3
commit
c78d698d89
@ -108,3 +108,10 @@ func WithFinalizedStateAtStartup(st state.BeaconState) Option {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithJwtId(jwtId string) Option {
|
||||
return func(s *Service) error {
|
||||
s.cfg.jwtId = jwtId
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
@ -128,6 +128,7 @@ type config struct {
|
||||
currHttpEndpoint network.Endpoint
|
||||
headers []string
|
||||
finalizedStateAtStartup state.BeaconState
|
||||
jwtId string
|
||||
}
|
||||
|
||||
// Service fetches important information about the canonical
|
||||
|
@ -675,6 +675,7 @@ func (b *BeaconNode) registerPOWChainService() error {
|
||||
execution.WithStateGen(b.stateGen),
|
||||
execution.WithBeaconNodeStatsUpdater(bs),
|
||||
execution.WithFinalizedStateAtStartup(b.finalizedStateAtStartUp),
|
||||
execution.WithJwtId(b.cliCtx.String(flags.JwtId.Name)),
|
||||
)
|
||||
web3Service, err := execution.NewService(b.ctx, opts...)
|
||||
if err != nil {
|
||||
|
@ -54,6 +54,11 @@ var (
|
||||
"This is not required if using an IPC connection.",
|
||||
Value: "",
|
||||
}
|
||||
// JwtId is the id field of the JWT claims. The consensus layer client MAY use this to communicate a unique identifier for the individual consensus layer client
|
||||
JwtId = &cli.StringFlag{
|
||||
Name: "jwt-id",
|
||||
Usage: "JWT claims id. Could be used to identify the client",
|
||||
}
|
||||
// DepositContractFlag defines a flag for the deposit contract address.
|
||||
DepositContractFlag = &cli.StringFlag{
|
||||
Name: "deposit-contract",
|
||||
|
@ -136,6 +136,7 @@ var appFlags = []cli.Flag{
|
||||
genesis.StatePath,
|
||||
genesis.BeaconAPIURL,
|
||||
flags.SlasherDirFlag,
|
||||
flags.JwtId,
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
@ -129,6 +129,7 @@ var appHelpFlagGroups = []flagGroup{
|
||||
flags.SlasherDirFlag,
|
||||
flags.LocalBlockValueBoost,
|
||||
flags.BlobRetentionEpoch,
|
||||
flags.JwtId,
|
||||
checkpoint.BlockPath,
|
||||
checkpoint.StatePath,
|
||||
checkpoint.RemoteURL,
|
||||
|
@ -24,6 +24,7 @@ const DefaultRPCHTTPTimeout = time.Second * 30
|
||||
type jwtTransport struct {
|
||||
underlyingTransport http.RoundTripper
|
||||
jwtSecret []byte
|
||||
jwtId string
|
||||
}
|
||||
|
||||
// RoundTrip ensures our transport implements http.RoundTripper interface from the
|
||||
@ -32,12 +33,16 @@ type jwtTransport struct {
|
||||
// an JWT bearer token in the Authorization request header of every outgoing request
|
||||
// our HTTP client makes.
|
||||
func (t *jwtTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
claims := jwt.MapClaims{
|
||||
// Required claim for engine API auth. "iat" stands for issued at
|
||||
// and it must be a unix timestamp that is +/- 5 seconds from the current
|
||||
// timestamp at the moment the server verifies this value.
|
||||
"iat": time.Now().Unix(),
|
||||
})
|
||||
}
|
||||
if len(t.jwtId) > 0 {
|
||||
claims["id"] = t.jwtId
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenString, err := token.SignedString(t.jwtSecret)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "could not produce signed JWT token")
|
||||
|
@ -51,3 +51,91 @@ func TestJWTAuthTransport(t *testing.T) {
|
||||
_, err := client.Get(srv.URL)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestJWTWithId(t *testing.T) {
|
||||
secret := bytesutil.PadTo([]byte("foo"), 32)
|
||||
jwtId := "abc"
|
||||
authTransport := &jwtTransport{
|
||||
underlyingTransport: http.DefaultTransport,
|
||||
jwtSecret: secret,
|
||||
jwtId: jwtId,
|
||||
}
|
||||
client := &http.Client{
|
||||
Timeout: DefaultRPCHTTPTimeout,
|
||||
Transport: authTransport,
|
||||
}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
reqToken := r.Header.Get("Authorization")
|
||||
splitToken := strings.Split(reqToken, "Bearer")
|
||||
// The format should be `Bearer ${token}`.
|
||||
require.Equal(t, 2, len(splitToken))
|
||||
reqToken = strings.TrimSpace(splitToken[1])
|
||||
token, err := jwt.Parse(reqToken, func(token *jwt.Token) (interface{}, error) {
|
||||
// We should be doing HMAC signing.
|
||||
_, ok := token.Method.(*jwt.SigningMethodHMAC)
|
||||
require.Equal(t, true, ok)
|
||||
return secret, nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, true, token.Valid)
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
require.Equal(t, true, ok)
|
||||
item, ok := claims["iat"]
|
||||
require.Equal(t, true, ok)
|
||||
iat, ok := item.(float64)
|
||||
require.Equal(t, true, ok)
|
||||
issuedAt := time.Unix(int64(iat), 0)
|
||||
// The claims should have an "iat" field (issued at) that is at most, 5 seconds ago.
|
||||
since := time.Since(issuedAt)
|
||||
require.Equal(t, true, since <= time.Second*5)
|
||||
// check jwt claims id
|
||||
id, ok := claims["id"]
|
||||
require.Equal(t, true, ok)
|
||||
require.Equal(t, id, jwtId)
|
||||
}))
|
||||
defer srv.Close()
|
||||
_, err := client.Get(srv.URL)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestJWTWithoutId(t *testing.T) {
|
||||
secret := bytesutil.PadTo([]byte("foo"), 32)
|
||||
authTransport := &jwtTransport{
|
||||
underlyingTransport: http.DefaultTransport,
|
||||
jwtSecret: secret,
|
||||
}
|
||||
client := &http.Client{
|
||||
Timeout: DefaultRPCHTTPTimeout,
|
||||
Transport: authTransport,
|
||||
}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
reqToken := r.Header.Get("Authorization")
|
||||
splitToken := strings.Split(reqToken, "Bearer")
|
||||
// The format should be `Bearer ${token}`.
|
||||
require.Equal(t, 2, len(splitToken))
|
||||
reqToken = strings.TrimSpace(splitToken[1])
|
||||
token, err := jwt.Parse(reqToken, func(token *jwt.Token) (interface{}, error) {
|
||||
// We should be doing HMAC signing.
|
||||
_, ok := token.Method.(*jwt.SigningMethodHMAC)
|
||||
require.Equal(t, true, ok)
|
||||
return secret, nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, true, token.Valid)
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
require.Equal(t, true, ok)
|
||||
item, ok := claims["iat"]
|
||||
require.Equal(t, true, ok)
|
||||
iat, ok := item.(float64)
|
||||
require.Equal(t, true, ok)
|
||||
issuedAt := time.Unix(int64(iat), 0)
|
||||
// The claims should have an "iat" field (issued at) that is at most, 5 seconds ago.
|
||||
since := time.Since(issuedAt)
|
||||
require.Equal(t, true, since <= time.Second*5)
|
||||
_, ok = claims["id"]
|
||||
require.Equal(t, false, ok)
|
||||
}))
|
||||
defer srv.Close()
|
||||
_, err := client.Get(srv.URL)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
@ -24,6 +24,7 @@ type Endpoint struct {
|
||||
type AuthorizationData struct {
|
||||
Method authorization.Method
|
||||
Value string
|
||||
JwtId string
|
||||
}
|
||||
|
||||
// Equals compares two endpoints for equality.
|
||||
@ -37,7 +38,7 @@ func (e Endpoint) HttpClient() *http.Client {
|
||||
if e.Auth.Method != authorization.Bearer {
|
||||
return http.DefaultClient
|
||||
}
|
||||
return NewHttpClientWithSecret(e.Auth.Value)
|
||||
return NewHttpClientWithSecret(e.Auth.Value, e.Auth.JwtId)
|
||||
}
|
||||
|
||||
// Equals compares two authorization data objects for equality.
|
||||
@ -112,10 +113,11 @@ func Method(auth string) authorization.Method {
|
||||
|
||||
// NewHttpClientWithSecret returns a http client that utilizes
|
||||
// jwt authentication.
|
||||
func NewHttpClientWithSecret(secret string) *http.Client {
|
||||
func NewHttpClientWithSecret(secret, id string) *http.Client {
|
||||
authTransport := &jwtTransport{
|
||||
underlyingTransport: http.DefaultTransport,
|
||||
jwtSecret: []byte(secret),
|
||||
jwtId: id,
|
||||
}
|
||||
return &http.Client{
|
||||
Timeout: DefaultRPCHTTPTimeout,
|
||||
|
@ -563,7 +563,7 @@ func (p *Builder) sendHttpRequest(req *http.Request, requestBytes []byte) (*http
|
||||
|
||||
client := &http.Client{}
|
||||
if p.cfg.secret != "" {
|
||||
client = network.NewHttpClientWithSecret(p.cfg.secret)
|
||||
client = network.NewHttpClientWithSecret(p.cfg.secret, "")
|
||||
}
|
||||
proxyRes, err := client.Do(proxyReq)
|
||||
if err != nil {
|
||||
|
@ -249,7 +249,7 @@ func (p *Proxy) sendHttpRequest(req *http.Request, requestBytes []byte) (*http.R
|
||||
|
||||
client := &http.Client{}
|
||||
if p.cfg.secret != "" {
|
||||
client = network.NewHttpClientWithSecret(p.cfg.secret)
|
||||
client = network.NewHttpClientWithSecret(p.cfg.secret, "")
|
||||
}
|
||||
proxyRes, err := client.Do(proxyReq)
|
||||
if err != nil {
|
||||
|
Loading…
Reference in New Issue
Block a user