package apimiddleware import ( "bytes" "context" "encoding/json" "net/http" "net/http/httptest" "strings" "testing" "time" "github.com/prysmaticlabs/prysm/beacon-chain/rpc/eventsv1" "github.com/prysmaticlabs/prysm/shared/gateway" "github.com/prysmaticlabs/prysm/shared/grpcutils" "github.com/prysmaticlabs/prysm/shared/testutil/assert" "github.com/prysmaticlabs/prysm/shared/testutil/require" "github.com/r3labs/sse" ) func TestSSZRequested(t *testing.T) { t.Run("ssz_requested", func(t *testing.T) { request := httptest.NewRequest("GET", "http://foo.example", nil) request.Header["Accept"] = []string{"application/octet-stream"} result := sszRequested(request) assert.Equal(t, true, result) }) t.Run("multiple_content_types", func(t *testing.T) { request := httptest.NewRequest("GET", "http://foo.example", nil) request.Header["Accept"] = []string{"application/json", "application/octet-stream"} result := sszRequested(request) assert.Equal(t, true, result) }) t.Run("no_header", func(t *testing.T) { request := httptest.NewRequest("GET", "http://foo.example", nil) result := sszRequested(request) assert.Equal(t, false, result) }) t.Run("other_content_type", func(t *testing.T) { request := httptest.NewRequest("GET", "http://foo.example", nil) request.Header["Accept"] = []string{"application/json"} result := sszRequested(request) assert.Equal(t, false, result) }) } func TestPrepareSSZRequestForProxying(t *testing.T) { middleware := &gateway.ApiProxyMiddleware{ GatewayAddress: "http://gateway.example", } endpoint := gateway.Endpoint{ Path: "http://foo.example", } var body bytes.Buffer request := httptest.NewRequest("GET", "http://foo.example", &body) errJson := prepareSSZRequestForProxying(middleware, endpoint, request, "/ssz") require.Equal(t, true, errJson == nil) assert.Equal(t, "/ssz", request.URL.Path) } func TestSerializeMiddlewareResponseIntoSSZ(t *testing.T) { t.Run("ok", func(t *testing.T) { ssz, errJson := serializeMiddlewareResponseIntoSSZ("Zm9v") require.Equal(t, true, errJson == nil) assert.DeepEqual(t, []byte("foo"), ssz) }) t.Run("invalid_data", func(t *testing.T) { _, errJson := serializeMiddlewareResponseIntoSSZ("invalid") require.Equal(t, false, errJson == nil) assert.Equal(t, true, strings.Contains(errJson.Msg(), "could not decode response body into base64")) assert.Equal(t, http.StatusInternalServerError, errJson.StatusCode()) }) } func TestWriteSSZResponseHeaderAndBody(t *testing.T) { t.Run("ok", func(t *testing.T) { response := &http.Response{ Header: http.Header{ "Foo": []string{"foo"}, "Grpc-Metadata-" + grpcutils.HttpCodeMetadataKey: []string{"204"}, }, } responseSsz := []byte("ssz") writer := httptest.NewRecorder() writer.Body = &bytes.Buffer{} errJson := writeSSZResponseHeaderAndBody(response, writer, responseSsz, "test.ssz") require.Equal(t, true, errJson == nil) v, ok := writer.Header()["Foo"] require.Equal(t, true, ok, "header not found") require.Equal(t, 1, len(v), "wrong number of header values") assert.Equal(t, "foo", v[0]) v, ok = writer.Header()["Content-Length"] require.Equal(t, true, ok, "header not found") require.Equal(t, 1, len(v), "wrong number of header values") assert.Equal(t, "3", v[0]) v, ok = writer.Header()["Content-Type"] require.Equal(t, true, ok, "header not found") require.Equal(t, 1, len(v), "wrong number of header values") assert.Equal(t, "application/octet-stream", v[0]) v, ok = writer.Header()["Content-Disposition"] require.Equal(t, true, ok, "header not found") require.Equal(t, 1, len(v), "wrong number of header values") assert.Equal(t, "attachment; filename=test.ssz", v[0]) assert.Equal(t, 204, writer.Code) }) t.Run("no_grpc_status_code_header", func(t *testing.T) { response := &http.Response{ Header: http.Header{}, StatusCode: 204, } responseSsz := []byte("ssz") writer := httptest.NewRecorder() writer.Body = &bytes.Buffer{} errJson := writeSSZResponseHeaderAndBody(response, writer, responseSsz, "test.ssz") require.Equal(t, true, errJson == nil) assert.Equal(t, 204, writer.Code) }) t.Run("invalid_status_code", func(t *testing.T) { response := &http.Response{ Header: http.Header{ "Foo": []string{"foo"}, "Grpc-Metadata-" + grpcutils.HttpCodeMetadataKey: []string{"invalid"}, }, } responseSsz := []byte("ssz") writer := httptest.NewRecorder() writer.Body = &bytes.Buffer{} errJson := writeSSZResponseHeaderAndBody(response, writer, responseSsz, "test.ssz") require.Equal(t, false, errJson == nil) assert.Equal(t, true, strings.Contains(errJson.Msg(), "could not parse status code")) assert.Equal(t, http.StatusInternalServerError, errJson.StatusCode()) }) } func TestReceiveEvents(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) ch := make(chan *sse.Event) w := httptest.NewRecorder() w.Body = &bytes.Buffer{} req := httptest.NewRequest("GET", "http://foo.example", &bytes.Buffer{}) req = req.WithContext(ctx) go func() { base64Val := "Zm9v" data := &eventFinalizedCheckpointJson{ Block: base64Val, State: base64Val, Epoch: "1", } bData, err := json.Marshal(data) require.NoError(t, err) msg := &sse.Event{ Data: bData, Event: []byte(eventsv1.FinalizedCheckpointTopic), } ch <- msg time.Sleep(time.Second) cancel() }() errJson := receiveEvents(ch, w, req) assert.Equal(t, true, errJson == nil) } func TestReceiveEvents_EventNotSupported(t *testing.T) { ch := make(chan *sse.Event) w := httptest.NewRecorder() w.Body = &bytes.Buffer{} req := httptest.NewRequest("GET", "http://foo.example", &bytes.Buffer{}) go func() { msg := &sse.Event{ Data: []byte("foo"), Event: []byte("not_supported"), } ch <- msg }() errJson := receiveEvents(ch, w, req) require.NotNil(t, errJson) assert.Equal(t, "Event type 'not_supported' not supported", errJson.Msg()) } func TestWriteEvent(t *testing.T) { base64Val := "Zm9v" data := &eventFinalizedCheckpointJson{ Block: base64Val, State: base64Val, Epoch: "1", } bData, err := json.Marshal(data) require.NoError(t, err) msg := &sse.Event{ Data: bData, Event: []byte("test_event"), } w := httptest.NewRecorder() w.Body = &bytes.Buffer{} errJson := writeEvent(msg, w, &eventFinalizedCheckpointJson{}) require.Equal(t, true, errJson == nil) written := w.Body.String() assert.Equal(t, "event: test_event\ndata: {\"block\":\"0x666f6f\",\"state\":\"0x666f6f\",\"epoch\":\"1\"}\n\n", written) }