mirror of
https://gitlab.com/pulsechaincom/prysm-pulse.git
synced 2025-01-10 03:31:20 +00:00
410 lines
13 KiB
Go
410 lines
13 KiB
Go
package apimiddleware
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/prysmaticlabs/prysm/api/grpc"
|
|
"github.com/prysmaticlabs/prysm/testing/assert"
|
|
"github.com/prysmaticlabs/prysm/testing/require"
|
|
"github.com/sirupsen/logrus/hooks/test"
|
|
)
|
|
|
|
// testRequestContainer is the request-side fixture used by the tests in this
// file. It pairs an untagged string field with a field tagged `hex:"true"`.
type testRequestContainer struct {
	// TestString carries no struct tag.
	TestString string
	// TestHexString is tagged `hex:"true"`; the tests in this file seed it
	// with a 0x-prefixed hex string.
	TestHexString string `hex:"true"`
}
|
|
|
|
func defaultRequestContainer() *testRequestContainer {
|
|
return &testRequestContainer{
|
|
TestString: "test string",
|
|
TestHexString: "0x666F6F", // hex encoding of "foo"
|
|
}
|
|
}
|
|
|
|
// testResponseContainer is the response-side fixture used by the tests in this
// file. It has one untagged field plus one field for each of the `hex`, `enum`
// and `time` struct tags.
type testResponseContainer struct {
	// TestString carries no struct tag.
	TestString string
	// TestHex is tagged `hex:"true"`.
	TestHex string `hex:"true"`
	// TestEnum is tagged `enum:"true"`.
	TestEnum string `enum:"true"`
	// TestTime is tagged `time:"true"`.
	TestTime string `time:"true"`
}
|
|
|
|
func defaultResponseContainer() *testResponseContainer {
|
|
return &testResponseContainer{
|
|
TestString: "test string",
|
|
TestHex: "Zm9v", // base64 encoding of "foo"
|
|
TestEnum: "Test Enum",
|
|
TestTime: "2006-01-02T15:04:05Z",
|
|
}
|
|
}
|
|
|
|
// testErrorJson is an error fixture holding a message, a status code and one
// extra field. Field names and order are kept as-is because a test in this
// file asserts the exact byte length of its JSON encoding.
type testErrorJson struct {
	Message     string
	Code        int
	CustomField string
}

// StatusCode reports the code currently stored on the error.
func (j *testErrorJson) StatusCode() int {
	return j.Code
}

// Msg reports the message currently stored on the error.
func (j *testErrorJson) Msg() string {
	return j.Message
}

// SetCode replaces the code stored on the error.
func (j *testErrorJson) SetCode(code int) {
	j.Code = code
}

