API Middleware for Keymanager Standard API Endpoints (#9936)

* begin the middleware approach

* attempt middleware

* middleware works in tandem with web ui

* handle delete as well

* delete request

* DELETE working

* tool to perform imports

* functioning

* commentary

* build

* gaz

* smol test

* enable keymanager api use protonames

* edit

* one rule

* rem gw

* Fix custom compiler

(cherry picked from commit 3b1f65919e04ddf7e07c8f60cba1be883a736476)

* gen proto

* imports

* Update validator/node/node.go

Co-authored-by: Radosław Kapka <rkapka@wp.pl>

* remaining comments

* update item

* rpc

* add

* run gateway

* simplify

* rem flag

* deep source

Co-authored-by: prestonvanloon <preston@prysmaticlabs.com>
Co-authored-by: Radosław Kapka <rkapka@wp.pl>
Co-authored-by: prylabs-bulldozer[bot] <58059840+prylabs-bulldozer[bot]@users.noreply.github.com>
This commit is contained in:
Raul Jordan 2021-12-07 15:26:21 -05:00 committed by GitHub
parent cee3b626f3
commit 424c8f6b46
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 424 additions and 251 deletions

View File

@ -5,6 +5,7 @@ go_library(
srcs = [ srcs = [
"gateway.go", "gateway.go",
"log.go", "log.go",
"options.go",
], ],
importpath = "github.com/prysmaticlabs/prysm/api/gateway", importpath = "github.com/prysmaticlabs/prysm/api/gateway",
visibility = [ visibility = [

View File

@ -14,6 +14,7 @@ import (
type ApiProxyMiddleware struct { type ApiProxyMiddleware struct {
GatewayAddress string GatewayAddress string
EndpointCreator EndpointFactory EndpointCreator EndpointFactory
router *mux.Router
} }
// EndpointFactory is responsible for creating new instances of Endpoint values. // EndpointFactory is responsible for creating new instances of Endpoint values.
@ -29,6 +30,8 @@ type Endpoint struct {
GetResponse interface{} // The struct corresponding to the JSON structure used in a GET response. GetResponse interface{} // The struct corresponding to the JSON structure used in a GET response.
PostRequest interface{} // The struct corresponding to the JSON structure used in a POST request. PostRequest interface{} // The struct corresponding to the JSON structure used in a POST request.
PostResponse interface{} // The struct corresponding to the JSON structure used in a POST response. PostResponse interface{} // The struct corresponding to the JSON structure used in a POST response.
DeleteRequest interface{} // The struct corresponding to the JSON structure used in a DELETE request.
DeleteResponse interface{} // The struct corresponding to the JSON structure used in a DELETE response.
RequestURLLiterals []string // Names of URL parameters that should not be base64-encoded. RequestURLLiterals []string // Names of URL parameters that should not be base64-encoded.
RequestQueryParams []QueryParam // Query parameters of the request. RequestQueryParams []QueryParam // Query parameters of the request.
Err ErrorJson // The struct corresponding to the error that should be returned in case of a request failure. Err ErrorJson // The struct corresponding to the error that should be returned in case of a request failure.
@ -74,18 +77,24 @@ type fieldProcessor struct {
// Run starts the proxy, registering all proxy endpoints. // Run starts the proxy, registering all proxy endpoints.
func (m *ApiProxyMiddleware) Run(gatewayRouter *mux.Router) { func (m *ApiProxyMiddleware) Run(gatewayRouter *mux.Router) {
for _, path := range m.EndpointCreator.Paths() { for _, path := range m.EndpointCreator.Paths() {
m.handleApiPath(gatewayRouter, path, m.EndpointCreator) gatewayRouter.HandleFunc(path, m.WithMiddleware(path))
} }
m.router = gatewayRouter
} }
func (m *ApiProxyMiddleware) handleApiPath(gatewayRouter *mux.Router, path string, endpointFactory EndpointFactory) { // ServeHTTP for the proxy middleware.
gatewayRouter.HandleFunc(path, func(w http.ResponseWriter, req *http.Request) { func (m *ApiProxyMiddleware) ServeHTTP(w http.ResponseWriter, req *http.Request) {
endpoint, err := endpointFactory.Create(path) m.router.ServeHTTP(w, req)
if err != nil { }
errJson := InternalServerErrorWithMessage(err, "could not create endpoint")
WriteError(w, errJson, nil)
}
// WithMiddleware wraps the given endpoint handler with the middleware logic.
func (m *ApiProxyMiddleware) WithMiddleware(path string) http.HandlerFunc {
endpoint, err := m.EndpointCreator.Create(path)
if err != nil {
log.WithError(err).Errorf("Could not create endpoint for path: %s", path)
return nil
}
return func(w http.ResponseWriter, req *http.Request) {
for _, handler := range endpoint.CustomHandlers { for _, handler := range endpoint.CustomHandlers {
if handler(m, *endpoint, w, req) { if handler(m, *endpoint, w, req) {
return return
@ -93,16 +102,14 @@ func (m *ApiProxyMiddleware) handleApiPath(gatewayRouter *mux.Router, path strin
} }
if req.Method == "POST" { if req.Method == "POST" {
if errJson := deserializeRequestBodyIntoContainerWrapped(endpoint, req, w); errJson != nil { if errJson := handlePostRequestForEndpoint(endpoint, w, req); errJson != nil {
WriteError(w, errJson, nil) WriteError(w, errJson, nil)
return return
} }
}
if errJson := ProcessRequestContainerFields(endpoint.PostRequest); errJson != nil { if req.Method == "DELETE" {
WriteError(w, errJson, nil) if errJson := handleDeleteRequestForEndpoint(endpoint, req); errJson != nil {
return
}
if errJson := SetRequestBodyToRequestContainer(endpoint.PostRequest, req); errJson != nil {
WriteError(w, errJson, nil) WriteError(w, errJson, nil)
return return
} }
@ -137,6 +144,8 @@ func (m *ApiProxyMiddleware) handleApiPath(gatewayRouter *mux.Router, path strin
var resp interface{} var resp interface{}
if req.Method == "GET" { if req.Method == "GET" {
resp = endpoint.GetResponse resp = endpoint.GetResponse
} else if req.Method == "DELETE" {
resp = endpoint.DeleteResponse
} else { } else {
resp = endpoint.PostResponse resp = endpoint.PostResponse
} }
@ -164,7 +173,27 @@ func (m *ApiProxyMiddleware) handleApiPath(gatewayRouter *mux.Router, path strin
WriteError(w, errJson, nil) WriteError(w, errJson, nil)
return return
} }
}) }
}
func handlePostRequestForEndpoint(endpoint *Endpoint, w http.ResponseWriter, req *http.Request) ErrorJson {
if errJson := deserializeRequestBodyIntoContainerWrapped(endpoint, req, w); errJson != nil {
return errJson
}
if errJson := ProcessRequestContainerFields(endpoint.PostRequest); errJson != nil {
return errJson
}
return SetRequestBodyToRequestContainer(endpoint.PostRequest, req)
}
func handleDeleteRequestForEndpoint(endpoint *Endpoint, req *http.Request) ErrorJson {
if errJson := DeserializeRequestBodyIntoContainer(req.Body, endpoint.DeleteRequest); errJson != nil {
return errJson
}
if errJson := ProcessRequestContainerFields(endpoint.DeleteRequest); errJson != nil {
return errJson
}
return SetRequestBodyToRequestContainer(endpoint.DeleteRequest, req)
} }
func deserializeRequestBodyIntoContainerWrapped(endpoint *Endpoint, req *http.Request, w http.ResponseWriter) ErrorJson { func deserializeRequestBodyIntoContainerWrapped(endpoint *Endpoint, req *http.Request, w http.ResponseWriter) ErrorJson {

View File

@ -34,74 +34,51 @@ type PbMux struct {
type PbHandlerRegistration func(context.Context, *gwruntime.ServeMux, *grpc.ClientConn) error type PbHandlerRegistration func(context.Context, *gwruntime.ServeMux, *grpc.ClientConn) error
// MuxHandler is a function that implements the mux handler functionality. // MuxHandler is a function that implements the mux handler functionality.
type MuxHandler func(http.Handler, http.ResponseWriter, *http.Request) type MuxHandler func(
apiMiddlewareHandler *apimiddleware.ApiProxyMiddleware,
h http.HandlerFunc,
w http.ResponseWriter,
req *http.Request,
)
// Config parameters for setting up the gateway service.
type config struct {
maxCallRecvMsgSize uint64
remoteCert string
gatewayAddr string
remoteAddr string
allowedOrigins []string
apiMiddlewareEndpointFactory apimiddleware.EndpointFactory
muxHandler MuxHandler
pbHandlers []*PbMux
router *mux.Router
}
// Gateway is the gRPC gateway to serve HTTP JSON traffic as a proxy and forward it to the gRPC server. // Gateway is the gRPC gateway to serve HTTP JSON traffic as a proxy and forward it to the gRPC server.
type Gateway struct { type Gateway struct {
conn *grpc.ClientConn cfg *config
pbHandlers []*PbMux conn *grpc.ClientConn
muxHandler MuxHandler server *http.Server
maxCallRecvMsgSize uint64 cancel context.CancelFunc
router *mux.Router proxy *apimiddleware.ApiProxyMiddleware
server *http.Server ctx context.Context
cancel context.CancelFunc startFailure error
remoteCert string
gatewayAddr string
apiMiddlewareEndpointFactory apimiddleware.EndpointFactory
ctx context.Context
startFailure error
remoteAddr string
allowedOrigins []string
} }
// New returns a new instance of the Gateway. // New returns a new instance of the Gateway.
func New( func New(ctx context.Context, opts ...Option) (*Gateway, error) {
ctx context.Context,
pbHandlers []*PbMux,
muxHandler MuxHandler,
remoteAddr,
gatewayAddress string,
) *Gateway {
g := &Gateway{ g := &Gateway{
pbHandlers: pbHandlers, ctx: ctx,
muxHandler: muxHandler, cfg: &config{
router: mux.NewRouter(), router: mux.NewRouter(),
gatewayAddr: gatewayAddress, },
ctx: ctx,
remoteAddr: remoteAddr,
allowedOrigins: []string{},
} }
return g for _, opt := range opts {
} if err := opt(g); err != nil {
return nil, err
// WithRouter allows adding a custom mux router to the gateway. }
func (g *Gateway) WithRouter(r *mux.Router) *Gateway { }
g.router = r return g, nil
return g
}
// WithAllowedOrigins allows adding a set of allowed origins to the gateway.
func (g *Gateway) WithAllowedOrigins(origins []string) *Gateway {
g.allowedOrigins = origins
return g
}
// WithRemoteCert allows adding a custom certificate to the gateway,
func (g *Gateway) WithRemoteCert(cert string) *Gateway {
g.remoteCert = cert
return g
}
// WithMaxCallRecvMsgSize allows specifying the maximum allowed gRPC message size.
func (g *Gateway) WithMaxCallRecvMsgSize(size uint64) *Gateway {
g.maxCallRecvMsgSize = size
return g
}
// WithApiMiddleware allows adding API Middleware proxy to the gateway.
func (g *Gateway) WithApiMiddleware(endpointFactory apimiddleware.EndpointFactory) *Gateway {
g.apiMiddlewareEndpointFactory = endpointFactory
return g
} }
// Start the gateway service. // Start the gateway service.
@ -109,7 +86,7 @@ func (g *Gateway) Start() {
ctx, cancel := context.WithCancel(g.ctx) ctx, cancel := context.WithCancel(g.ctx)
g.cancel = cancel g.cancel = cancel
conn, err := g.dial(ctx, "tcp", g.remoteAddr) conn, err := g.dial(ctx, "tcp", g.cfg.remoteAddr)
if err != nil { if err != nil {
log.WithError(err).Error("Failed to connect to gRPC server") log.WithError(err).Error("Failed to connect to gRPC server")
g.startFailure = err g.startFailure = err
@ -117,7 +94,7 @@ func (g *Gateway) Start() {
} }
g.conn = conn g.conn = conn
for _, h := range g.pbHandlers { for _, h := range g.cfg.pbHandlers {
for _, r := range h.Registrations { for _, r := range h.Registrations {
if err := r(ctx, h.Mux, g.conn); err != nil { if err := r(ctx, h.Mux, g.conn); err != nil {
log.WithError(err).Error("Failed to register handler") log.WithError(err).Error("Failed to register handler")
@ -126,29 +103,30 @@ func (g *Gateway) Start() {
} }
} }
for _, p := range h.Patterns { for _, p := range h.Patterns {
g.router.PathPrefix(p).Handler(h.Mux) g.cfg.router.PathPrefix(p).Handler(h.Mux)
} }
} }
corsMux := g.corsMiddleware(g.router) corsMux := g.corsMiddleware(g.cfg.router)
_ = corsMux
if g.muxHandler != nil { if g.cfg.apiMiddlewareEndpointFactory != nil && !g.cfg.apiMiddlewareEndpointFactory.IsNil() {
g.router.PathPrefix("/").HandlerFunc(func(w http.ResponseWriter, r *http.Request) { g.registerApiMiddleware()
g.muxHandler(corsMux, w, r) }
if g.cfg.muxHandler != nil {
g.cfg.router.PathPrefix("/").HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
g.cfg.muxHandler(g.proxy, g.cfg.router.ServeHTTP, w, r)
}) })
} }
if g.apiMiddlewareEndpointFactory != nil && !g.apiMiddlewareEndpointFactory.IsNil() {
go g.registerApiMiddleware()
}
g.server = &http.Server{ g.server = &http.Server{
Addr: g.gatewayAddr, Addr: g.cfg.gatewayAddr,
Handler: g.router, Handler: g.cfg.router,
} }
go func() { go func() {
log.WithField("address", g.gatewayAddr).Info("Starting gRPC gateway") log.WithField("address", g.cfg.gatewayAddr).Info("Starting gRPC gateway")
if err := g.server.ListenAndServe(); err != http.ErrServerClosed { if err := g.server.ListenAndServe(); err != http.ErrServerClosed {
log.WithError(err).Error("Failed to start gRPC gateway") log.WithError(err).Error("Failed to start gRPC gateway")
g.startFailure = err g.startFailure = err
@ -162,11 +140,9 @@ func (g *Gateway) Status() error {
if g.startFailure != nil { if g.startFailure != nil {
return g.startFailure return g.startFailure
} }
if s := g.conn.GetState(); s != connectivity.Ready { if s := g.conn.GetState(); s != connectivity.Ready {
return fmt.Errorf("grpc server is %s", s) return fmt.Errorf("grpc server is %s", s)
} }
return nil return nil
} }
@ -183,18 +159,16 @@ func (g *Gateway) Stop() error {
} }
} }
} }
if g.cancel != nil { if g.cancel != nil {
g.cancel() g.cancel()
} }
return nil return nil
} }
func (g *Gateway) corsMiddleware(h http.Handler) http.Handler { func (g *Gateway) corsMiddleware(h http.Handler) http.Handler {
c := cors.New(cors.Options{ c := cors.New(cors.Options{
AllowedOrigins: g.allowedOrigins, AllowedOrigins: g.cfg.allowedOrigins,
AllowedMethods: []string{http.MethodPost, http.MethodGet, http.MethodOptions}, AllowedMethods: []string{http.MethodPost, http.MethodGet, http.MethodDelete, http.MethodOptions},
AllowCredentials: true, AllowCredentials: true,
MaxAge: 600, MaxAge: 600,
AllowedHeaders: []string{"*"}, AllowedHeaders: []string{"*"},
@ -236,8 +210,8 @@ func (g *Gateway) dial(ctx context.Context, network, addr string) (*grpc.ClientC
// "addr" must be a valid TCP address with a port number. // "addr" must be a valid TCP address with a port number.
func (g *Gateway) dialTCP(ctx context.Context, addr string) (*grpc.ClientConn, error) { func (g *Gateway) dialTCP(ctx context.Context, addr string) (*grpc.ClientConn, error) {
security := grpc.WithInsecure() security := grpc.WithInsecure()
if len(g.remoteCert) > 0 { if len(g.cfg.remoteCert) > 0 {
creds, err := credentials.NewClientTLSFromFile(g.remoteCert, "") creds, err := credentials.NewClientTLSFromFile(g.cfg.remoteCert, "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -245,7 +219,7 @@ func (g *Gateway) dialTCP(ctx context.Context, addr string) (*grpc.ClientConn, e
} }
opts := []grpc.DialOption{ opts := []grpc.DialOption{
security, security,
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(int(g.maxCallRecvMsgSize))), grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(int(g.cfg.maxCallRecvMsgSize))),
} }
return grpc.DialContext(ctx, addr, opts...) return grpc.DialContext(ctx, addr, opts...)
@ -266,16 +240,16 @@ func (g *Gateway) dialUnix(ctx context.Context, addr string) (*grpc.ClientConn,
opts := []grpc.DialOption{ opts := []grpc.DialOption{
grpc.WithInsecure(), grpc.WithInsecure(),
grpc.WithContextDialer(f), grpc.WithContextDialer(f),
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(int(g.maxCallRecvMsgSize))), grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(int(g.cfg.maxCallRecvMsgSize))),
} }
return grpc.DialContext(ctx, addr, opts...) return grpc.DialContext(ctx, addr, opts...)
} }
func (g *Gateway) registerApiMiddleware() { func (g *Gateway) registerApiMiddleware() {
proxy := &apimiddleware.ApiProxyMiddleware{ g.proxy = &apimiddleware.ApiProxyMiddleware{
GatewayAddress: g.gatewayAddr, GatewayAddress: g.cfg.gatewayAddr,
EndpointCreator: g.apiMiddlewareEndpointFactory, EndpointCreator: g.cfg.apiMiddlewareEndpointFactory,
} }
log.Info("Starting API middleware") log.Info("Starting API middleware")
proxy.Run(g.router) g.proxy.Run(g.cfg.router)
} }

View File

@ -40,26 +40,30 @@ func TestGateway_Customized(t *testing.T) {
size := uint64(100) size := uint64(100)
endpointFactory := &mockEndpointFactory{} endpointFactory := &mockEndpointFactory{}
g := New( opts := []Option{
context.Background(), WithRouter(r),
[]*PbMux{}, WithRemoteCert(cert),
func(handler http.Handler, writer http.ResponseWriter, request *http.Request) { WithAllowedOrigins(origins),
WithMaxCallRecvMsgSize(size),
WithApiMiddleware(endpointFactory),
WithMuxHandler(func(
_ *apimiddleware.ApiProxyMiddleware,
_ http.HandlerFunc,
_ http.ResponseWriter,
_ *http.Request,
) {
}),
}
}, g, err := New(context.Background(), opts...)
"", require.NoError(t, err)
"",
).WithRouter(r).
WithRemoteCert(cert).
WithAllowedOrigins(origins).
WithMaxCallRecvMsgSize(size).
WithApiMiddleware(endpointFactory)
assert.Equal(t, r, g.router) assert.Equal(t, r, g.cfg.router)
assert.Equal(t, cert, g.remoteCert) assert.Equal(t, cert, g.cfg.remoteCert)
require.Equal(t, 1, len(g.allowedOrigins)) require.Equal(t, 1, len(g.cfg.allowedOrigins))
assert.Equal(t, origins[0], g.allowedOrigins[0]) assert.Equal(t, origins[0], g.cfg.allowedOrigins[0])
assert.Equal(t, size, g.maxCallRecvMsgSize) assert.Equal(t, size, g.cfg.maxCallRecvMsgSize)
assert.Equal(t, endpointFactory, g.apiMiddlewareEndpointFactory) assert.Equal(t, endpointFactory, g.cfg.apiMiddlewareEndpointFactory)
} }
func TestGateway_StartStop(t *testing.T) { func TestGateway_StartStop(t *testing.T) {
@ -75,23 +79,27 @@ func TestGateway_StartStop(t *testing.T) {
selfAddress := fmt.Sprintf("%s:%d", rpcHost, ctx.Int(flags.RPCPort.Name)) selfAddress := fmt.Sprintf("%s:%d", rpcHost, ctx.Int(flags.RPCPort.Name))
gatewayAddress := fmt.Sprintf("%s:%d", gatewayHost, gatewayPort) gatewayAddress := fmt.Sprintf("%s:%d", gatewayHost, gatewayPort)
g := New( opts := []Option{
ctx.Context, WithGatewayAddr(gatewayAddress),
[]*PbMux{}, WithRemoteAddr(selfAddress),
func(handler http.Handler, writer http.ResponseWriter, request *http.Request) { WithMuxHandler(func(
_ *apimiddleware.ApiProxyMiddleware,
_ http.HandlerFunc,
_ http.ResponseWriter,
_ *http.Request,
) {
}),
}
}, g, err := New(context.Background(), opts...)
selfAddress, require.NoError(t, err)
gatewayAddress,
)
g.Start() g.Start()
go func() { go func() {
require.LogsContain(t, hook, "Starting gRPC gateway") require.LogsContain(t, hook, "Starting gRPC gateway")
require.LogsDoNotContain(t, hook, "Starting API middleware") require.LogsDoNotContain(t, hook, "Starting API middleware")
}() }()
err = g.Stop()
err := g.Stop()
require.NoError(t, err) require.NoError(t, err)
} }
@ -106,15 +114,15 @@ func TestGateway_NilHandler_NotFoundHandlerRegistered(t *testing.T) {
selfAddress := fmt.Sprintf("%s:%d", rpcHost, ctx.Int(flags.RPCPort.Name)) selfAddress := fmt.Sprintf("%s:%d", rpcHost, ctx.Int(flags.RPCPort.Name))
gatewayAddress := fmt.Sprintf("%s:%d", gatewayHost, gatewayPort) gatewayAddress := fmt.Sprintf("%s:%d", gatewayHost, gatewayPort)
g := New( opts := []Option{
ctx.Context, WithGatewayAddr(gatewayAddress),
[]*PbMux{}, WithRemoteAddr(selfAddress),
/* muxHandler */ nil, }
selfAddress,
gatewayAddress, g, err := New(context.Background(), opts...)
) require.NoError(t, err)
writer := httptest.NewRecorder() writer := httptest.NewRecorder()
g.router.ServeHTTP(writer, &http.Request{Method: "GET", Host: "localhost", URL: &url.URL{Path: "/foo"}}) g.cfg.router.ServeHTTP(writer, &http.Request{Method: "GET", Host: "localhost", URL: &url.URL{Path: "/foo"}})
assert.Equal(t, http.StatusNotFound, writer.Code) assert.Equal(t, http.StatusNotFound, writer.Code)
} }

81
api/gateway/options.go Normal file
View File

@ -0,0 +1,81 @@
package gateway
import (
"github.com/gorilla/mux"
"github.com/prysmaticlabs/prysm/api/gateway/apimiddleware"
)
type Option func(g *Gateway) error
func (g *Gateway) SetRouter(r *mux.Router) *Gateway {
g.cfg.router = r
return g
}
func WithPbHandlers(handlers []*PbMux) Option {
return func(g *Gateway) error {
g.cfg.pbHandlers = handlers
return nil
}
}
func WithMuxHandler(m MuxHandler) Option {
return func(g *Gateway) error {
g.cfg.muxHandler = m
return nil
}
}
func WithGatewayAddr(addr string) Option {
return func(g *Gateway) error {
g.cfg.gatewayAddr = addr
return nil
}
}
func WithRemoteAddr(addr string) Option {
return func(g *Gateway) error {
g.cfg.remoteAddr = addr
return nil
}
}
// WithRouter allows adding a custom mux router to the gateway.
func WithRouter(r *mux.Router) Option {
return func(g *Gateway) error {
g.cfg.router = r
return nil
}
}
// WithAllowedOrigins allows adding a set of allowed origins to the gateway.
func WithAllowedOrigins(origins []string) Option {
return func(g *Gateway) error {
g.cfg.allowedOrigins = origins
return nil
}
}
// WithRemoteCert allows adding a custom certificate to the gateway,
func WithRemoteCert(cert string) Option {
return func(g *Gateway) error {
g.cfg.remoteCert = cert
return nil
}
}
// WithMaxCallRecvMsgSize allows specifying the maximum allowed gRPC message size.
func WithMaxCallRecvMsgSize(size uint64) Option {
return func(g *Gateway) error {
g.cfg.maxCallRecvMsgSize = size
return nil
}
}
// WithApiMiddleware allows adding an API middleware proxy to the gateway.
func WithApiMiddleware(endpointFactory apimiddleware.EndpointFactory) Option {
return func(g *Gateway) error {
g.cfg.apiMiddlewareEndpointFactory = endpointFactory
return nil
}
}

View File

@ -830,19 +830,22 @@ func (b *BeaconNode) registerGRPCGateway() error {
muxs = append(muxs, gatewayConfig.EthPbMux) muxs = append(muxs, gatewayConfig.EthPbMux)
} }
g := apigateway.New( opts := []apigateway.Option{
b.ctx, apigateway.WithGatewayAddr(gatewayAddress),
muxs, apigateway.WithRemoteAddr(selfAddress),
gatewayConfig.Handler, apigateway.WithPbHandlers(muxs),
selfAddress, apigateway.WithMuxHandler(gatewayConfig.Handler),
gatewayAddress, apigateway.WithRemoteCert(selfCert),
).WithAllowedOrigins(allowedOrigins). apigateway.WithMaxCallRecvMsgSize(maxCallSize),
WithRemoteCert(selfCert). apigateway.WithAllowedOrigins(allowedOrigins),
WithMaxCallRecvMsgSize(maxCallSize) }
if flags.EnableHTTPEthAPI(httpModules) { if flags.EnableHTTPEthAPI(httpModules) {
g.WithApiMiddleware(&apimiddleware.BeaconEndpointFactory{}) opts = append(opts, apigateway.WithApiMiddleware(&apimiddleware.BeaconEndpointFactory{}))
}
g, err := apigateway.New(b.ctx, opts...)
if err != nil {
return err
} }
return b.services.RegisterService(g) return b.services.RegisterService(g)
} }

View File

@ -52,23 +52,28 @@ func main() {
if gatewayConfig.EthPbMux != nil { if gatewayConfig.EthPbMux != nil {
muxs = append(muxs, gatewayConfig.EthPbMux) muxs = append(muxs, gatewayConfig.EthPbMux)
} }
opts := []gateway.Option{
gateway.WithPbHandlers(muxs),
gateway.WithMuxHandler(gatewayConfig.Handler),
gateway.WithRemoteAddr(*beaconRPC),
gateway.WithGatewayAddr(fmt.Sprintf("%s:%d", *host, *port)),
gateway.WithAllowedOrigins(strings.Split(*allowedOrigins, ",")),
gateway.WithMaxCallRecvMsgSize(uint64(*grpcMaxMsgSize)),
}
gw := gateway.New(
context.Background(),
muxs,
gatewayConfig.Handler,
*beaconRPC,
fmt.Sprintf("%s:%d", *host, *port),
).WithAllowedOrigins(strings.Split(*allowedOrigins, ",")).
WithMaxCallRecvMsgSize(uint64(*grpcMaxMsgSize))
if flags.EnableHTTPEthAPI(*httpModules) { if flags.EnableHTTPEthAPI(*httpModules) {
gw.WithApiMiddleware(&apimiddleware.BeaconEndpointFactory{}) opts = append(opts, gateway.WithApiMiddleware(&apimiddleware.BeaconEndpointFactory{}))
}
gw, err := gateway.New(context.Background(), opts...)
if err != nil {
log.Fatal(err)
} }
r := mux.NewRouter() r := mux.NewRouter()
r.HandleFunc("/swagger/", gateway.SwaggerServer()) r.HandleFunc("/swagger/", gateway.SwaggerServer())
r.HandleFunc("/healthz", healthzServer(gw)) r.HandleFunc("/healthz", healthzServer(gw))
gw = gw.WithRouter(r) gw.SetRouter(r)
gw.Start() gw.Start()

View File

@ -12,27 +12,6 @@ proto_library(
"events_service.proto", "events_service.proto",
"node_service.proto", "node_service.proto",
"validator_service.proto", "validator_service.proto",
],
visibility = ["//visibility:public"],
deps = [
"//proto/eth/ext:proto",
"//proto/eth/v1:proto",
"//proto/eth/v2:proto",
"@com_google_protobuf//:descriptor_proto",
"@com_google_protobuf//:empty_proto",
"@com_google_protobuf//:timestamp_proto",
"@go_googleapis//google/api:annotations_proto",
"@com_github_grpc_ecosystem_grpc_gateway_v2//proto/gateway:event_source_proto",
],
)
# We create a custom proto library for key_management.proto as it requires a different
# compiler plugin for grpc gateway than the others. Namely, it requires adding the option
# --allow_delete_body=true to allow grpc gateway endpoints to take in DELETE HTTP requests
# with a request body properly.
proto_library(
name = "custom_proto",
srcs = [
"key_management.proto", "key_management.proto",
], ],
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
@ -67,35 +46,32 @@ go_proto_library(
], ],
) )
go_proto_library(
name = "custom_go_proto",
compilers = ["@com_github_prysmaticlabs_protoc_gen_go_cast//:go_cast_grpc",],
importpath = "github.com/prysmaticlabs/prysm/proto/eth/service",
proto = ":custom_proto",
visibility = ["//visibility:public"],
deps = [
"//proto/eth/ext:go_default_library",
"//proto/eth/v1:go_default_library",
"//proto/eth/v2:go_default_library",
"@io_bazel_rules_go//proto/wkt:descriptor_go_proto",
"@io_bazel_rules_go//proto/wkt:empty_go_proto",
"@com_github_prysmaticlabs_eth2_types//:go_default_library",
"@go_googleapis//google/api:annotations_go_proto",
"@com_github_golang_protobuf//proto:go_default_library",
"@com_github_grpc_ecosystem_grpc_gateway_v2//proto/gateway:go_default_library",
],
)
go_proto_compiler( go_proto_compiler(
name = "allow_delete_body_gateway_compiler", name = "allow_delete_body_gateway_compiler",
options = [
"logtostderr=true",
"allow_repeated_fields_in_body=true",
"allow_delete_body=true",
],
plugin = "@com_github_grpc_ecosystem_grpc_gateway_v2//protoc-gen-grpc-gateway:protoc-gen-grpc-gateway", plugin = "@com_github_grpc_ecosystem_grpc_gateway_v2//protoc-gen-grpc-gateway:protoc-gen-grpc-gateway",
options = ["allow_delete_body=true"], suffix = ".pb.gw.go",
visibility = ["//visibility:public"],
deps = [
"@com_github_grpc_ecosystem_grpc_gateway_v2//runtime:go_default_library",
"@com_github_grpc_ecosystem_grpc_gateway_v2//utilities:go_default_library",
"@org_golang_google_grpc//:go_default_library",
"@org_golang_google_grpc//codes:go_default_library",
"@org_golang_google_grpc//grpclog:go_default_library",
"@org_golang_google_grpc//metadata:go_default_library",
"@org_golang_google_grpc//status:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
],
) )
go_proto_library( go_proto_library(
name = "go_grpc_gateway_library", name = "go_grpc_gateway_library",
compilers = [ compilers = [
"@com_github_grpc_ecosystem_grpc_gateway_v2//protoc-gen-grpc-gateway:go_gen_grpc_gateway", "allow_delete_body_gateway_compiler",
], ],
embed = [":go_proto"], embed = [":go_proto"],
importpath = "github.com/prysmaticlabs/prysm/proto/eth/service", importpath = "github.com/prysmaticlabs/prysm/proto/eth/service",
@ -113,30 +89,9 @@ go_proto_library(
], ],
) )
go_proto_library(
name = "custom_go_grpc_gateway_library",
compilers = [
"allow_delete_body_gateway_compiler",
],
embed = [":custom_go_proto"],
importpath = "github.com/prysmaticlabs/prysm/proto/eth/service",
protos = [":custom_proto"],
visibility = ["//proto:__subpackages__"],
deps = [
"//proto/eth/ext:go_default_library",
"@io_bazel_rules_go//proto/wkt:empty_go_proto",
"@com_github_grpc_ecosystem_grpc_gateway_v2//protoc-gen-openapiv2/options:options_go_proto",
"@com_github_prysmaticlabs_go_bitfield//:go_default_library",
"@go_googleapis//google/api:annotations_go_proto",
"@io_bazel_rules_go//proto/wkt:timestamp_go_proto",
"@io_bazel_rules_go//proto/wkt:descriptor_go_proto",
"@com_github_grpc_ecosystem_grpc_gateway_v2//proto/gateway:go_default_library",
],
)
go_library( go_library(
name = "go_default_library", name = "go_default_library",
embed = [":go_grpc_gateway_library", ":custom_go_grpc_gateway_library"], embed = [":go_grpc_gateway_library"],
importpath = "github.com/prysmaticlabs/prysm/proto/eth/service", importpath = "github.com/prysmaticlabs/prysm/proto/eth/service",
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
) )

View File

@ -89,18 +89,15 @@ func local_request_KeyManagement_ImportKeystores_0(ctx context.Context, marshale
} }
var (
filter_KeyManagement_DeleteKeystores_0 = &utilities.DoubleArray{Encoding: map[string]int{}, Base: []int(nil), Check: []int(nil)}
)
func request_KeyManagement_DeleteKeystores_0(ctx context.Context, marshaler runtime.Marshaler, client KeyManagementClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { func request_KeyManagement_DeleteKeystores_0(ctx context.Context, marshaler runtime.Marshaler, client KeyManagementClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {
var protoReq DeleteKeystoresRequest var protoReq DeleteKeystoresRequest
var metadata runtime.ServerMetadata var metadata runtime.ServerMetadata
if err := req.ParseForm(); err != nil { newReader, berr := utilities.IOReaderFactory(req.Body)
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) if berr != nil {
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr)
} }
if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_KeyManagement_DeleteKeystores_0); err != nil { if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF {
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
} }
@ -113,10 +110,11 @@ func local_request_KeyManagement_DeleteKeystores_0(ctx context.Context, marshale
var protoReq DeleteKeystoresRequest var protoReq DeleteKeystoresRequest
var metadata runtime.ServerMetadata var metadata runtime.ServerMetadata
if err := req.ParseForm(); err != nil { newReader, berr := utilities.IOReaderFactory(req.Body)
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) if berr != nil {
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr)
} }
if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_KeyManagement_DeleteKeystores_0); err != nil { if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF {
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
} }

View File

@ -29,6 +29,7 @@ go_library(
], ],
deps = [ deps = [
"//api/gateway:go_default_library", "//api/gateway:go_default_library",
"//api/gateway/apimiddleware:go_default_library",
"//async/event:go_default_library", "//async/event:go_default_library",
"//cmd:go_default_library", "//cmd:go_default_library",
"//cmd/validator/flags:go_default_library", "//cmd/validator/flags:go_default_library",
@ -38,6 +39,7 @@ go_library(
"//monitoring/backup:go_default_library", "//monitoring/backup:go_default_library",
"//monitoring/prometheus:go_default_library", "//monitoring/prometheus:go_default_library",
"//monitoring/tracing:go_default_library", "//monitoring/tracing:go_default_library",
"//proto/eth/service:go_default_library",
"//proto/prysm/v1alpha1:go_default_library", "//proto/prysm/v1alpha1:go_default_library",
"//proto/prysm/v1alpha1/validator-client:go_default_library", "//proto/prysm/v1alpha1/validator-client:go_default_library",
"//runtime:go_default_library", "//runtime:go_default_library",
@ -52,6 +54,7 @@ go_library(
"//validator/keymanager:go_default_library", "//validator/keymanager:go_default_library",
"//validator/keymanager/imported:go_default_library", "//validator/keymanager/imported:go_default_library",
"//validator/rpc:go_default_library", "//validator/rpc:go_default_library",
"//validator/rpc/apimiddleware:go_default_library",
"//validator/web:go_default_library", "//validator/web:go_default_library",
"@com_github_grpc_ecosystem_grpc_gateway_v2//runtime:go_default_library", "@com_github_grpc_ecosystem_grpc_gateway_v2//runtime:go_default_library",
"@com_github_pkg_errors//:go_default_library", "@com_github_pkg_errors//:go_default_library",

View File

@ -17,6 +17,7 @@ import (
gwruntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" gwruntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/prysmaticlabs/prysm/api/gateway" "github.com/prysmaticlabs/prysm/api/gateway"
"github.com/prysmaticlabs/prysm/api/gateway/apimiddleware"
"github.com/prysmaticlabs/prysm/async/event" "github.com/prysmaticlabs/prysm/async/event"
"github.com/prysmaticlabs/prysm/cmd" "github.com/prysmaticlabs/prysm/cmd"
"github.com/prysmaticlabs/prysm/cmd/validator/flags" "github.com/prysmaticlabs/prysm/cmd/validator/flags"
@ -26,6 +27,7 @@ import (
"github.com/prysmaticlabs/prysm/monitoring/backup" "github.com/prysmaticlabs/prysm/monitoring/backup"
"github.com/prysmaticlabs/prysm/monitoring/prometheus" "github.com/prysmaticlabs/prysm/monitoring/prometheus"
tracing2 "github.com/prysmaticlabs/prysm/monitoring/tracing" tracing2 "github.com/prysmaticlabs/prysm/monitoring/tracing"
ethpbservice "github.com/prysmaticlabs/prysm/proto/eth/service"
pb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1" pb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
validatorpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/validator-client" validatorpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/validator-client"
"github.com/prysmaticlabs/prysm/runtime" "github.com/prysmaticlabs/prysm/runtime"
@ -40,6 +42,7 @@ import (
"github.com/prysmaticlabs/prysm/validator/keymanager" "github.com/prysmaticlabs/prysm/validator/keymanager"
"github.com/prysmaticlabs/prysm/validator/keymanager/imported" "github.com/prysmaticlabs/prysm/validator/keymanager/imported"
"github.com/prysmaticlabs/prysm/validator/rpc" "github.com/prysmaticlabs/prysm/validator/rpc"
validatorMiddleware "github.com/prysmaticlabs/prysm/validator/rpc/apimiddleware"
"github.com/prysmaticlabs/prysm/validator/web" "github.com/prysmaticlabs/prysm/validator/web"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
@ -487,12 +490,14 @@ func (c *ValidatorClient) registerRPCGatewayService(cliCtx *cli.Context) error {
validatorpb.RegisterAccountsHandler, validatorpb.RegisterAccountsHandler,
validatorpb.RegisterBeaconHandler, validatorpb.RegisterBeaconHandler,
validatorpb.RegisterSlashingProtectionHandler, validatorpb.RegisterSlashingProtectionHandler,
ethpbservice.RegisterKeyManagementHandler,
} }
mux := gwruntime.NewServeMux( gwmux := gwruntime.NewServeMux(
gwruntime.WithMarshalerOption(gwruntime.MIMEWildcard, &gwruntime.HTTPBodyMarshaler{ gwruntime.WithMarshalerOption(gwruntime.MIMEWildcard, &gwruntime.HTTPBodyMarshaler{
Marshaler: &gwruntime.JSONPb{ Marshaler: &gwruntime.JSONPb{
MarshalOptions: protojson.MarshalOptions{ MarshalOptions: protojson.MarshalOptions{
EmitUnpopulated: true, EmitUnpopulated: true,
UseProtoNames: true,
}, },
UnmarshalOptions: protojson.UnmarshalOptions{ UnmarshalOptions: protojson.UnmarshalOptions{
DiscardUnknown: true, DiscardUnknown: true,
@ -503,28 +508,42 @@ func (c *ValidatorClient) registerRPCGatewayService(cliCtx *cli.Context) error {
"text/event-stream", &gwruntime.EventSourceJSONPb{}, "text/event-stream", &gwruntime.EventSourceJSONPb{},
), ),
) )
muxHandler := func(h http.Handler, w http.ResponseWriter, req *http.Request) { muxHandler := func(apiMware *apimiddleware.ApiProxyMiddleware, h http.HandlerFunc, w http.ResponseWriter, req *http.Request) {
if strings.HasPrefix(req.URL.Path, "/api") { // The validator gateway handler requires this special logic as it serves two kinds of APIs, namely
http.StripPrefix("/api", h).ServeHTTP(w, req) // the standard validator keymanager API under the /eth namespace, and the Prysm internal
// validator API under the /api namespace. Finally, it also serves requests to host the validator web UI.
if strings.HasPrefix(req.URL.Path, "/api/eth/") {
req.URL.Path = strings.Replace(req.URL.Path, "/api", "", 1)
// If the prefix has /eth/, we handle it with the standard API gateway middleware.
apiMware.ServeHTTP(w, req)
} else if strings.HasPrefix(req.URL.Path, "/api") {
req.URL.Path = strings.Replace(req.URL.Path, "/api", "", 1)
// Else, we handle with the Prysm API gateway without a middleware.
h(w, req)
} else { } else {
// Finally, we handle with the web server.
web.Handler(w, req) web.Handler(w, req)
} }
} }
pbHandler := &gateway.PbMux{ pbHandler := &gateway.PbMux{
Registrations: registrations, Registrations: registrations,
Patterns: []string{"/accounts/", "/v2/"}, Patterns: []string{"/accounts/", "/v2/", "/internal/eth/v1/"},
Mux: mux, Mux: gwmux,
}
opts := []gateway.Option{
gateway.WithRemoteAddr(rpcAddr),
gateway.WithGatewayAddr(gatewayAddress),
gateway.WithMaxCallRecvMsgSize(maxCallSize),
gateway.WithPbHandlers([]*gateway.PbMux{pbHandler}),
gateway.WithAllowedOrigins(allowedOrigins),
gateway.WithApiMiddleware(&validatorMiddleware.ValidatorEndpointFactory{}),
gateway.WithMuxHandler(muxHandler),
}
gw, err := gateway.New(cliCtx.Context, opts...)
if err != nil {
return err
} }
gw := gateway.New(
cliCtx.Context,
[]*gateway.PbMux{pbHandler},
muxHandler,
rpcAddr,
gatewayAddress,
).WithAllowedOrigins(allowedOrigins).WithMaxCallRecvMsgSize(maxCallSize)
return c.services.RegisterService(gw) return c.services.RegisterService(gw)
} }

View File

@ -0,0 +1,15 @@
load("@prysm//tools/go:def.bzl", "go_library")
go_library(
name = "go_default_library",
srcs = [
"endpoint_factory.go",
"structs.go",
],
importpath = "github.com/prysmaticlabs/prysm/validator/rpc/apimiddleware",
visibility = ["//visibility:public"],
deps = [
"//api/gateway/apimiddleware:go_default_library",
"@com_github_pkg_errors//:go_default_library",
],
)

View File

@ -0,0 +1,38 @@
package apimiddleware
import (
"github.com/pkg/errors"
"github.com/prysmaticlabs/prysm/api/gateway/apimiddleware"
)
// ValidatorEndpointFactory creates endpoints used for running validator API calls through the API Middleware.
type ValidatorEndpointFactory struct {
}
func (f *ValidatorEndpointFactory) IsNil() bool {
return f == nil
}
// Paths is a collection of all valid validator API paths.
func (*ValidatorEndpointFactory) Paths() []string {
return []string{
"/eth/v1/keystores",
}
}
// Create returns a new endpoint for the provided API path.
func (*ValidatorEndpointFactory) Create(path string) (*apimiddleware.Endpoint, error) {
endpoint := apimiddleware.DefaultEndpoint()
switch path {
case "/eth/v1/keystores":
endpoint.GetResponse = &listKeystoresResponseJson{}
endpoint.PostRequest = &importKeystoresRequestJson{}
endpoint.PostResponse = &importKeystoresResponseJson{}
endpoint.DeleteRequest = &deleteKeystoresRequestJson{}
endpoint.DeleteResponse = &deleteKeystoresResponseJson{}
default:
return nil, errors.New("invalid path")
}
endpoint.Path = path
return &endpoint, nil
}

View File

@ -0,0 +1,34 @@
package apimiddleware
type listKeystoresResponseJson struct {
Keystores []*keystoreJson `json:"keystores"`
}
type keystoreJson struct {
ValidatingPubkey string `json:"validating_pubkey" hex:"true"`
DerivationPath string `json:"derivation_path"`
}
type importKeystoresRequestJson struct {
Keystores []string `json:"keystores"`
Passwords []string `json:"passwords"`
SlashingProtection string `json:"slashing_protection"`
}
type importKeystoresResponseJson struct {
Statuses []*statusJson `json:"statuses"`
}
type deleteKeystoresRequestJson struct {
PublicKeys []string `json:"public_keys" hex:"true"`
}
type statusJson struct {
Status string `json:"status"`
Message string `json:"message"`
}
type deleteKeystoresResponseJson struct {
Statuses []*statusJson `json:"statuses"`
SlashingProtection string `json:"slashing_protection"`
}

View File

@ -97,6 +97,9 @@ func (s *Server) DeleteKeystores(
if !ok { if !ok {
return nil, status.Error(codes.Internal, "Keymanager kind cannot delete keys") return nil, status.Error(codes.Internal, "Keymanager kind cannot delete keys")
} }
if len(req.PublicKeys) == 0 {
return &ethpbservice.DeleteKeystoresResponse{Statuses: make([]*ethpbservice.DeletedKeystoreStatus, 0)}, nil
}
statuses, err := deleter.DeleteKeystores(ctx, req.PublicKeys) statuses, err := deleter.DeleteKeystores(ctx, req.PublicKeys)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "Could not delete keys: %v", err) return nil, status.Errorf(codes.Internal, "Could not delete keys: %v", err)

View File

@ -203,7 +203,6 @@ func TestServer_ImportKeystores(t *testing.T) {
} }
}) })
} }
func TestServer_DeleteKeystores(t *testing.T) { func TestServer_DeleteKeystores(t *testing.T) {
ctx := context.Background() ctx := context.Background()
srv := setupServerWithWallet(t) srv := setupServerWithWallet(t)
@ -247,6 +246,14 @@ func TestServer_DeleteKeystores(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
t.Run("no slashing protection response if no keys in request even if we have a history in DB", func(t *testing.T) {
resp, err := srv.DeleteKeystores(context.Background(), &ethpbservice.DeleteKeystoresRequest{
PublicKeys: nil,
})
require.NoError(t, err)
require.Equal(t, "", resp.SlashingProtection)
})
// For ease of test setup, we'll give each public key a string identifier. // For ease of test setup, we'll give each public key a string identifier.
publicKeysWithId := map[string][48]byte{ publicKeysWithId := map[string][48]byte{
"a": publicKeys[0], "a": publicKeys[0],