mirror of
https://gitlab.com/pulsechaincom/prysm-pulse.git
synced 2024-12-25 12:57:18 +00:00
dd0ae1bbef
Co-authored-by: prylabs-bulldozer[bot] <58059840+prylabs-bulldozer[bot]@users.noreply.github.com>
327 lines
11 KiB
Go
327 lines
11 KiB
Go
package gateway
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/ethereum/go-ethereum/common/hexutil"
|
|
"github.com/pkg/errors"
|
|
"github.com/prysmaticlabs/prysm/shared/grpcutils"
|
|
"github.com/wealdtech/go-bytesutil"
|
|
)
|
|
|
|
// DeserializeRequestBodyIntoContainer deserializes the request's body into an endpoint-specific struct.
|
|
func DeserializeRequestBodyIntoContainer(body io.Reader, requestContainer interface{}) ErrorJson {
|
|
if err := json.NewDecoder(body).Decode(&requestContainer); err != nil {
|
|
e := errors.Wrap(err, "could not decode request body")
|
|
return &DefaultErrorJson{Message: e.Error(), Code: http.StatusInternalServerError}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ProcessRequestContainerFields processes fields of an endpoint-specific container according to field tags.
|
|
func ProcessRequestContainerFields(requestContainer interface{}) ErrorJson {
|
|
if err := processField(requestContainer, []fieldProcessor{
|
|
{
|
|
tag: "hex",
|
|
f: hexToBase64Processor,
|
|
},
|
|
}); err != nil {
|
|
e := errors.Wrapf(err, "could not process request data")
|
|
return &DefaultErrorJson{Message: e.Error(), Code: http.StatusInternalServerError}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// SetRequestBodyToRequestContainer makes the endpoint-specific container the new body of the request.
|
|
func SetRequestBodyToRequestContainer(requestContainer interface{}, req *http.Request) ErrorJson {
|
|
// Serialize the struct, which now includes a base64-encoded value, into JSON.
|
|
j, err := json.Marshal(requestContainer)
|
|
if err != nil {
|
|
e := errors.Wrapf(err, "could not marshal request")
|
|
return &DefaultErrorJson{Message: e.Error(), Code: http.StatusInternalServerError}
|
|
}
|
|
// Set the body to the new JSON.
|
|
req.Body = ioutil.NopCloser(bytes.NewReader(j))
|
|
req.Header.Set("Content-Length", strconv.Itoa(len(j)))
|
|
req.ContentLength = int64(len(j))
|
|
return nil
|
|
}
|
|
|
|
// PrepareRequestForProxying applies additional logic to the request so that it can be correctly proxied to grpc-gateway.
|
|
func (m *ApiProxyMiddleware) PrepareRequestForProxying(endpoint Endpoint, req *http.Request) ErrorJson {
|
|
req.URL.Scheme = "http"
|
|
req.URL.Host = m.GatewayAddress
|
|
req.RequestURI = ""
|
|
if errJson := HandleURLParameters(endpoint.Path, req, endpoint.GetRequestURLLiterals); errJson != nil {
|
|
return errJson
|
|
}
|
|
return HandleQueryParameters(req, endpoint.GetRequestQueryParams)
|
|
}
|
|
|
|
// ProxyRequest proxies the request to grpc-gateway.
|
|
func ProxyRequest(req *http.Request) (*http.Response, ErrorJson) {
|
|
grpcResp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
e := errors.Wrapf(err, "could not proxy request")
|
|
return nil, &DefaultErrorJson{Message: e.Error(), Code: http.StatusInternalServerError}
|
|
}
|
|
if grpcResp == nil {
|
|
return nil, &DefaultErrorJson{Message: "nil response from gRPC-gateway", Code: http.StatusInternalServerError}
|
|
}
|
|
return grpcResp, nil
|
|
}
|
|
|
|
// ReadGrpcResponseBody reads the body from the grpc-gateway's response.
|
|
func ReadGrpcResponseBody(r io.Reader) ([]byte, ErrorJson) {
|
|
body, err := ioutil.ReadAll(r)
|
|
if err != nil {
|
|
e := errors.Wrapf(err, "could not read response body")
|
|
return nil, &DefaultErrorJson{Message: e.Error(), Code: http.StatusInternalServerError}
|
|
}
|
|
return body, nil
|
|
}
|
|
|
|
// DeserializeGrpcResponseBodyIntoErrorJson deserializes the body from the grpc-gateway's response into an error struct.
|
|
// The struct can be later examined to check if the request resulted in an error.
|
|
func DeserializeGrpcResponseBodyIntoErrorJson(errJson ErrorJson, body []byte) ErrorJson {
|
|
if err := json.Unmarshal(body, errJson); err != nil {
|
|
e := errors.Wrapf(err, "could not unmarshal error")
|
|
return &DefaultErrorJson{Message: e.Error(), Code: http.StatusInternalServerError}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// HandleGrpcResponseError acts on an error that resulted from a grpc-gateway's response.
|
|
func HandleGrpcResponseError(errJson ErrorJson, resp *http.Response, w http.ResponseWriter) {
|
|
// Something went wrong, but the request completed, meaning we can write headers and the error message.
|
|
for h, vs := range resp.Header {
|
|
for _, v := range vs {
|
|
w.Header().Set(h, v)
|
|
}
|
|
}
|
|
// Set code to HTTP code because unmarshalled body contained gRPC code.
|
|
errJson.SetCode(resp.StatusCode)
|
|
WriteError(w, errJson, resp.Header)
|
|
}
|
|
|
|
// GrpcResponseIsStatusCodeOnly reports whether the grpc-gateway response is
// expected to carry no body: the request method is GET and no response
// container was supplied (responseContainer is the nil interface).
func GrpcResponseIsStatusCodeOnly(req *http.Request, responseContainer interface{}) bool {
	if req.Method != http.MethodGet {
		return false
	}
	return responseContainer == nil
}
|
|
|
|
// DeserializeGrpcResponseBodyIntoContainer deserializes the grpc-gateway's response body into an endpoint-specific struct.
|
|
func DeserializeGrpcResponseBodyIntoContainer(body []byte, responseContainer interface{}) ErrorJson {
|
|
if err := json.Unmarshal(body, &responseContainer); err != nil {
|
|
e := errors.Wrapf(err, "could not unmarshal response")
|
|
return &DefaultErrorJson{Message: e.Error(), Code: http.StatusInternalServerError}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ProcessMiddlewareResponseFields processes fields of an endpoint-specific container according to field tags.
|
|
func ProcessMiddlewareResponseFields(responseContainer interface{}) ErrorJson {
|
|
if err := processField(responseContainer, []fieldProcessor{
|
|
{
|
|
tag: "hex",
|
|
f: base64ToHexProcessor,
|
|
},
|
|
{
|
|
tag: "enum",
|
|
f: enumToLowercaseProcessor,
|
|
},
|
|
{
|
|
tag: "time",
|
|
f: timeToUnixProcessor,
|
|
},
|
|
}); err != nil {
|
|
e := errors.Wrapf(err, "could not process response data")
|
|
return &DefaultErrorJson{Message: e.Error(), Code: http.StatusInternalServerError}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// SerializeMiddlewareResponseIntoJson serializes the endpoint-specific response struct into a JSON representation.
|
|
func SerializeMiddlewareResponseIntoJson(responseContainer interface{}) (jsonResponse []byte, errJson ErrorJson) {
|
|
j, err := json.Marshal(responseContainer)
|
|
if err != nil {
|
|
e := errors.Wrapf(err, "could not marshal response")
|
|
return nil, &DefaultErrorJson{Message: e.Error(), Code: http.StatusInternalServerError}
|
|
}
|
|
return j, nil
|
|
}
|
|
|
|
// WriteMiddlewareResponseHeadersAndBody populates headers and the body of the final response.
|
|
func WriteMiddlewareResponseHeadersAndBody(req *http.Request, grpcResp *http.Response, responseJson []byte, w http.ResponseWriter) ErrorJson {
|
|
var statusCodeHeader string
|
|
for h, vs := range grpcResp.Header {
|
|
// We don't want to expose any gRPC metadata in the HTTP response, so we skip forwarding metadata headers.
|
|
if strings.HasPrefix(h, "Grpc-Metadata") {
|
|
if h == "Grpc-Metadata-"+grpcutils.HttpCodeMetadataKey {
|
|
statusCodeHeader = vs[0]
|
|
}
|
|
} else {
|
|
for _, v := range vs {
|
|
w.Header().Set(h, v)
|
|
}
|
|
}
|
|
}
|
|
if req.Method == "GET" {
|
|
w.Header().Set("Content-Length", strconv.Itoa(len(responseJson)))
|
|
if statusCodeHeader != "" {
|
|
code, err := strconv.Atoi(statusCodeHeader)
|
|
if err != nil {
|
|
e := errors.Wrapf(err, "could not parse status code")
|
|
return &DefaultErrorJson{Message: e.Error(), Code: http.StatusInternalServerError}
|
|
}
|
|
w.WriteHeader(code)
|
|
} else {
|
|
w.WriteHeader(grpcResp.StatusCode)
|
|
}
|
|
if _, err := io.Copy(w, ioutil.NopCloser(bytes.NewReader(responseJson))); err != nil {
|
|
e := errors.Wrapf(err, "could not write response message")
|
|
return &DefaultErrorJson{Message: e.Error(), Code: http.StatusInternalServerError}
|
|
}
|
|
} else if req.Method == "POST" {
|
|
w.WriteHeader(grpcResp.StatusCode)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// WriteError writes the error by manipulating headers and the body of the final response.
|
|
func WriteError(w http.ResponseWriter, errJson ErrorJson, responseHeader http.Header) {
|
|
// Include custom error in the error JSON.
|
|
if responseHeader != nil {
|
|
customError, ok := responseHeader["Grpc-Metadata-"+grpcutils.CustomErrorMetadataKey]
|
|
if ok {
|
|
// Assume header has only one value and read the 0 index.
|
|
if err := json.Unmarshal([]byte(customError[0]), errJson); err != nil {
|
|
log.WithError(err).Error("Could not unmarshal custom error message")
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
j, err := json.Marshal(errJson)
|
|
if err != nil {
|
|
log.WithError(err).Error("Could not marshal error message")
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Length", strconv.Itoa(len(j)))
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(errJson.StatusCode())
|
|
if _, err := io.Copy(w, ioutil.NopCloser(bytes.NewReader(j))); err != nil {
|
|
log.WithError(err).Error("Could not write error message")
|
|
}
|
|
}
|
|
|
|
// Cleanup performs final cleanup on the initial response from grpc-gateway.
|
|
func Cleanup(grpcResponseBody io.ReadCloser) ErrorJson {
|
|
if err := grpcResponseBody.Close(); err != nil {
|
|
e := errors.Wrapf(err, "could not close response body")
|
|
return &DefaultErrorJson{Message: e.Error(), Code: http.StatusInternalServerError}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// processField calls each processor function on any field that has the matching tag set.
|
|
// It is a recursive function.
|
|
func processField(s interface{}, processors []fieldProcessor) error {
|
|
kind := reflect.TypeOf(s).Kind()
|
|
if kind != reflect.Ptr && kind != reflect.Slice && kind != reflect.Array {
|
|
return fmt.Errorf("processing fields of kind '%v' is unsupported", kind)
|
|
}
|
|
|
|
t := reflect.TypeOf(s).Elem()
|
|
v := reflect.Indirect(reflect.ValueOf(s))
|
|
|
|
for i := 0; i < t.NumField(); i++ {
|
|
switch v.Field(i).Kind() {
|
|
case reflect.Slice:
|
|
sliceElem := t.Field(i).Type.Elem()
|
|
kind := sliceElem.Kind()
|
|
// Recursively process slices to struct pointers.
|
|
if kind == reflect.Ptr && sliceElem.Elem().Kind() == reflect.Struct {
|
|
for j := 0; j < v.Field(i).Len(); j++ {
|
|
if err := processField(v.Field(i).Index(j).Interface(), processors); err != nil {
|
|
return errors.Wrapf(err, "could not process field '%s'", t.Field(i).Name)
|
|
}
|
|
}
|
|
}
|
|
// Process each string in string slices.
|
|
if kind == reflect.String {
|
|
for _, proc := range processors {
|
|
_, hasTag := t.Field(i).Tag.Lookup(proc.tag)
|
|
if hasTag {
|
|
for j := 0; j < v.Field(i).Len(); j++ {
|
|
if err := proc.f(v.Field(i).Index(j)); err != nil {
|
|
return errors.Wrapf(err, "could not process field '%s'", t.Field(i).Name)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
}
|
|
// Recursively process struct pointers.
|
|
case reflect.Ptr:
|
|
if v.Field(i).Elem().Kind() == reflect.Struct {
|
|
if err := processField(v.Field(i).Interface(), processors); err != nil {
|
|
return errors.Wrapf(err, "could not process field '%s'", t.Field(i).Name)
|
|
}
|
|
}
|
|
default:
|
|
field := t.Field(i)
|
|
for _, proc := range processors {
|
|
if _, hasTag := field.Tag.Lookup(proc.tag); hasTag {
|
|
if err := proc.f(v.Field(i)); err != nil {
|
|
return errors.Wrapf(err, "could not process field '%s'", t.Field(i).Name)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func hexToBase64Processor(v reflect.Value) error {
|
|
b, err := bytesutil.FromHexString(v.String())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
v.SetString(base64.StdEncoding.EncodeToString(b))
|
|
return nil
|
|
}
|
|
|
|
func base64ToHexProcessor(v reflect.Value) error {
|
|
b, err := base64.StdEncoding.DecodeString(v.String())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
v.SetString(hexutil.Encode(b))
|
|
return nil
|
|
}
|
|
|
|
// enumToLowercaseProcessor replaces the string held by v with its lowercase
// form. It always returns nil.
func enumToLowercaseProcessor(v reflect.Value) error {
	lowered := strings.ToLower(v.String())
	v.SetString(lowered)
	return nil
}
|
|
|
|
// timeToUnixProcessor parses the string held by v as an RFC 3339 timestamp and
// replaces it with the decimal Unix time in seconds, formatted as an unsigned
// 64-bit integer. If parsing fails, the error is returned and v is left unchanged.
func timeToUnixProcessor(v reflect.Value) error {
	parsed, err := time.Parse(time.RFC3339, v.String())
	if err != nil {
		return err
	}
	seconds := uint64(parsed.Unix())
	v.SetString(strconv.FormatUint(seconds, 10))
	return nil
}
|