// SetMsg replaces the message stored on the error.
func (j *testErrorJson) SetMsg(msg string) {
	j.Message = msg
}
|
|
|
|
func TestDeserializeRequestBodyIntoContainer(t *testing.T) {
|
|
t.Run("ok", func(t *testing.T) {
|
|
var bodyJson bytes.Buffer
|
|
err := json.NewEncoder(&bodyJson).Encode(defaultRequestContainer())
|
|
require.NoError(t, err)
|
|
|
|
container := &testRequestContainer{}
|
|
errJson := DeserializeRequestBodyIntoContainer(&bodyJson, container)
|
|
require.Equal(t, true, errJson == nil)
|
|
assert.Equal(t, "test string", container.TestString)
|
|
})
|
|
|
|
t.Run("error", func(t *testing.T) {
|
|
var bodyJson bytes.Buffer
|
|
bodyJson.Write([]byte("foo"))
|
|
errJson := DeserializeRequestBodyIntoContainer(&bodyJson, &testRequestContainer{})
|
|
require.NotNil(t, errJson)
|
|
assert.Equal(t, true, strings.Contains(errJson.Msg(), "could not decode request body"))
|
|
assert.Equal(t, http.StatusInternalServerError, errJson.StatusCode())
|
|
})
|
|
|
|
t.Run("unknown field", func(t *testing.T) {
|
|
var bodyJson bytes.Buffer
|
|
bodyJson.Write([]byte("{\"foo\":\"foo\"}"))
|
|
errJson := DeserializeRequestBodyIntoContainer(&bodyJson, &testRequestContainer{})
|
|
require.NotNil(t, errJson)
|
|
assert.Equal(t, true, strings.Contains(errJson.Msg(), "could not decode request body"))
|
|
assert.Equal(t, http.StatusBadRequest, errJson.StatusCode())
|
|
})
|
|
}
|
|
|
|
func TestProcessRequestContainerFields(t *testing.T) {
|
|
t.Run("ok", func(t *testing.T) {
|
|
container := defaultRequestContainer()
|
|
|
|
errJson := ProcessRequestContainerFields(container)
|
|
require.Equal(t, true, errJson == nil)
|
|
assert.Equal(t, "Zm9v", container.TestHexString)
|
|
})
|
|
|
|
t.Run("error", func(t *testing.T) {
|
|
errJson := ProcessRequestContainerFields("foo")
|
|
require.NotNil(t, errJson)
|
|
assert.Equal(t, true, strings.Contains(errJson.Msg(), "could not process request data"))
|
|
assert.Equal(t, http.StatusInternalServerError, errJson.StatusCode())
|
|
})
|
|
}
|
|
|
|
func TestSetRequestBodyToRequestContainer(t *testing.T) {
|
|
var body bytes.Buffer
|
|
request := httptest.NewRequest("GET", "http://foo.example", &body)
|
|
|
|
errJson := SetRequestBodyToRequestContainer(defaultRequestContainer(), request)
|
|
require.Equal(t, true, errJson == nil)
|
|
container := &testRequestContainer{}
|
|
require.NoError(t, json.NewDecoder(request.Body).Decode(container))
|
|
assert.Equal(t, "test string", container.TestString)
|
|
contentLengthHeader, ok := request.Header["Content-Length"]
|
|
require.Equal(t, true, ok)
|
|
require.Equal(t, 1, len(contentLengthHeader), "wrong number of header values")
|
|
assert.Equal(t, "55", contentLengthHeader[0])
|
|
assert.Equal(t, int64(55), request.ContentLength)
|
|
}
|
|
|
|
func TestPrepareRequestForProxying(t *testing.T) {
|
|
middleware := &ApiProxyMiddleware{
|
|
GatewayAddress: "http://gateway.example",
|
|
}
|
|
// We will set some params to make the request more interesting.
|
|
endpoint := Endpoint{
|
|
Path: "/{url_param}",
|
|
RequestURLLiterals: []string{"url_param"},
|
|
RequestQueryParams: []QueryParam{{Name: "query_param"}},
|
|
}
|
|
var body bytes.Buffer
|
|
request := httptest.NewRequest("GET", "http://foo.example?query_param=bar", &body)
|
|
|
|
errJson := middleware.PrepareRequestForProxying(endpoint, request)
|
|
require.Equal(t, true, errJson == nil)
|
|
assert.Equal(t, "http", request.URL.Scheme)
|
|
assert.Equal(t, middleware.GatewayAddress, request.URL.Host)
|
|
assert.Equal(t, "", request.RequestURI)
|
|
}
|
|
|
|
func TestReadGrpcResponseBody(t *testing.T) {
|
|
var b bytes.Buffer
|
|
b.Write([]byte("foo"))
|
|
|
|
body, jsonErr := ReadGrpcResponseBody(&b)
|
|
require.Equal(t, true, jsonErr == nil)
|
|
assert.Equal(t, "foo", string(body))
|
|
}
|
|
|
|
func TestHandleGrpcResponseError(t *testing.T) {
|
|
response := &http.Response{
|
|
StatusCode: 400,
|
|
Header: http.Header{
|
|
"Foo": []string{"foo"},
|
|
"Bar": []string{"bar"},
|
|
},
|
|
}
|
|
writer := httptest.NewRecorder()
|
|
errJson := &testErrorJson{
|
|
Message: "foo",
|
|
Code: 400,
|
|
}
|
|
b, err := json.Marshal(errJson)
|
|
require.NoError(t, err)
|
|
|
|
hasError, e := HandleGrpcResponseError(errJson, response, b, writer)
|
|
require.Equal(t, true, e == nil)
|
|
assert.Equal(t, true, hasError)
|
|
v, ok := writer.Header()["Foo"]
|
|
require.Equal(t, true, ok, "header not found")
|
|
require.Equal(t, 1, len(v), "wrong number of header values")
|
|
assert.Equal(t, "foo", v[0])
|
|
v, ok = writer.Header()["Bar"]
|
|
require.Equal(t, true, ok, "header not found")
|
|
require.Equal(t, 1, len(v), "wrong number of header values")
|
|
assert.Equal(t, "bar", v[0])
|
|
assert.Equal(t, 400, errJson.StatusCode())
|
|
}
|
|
|
|
func TestGrpcResponseIsEmpty(t *testing.T) {
|
|
t.Run("nil", func(t *testing.T) {
|
|
assert.Equal(t, true, GrpcResponseIsEmpty(nil))
|
|
})
|
|
t.Run("empty_slice", func(t *testing.T) {
|
|
assert.Equal(t, true, GrpcResponseIsEmpty(make([]byte, 0)))
|
|
})
|
|
t.Run("empty_brackets", func(t *testing.T) {
|
|
assert.Equal(t, true, GrpcResponseIsEmpty([]byte("{}")))
|
|
})
|
|
t.Run("non_empty", func(t *testing.T) {
|
|
assert.Equal(t, false, GrpcResponseIsEmpty([]byte("{\"foo\":\"bar\"})")))
|
|
})
|
|
}
|
|
|
|
func TestDeserializeGrpcResponseBodyIntoContainer(t *testing.T) {
|
|
t.Run("ok", func(t *testing.T) {
|
|
body, err := json.Marshal(defaultRequestContainer())
|
|
require.NoError(t, err)
|
|
|
|
container := &testRequestContainer{}
|
|
errJson := DeserializeGrpcResponseBodyIntoContainer(body, container)
|
|
require.Equal(t, true, errJson == nil)
|
|
assert.Equal(t, "test string", container.TestString)
|
|
})
|
|
|
|
t.Run("error", func(t *testing.T) {
|
|
var bodyJson bytes.Buffer
|
|
bodyJson.Write([]byte("foo"))
|
|
errJson := DeserializeGrpcResponseBodyIntoContainer(bodyJson.Bytes(), &testRequestContainer{})
|
|
require.NotNil(t, errJson)
|
|
assert.Equal(t, true, strings.Contains(errJson.Msg(), "could not unmarshal response"))
|
|
assert.Equal(t, http.StatusInternalServerError, errJson.StatusCode())
|
|
})
|
|
}
|
|
|
|
func TestProcessMiddlewareResponseFields(t *testing.T) {
|
|
t.Run("Ok", func(t *testing.T) {
|
|
container := defaultResponseContainer()
|
|
|
|
errJson := ProcessMiddlewareResponseFields(container)
|
|
require.Equal(t, true, errJson == nil)
|
|
assert.Equal(t, "0x666f6f", container.TestHex)
|
|
assert.Equal(t, "test enum", container.TestEnum)
|
|
assert.Equal(t, "1136214245", container.TestTime)
|
|
})
|
|
|
|
t.Run("error", func(t *testing.T) {
|
|
errJson := ProcessMiddlewareResponseFields("foo")
|
|
require.NotNil(t, errJson)
|
|
assert.Equal(t, true, strings.Contains(errJson.Msg(), "could not process response data"))
|
|
assert.Equal(t, http.StatusInternalServerError, errJson.StatusCode())
|
|
})
|
|
}
|
|
|
|
func TestSerializeMiddlewareResponseIntoJson(t *testing.T) {
|
|
container := defaultResponseContainer()
|
|
j, errJson := SerializeMiddlewareResponseIntoJson(container)
|
|
assert.Equal(t, true, errJson == nil)
|
|
cToDeserialize := &testResponseContainer{}
|
|
require.NoError(t, json.Unmarshal(j, cToDeserialize))
|
|
assert.Equal(t, "test string", cToDeserialize.TestString)
|
|
}
|
|
|
|
func TestWriteMiddlewareResponseHeadersAndBody(t *testing.T) {
|
|
t.Run("GET", func(t *testing.T) {
|
|
response := &http.Response{
|
|
Header: http.Header{
|
|
"Foo": []string{"foo"},
|
|
"Grpc-Metadata-" + grpc.HttpCodeMetadataKey: []string{"204"},
|
|
},
|
|
}
|
|
container := defaultResponseContainer()
|
|
responseJson, err := json.Marshal(container)
|
|
require.NoError(t, err)
|
|
writer := httptest.NewRecorder()
|
|
writer.Body = &bytes.Buffer{}
|
|
|
|
errJson := WriteMiddlewareResponseHeadersAndBody(response, responseJson, writer)
|
|
require.Equal(t, true, errJson == nil)
|
|
v, ok := writer.Header()["Foo"]
|
|
require.Equal(t, true, ok, "header not found")
|
|
require.Equal(t, 1, len(v), "wrong number of header values")
|
|
assert.Equal(t, "foo", v[0])
|
|
v, ok = writer.Header()["Content-Length"]
|
|
require.Equal(t, true, ok, "header not found")
|
|
require.Equal(t, 1, len(v), "wrong number of header values")
|
|
assert.Equal(t, "102", v[0])
|
|
assert.Equal(t, 204, writer.Code)
|
|
assert.DeepEqual(t, responseJson, writer.Body.Bytes())
|
|
})
|
|
|
|
t.Run("GET_no_grpc_status_code_header", func(t *testing.T) {
|
|
response := &http.Response{
|
|
Header: http.Header{},
|
|
StatusCode: 204,
|
|
}
|
|
container := defaultResponseContainer()
|
|
responseJson, err := json.Marshal(container)
|
|
require.NoError(t, err)
|
|
writer := httptest.NewRecorder()
|
|
|
|
errJson := WriteMiddlewareResponseHeadersAndBody(response, responseJson, writer)
|
|
require.Equal(t, true, errJson == nil)
|
|
assert.Equal(t, 204, writer.Code)
|
|
})
|
|
|
|
t.Run("GET_invalid_status_code", func(t *testing.T) {
|
|
response := &http.Response{
|
|
Header: http.Header{},
|
|
}
|
|
|
|
// Set invalid status code.
|
|
response.Header["Grpc-Metadata-"+grpc.HttpCodeMetadataKey] = []string{"invalid"}
|
|
|
|
container := defaultResponseContainer()
|
|
responseJson, err := json.Marshal(container)
|
|
require.NoError(t, err)
|
|
writer := httptest.NewRecorder()
|
|
|
|
errJson := WriteMiddlewareResponseHeadersAndBody(response, responseJson, writer)
|
|
require.Equal(t, false, errJson == nil)
|
|
assert.Equal(t, true, strings.Contains(errJson.Msg(), "could not parse status code"))
|
|
assert.Equal(t, http.StatusInternalServerError, errJson.StatusCode())
|
|
})
|
|
|
|
t.Run("POST", func(t *testing.T) {
|
|
response := &http.Response{
|
|
Header: http.Header{},
|
|
StatusCode: 204,
|
|
}
|
|
container := defaultResponseContainer()
|
|
responseJson, err := json.Marshal(container)
|
|
require.NoError(t, err)
|
|
writer := httptest.NewRecorder()
|
|
|
|
errJson := WriteMiddlewareResponseHeadersAndBody(response, responseJson, writer)
|
|
require.Equal(t, true, errJson == nil)
|
|
assert.Equal(t, 204, writer.Code)
|
|
})
|
|
|
|
t.Run("POST_with_response_body", func(t *testing.T) {
|
|
response := &http.Response{
|
|
Header: http.Header{},
|
|
StatusCode: 204,
|
|
}
|
|
container := defaultResponseContainer()
|
|
responseJson, err := json.Marshal(container)
|
|
require.NoError(t, err)
|
|
writer := httptest.NewRecorder()
|
|
writer.Body = &bytes.Buffer{}
|
|
|
|
errJson := WriteMiddlewareResponseHeadersAndBody(response, responseJson, writer)
|
|
require.Equal(t, true, errJson == nil)
|
|
assert.Equal(t, 204, writer.Code)
|
|
assert.DeepEqual(t, responseJson, writer.Body.Bytes())
|
|
})
|
|
|
|
t.Run("POST_with_empty_json_body", func(t *testing.T) {
|
|
response := &http.Response{
|
|
Header: http.Header{},
|
|
StatusCode: 204,
|
|
}
|
|
responseJson, err := json.Marshal(struct{}{})
|
|
require.NoError(t, err)
|
|
writer := httptest.NewRecorder()
|
|
writer.Body = &bytes.Buffer{}
|
|
|
|
errJson := WriteMiddlewareResponseHeadersAndBody(response, responseJson, writer)
|
|
require.Equal(t, true, errJson == nil)
|
|
assert.Equal(t, 204, writer.Code)
|
|
assert.DeepEqual(t, []byte(nil), writer.Body.Bytes())
|
|
assert.Equal(t, "0", writer.Header()["Content-Length"][0])
|
|
})
|
|
}
|
|
|
|
func TestWriteError(t *testing.T) {
|
|
t.Run("ok", func(t *testing.T) {
|
|
responseHeader := http.Header{
|
|
"Grpc-Metadata-" + grpc.CustomErrorMetadataKey: []string{"{\"CustomField\":\"bar\"}"},
|
|
}
|
|
errJson := &testErrorJson{
|
|
Message: "foo",
|
|
Code: 500,
|
|
}
|
|
writer := httptest.NewRecorder()
|
|
writer.Body = &bytes.Buffer{}
|
|
|
|
WriteError(writer, errJson, responseHeader)
|
|
v, ok := writer.Header()["Content-Length"]
|
|
require.Equal(t, true, ok, "header not found")
|
|
require.Equal(t, 1, len(v), "wrong number of header values")
|
|
assert.Equal(t, "48", v[0])
|
|
v, ok = writer.Header()["Content-Type"]
|
|
require.Equal(t, true, ok, "header not found")
|
|
require.Equal(t, 1, len(v), "wrong number of header values")
|
|
assert.Equal(t, "application/json", v[0])
|
|
assert.Equal(t, 500, writer.Code)
|
|
eDeserialize := &testErrorJson{}
|
|
require.NoError(t, json.Unmarshal(writer.Body.Bytes(), eDeserialize))
|
|
assert.Equal(t, "foo", eDeserialize.Message)
|
|
assert.Equal(t, 500, eDeserialize.Code)
|
|
assert.Equal(t, "bar", eDeserialize.CustomField)
|
|
})
|
|
|
|
t.Run("invalid_custom_error_header", func(t *testing.T) {
|
|
logHook := test.NewGlobal()
|
|
|
|
responseHeader := http.Header{
|
|
"Grpc-Metadata-" + grpc.CustomErrorMetadataKey: []string{"invalid"},
|
|
}
|
|
|
|
WriteError(httptest.NewRecorder(), &testErrorJson{}, responseHeader)
|
|
assert.LogsContain(t, logHook, "Could not unmarshal custom error message")
|
|
})
|
|
}
|