diff --git a/cmd/rpcdaemon/cli/config.go b/cmd/rpcdaemon/cli/config.go index cfaa65b6b..c8ccb76b0 100644 --- a/cmd/rpcdaemon/cli/config.go +++ b/cmd/rpcdaemon/cli/config.go @@ -4,8 +4,8 @@ import ( "context" "crypto/rand" "encoding/binary" + "errors" "fmt" - "io/ioutil" "net" "net/http" "os" @@ -32,6 +32,7 @@ import ( "github.com/ledgerwatch/erigon/cmd/rpcdaemon/services" "github.com/ledgerwatch/erigon/cmd/utils" "github.com/ledgerwatch/erigon/common" + "github.com/ledgerwatch/erigon/common/hexutil" "github.com/ledgerwatch/erigon/common/paths" "github.com/ledgerwatch/erigon/core/rawdb" "github.com/ledgerwatch/erigon/eth/ethconfig" @@ -517,35 +518,44 @@ func isWebsocket(r *http.Request) bool { strings.Contains(strings.ToLower(r.Header.Get("Connection")), "upgrade") } -func createHandler(cfg httpcfg.HttpCfg, apiList []rpc.API, httpHandler http.Handler, wsHandler http.Handler, isAuth bool) (http.Handler, error) { - var jwtVerificationKey []byte - var err error - - if isAuth { - // If no file is specified we generate a key in jwt.hex - if cfg.JWTSecretPath == "" { - jwtVerificationKey := make([]byte, 32) - rand.Read(jwtVerificationKey) - jwtVerificationKey = []byte(common.Bytes2Hex(jwtVerificationKey)) - f, err := os.Create(JwtDefaultFile) - if err != nil { - return nil, err - } - defer f.Close() - - _, err = f.Write(jwtVerificationKey) - if err != nil { - return nil, err - } - } else { - jwtVerificationKey, err = ioutil.ReadFile(cfg.JWTSecretPath) - if err != nil { - return nil, err - } - if len(jwtVerificationKey) != 64 { - return nil, fmt.Errorf("error: invalid size of verification key in %s", cfg.JWTSecretPath) - } +// obtainJWTSecret loads the jwt-secret, either from the provided config, +// or from the default location. If neither of those are present, it generates +// a new secret and stores to the default location. +func obtainJWTSecret(cfg httpcfg.HttpCfg) ([]byte, error) { + var fileName string + if len(cfg.JWTSecretPath) > 0 { + // path provided + fileName = cfg.JWTSecretPath + } else { + // no path provided, use default + fileName = JwtDefaultFile + } + // try reading from file + log.Info("Reading JWT secret", "path", fileName) + if data, err := os.ReadFile(fileName); err == nil { + jwtSecret := common.FromHex(strings.TrimSpace(string(data))) + if len(jwtSecret) == 32 { + return jwtSecret, nil } + log.Error("Invalid JWT secret", "path", fileName, "length", len(jwtSecret)) + return nil, errors.New("invalid JWT secret") + } + // Need to generate one + jwtSecret := make([]byte, 32) + rand.Read(jwtSecret) + + if err := os.WriteFile(fileName, []byte(hexutil.Encode(jwtSecret)), 0600); err != nil { + return nil, err + } + log.Info("Generated JWT secret", "path", fileName) + return jwtSecret, nil +} + +func createHandler(cfg httpcfg.HttpCfg, apiList []rpc.API, httpHandler http.Handler, wsHandler http.Handler, isAuth bool) (http.Handler, error) { + // Finds jwt secret + jwtVerificationKey, err := obtainJWTSecret(cfg) + if err != nil { + return nil, err } var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -559,26 +569,41 @@ func createHandler(cfg httpcfg.HttpCfg, apiList []rpc.API, httpHandler http.Hand } if isAuth { + var tokenStr string // Check if JWT signature is correct - tokenStr, ok := r.Header["Authorization"] - if !ok { - w.WriteHeader(http.StatusBadRequest) + if auth := r.Header.Get("Authorization"); strings.HasPrefix(auth, "Bearer ") { + tokenStr = strings.TrimPrefix(auth, "Bearer ") + } + + if len(tokenStr) == 0 { + http.Error(w, "missing token", http.StatusForbidden) return } - claims := jwt.StandardClaims{} - tkn, err := jwt.ParseWithClaims(strings.Replace(tokenStr[0], "Bearer ", "", 1), &claims, func(token *jwt.Token) (interface{}, error) { + keyFunc := func(token *jwt.Token) (interface{}, error) { return jwtVerificationKey, nil - }) - if err != nil || !tkn.Valid { - w.WriteHeader(http.StatusUnauthorized) - return } - // Validate time of iat - now := time.Now().Unix() - if claims.IssuedAt > now+JwtTokenExpiry.Nanoseconds() && claims.IssuedAt < now-JwtTokenExpiry.Nanoseconds() { - w.WriteHeader(http.StatusUnauthorized) - return + claims := jwt.RegisteredClaims{} + // We explicitly set only HS256 allowed, and also disables the + // claim-check: the RegisteredClaims internally requires 'iat' to + // be no later than 'now', but we allow for a bit of drift. + token, err := jwt.ParseWithClaims(tokenStr, &claims, keyFunc, + jwt.WithValidMethods([]string{"HS256"}), + jwt.WithoutClaimsValidation()) + + switch { + case err != nil: + http.Error(w, err.Error(), http.StatusForbidden) + case !token.Valid: + http.Error(w, "invalid token", http.StatusForbidden) + case !claims.VerifyExpiresAt(time.Now(), false): // optional + http.Error(w, "token is expired", http.StatusForbidden) + case claims.IssuedAt == nil: + http.Error(w, "missing issued-at", http.StatusForbidden) + case time.Since(claims.IssuedAt.Time) > JwtTokenExpiry: + http.Error(w, "stale token", http.StatusForbidden) + case time.Until(claims.IssuedAt.Time) > JwtTokenExpiry: + http.Error(w, "future token", http.StatusForbidden) } }