mirror of
https://gitlab.com/pulsechaincom/prysm-pulse.git
synced 2024-12-24 12:27:18 +00:00
700f5fee8c
* add context to beacon APIs * add TODO to merge GET and POST methods * fix linter action Co-authored-by: kasey <489222+kasey@users.noreply.github.com> Co-authored-by: james-prysm <90280386+james-prysm@users.noreply.github.com>
437 lines
13 KiB
Go
437 lines
13 KiB
Go
package beacon_api
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/prysmaticlabs/prysm/v3/api/gateway/apimiddleware"
|
|
rpcmiddleware "github.com/prysmaticlabs/prysm/v3/beacon-chain/rpc/apimiddleware"
|
|
"github.com/prysmaticlabs/prysm/v3/testing/assert"
|
|
"github.com/prysmaticlabs/prysm/v3/testing/require"
|
|
)
|
|
|
|
func TestGetRestJsonResponse_Valid(t *testing.T) {
|
|
const endpoint = "/example/rest/api/endpoint"
|
|
|
|
genesisJson := &rpcmiddleware.GenesisResponseJson{
|
|
Data: &rpcmiddleware.GenesisResponse_GenesisJson{
|
|
GenesisTime: "123",
|
|
GenesisValidatorsRoot: "0x456",
|
|
GenesisForkVersion: "0x789",
|
|
},
|
|
}
|
|
|
|
ctx := context.Background()
|
|
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc(endpoint, func(w http.ResponseWriter, r *http.Request) {
|
|
// Make sure the url parameters match
|
|
assert.Equal(t, "abc", r.URL.Query().Get("arg1"))
|
|
assert.Equal(t, "def", r.URL.Query().Get("arg2"))
|
|
|
|
marshalledJson, err := json.Marshal(genesisJson)
|
|
require.NoError(t, err)
|
|
|
|
_, err = w.Write(marshalledJson)
|
|
require.NoError(t, err)
|
|
})
|
|
server := httptest.NewServer(mux)
|
|
defer server.Close()
|
|
|
|
jsonRestHandler := beaconApiJsonRestHandler{
|
|
httpClient: http.Client{Timeout: time.Second * 5},
|
|
host: server.URL,
|
|
}
|
|
|
|
responseJson := &rpcmiddleware.GenesisResponseJson{}
|
|
_, err := jsonRestHandler.GetRestJsonResponse(ctx, endpoint+"?arg1=abc&arg2=def", responseJson)
|
|
assert.NoError(t, err)
|
|
assert.DeepEqual(t, genesisJson, responseJson)
|
|
}
|
|
|
|
func TestGetRestJsonResponse_Error(t *testing.T) {
|
|
const endpoint = "/example/rest/api/endpoint"
|
|
|
|
testCases := []struct {
|
|
name string
|
|
funcHandler func(w http.ResponseWriter, r *http.Request)
|
|
expectedErrorJson *apimiddleware.DefaultErrorJson
|
|
expectedErrorMessage string
|
|
timeout time.Duration
|
|
responseJson interface{}
|
|
}{
|
|
{
|
|
name: "nil response json",
|
|
funcHandler: invalidJsonResponseHandler,
|
|
expectedErrorMessage: "responseJson is nil",
|
|
timeout: time.Second * 5,
|
|
responseJson: nil,
|
|
},
|
|
{
|
|
name: "400 error",
|
|
funcHandler: httpErrorJsonHandler(http.StatusBadRequest, "Bad request"),
|
|
expectedErrorMessage: "error 400: Bad request",
|
|
expectedErrorJson: &apimiddleware.DefaultErrorJson{
|
|
Code: http.StatusBadRequest,
|
|
Message: "Bad request",
|
|
},
|
|
timeout: time.Second * 5,
|
|
responseJson: &rpcmiddleware.GenesisResponseJson{},
|
|
},
|
|
{
|
|
name: "404 error",
|
|
funcHandler: httpErrorJsonHandler(http.StatusNotFound, "Not found"),
|
|
expectedErrorMessage: "error 404: Not found",
|
|
expectedErrorJson: &apimiddleware.DefaultErrorJson{
|
|
Code: http.StatusNotFound,
|
|
Message: "Not found",
|
|
},
|
|
timeout: time.Second * 5,
|
|
responseJson: &rpcmiddleware.GenesisResponseJson{},
|
|
},
|
|
{
|
|
name: "500 error",
|
|
funcHandler: httpErrorJsonHandler(http.StatusInternalServerError, "Internal server error"),
|
|
expectedErrorMessage: "error 500: Internal server error",
|
|
expectedErrorJson: &apimiddleware.DefaultErrorJson{
|
|
Code: http.StatusInternalServerError,
|
|
Message: "Internal server error",
|
|
},
|
|
timeout: time.Second * 5,
|
|
responseJson: &rpcmiddleware.GenesisResponseJson{},
|
|
},
|
|
{
|
|
name: "999 error",
|
|
funcHandler: httpErrorJsonHandler(999, "Invalid error"),
|
|
expectedErrorMessage: "error 999: Invalid error",
|
|
expectedErrorJson: &apimiddleware.DefaultErrorJson{
|
|
Code: 999,
|
|
Message: "Invalid error",
|
|
},
|
|
timeout: time.Second * 5,
|
|
responseJson: &rpcmiddleware.GenesisResponseJson{},
|
|
},
|
|
{
|
|
name: "bad error json formatting",
|
|
funcHandler: invalidJsonErrHandler,
|
|
expectedErrorMessage: "failed to decode error json",
|
|
timeout: time.Second * 5,
|
|
responseJson: &rpcmiddleware.GenesisResponseJson{},
|
|
},
|
|
{
|
|
name: "bad response json formatting",
|
|
funcHandler: invalidJsonResponseHandler,
|
|
expectedErrorMessage: "failed to decode response json",
|
|
timeout: time.Second * 5,
|
|
responseJson: &rpcmiddleware.GenesisResponseJson{},
|
|
},
|
|
{
|
|
name: "timeout",
|
|
funcHandler: httpErrorJsonHandler(http.StatusNotFound, "Not found"),
|
|
expectedErrorMessage: "failed to query REST API",
|
|
timeout: 1,
|
|
responseJson: &rpcmiddleware.GenesisResponseJson{},
|
|
},
|
|
}
|
|
|
|
for _, testCase := range testCases {
|
|
t.Run(testCase.name, func(t *testing.T) {
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc(endpoint, testCase.funcHandler)
|
|
server := httptest.NewServer(mux)
|
|
defer server.Close()
|
|
|
|
ctx := context.Background()
|
|
|
|
jsonRestHandler := beaconApiJsonRestHandler{
|
|
httpClient: http.Client{Timeout: testCase.timeout},
|
|
host: server.URL,
|
|
}
|
|
errorJson, err := jsonRestHandler.GetRestJsonResponse(ctx, endpoint, testCase.responseJson)
|
|
assert.ErrorContains(t, testCase.expectedErrorMessage, err)
|
|
assert.DeepEqual(t, testCase.expectedErrorJson, errorJson)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestPostRestJson_Valid(t *testing.T) {
|
|
const endpoint = "/example/rest/api/endpoint"
|
|
dataBytes := []byte{1, 2, 3, 4, 5}
|
|
|
|
genesisJson := &rpcmiddleware.GenesisResponseJson{
|
|
Data: &rpcmiddleware.GenesisResponse_GenesisJson{
|
|
GenesisTime: "123",
|
|
GenesisValidatorsRoot: "0x456",
|
|
GenesisForkVersion: "0x789",
|
|
},
|
|
}
|
|
|
|
testCases := []struct {
|
|
name string
|
|
headers map[string]string
|
|
data *bytes.Buffer
|
|
responseJson interface{}
|
|
}{
|
|
{
|
|
name: "nil headers",
|
|
headers: nil,
|
|
data: bytes.NewBuffer(dataBytes),
|
|
responseJson: &rpcmiddleware.GenesisResponseJson{},
|
|
},
|
|
{
|
|
name: "empty headers",
|
|
headers: map[string]string{},
|
|
data: bytes.NewBuffer(dataBytes),
|
|
responseJson: &rpcmiddleware.GenesisResponseJson{},
|
|
},
|
|
{
|
|
name: "nil response json",
|
|
headers: map[string]string{"DummyHeaderKey": "DummyHeaderValue"},
|
|
data: bytes.NewBuffer(dataBytes),
|
|
responseJson: nil,
|
|
},
|
|
}
|
|
|
|
for _, testCase := range testCases {
|
|
t.Run(testCase.name, func(t *testing.T) {
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc(endpoint, func(w http.ResponseWriter, r *http.Request) {
|
|
// Make sure the request headers have been set
|
|
for headerKey, headerValue := range testCase.headers {
|
|
assert.Equal(t, headerValue, r.Header.Get(headerKey))
|
|
}
|
|
|
|
// Make sure the data matches
|
|
receivedBytes := make([]byte, len(dataBytes))
|
|
numBytes, err := r.Body.Read(receivedBytes)
|
|
assert.Equal(t, io.EOF, err)
|
|
assert.Equal(t, len(dataBytes), numBytes)
|
|
assert.DeepEqual(t, dataBytes, receivedBytes)
|
|
|
|
marshalledJson, err := json.Marshal(genesisJson)
|
|
require.NoError(t, err)
|
|
|
|
_, err = w.Write(marshalledJson)
|
|
require.NoError(t, err)
|
|
})
|
|
server := httptest.NewServer(mux)
|
|
defer server.Close()
|
|
|
|
ctx := context.Background()
|
|
|
|
jsonRestHandler := beaconApiJsonRestHandler{
|
|
httpClient: http.Client{Timeout: time.Second * 5},
|
|
host: server.URL,
|
|
}
|
|
|
|
_, err := jsonRestHandler.PostRestJson(
|
|
ctx,
|
|
endpoint,
|
|
testCase.headers,
|
|
testCase.data,
|
|
testCase.responseJson,
|
|
)
|
|
|
|
assert.NoError(t, err)
|
|
|
|
if testCase.responseJson != nil {
|
|
assert.DeepEqual(t, genesisJson, testCase.responseJson)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestPostRestJson_Error(t *testing.T) {
|
|
const endpoint = "/example/rest/api/endpoint"
|
|
|
|
testCases := []struct {
|
|
name string
|
|
funcHandler func(w http.ResponseWriter, r *http.Request)
|
|
expectedErrorJson *apimiddleware.DefaultErrorJson
|
|
expectedErrorMessage string
|
|
timeout time.Duration
|
|
responseJson *rpcmiddleware.GenesisResponseJson
|
|
data *bytes.Buffer
|
|
}{
|
|
{
|
|
name: "nil POST data",
|
|
funcHandler: httpErrorJsonHandler(http.StatusNotFound, "Not found"),
|
|
expectedErrorMessage: "POST data is nil",
|
|
timeout: time.Second * 5,
|
|
data: nil,
|
|
},
|
|
{
|
|
name: "400 error",
|
|
funcHandler: httpErrorJsonHandler(http.StatusBadRequest, "Bad request"),
|
|
expectedErrorMessage: "error 400: Bad request",
|
|
expectedErrorJson: &apimiddleware.DefaultErrorJson{
|
|
Code: http.StatusBadRequest,
|
|
Message: "Bad request",
|
|
},
|
|
timeout: time.Second * 5,
|
|
responseJson: &rpcmiddleware.GenesisResponseJson{},
|
|
data: &bytes.Buffer{},
|
|
},
|
|
{
|
|
name: "404 error",
|
|
funcHandler: httpErrorJsonHandler(http.StatusNotFound, "Not found"),
|
|
expectedErrorMessage: "error 404: Not found",
|
|
expectedErrorJson: &apimiddleware.DefaultErrorJson{
|
|
Code: http.StatusNotFound,
|
|
Message: "Not found",
|
|
},
|
|
timeout: time.Second * 5,
|
|
data: &bytes.Buffer{},
|
|
},
|
|
{
|
|
name: "500 error",
|
|
funcHandler: httpErrorJsonHandler(http.StatusInternalServerError, "Internal server error"),
|
|
expectedErrorMessage: "error 500: Internal server error",
|
|
expectedErrorJson: &apimiddleware.DefaultErrorJson{
|
|
Code: http.StatusInternalServerError,
|
|
Message: "Internal server error",
|
|
},
|
|
timeout: time.Second * 5,
|
|
data: &bytes.Buffer{},
|
|
},
|
|
{
|
|
name: "999 error",
|
|
funcHandler: httpErrorJsonHandler(999, "Invalid error"),
|
|
expectedErrorMessage: "error 999: Invalid error",
|
|
expectedErrorJson: &apimiddleware.DefaultErrorJson{
|
|
Code: 999,
|
|
Message: "Invalid error",
|
|
},
|
|
timeout: time.Second * 5,
|
|
data: &bytes.Buffer{},
|
|
},
|
|
{
|
|
name: "bad error json formatting",
|
|
funcHandler: invalidJsonErrHandler,
|
|
expectedErrorMessage: "failed to decode error json",
|
|
timeout: time.Second * 5,
|
|
data: &bytes.Buffer{},
|
|
},
|
|
{
|
|
name: "bad response json formatting",
|
|
funcHandler: invalidJsonResponseHandler,
|
|
expectedErrorMessage: "failed to decode response json",
|
|
timeout: time.Second * 5,
|
|
responseJson: &rpcmiddleware.GenesisResponseJson{},
|
|
data: &bytes.Buffer{},
|
|
},
|
|
{
|
|
name: "timeout",
|
|
funcHandler: httpErrorJsonHandler(http.StatusNotFound, "Not found"),
|
|
expectedErrorMessage: "failed to send POST data to REST endpoint",
|
|
timeout: 1,
|
|
data: &bytes.Buffer{},
|
|
},
|
|
}
|
|
|
|
for _, testCase := range testCases {
|
|
t.Run(testCase.name, func(t *testing.T) {
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc(endpoint, testCase.funcHandler)
|
|
server := httptest.NewServer(mux)
|
|
defer server.Close()
|
|
|
|
ctx := context.Background()
|
|
|
|
jsonRestHandler := beaconApiJsonRestHandler{
|
|
httpClient: http.Client{Timeout: testCase.timeout},
|
|
host: server.URL,
|
|
}
|
|
|
|
errorJson, err := jsonRestHandler.PostRestJson(
|
|
ctx,
|
|
endpoint,
|
|
map[string]string{},
|
|
testCase.data,
|
|
testCase.responseJson,
|
|
)
|
|
|
|
assert.ErrorContains(t, testCase.expectedErrorMessage, err)
|
|
assert.DeepEqual(t, testCase.expectedErrorJson, errorJson)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestJsonHandler_ContextError(t *testing.T) {
|
|
const endpoint = "/example/rest/api/endpoint"
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc(endpoint, func(writer http.ResponseWriter, request *http.Request) {})
|
|
server := httptest.NewServer(mux)
|
|
defer server.Close()
|
|
|
|
// Instantiate a cancellable context.
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
// Cancel the context which results in "context canceled" error.
|
|
cancel()
|
|
|
|
jsonRestHandler := beaconApiJsonRestHandler{
|
|
httpClient: http.Client{},
|
|
host: server.URL,
|
|
}
|
|
|
|
_, err := jsonRestHandler.PostRestJson(
|
|
ctx,
|
|
endpoint,
|
|
map[string]string{},
|
|
&bytes.Buffer{},
|
|
nil,
|
|
)
|
|
|
|
assert.ErrorContains(t, context.Canceled.Error(), err)
|
|
|
|
_, err = jsonRestHandler.GetRestJsonResponse(
|
|
ctx,
|
|
endpoint,
|
|
&rpcmiddleware.GenesisResponseJson{},
|
|
)
|
|
|
|
assert.ErrorContains(t, context.Canceled.Error(), err)
|
|
}
|
|
|
|
func httpErrorJsonHandler(statusCode int, errorMessage string) func(w http.ResponseWriter, r *http.Request) {
|
|
return func(w http.ResponseWriter, _ *http.Request) {
|
|
errorJson := &apimiddleware.DefaultErrorJson{
|
|
Code: statusCode,
|
|
Message: errorMessage,
|
|
}
|
|
|
|
marshalledError, err := json.Marshal(errorJson)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
w.WriteHeader(statusCode)
|
|
_, err = w.Write(marshalledError)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// invalidJsonErrHandler replies 404 Not Found with the body "foo", which is not
// valid JSON. The table tests use it to trigger the
// "failed to decode error json" path. A failed write panics.
func invalidJsonErrHandler(w http.ResponseWriter, _ *http.Request) {
	w.WriteHeader(http.StatusNotFound)
	if _, err := w.Write([]byte("foo")); err != nil {
		panic(err)
	}
}
|
|
|
|
// invalidJsonResponseHandler writes the body "foo", which is not valid JSON,
// without setting a status. net/http therefore sends an implicit 200 OK. The
// table tests use it to trigger the "failed to decode response json" path.
// A failed write panics.
func invalidJsonResponseHandler(w http.ResponseWriter, _ *http.Request) {
	if _, err := w.Write([]byte("foo")); err != nil {
		panic(err)
	}
}
|