diff --git a/api/gateway/BUILD.bazel b/api/gateway/BUILD.bazel index 91a943afc..9b9643d95 100644 --- a/api/gateway/BUILD.bazel +++ b/api/gateway/BUILD.bazel @@ -5,6 +5,7 @@ go_library( srcs = [ "gateway.go", "log.go", + "options.go", ], importpath = "github.com/prysmaticlabs/prysm/api/gateway", visibility = [ diff --git a/api/gateway/apimiddleware/api_middleware.go b/api/gateway/apimiddleware/api_middleware.go index 9f4040616..733e16e53 100644 --- a/api/gateway/apimiddleware/api_middleware.go +++ b/api/gateway/apimiddleware/api_middleware.go @@ -14,6 +14,7 @@ import ( type ApiProxyMiddleware struct { GatewayAddress string EndpointCreator EndpointFactory + router *mux.Router } // EndpointFactory is responsible for creating new instances of Endpoint values. @@ -29,6 +30,8 @@ type Endpoint struct { GetResponse interface{} // The struct corresponding to the JSON structure used in a GET response. PostRequest interface{} // The struct corresponding to the JSON structure used in a POST request. PostResponse interface{} // The struct corresponding to the JSON structure used in a POST response. + DeleteRequest interface{} // The struct corresponding to the JSON structure used in a DELETE request. + DeleteResponse interface{} // The struct corresponding to the JSON structure used in a DELETE response. RequestURLLiterals []string // Names of URL parameters that should not be base64-encoded. RequestQueryParams []QueryParam // Query parameters of the request. Err ErrorJson // The struct corresponding to the error that should be returned in case of a request failure. @@ -74,18 +77,24 @@ type fieldProcessor struct { // Run starts the proxy, registering all proxy endpoints. func (m *ApiProxyMiddleware) Run(gatewayRouter *mux.Router) { for _, path := range m.EndpointCreator.Paths() { - m.handleApiPath(gatewayRouter, path, m.EndpointCreator) + gatewayRouter.HandleFunc(path, m.WithMiddleware(path)) } + m.router = gatewayRouter } -func (m *ApiProxyMiddleware) handleApiPath(gatewayRouter *mux.Router, path string, endpointFactory EndpointFactory) { - gatewayRouter.HandleFunc(path, func(w http.ResponseWriter, req *http.Request) { - endpoint, err := endpointFactory.Create(path) - if err != nil { - errJson := InternalServerErrorWithMessage(err, "could not create endpoint") - WriteError(w, errJson, nil) - } +// ServeHTTP for the proxy middleware. +func (m *ApiProxyMiddleware) ServeHTTP(w http.ResponseWriter, req *http.Request) { + m.router.ServeHTTP(w, req) +} +// WithMiddleware wraps the given endpoint handler with the middleware logic. +func (m *ApiProxyMiddleware) WithMiddleware(path string) http.HandlerFunc { + endpoint, err := m.EndpointCreator.Create(path) + if err != nil { + log.WithError(err).Errorf("Could not create endpoint for path: %s", path) + return nil + } + return func(w http.ResponseWriter, req *http.Request) { for _, handler := range endpoint.CustomHandlers { if handler(m, *endpoint, w, req) { return @@ -93,16 +102,14 @@ func (m *ApiProxyMiddleware) handleApiPath(gatewayRouter *mux.Router, path strin } if req.Method == "POST" { - if errJson := deserializeRequestBodyIntoContainerWrapped(endpoint, req, w); errJson != nil { + if errJson := handlePostRequestForEndpoint(endpoint, w, req); errJson != nil { WriteError(w, errJson, nil) return } + } - if errJson := ProcessRequestContainerFields(endpoint.PostRequest); errJson != nil { - WriteError(w, errJson, nil) - return - } - if errJson := SetRequestBodyToRequestContainer(endpoint.PostRequest, req); errJson != nil { + if req.Method == "DELETE" { + if errJson := handleDeleteRequestForEndpoint(endpoint, req); errJson != nil { WriteError(w, errJson, nil) return } @@ -137,6 +144,8 @@ func (m *ApiProxyMiddleware) handleApiPath(gatewayRouter *mux.Router, path strin var resp interface{} if req.Method == "GET" { resp = endpoint.GetResponse + } else if req.Method == "DELETE" { + resp = endpoint.DeleteResponse } else { resp = endpoint.PostResponse } @@ -164,7 +173,27 @@ func (m *ApiProxyMiddleware) handleApiPath(gatewayRouter *mux.Router, path strin WriteError(w, errJson, nil) return } - }) + } +} + +func handlePostRequestForEndpoint(endpoint *Endpoint, w http.ResponseWriter, req *http.Request) ErrorJson { + if errJson := deserializeRequestBodyIntoContainerWrapped(endpoint, req, w); errJson != nil { + return errJson + } + if errJson := ProcessRequestContainerFields(endpoint.PostRequest); errJson != nil { + return errJson + } + return SetRequestBodyToRequestContainer(endpoint.PostRequest, req) +} + +func handleDeleteRequestForEndpoint(endpoint *Endpoint, req *http.Request) ErrorJson { + if errJson := DeserializeRequestBodyIntoContainer(req.Body, endpoint.DeleteRequest); errJson != nil { + return errJson + } + if errJson := ProcessRequestContainerFields(endpoint.DeleteRequest); errJson != nil { + return errJson + } + return SetRequestBodyToRequestContainer(endpoint.DeleteRequest, req) } func deserializeRequestBodyIntoContainerWrapped(endpoint *Endpoint, req *http.Request, w http.ResponseWriter) ErrorJson { diff --git a/api/gateway/gateway.go b/api/gateway/gateway.go index 3ec0ac1f3..45e400a15 100644 --- a/api/gateway/gateway.go +++ b/api/gateway/gateway.go @@ -34,74 +34,51 @@ type PbMux struct { type PbHandlerRegistration func(context.Context, *gwruntime.ServeMux, *grpc.ClientConn) error // MuxHandler is a function that implements the mux handler functionality. -type MuxHandler func(http.Handler, http.ResponseWriter, *http.Request) +type MuxHandler func( + apiMiddlewareHandler *apimiddleware.ApiProxyMiddleware, + h http.HandlerFunc, + w http.ResponseWriter, + req *http.Request, +) + +// Config parameters for setting up the gateway service. +type config struct { + maxCallRecvMsgSize uint64 + remoteCert string + gatewayAddr string + remoteAddr string + allowedOrigins []string + apiMiddlewareEndpointFactory apimiddleware.EndpointFactory + muxHandler MuxHandler + pbHandlers []*PbMux + router *mux.Router +} // Gateway is the gRPC gateway to serve HTTP JSON traffic as a proxy and forward it to the gRPC server. type Gateway struct { - conn *grpc.ClientConn - pbHandlers []*PbMux - muxHandler MuxHandler - maxCallRecvMsgSize uint64 - router *mux.Router - server *http.Server - cancel context.CancelFunc - remoteCert string - gatewayAddr string - apiMiddlewareEndpointFactory apimiddleware.EndpointFactory - ctx context.Context - startFailure error - remoteAddr string - allowedOrigins []string + cfg *config + conn *grpc.ClientConn + server *http.Server + cancel context.CancelFunc + proxy *apimiddleware.ApiProxyMiddleware + ctx context.Context + startFailure error } // New returns a new instance of the Gateway. -func New( - ctx context.Context, - pbHandlers []*PbMux, - muxHandler MuxHandler, - remoteAddr, - gatewayAddress string, -) *Gateway { +func New(ctx context.Context, opts ...Option) (*Gateway, error) { g := &Gateway{ - pbHandlers: pbHandlers, - muxHandler: muxHandler, - router: mux.NewRouter(), - gatewayAddr: gatewayAddress, - ctx: ctx, - remoteAddr: remoteAddr, - allowedOrigins: []string{}, + ctx: ctx, + cfg: &config{ + router: mux.NewRouter(), + }, } - return g -} - -// WithRouter allows adding a custom mux router to the gateway. -func (g *Gateway) WithRouter(r *mux.Router) *Gateway { - g.router = r - return g -} - -// WithAllowedOrigins allows adding a set of allowed origins to the gateway. -func (g *Gateway) WithAllowedOrigins(origins []string) *Gateway { - g.allowedOrigins = origins - return g -} - -// WithRemoteCert allows adding a custom certificate to the gateway, -func (g *Gateway) WithRemoteCert(cert string) *Gateway { - g.remoteCert = cert - return g -} - -// WithMaxCallRecvMsgSize allows specifying the maximum allowed gRPC message size. -func (g *Gateway) WithMaxCallRecvMsgSize(size uint64) *Gateway { - g.maxCallRecvMsgSize = size - return g -} - -// WithApiMiddleware allows adding API Middleware proxy to the gateway. -func (g *Gateway) WithApiMiddleware(endpointFactory apimiddleware.EndpointFactory) *Gateway { - g.apiMiddlewareEndpointFactory = endpointFactory - return g + for _, opt := range opts { + if err := opt(g); err != nil { + return nil, err + } + } + return g, nil } // Start the gateway service. @@ -109,7 +86,7 @@ func (g *Gateway) Start() { ctx, cancel := context.WithCancel(g.ctx) g.cancel = cancel - conn, err := g.dial(ctx, "tcp", g.remoteAddr) + conn, err := g.dial(ctx, "tcp", g.cfg.remoteAddr) if err != nil { log.WithError(err).Error("Failed to connect to gRPC server") g.startFailure = err @@ -117,7 +94,7 @@ func (g *Gateway) Start() { } g.conn = conn - for _, h := range g.pbHandlers { + for _, h := range g.cfg.pbHandlers { for _, r := range h.Registrations { if err := r(ctx, h.Mux, g.conn); err != nil { log.WithError(err).Error("Failed to register handler") @@ -126,29 +103,30 @@ func (g *Gateway) Start() { } } for _, p := range h.Patterns { - g.router.PathPrefix(p).Handler(h.Mux) + g.cfg.router.PathPrefix(p).Handler(h.Mux) } } - corsMux := g.corsMiddleware(g.router) + corsMux := g.corsMiddleware(g.cfg.router) + _ = corsMux - if g.muxHandler != nil { - g.router.PathPrefix("/").HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - g.muxHandler(corsMux, w, r) + if g.cfg.apiMiddlewareEndpointFactory != nil && !g.cfg.apiMiddlewareEndpointFactory.IsNil() { + g.registerApiMiddleware() + } + + if g.cfg.muxHandler != nil { + g.cfg.router.PathPrefix("/").HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + g.cfg.muxHandler(g.proxy, g.cfg.router.ServeHTTP, w, r) }) } - if g.apiMiddlewareEndpointFactory != nil && !g.apiMiddlewareEndpointFactory.IsNil() { - go g.registerApiMiddleware() - } - g.server = &http.Server{ - Addr: g.gatewayAddr, - Handler: g.router, + Addr: g.cfg.gatewayAddr, + Handler: g.cfg.router, } go func() { - log.WithField("address", g.gatewayAddr).Info("Starting gRPC gateway") + log.WithField("address", g.cfg.gatewayAddr).Info("Starting gRPC gateway") if err := g.server.ListenAndServe(); err != http.ErrServerClosed { log.WithError(err).Error("Failed to start gRPC gateway") g.startFailure = err @@ -162,11 +140,9 @@ func (g *Gateway) Status() error { if g.startFailure != nil { return g.startFailure } - if s := g.conn.GetState(); s != connectivity.Ready { return fmt.Errorf("grpc server is %s", s) } - return nil } @@ -183,18 +159,16 @@ func (g *Gateway) Stop() error { } } } - if g.cancel != nil { g.cancel() } - return nil } func (g *Gateway) corsMiddleware(h http.Handler) http.Handler { c := cors.New(cors.Options{ - AllowedOrigins: g.allowedOrigins, - AllowedMethods: []string{http.MethodPost, http.MethodGet, http.MethodOptions}, + AllowedOrigins: g.cfg.allowedOrigins, + AllowedMethods: []string{http.MethodPost, http.MethodGet, http.MethodDelete, http.MethodOptions}, AllowCredentials: true, MaxAge: 600, AllowedHeaders: []string{"*"}, @@ -236,8 +210,8 @@ func (g *Gateway) dial(ctx context.Context, network, addr string) (*grpc.ClientC // "addr" must be a valid TCP address with a port number. func (g *Gateway) dialTCP(ctx context.Context, addr string) (*grpc.ClientConn, error) { security := grpc.WithInsecure() - if len(g.remoteCert) > 0 { - creds, err := credentials.NewClientTLSFromFile(g.remoteCert, "") + if len(g.cfg.remoteCert) > 0 { + creds, err := credentials.NewClientTLSFromFile(g.cfg.remoteCert, "") if err != nil { return nil, err } @@ -245,7 +219,7 @@ func (g *Gateway) dialTCP(ctx context.Context, addr string) (*grpc.ClientConn, e } opts := []grpc.DialOption{ security, - grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(int(g.maxCallRecvMsgSize))), + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(int(g.cfg.maxCallRecvMsgSize))), } return grpc.DialContext(ctx, addr, opts...) @@ -266,16 +240,16 @@ func (g *Gateway) dialUnix(ctx context.Context, addr string) (*grpc.ClientConn, opts := []grpc.DialOption{ grpc.WithInsecure(), grpc.WithContextDialer(f), - grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(int(g.maxCallRecvMsgSize))), + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(int(g.cfg.maxCallRecvMsgSize))), } return grpc.DialContext(ctx, addr, opts...) } func (g *Gateway) registerApiMiddleware() { - proxy := &apimiddleware.ApiProxyMiddleware{ - GatewayAddress: g.gatewayAddr, - EndpointCreator: g.apiMiddlewareEndpointFactory, + g.proxy = &apimiddleware.ApiProxyMiddleware{ + GatewayAddress: g.cfg.gatewayAddr, + EndpointCreator: g.cfg.apiMiddlewareEndpointFactory, } log.Info("Starting API middleware") - proxy.Run(g.router) + g.proxy.Run(g.cfg.router) } diff --git a/api/gateway/gateway_test.go b/api/gateway/gateway_test.go index 8a00ff859..2d03e6d20 100644 --- a/api/gateway/gateway_test.go +++ b/api/gateway/gateway_test.go @@ -40,26 +40,30 @@ func TestGateway_Customized(t *testing.T) { size := uint64(100) endpointFactory := &mockEndpointFactory{} - g := New( - context.Background(), - []*PbMux{}, - func(handler http.Handler, writer http.ResponseWriter, request *http.Request) { + opts := []Option{ + WithRouter(r), + WithRemoteCert(cert), + WithAllowedOrigins(origins), + WithMaxCallRecvMsgSize(size), + WithApiMiddleware(endpointFactory), + WithMuxHandler(func( + _ *apimiddleware.ApiProxyMiddleware, + _ http.HandlerFunc, + _ http.ResponseWriter, + _ *http.Request, + ) { + }), + } - }, - "", - "", - ).WithRouter(r). - WithRemoteCert(cert). - WithAllowedOrigins(origins). - WithMaxCallRecvMsgSize(size). - WithApiMiddleware(endpointFactory) + g, err := New(context.Background(), opts...) + require.NoError(t, err) - assert.Equal(t, r, g.router) - assert.Equal(t, cert, g.remoteCert) - require.Equal(t, 1, len(g.allowedOrigins)) - assert.Equal(t, origins[0], g.allowedOrigins[0]) - assert.Equal(t, size, g.maxCallRecvMsgSize) - assert.Equal(t, endpointFactory, g.apiMiddlewareEndpointFactory) + assert.Equal(t, r, g.cfg.router) + assert.Equal(t, cert, g.cfg.remoteCert) + require.Equal(t, 1, len(g.cfg.allowedOrigins)) + assert.Equal(t, origins[0], g.cfg.allowedOrigins[0]) + assert.Equal(t, size, g.cfg.maxCallRecvMsgSize) + assert.Equal(t, endpointFactory, g.cfg.apiMiddlewareEndpointFactory) } func TestGateway_StartStop(t *testing.T) { @@ -75,23 +79,27 @@ func TestGateway_StartStop(t *testing.T) { selfAddress := fmt.Sprintf("%s:%d", rpcHost, ctx.Int(flags.RPCPort.Name)) gatewayAddress := fmt.Sprintf("%s:%d", gatewayHost, gatewayPort) - g := New( - ctx.Context, - []*PbMux{}, - func(handler http.Handler, writer http.ResponseWriter, request *http.Request) { + opts := []Option{ + WithGatewayAddr(gatewayAddress), + WithRemoteAddr(selfAddress), + WithMuxHandler(func( + _ *apimiddleware.ApiProxyMiddleware, + _ http.HandlerFunc, + _ http.ResponseWriter, + _ *http.Request, + ) { + }), + } - }, - selfAddress, - gatewayAddress, - ) + g, err := New(context.Background(), opts...) + require.NoError(t, err) g.Start() go func() { require.LogsContain(t, hook, "Starting gRPC gateway") require.LogsDoNotContain(t, hook, "Starting API middleware") }() - - err := g.Stop() + err = g.Stop() require.NoError(t, err) } @@ -106,15 +114,15 @@ func TestGateway_NilHandler_NotFoundHandlerRegistered(t *testing.T) { selfAddress := fmt.Sprintf("%s:%d", rpcHost, ctx.Int(flags.RPCPort.Name)) gatewayAddress := fmt.Sprintf("%s:%d", gatewayHost, gatewayPort) - g := New( - ctx.Context, - []*PbMux{}, - /* muxHandler */ nil, - selfAddress, - gatewayAddress, - ) + opts := []Option{ + WithGatewayAddr(gatewayAddress), + WithRemoteAddr(selfAddress), + } + + g, err := New(context.Background(), opts...) + require.NoError(t, err) writer := httptest.NewRecorder() - g.router.ServeHTTP(writer, &http.Request{Method: "GET", Host: "localhost", URL: &url.URL{Path: "/foo"}}) + g.cfg.router.ServeHTTP(writer, &http.Request{Method: "GET", Host: "localhost", URL: &url.URL{Path: "/foo"}}) assert.Equal(t, http.StatusNotFound, writer.Code) } diff --git a/api/gateway/options.go b/api/gateway/options.go new file mode 100644 index 000000000..0cf932108 --- /dev/null +++ b/api/gateway/options.go @@ -0,0 +1,81 @@ +package gateway + +import ( + "github.com/gorilla/mux" + "github.com/prysmaticlabs/prysm/api/gateway/apimiddleware" +) + +type Option func(g *Gateway) error + +func (g *Gateway) SetRouter(r *mux.Router) *Gateway { + g.cfg.router = r + return g +} + +func WithPbHandlers(handlers []*PbMux) Option { + return func(g *Gateway) error { + g.cfg.pbHandlers = handlers + return nil + } +} + +func WithMuxHandler(m MuxHandler) Option { + return func(g *Gateway) error { + g.cfg.muxHandler = m + return nil + } +} + +func WithGatewayAddr(addr string) Option { + return func(g *Gateway) error { + g.cfg.gatewayAddr = addr + return nil + } +} + +func WithRemoteAddr(addr string) Option { + return func(g *Gateway) error { + g.cfg.remoteAddr = addr + return nil + } +} + +// WithRouter allows adding a custom mux router to the gateway. +func WithRouter(r *mux.Router) Option { + return func(g *Gateway) error { + g.cfg.router = r + return nil + } +} + +// WithAllowedOrigins allows adding a set of allowed origins to the gateway. +func WithAllowedOrigins(origins []string) Option { + return func(g *Gateway) error { + g.cfg.allowedOrigins = origins + return nil + } +} + +// WithRemoteCert allows adding a custom certificate to the gateway, +func WithRemoteCert(cert string) Option { + return func(g *Gateway) error { + g.cfg.remoteCert = cert + return nil + } +} + +// WithMaxCallRecvMsgSize allows specifying the maximum allowed gRPC message size. +func WithMaxCallRecvMsgSize(size uint64) Option { + return func(g *Gateway) error { + g.cfg.maxCallRecvMsgSize = size + return nil + } +} + +// WithApiMiddleware allows adding an API middleware proxy to the gateway. +func WithApiMiddleware(endpointFactory apimiddleware.EndpointFactory) Option { + return func(g *Gateway) error { + g.cfg.apiMiddlewareEndpointFactory = endpointFactory + return nil + } +} diff --git a/beacon-chain/node/node.go b/beacon-chain/node/node.go index 7e2ce1899..5a5eb61ba 100644 --- a/beacon-chain/node/node.go +++ b/beacon-chain/node/node.go @@ -830,19 +830,22 @@ func (b *BeaconNode) registerGRPCGateway() error { muxs = append(muxs, gatewayConfig.EthPbMux) } - g := apigateway.New( - b.ctx, - muxs, - gatewayConfig.Handler, - selfAddress, - gatewayAddress, - ).WithAllowedOrigins(allowedOrigins). - WithRemoteCert(selfCert). - WithMaxCallRecvMsgSize(maxCallSize) - if flags.EnableHTTPEthAPI(httpModules) { - g.WithApiMiddleware(&apimiddleware.BeaconEndpointFactory{}) + opts := []apigateway.Option{ + apigateway.WithGatewayAddr(gatewayAddress), + apigateway.WithRemoteAddr(selfAddress), + apigateway.WithPbHandlers(muxs), + apigateway.WithMuxHandler(gatewayConfig.Handler), + apigateway.WithRemoteCert(selfCert), + apigateway.WithMaxCallRecvMsgSize(maxCallSize), + apigateway.WithAllowedOrigins(allowedOrigins), + } + if flags.EnableHTTPEthAPI(httpModules) { + opts = append(opts, apigateway.WithApiMiddleware(&apimiddleware.BeaconEndpointFactory{})) + } + g, err := apigateway.New(b.ctx, opts...) + if err != nil { + return err } - return b.services.RegisterService(g) } diff --git a/beacon-chain/server/main.go b/beacon-chain/server/main.go index 3a1c5c609..e3a7f9a93 100644 --- a/beacon-chain/server/main.go +++ b/beacon-chain/server/main.go @@ -52,23 +52,28 @@ func main() { if gatewayConfig.EthPbMux != nil { muxs = append(muxs, gatewayConfig.EthPbMux) } + opts := []gateway.Option{ + gateway.WithPbHandlers(muxs), + gateway.WithMuxHandler(gatewayConfig.Handler), + gateway.WithRemoteAddr(*beaconRPC), + gateway.WithGatewayAddr(fmt.Sprintf("%s:%d", *host, *port)), + gateway.WithAllowedOrigins(strings.Split(*allowedOrigins, ",")), + gateway.WithMaxCallRecvMsgSize(uint64(*grpcMaxMsgSize)), + } - gw := gateway.New( - context.Background(), - muxs, - gatewayConfig.Handler, - *beaconRPC, - fmt.Sprintf("%s:%d", *host, *port), - ).WithAllowedOrigins(strings.Split(*allowedOrigins, ",")). - WithMaxCallRecvMsgSize(uint64(*grpcMaxMsgSize)) if flags.EnableHTTPEthAPI(*httpModules) { - gw.WithApiMiddleware(&apimiddleware.BeaconEndpointFactory{}) + opts = append(opts, gateway.WithApiMiddleware(&apimiddleware.BeaconEndpointFactory{})) + } + + gw, err := gateway.New(context.Background(), opts...) + if err != nil { + log.Fatal(err) } r := mux.NewRouter() r.HandleFunc("/swagger/", gateway.SwaggerServer()) r.HandleFunc("/healthz", healthzServer(gw)) - gw = gw.WithRouter(r) + gw.SetRouter(r) gw.Start() diff --git a/proto/eth/service/BUILD.bazel b/proto/eth/service/BUILD.bazel index 1f3bc8347..bc18c253c 100644 --- a/proto/eth/service/BUILD.bazel +++ b/proto/eth/service/BUILD.bazel @@ -12,27 +12,6 @@ proto_library( "events_service.proto", "node_service.proto", "validator_service.proto", - ], - visibility = ["//visibility:public"], - deps = [ - "//proto/eth/ext:proto", - "//proto/eth/v1:proto", - "//proto/eth/v2:proto", - "@com_google_protobuf//:descriptor_proto", - "@com_google_protobuf//:empty_proto", - "@com_google_protobuf//:timestamp_proto", - "@go_googleapis//google/api:annotations_proto", - "@com_github_grpc_ecosystem_grpc_gateway_v2//proto/gateway:event_source_proto", - ], -) - -# We create a custom proto library for key_management.proto as it requires a different -# compiler plugin for grpc gateway than the others. Namely, it requires adding the option -# --allow_delete_body=true to allow grpc gateway endpoints to take in DELETE HTTP requests -# with a request body properly. -proto_library( - name = "custom_proto", - srcs = [ "key_management.proto", ], visibility = ["//visibility:public"], @@ -67,35 +46,32 @@ go_proto_library( ], ) -go_proto_library( - name = "custom_go_proto", - compilers = ["@com_github_prysmaticlabs_protoc_gen_go_cast//:go_cast_grpc",], - importpath = "github.com/prysmaticlabs/prysm/proto/eth/service", - proto = ":custom_proto", - visibility = ["//visibility:public"], - deps = [ - "//proto/eth/ext:go_default_library", - "//proto/eth/v1:go_default_library", - "//proto/eth/v2:go_default_library", - "@io_bazel_rules_go//proto/wkt:descriptor_go_proto", - "@io_bazel_rules_go//proto/wkt:empty_go_proto", - "@com_github_prysmaticlabs_eth2_types//:go_default_library", - "@go_googleapis//google/api:annotations_go_proto", - "@com_github_golang_protobuf//proto:go_default_library", - "@com_github_grpc_ecosystem_grpc_gateway_v2//proto/gateway:go_default_library", - ], -) - go_proto_compiler( name = "allow_delete_body_gateway_compiler", + options = [ + "logtostderr=true", + "allow_repeated_fields_in_body=true", + "allow_delete_body=true", + ], plugin = "@com_github_grpc_ecosystem_grpc_gateway_v2//protoc-gen-grpc-gateway:protoc-gen-grpc-gateway", - options = ["allow_delete_body=true"], + suffix = ".pb.gw.go", + visibility = ["//visibility:public"], + deps = [ + "@com_github_grpc_ecosystem_grpc_gateway_v2//runtime:go_default_library", + "@com_github_grpc_ecosystem_grpc_gateway_v2//utilities:go_default_library", + "@org_golang_google_grpc//:go_default_library", + "@org_golang_google_grpc//codes:go_default_library", + "@org_golang_google_grpc//grpclog:go_default_library", + "@org_golang_google_grpc//metadata:go_default_library", + "@org_golang_google_grpc//status:go_default_library", + "@org_golang_google_protobuf//proto:go_default_library", + ], ) go_proto_library( name = "go_grpc_gateway_library", compilers = [ - "@com_github_grpc_ecosystem_grpc_gateway_v2//protoc-gen-grpc-gateway:go_gen_grpc_gateway", + "allow_delete_body_gateway_compiler", ], embed = [":go_proto"], importpath = "github.com/prysmaticlabs/prysm/proto/eth/service", @@ -113,30 +89,9 @@ go_proto_library( ], ) -go_proto_library( - name = "custom_go_grpc_gateway_library", - compilers = [ - "allow_delete_body_gateway_compiler", - ], - embed = [":custom_go_proto"], - importpath = "github.com/prysmaticlabs/prysm/proto/eth/service", - protos = [":custom_proto"], - visibility = ["//proto:__subpackages__"], - deps = [ - "//proto/eth/ext:go_default_library", - "@io_bazel_rules_go//proto/wkt:empty_go_proto", - "@com_github_grpc_ecosystem_grpc_gateway_v2//protoc-gen-openapiv2/options:options_go_proto", - "@com_github_prysmaticlabs_go_bitfield//:go_default_library", - "@go_googleapis//google/api:annotations_go_proto", - "@io_bazel_rules_go//proto/wkt:timestamp_go_proto", - "@io_bazel_rules_go//proto/wkt:descriptor_go_proto", - "@com_github_grpc_ecosystem_grpc_gateway_v2//proto/gateway:go_default_library", - ], -) - go_library( name = "go_default_library", - embed = [":go_grpc_gateway_library", ":custom_go_grpc_gateway_library"], + embed = [":go_grpc_gateway_library"], importpath = "github.com/prysmaticlabs/prysm/proto/eth/service", visibility = ["//visibility:public"], ) diff --git a/proto/eth/service/key_management.pb.gw.go b/proto/eth/service/key_management.pb.gw.go index 8dd77e4da..298f0e793 100755 --- a/proto/eth/service/key_management.pb.gw.go +++ b/proto/eth/service/key_management.pb.gw.go @@ -89,18 +89,15 @@ func local_request_KeyManagement_ImportKeystores_0(ctx context.Context, marshale } -var ( - filter_KeyManagement_DeleteKeystores_0 = &utilities.DoubleArray{Encoding: map[string]int{}, Base: []int(nil), Check: []int(nil)} -) - func request_KeyManagement_DeleteKeystores_0(ctx context.Context, marshaler runtime.Marshaler, client KeyManagementClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { var protoReq DeleteKeystoresRequest var metadata runtime.ServerMetadata - if err := req.ParseForm(); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + newReader, berr := utilities.IOReaderFactory(req.Body) + if berr != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) } - if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_KeyManagement_DeleteKeystores_0); err != nil { + if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } @@ -113,10 +110,11 @@ func local_request_KeyManagement_DeleteKeystores_0(ctx context.Context, marshale var protoReq DeleteKeystoresRequest var metadata runtime.ServerMetadata - if err := req.ParseForm(); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + newReader, berr := utilities.IOReaderFactory(req.Body) + if berr != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) } - if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_KeyManagement_DeleteKeystores_0); err != nil { + if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } diff --git a/validator/node/BUILD.bazel b/validator/node/BUILD.bazel index 002ee914d..fb065531d 100644 --- a/validator/node/BUILD.bazel +++ b/validator/node/BUILD.bazel @@ -29,6 +29,7 @@ go_library( ], deps = [ "//api/gateway:go_default_library", + "//api/gateway/apimiddleware:go_default_library", "//async/event:go_default_library", "//cmd:go_default_library", "//cmd/validator/flags:go_default_library", @@ -38,6 +39,7 @@ go_library( "//monitoring/backup:go_default_library", "//monitoring/prometheus:go_default_library", "//monitoring/tracing:go_default_library", + "//proto/eth/service:go_default_library", "//proto/prysm/v1alpha1:go_default_library", "//proto/prysm/v1alpha1/validator-client:go_default_library", "//runtime:go_default_library", @@ -52,6 +54,7 @@ go_library( "//validator/keymanager:go_default_library", "//validator/keymanager/imported:go_default_library", "//validator/rpc:go_default_library", + "//validator/rpc/apimiddleware:go_default_library", "//validator/web:go_default_library", "@com_github_grpc_ecosystem_grpc_gateway_v2//runtime:go_default_library", "@com_github_pkg_errors//:go_default_library", diff --git a/validator/node/node.go b/validator/node/node.go index 4d4871f55..25304dbe3 100644 --- a/validator/node/node.go +++ b/validator/node/node.go @@ -17,6 +17,7 @@ import ( gwruntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/pkg/errors" "github.com/prysmaticlabs/prysm/api/gateway" + "github.com/prysmaticlabs/prysm/api/gateway/apimiddleware" "github.com/prysmaticlabs/prysm/async/event" "github.com/prysmaticlabs/prysm/cmd" "github.com/prysmaticlabs/prysm/cmd/validator/flags" @@ -26,6 +27,7 @@ import ( "github.com/prysmaticlabs/prysm/monitoring/backup" "github.com/prysmaticlabs/prysm/monitoring/prometheus" tracing2 "github.com/prysmaticlabs/prysm/monitoring/tracing" + ethpbservice "github.com/prysmaticlabs/prysm/proto/eth/service" pb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" validatorpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/validator-client" "github.com/prysmaticlabs/prysm/runtime" @@ -40,6 +42,7 @@ import ( "github.com/prysmaticlabs/prysm/validator/keymanager" "github.com/prysmaticlabs/prysm/validator/keymanager/imported" "github.com/prysmaticlabs/prysm/validator/rpc" + validatorMiddleware "github.com/prysmaticlabs/prysm/validator/rpc/apimiddleware" "github.com/prysmaticlabs/prysm/validator/web" "github.com/sirupsen/logrus" "github.com/urfave/cli/v2" @@ -487,12 +490,14 @@ func (c *ValidatorClient) registerRPCGatewayService(cliCtx *cli.Context) error { validatorpb.RegisterAccountsHandler, validatorpb.RegisterBeaconHandler, validatorpb.RegisterSlashingProtectionHandler, + ethpbservice.RegisterKeyManagementHandler, } - mux := gwruntime.NewServeMux( + gwmux := gwruntime.NewServeMux( gwruntime.WithMarshalerOption(gwruntime.MIMEWildcard, &gwruntime.HTTPBodyMarshaler{ Marshaler: &gwruntime.JSONPb{ MarshalOptions: protojson.MarshalOptions{ EmitUnpopulated: true, + UseProtoNames: true, }, UnmarshalOptions: protojson.UnmarshalOptions{ DiscardUnknown: true, @@ -503,28 +508,42 @@ func (c *ValidatorClient) registerRPCGatewayService(cliCtx *cli.Context) error { "text/event-stream", &gwruntime.EventSourceJSONPb{}, ), ) - muxHandler := func(h http.Handler, w http.ResponseWriter, req *http.Request) { - if strings.HasPrefix(req.URL.Path, "/api") { - http.StripPrefix("/api", h).ServeHTTP(w, req) + muxHandler := func(apiMware *apimiddleware.ApiProxyMiddleware, h http.HandlerFunc, w http.ResponseWriter, req *http.Request) { + // The validator gateway handler requires this special logic as it serves two kinds of APIs, namely + // the standard validator keymanager API under the /eth namespace, and the Prysm internal + // validator API under the /api namespace. Finally, it also serves requests to host the validator web UI. + if strings.HasPrefix(req.URL.Path, "/api/eth/") { + req.URL.Path = strings.Replace(req.URL.Path, "/api", "", 1) + // If the prefix has /eth/, we handle it with the standard API gateway middleware. + apiMware.ServeHTTP(w, req) + } else if strings.HasPrefix(req.URL.Path, "/api") { + req.URL.Path = strings.Replace(req.URL.Path, "/api", "", 1) + // Else, we handle with the Prysm API gateway without a middleware. + h(w, req) } else { + // Finally, we handle with the web server. web.Handler(w, req) } } pbHandler := &gateway.PbMux{ Registrations: registrations, - Patterns: []string{"/accounts/", "/v2/"}, - Mux: mux, + Patterns: []string{"/accounts/", "/v2/", "/internal/eth/v1/"}, + Mux: gwmux, + } + opts := []gateway.Option{ + gateway.WithRemoteAddr(rpcAddr), + gateway.WithGatewayAddr(gatewayAddress), + gateway.WithMaxCallRecvMsgSize(maxCallSize), + gateway.WithPbHandlers([]*gateway.PbMux{pbHandler}), + gateway.WithAllowedOrigins(allowedOrigins), + gateway.WithApiMiddleware(&validatorMiddleware.ValidatorEndpointFactory{}), + gateway.WithMuxHandler(muxHandler), + } + gw, err := gateway.New(cliCtx.Context, opts...) + if err != nil { + return err } - - gw := gateway.New( - cliCtx.Context, - []*gateway.PbMux{pbHandler}, - muxHandler, - rpcAddr, - gatewayAddress, - ).WithAllowedOrigins(allowedOrigins).WithMaxCallRecvMsgSize(maxCallSize) - return c.services.RegisterService(gw) } diff --git a/validator/rpc/apimiddleware/BUILD.bazel b/validator/rpc/apimiddleware/BUILD.bazel new file mode 100644 index 000000000..865ddcf58 --- /dev/null +++ b/validator/rpc/apimiddleware/BUILD.bazel @@ -0,0 +1,15 @@ +load("@prysm//tools/go:def.bzl", "go_library") + +go_library( + name = "go_default_library", + srcs = [ + "endpoint_factory.go", + "structs.go", + ], + importpath = "github.com/prysmaticlabs/prysm/validator/rpc/apimiddleware", + visibility = ["//visibility:public"], + deps = [ + "//api/gateway/apimiddleware:go_default_library", + "@com_github_pkg_errors//:go_default_library", + ], +) diff --git a/validator/rpc/apimiddleware/endpoint_factory.go b/validator/rpc/apimiddleware/endpoint_factory.go new file mode 100644 index 000000000..dee975556 --- /dev/null +++ b/validator/rpc/apimiddleware/endpoint_factory.go @@ -0,0 +1,38 @@ +package apimiddleware + +import ( + "github.com/pkg/errors" + "github.com/prysmaticlabs/prysm/api/gateway/apimiddleware" +) + +// ValidatorEndpointFactory creates endpoints used for running validator API calls through the API Middleware. +type ValidatorEndpointFactory struct { +} + +func (f *ValidatorEndpointFactory) IsNil() bool { + return f == nil +} + +// Paths is a collection of all valid validator API paths. +func (*ValidatorEndpointFactory) Paths() []string { + return []string{ + "/eth/v1/keystores", + } +} + +// Create returns a new endpoint for the provided API path. +func (*ValidatorEndpointFactory) Create(path string) (*apimiddleware.Endpoint, error) { + endpoint := apimiddleware.DefaultEndpoint() + switch path { + case "/eth/v1/keystores": + endpoint.GetResponse = &listKeystoresResponseJson{} + endpoint.PostRequest = &importKeystoresRequestJson{} + endpoint.PostResponse = &importKeystoresResponseJson{} + endpoint.DeleteRequest = &deleteKeystoresRequestJson{} + endpoint.DeleteResponse = &deleteKeystoresResponseJson{} + default: + return nil, errors.New("invalid path") + } + endpoint.Path = path + return &endpoint, nil +} diff --git a/validator/rpc/apimiddleware/structs.go b/validator/rpc/apimiddleware/structs.go new file mode 100644 index 000000000..670191230 --- /dev/null +++ b/validator/rpc/apimiddleware/structs.go @@ -0,0 +1,34 @@ +package apimiddleware + +type listKeystoresResponseJson struct { + Keystores []*keystoreJson `json:"keystores"` +} + +type keystoreJson struct { + ValidatingPubkey string `json:"validating_pubkey" hex:"true"` + DerivationPath string `json:"derivation_path"` +} + +type importKeystoresRequestJson struct { + Keystores []string `json:"keystores"` + Passwords []string `json:"passwords"` + SlashingProtection string `json:"slashing_protection"` +} + +type importKeystoresResponseJson struct { + Statuses []*statusJson `json:"statuses"` +} + +type deleteKeystoresRequestJson struct { + PublicKeys []string `json:"public_keys" hex:"true"` +} + +type statusJson struct { + Status string `json:"status"` + Message string `json:"message"` +} + +type deleteKeystoresResponseJson struct { + Statuses []*statusJson `json:"statuses"` + SlashingProtection string `json:"slashing_protection"` +} diff --git a/validator/rpc/standard_api.go b/validator/rpc/standard_api.go index 2df3feefe..e1fe986cd 100644 --- a/validator/rpc/standard_api.go +++ b/validator/rpc/standard_api.go @@ -97,6 +97,9 @@ func (s *Server) DeleteKeystores( if !ok { return nil, status.Error(codes.Internal, "Keymanager kind cannot delete keys") } + if len(req.PublicKeys) == 0 { + return ðpbservice.DeleteKeystoresResponse{Statuses: make([]*ethpbservice.DeletedKeystoreStatus, 0)}, nil + } statuses, err := deleter.DeleteKeystores(ctx, req.PublicKeys) if err != nil { return nil, status.Errorf(codes.Internal, "Could not delete keys: %v", err) diff --git a/validator/rpc/standard_api_test.go b/validator/rpc/standard_api_test.go index 603c69af9..4bc44e1d4 100644 --- a/validator/rpc/standard_api_test.go +++ b/validator/rpc/standard_api_test.go @@ -203,7 +203,6 @@ func TestServer_ImportKeystores(t *testing.T) { } }) } - func TestServer_DeleteKeystores(t *testing.T) { ctx := context.Background() srv := setupServerWithWallet(t) @@ -247,6 +246,14 @@ func TestServer_DeleteKeystores(t *testing.T) { }) require.NoError(t, err) + t.Run("no slashing protection response if no keys in request even if we have a history in DB", func(t *testing.T) { + resp, err := srv.DeleteKeystores(context.Background(), ðpbservice.DeleteKeystoresRequest{ + PublicKeys: nil, + }) + require.NoError(t, err) + require.Equal(t, "", resp.SlashingProtection) + }) + // For ease of test setup, we'll give each public key a string identifier. publicKeysWithId := map[string][48]byte{ "a": publicKeys[0],