Use middleware to handle comma-separated query params (#12995)

This commit is contained in:
Radosław Kapka 2023-10-04 15:49:42 +02:00 committed by GitHub
parent f37301c0c0
commit 7454041356
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 35 additions and 69 deletions

View File

@ -8,6 +8,7 @@ go_library(
"node.go",
"options.go",
"prometheus.go",
"router.go",
],
importpath = "github.com/prysmaticlabs/prysm/v4/beacon-chain/node",
visibility = [
@ -40,6 +41,7 @@ go_library(
"//beacon-chain/p2p:go_default_library",
"//beacon-chain/rpc:go_default_library",
"//beacon-chain/rpc/apimiddleware:go_default_library",
"//beacon-chain/rpc/eth/helpers:go_default_library",
"//beacon-chain/slasher:go_default_library",
"//beacon-chain/startup:go_default_library",
"//beacon-chain/state:go_default_library",

View File

@ -271,6 +271,7 @@ func New(cliCtx *cli.Context, opts ...Option) (*BeaconNode, error) {
log.Debugln("Registering RPC Service")
router := mux.NewRouter()
router.Use(middleware)
if err := beacon.registerRPCService(router); err != nil {
return nil, err
}

View File

@ -0,0 +1,17 @@
package node
import (
"net/http"
"github.com/prysmaticlabs/prysm/v4/beacon-chain/rpc/eth/helpers"
)
func middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
helpers.NormalizeQueryValues(query)
r.URL.RawQuery = query.Encode()
next.ServeHTTP(w, r)
})
}

View File

@ -806,10 +806,6 @@ func (s *Server) GetBlockHeaders(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.StartSpan(r.Context(), "beacon.GetBlockHeaders")
defer span.End()
query := r.URL.Query()
helpers.NormalizeQueryValues(query)
r.URL.RawQuery = query.Encode()
rawSlot := r.URL.Query().Get("slot")
rawParentRoot := r.URL.Query().Get("parent_root")

View File

@ -26,10 +26,6 @@ func (s *Server) GetValidators(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.StartSpan(r.Context(), "beacon.GetValidators")
defer span.End()
query := r.URL.Query()
helpers.NormalizeQueryValues(query)
r.URL.RawQuery = query.Encode()
stateId := mux.Vars(r)["state_id"]
if stateId == "" {
http2.HandleError(w, "state_id is required in URL params", http.StatusBadRequest)
@ -215,10 +211,6 @@ func (bs *Server) GetValidatorBalances(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.StartSpan(r.Context(), "beacon.GetValidatorBalances")
defer span.End()
query := r.URL.Query()
helpers.NormalizeQueryValues(query)
r.URL.RawQuery = query.Encode()
stateId := mux.Vars(r)["state_id"]
if stateId == "" {
http2.HandleError(w, "state_id is required in URL params", http.StatusBadRequest)

View File

@ -154,36 +154,6 @@ func TestGetValidators(t *testing.T) {
assert.Equal(t, "20", resp.Data[0].Index)
assert.Equal(t, "60", resp.Data[1].Index)
})
t.Run("multiple comma-separated", func(t *testing.T) {
chainService := &chainMock.ChainService{}
s := Server{
Stater: &testutil.MockStater{
BeaconState: st,
},
HeadFetcher: chainService,
OptimisticModeFetcher: chainService,
FinalizationFetcher: chainService,
}
pubkey := st.PubkeyAtIndex(primitives.ValidatorIndex(20))
hexPubkey := hexutil.Encode(pubkey[:])
request := httptest.NewRequest(
http.MethodGet,
fmt.Sprintf("http://example.com/eth/v1/beacon/states/{state_id}/validators?id=%s,60", hexPubkey),
nil,
)
request = mux.SetURLVars(request, map[string]string{"state_id": "head"})
writer := httptest.NewRecorder()
writer.Body = &bytes.Buffer{}
s.GetValidators(writer, request)
assert.Equal(t, http.StatusOK, writer.Code)
resp := &GetValidatorsResponse{}
require.NoError(t, json.Unmarshal(writer.Body.Bytes(), resp))
require.Equal(t, 2, len(resp.Data))
assert.Equal(t, "20", resp.Data[0].Index)
assert.Equal(t, "60", resp.Data[1].Index)
})
t.Run("state ID required", func(t *testing.T) {
s := Server{
Stater: &testutil.MockStater{

View File

@ -12,7 +12,6 @@ go_library(
deps = [
"//beacon-chain/blockchain:go_default_library",
"//beacon-chain/db:go_default_library",
"//beacon-chain/rpc/eth/helpers:go_default_library",
"//beacon-chain/rpc/lookup:go_default_library",
"//config/fieldparams:go_default_library",
"//config/params:go_default_library",

View File

@ -8,7 +8,6 @@ import (
"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/pkg/errors"
"github.com/prysmaticlabs/prysm/v4/beacon-chain/rpc/eth/helpers"
"github.com/prysmaticlabs/prysm/v4/beacon-chain/rpc/lookup"
field_params "github.com/prysmaticlabs/prysm/v4/config/fieldparams"
"github.com/prysmaticlabs/prysm/v4/config/params"
@ -122,9 +121,7 @@ func (s *Server) Blobs(w http.ResponseWriter, r *http.Request) {
// parseIndices filters out invalid and duplicate blob indices
func parseIndices(url *url.URL) []uint64 {
query := url.Query()
helpers.NormalizeQueryValues(query)
rawIndices := query["indices"]
rawIndices := url.Query()["indices"]
indices := make([]uint64, 0, field_params.MaxBlobsPerBlock)
loop:
for _, raw := range rawIndices {

View File

@ -23,7 +23,7 @@ import (
)
func TestParseIndices(t *testing.T) {
assert.DeepEqual(t, []uint64{1, 2, 3}, parseIndices(&url.URL{RawQuery: "indices=1,2,foo,1&indices=3,1&bar=bar"}))
assert.DeepEqual(t, []uint64{1, 2, 3}, parseIndices(&url.URL{RawQuery: "indices=1&indices=2&indices=foo&indices=1&indices=3&bar=bar"}))
}
func TestBlobs(t *testing.T) {

View File

@ -44,10 +44,6 @@ func (s *Server) GetAggregateAttestation(w http.ResponseWriter, r *http.Request)
ctx, span := trace.StartSpan(r.Context(), "validator.GetAggregateAttestation")
defer span.End()
query := r.URL.Query()
rpchelpers.NormalizeQueryValues(query)
r.URL.RawQuery = query.Encode()
attDataRoot := r.URL.Query().Get("attestation_data_root")
attDataRootBytes, valid := shared.ValidateHex(w, "Attestation data root", attDataRoot, fieldparams.RootLength)
if !valid {
@ -402,10 +398,6 @@ func (s *Server) GetAttestationData(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.StartSpan(r.Context(), "validator.GetAttestationData")
defer span.End()
query := r.URL.Query()
rpchelpers.NormalizeQueryValues(query)
r.URL.RawQuery = query.Encode()
if shared.IsSyncing(ctx, w, s.SyncChecker, s.HeadFetcher, s.TimeFetcher, s.OptimisticModeFetcher) {
return
}
@ -458,10 +450,6 @@ func (s *Server) ProduceSyncCommitteeContribution(w http.ResponseWriter, r *http
ctx, span := trace.StartSpan(r.Context(), "validator.ProduceSyncCommitteeContribution")
defer span.End()
query := r.URL.Query()
rpchelpers.NormalizeQueryValues(query)
r.URL.RawQuery = query.Encode()
subIndex := r.URL.Query().Get("subcommittee_index")
index, valid := shared.ValidateUint(w, "Subcommittee Index", subIndex)
if !valid {

View File

@ -9,7 +9,6 @@ import (
"github.com/pkg/errors"
"github.com/prysmaticlabs/prysm/v4/api"
"github.com/prysmaticlabs/prysm/v4/beacon-chain/rpc/eth/helpers"
"github.com/prysmaticlabs/prysm/v4/beacon-chain/rpc/eth/shared"
fieldparams "github.com/prysmaticlabs/prysm/v4/config/fieldparams"
"github.com/prysmaticlabs/prysm/v4/consensus-types/primitives"
@ -32,10 +31,6 @@ func (s *Server) ProduceBlockV3(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.StartSpan(r.Context(), "validator.ProduceBlockV3")
defer span.End()
query := r.URL.Query()
helpers.NormalizeQueryValues(query)
r.URL.RawQuery = query.Encode()
if shared.IsSyncing(r.Context(), w, s.SyncChecker, s.HeadFetcher, s.TimeFetcher, s.OptimisticModeFetcher) {
return
}

View File

@ -63,10 +63,6 @@ func (vs *Server) GetValidatorCount(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.StartSpan(r.Context(), "beacon.GetValidatorCount")
defer span.End()
query := r.URL.Query()
helpers.NormalizeQueryValues(query)
r.URL.RawQuery = query.Encode()
stateID := mux.Vars(r)["state_id"]
isOptimistic, err := helpers.IsOptimistic(ctx, []byte(stateID), vs.OptimisticModeFetcher, vs.Stater, vs.ChainInfoFetcher, vs.BeaconDB)

View File

@ -188,6 +188,19 @@ var beaconPathsAndObjects = map[string]metadata{
"json": &beacon.GetBlockHeaderResponse{},
},
},
// we want to test comma-separated query params
"/beacon/states/{param1}/validators?id=0,1": {
basepath: v1MiddlewarePathTemplate,
params: func(_ string, e primitives.Epoch) []string {
return []string{"head"}
},
prysmResps: map[string]interface{}{
"json": &beacon.GetValidatorsResponse{},
},
lighthouseResps: map[string]interface{}{
"json": &beacon.GetValidatorsResponse{},
},
},
"/node/identity": {
basepath: v1MiddlewarePathTemplate,
params: func(_ string, _ primitives.Epoch) []string {