API Middleware for Keymanager Standard API Endpoints (#9936)

* begin the middleware approach

* attempt middleware

* middleware works in tandem with web ui

* handle delete as well

* delete request

* DELETE working

* tool to perform imports

* functioning

* commentary

* build

* gaz

* smol test

* enable keymanager api use protonames

* edit

* one rule

* rem gw

* Fix custom compiler

(cherry picked from commit 3b1f65919e04ddf7e07c8f60cba1be883a736476)

* gen proto

* imports

* Update validator/node/node.go

Co-authored-by: Radosław Kapka <rkapka@wp.pl>

* remaining comments

* update item

* rpc

* add

* run gateway

* simplify

* rem flag

* deep source

Co-authored-by: prestonvanloon <preston@prysmaticlabs.com>
Co-authored-by: Radosław Kapka <rkapka@wp.pl>
Co-authored-by: prylabs-bulldozer[bot] <58059840+prylabs-bulldozer[bot]@users.noreply.github.com>
This commit is contained in:
Raul Jordan 2021-12-07 15:26:21 -05:00 committed by GitHub
parent cee3b626f3
commit 424c8f6b46
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 424 additions and 251 deletions

View File

@ -5,6 +5,7 @@ go_library(
srcs = [
"gateway.go",
"log.go",
"options.go",
],
importpath = "github.com/prysmaticlabs/prysm/api/gateway",
visibility = [

View File

@ -14,6 +14,7 @@ import (
type ApiProxyMiddleware struct {
GatewayAddress string
EndpointCreator EndpointFactory
router *mux.Router
}
// EndpointFactory is responsible for creating new instances of Endpoint values.
@ -29,6 +30,8 @@ type Endpoint struct {
GetResponse interface{} // The struct corresponding to the JSON structure used in a GET response.
PostRequest interface{} // The struct corresponding to the JSON structure used in a POST request.
PostResponse interface{} // The struct corresponding to the JSON structure used in a POST response.
DeleteRequest interface{} // The struct corresponding to the JSON structure used in a DELETE request.
DeleteResponse interface{} // The struct corresponding to the JSON structure used in a DELETE response.
RequestURLLiterals []string // Names of URL parameters that should not be base64-encoded.
RequestQueryParams []QueryParam // Query parameters of the request.
Err ErrorJson // The struct corresponding to the error that should be returned in case of a request failure.
@ -74,18 +77,24 @@ type fieldProcessor struct {
// Run starts the proxy, registering all proxy endpoints.
func (m *ApiProxyMiddleware) Run(gatewayRouter *mux.Router) {
for _, path := range m.EndpointCreator.Paths() {
m.handleApiPath(gatewayRouter, path, m.EndpointCreator)
gatewayRouter.HandleFunc(path, m.WithMiddleware(path))
}
m.router = gatewayRouter
}
func (m *ApiProxyMiddleware) handleApiPath(gatewayRouter *mux.Router, path string, endpointFactory EndpointFactory) {
gatewayRouter.HandleFunc(path, func(w http.ResponseWriter, req *http.Request) {
endpoint, err := endpointFactory.Create(path)
if err != nil {
errJson := InternalServerErrorWithMessage(err, "could not create endpoint")
WriteError(w, errJson, nil)
}
// ServeHTTP for the proxy middleware.
func (m *ApiProxyMiddleware) ServeHTTP(w http.ResponseWriter, req *http.Request) {
m.router.ServeHTTP(w, req)
}
// WithMiddleware wraps the given endpoint handler with the middleware logic.
func (m *ApiProxyMiddleware) WithMiddleware(path string) http.HandlerFunc {
endpoint, err := m.EndpointCreator.Create(path)
if err != nil {
log.WithError(err).Errorf("Could not create endpoint for path: %s", path)
return nil
}
return func(w http.ResponseWriter, req *http.Request) {
for _, handler := range endpoint.CustomHandlers {
if handler(m, *endpoint, w, req) {
return
@ -93,16 +102,14 @@ func (m *ApiProxyMiddleware) handleApiPath(gatewayRouter *mux.Router, path strin
}
if req.Method == "POST" {
if errJson := deserializeRequestBodyIntoContainerWrapped(endpoint, req, w); errJson != nil {
if errJson := handlePostRequestForEndpoint(endpoint, w, req); errJson != nil {
WriteError(w, errJson, nil)
return
}
}
if errJson := ProcessRequestContainerFields(endpoint.PostRequest); errJson != nil {
WriteError(w, errJson, nil)
return
}
if errJson := SetRequestBodyToRequestContainer(endpoint.PostRequest, req); errJson != nil {
if req.Method == "DELETE" {
if errJson := handleDeleteRequestForEndpoint(endpoint, req); errJson != nil {
WriteError(w, errJson, nil)
return
}
@ -137,6 +144,8 @@ func (m *ApiProxyMiddleware) handleApiPath(gatewayRouter *mux.Router, path strin
var resp interface{}
if req.Method == "GET" {
resp = endpoint.GetResponse
} else if req.Method == "DELETE" {
resp = endpoint.DeleteResponse
} else {
resp = endpoint.PostResponse
}
@ -164,7 +173,27 @@ func (m *ApiProxyMiddleware) handleApiPath(gatewayRouter *mux.Router, path strin
WriteError(w, errJson, nil)
return
}
})
}
}
func handlePostRequestForEndpoint(endpoint *Endpoint, w http.ResponseWriter, req *http.Request) ErrorJson {
if errJson := deserializeRequestBodyIntoContainerWrapped(endpoint, req, w); errJson != nil {
return errJson
}
if errJson := ProcessRequestContainerFields(endpoint.PostRequest); errJson != nil {
return errJson
}
return SetRequestBodyToRequestContainer(endpoint.PostRequest, req)
}
func handleDeleteRequestForEndpoint(endpoint *Endpoint, req *http.Request) ErrorJson {
if errJson := DeserializeRequestBodyIntoContainer(req.Body, endpoint.DeleteRequest); errJson != nil {
return errJson
}
if errJson := ProcessRequestContainerFields(endpoint.DeleteRequest); errJson != nil {
return errJson
}
return SetRequestBodyToRequestContainer(endpoint.DeleteRequest, req)
}
func deserializeRequestBodyIntoContainerWrapped(endpoint *Endpoint, req *http.Request, w http.ResponseWriter) ErrorJson {

View File

@ -34,74 +34,51 @@ type PbMux struct {
type PbHandlerRegistration func(context.Context, *gwruntime.ServeMux, *grpc.ClientConn) error
// MuxHandler is a function that implements the mux handler functionality.
type MuxHandler func(http.Handler, http.ResponseWriter, *http.Request)
type MuxHandler func(
apiMiddlewareHandler *apimiddleware.ApiProxyMiddleware,
h http.HandlerFunc,
w http.ResponseWriter,
req *http.Request,
)
// Config parameters for setting up the gateway service.
type config struct {
maxCallRecvMsgSize uint64
remoteCert string
gatewayAddr string
remoteAddr string
allowedOrigins []string
apiMiddlewareEndpointFactory apimiddleware.EndpointFactory
muxHandler MuxHandler
pbHandlers []*PbMux
router *mux.Router
}
// Gateway is the gRPC gateway to serve HTTP JSON traffic as a proxy and forward it to the gRPC server.
type Gateway struct {
conn *grpc.ClientConn
pbHandlers []*PbMux
muxHandler MuxHandler
maxCallRecvMsgSize uint64
router *mux.Router
server *http.Server
cancel context.CancelFunc
remoteCert string
gatewayAddr string
apiMiddlewareEndpointFactory apimiddleware.EndpointFactory
ctx context.Context
startFailure error
remoteAddr string
allowedOrigins []string
cfg *config
conn *grpc.ClientConn
server *http.Server
cancel context.CancelFunc
proxy *apimiddleware.ApiProxyMiddleware
ctx context.Context
startFailure error
}
// New returns a new instance of the Gateway.
func New(
ctx context.Context,
pbHandlers []*PbMux,
muxHandler MuxHandler,
remoteAddr,
gatewayAddress string,
) *Gateway {
func New(ctx context.Context, opts ...Option) (*Gateway, error) {
g := &Gateway{
pbHandlers: pbHandlers,
muxHandler: muxHandler,
router: mux.NewRouter(),
gatewayAddr: gatewayAddress,
ctx: ctx,
remoteAddr: remoteAddr,
allowedOrigins: []string{},
ctx: ctx,
cfg: &config{
router: mux.NewRouter(),
},
}
return g
}
// WithRouter allows adding a custom mux router to the gateway.
func (g *Gateway) WithRouter(r *mux.Router) *Gateway {
g.router = r
return g
}
// WithAllowedOrigins allows adding a set of allowed origins to the gateway.
func (g *Gateway) WithAllowedOrigins(origins []string) *Gateway {
g.allowedOrigins = origins
return g
}
// WithRemoteCert allows adding a custom certificate to the gateway,
func (g *Gateway) WithRemoteCert(cert string) *Gateway {
g.remoteCert = cert
return g
}
// WithMaxCallRecvMsgSize allows specifying the maximum allowed gRPC message size.
func (g *Gateway) WithMaxCallRecvMsgSize(size uint64) *Gateway {
g.maxCallRecvMsgSize = size
return g
}
// WithApiMiddleware allows adding API Middleware proxy to the gateway.
func (g *Gateway) WithApiMiddleware(endpointFactory apimiddleware.EndpointFactory) *Gateway {
g.apiMiddlewareEndpointFactory = endpointFactory
return g
for _, opt := range opts {
if err := opt(g); err != nil {
return nil, err
}
}
return g, nil
}
// Start the gateway service.
@ -109,7 +86,7 @@ func (g *Gateway) Start() {
ctx, cancel := context.WithCancel(g.ctx)
g.cancel = cancel
conn, err := g.dial(ctx, "tcp", g.remoteAddr)
conn, err := g.dial(ctx, "tcp", g.cfg.remoteAddr)
if err != nil {
log.WithError(err).Error("Failed to connect to gRPC server")
g.startFailure = err
@ -117,7 +94,7 @@ func (g *Gateway) Start() {
}
g.conn = conn
for _, h := range g.pbHandlers {
for _, h := range g.cfg.pbHandlers {
for _, r := range h.Registrations {
if err := r(ctx, h.Mux, g.conn); err != nil {
log.WithError(err).Error("Failed to register handler")
@ -126,29 +103,30 @@ func (g *Gateway) Start() {
}
}
for _, p := range h.Patterns {
g.router.PathPrefix(p).Handler(h.Mux)
g.cfg.router.PathPrefix(p).Handler(h.Mux)
}
}
corsMux := g.corsMiddleware(g.router)
corsMux := g.corsMiddleware(g.cfg.router)
_ = corsMux
if g.muxHandler != nil {
g.router.PathPrefix("/").HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
g.muxHandler(corsMux, w, r)
if g.cfg.apiMiddlewareEndpointFactory != nil && !g.cfg.apiMiddlewareEndpointFactory.IsNil() {
g.registerApiMiddleware()
}
if g.cfg.muxHandler != nil {
g.cfg.router.PathPrefix("/").HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
g.cfg.muxHandler(g.proxy, g.cfg.router.ServeHTTP, w, r)
})
}
if g.apiMiddlewareEndpointFactory != nil && !g.apiMiddlewareEndpointFactory.IsNil() {
go g.registerApiMiddleware()
}
g.server = &http.Server{
Addr: g.gatewayAddr,
Handler: g.router,
Addr: g.cfg.gatewayAddr,
Handler: g.cfg.router,
}
go func() {
log.WithField("address", g.gatewayAddr).Info("Starting gRPC gateway")
log.WithField("address", g.cfg.gatewayAddr).Info("Starting gRPC gateway")
if err := g.server.ListenAndServe(); err != http.ErrServerClosed {
log.WithError(err).Error("Failed to start gRPC gateway")
g.startFailure = err
@ -162,11 +140,9 @@ func (g *Gateway) Status() error {
if g.startFailure != nil {
return g.startFailure
}
if s := g.conn.GetState(); s != connectivity.Ready {
return fmt.Errorf("grpc server is %s", s)
}
return nil
}
@ -183,18 +159,16 @@ func (g *Gateway) Stop() error {
}
}
}
if g.cancel != nil {
g.cancel()
}
return nil
}
func (g *Gateway) corsMiddleware(h http.Handler) http.Handler {
c := cors.New(cors.Options{
AllowedOrigins: g.allowedOrigins,
AllowedMethods: []string{http.MethodPost, http.MethodGet, http.MethodOptions},
AllowedOrigins: g.cfg.allowedOrigins,
AllowedMethods: []string{http.MethodPost, http.MethodGet, http.MethodDelete, http.MethodOptions},
AllowCredentials: true,
MaxAge: 600,
AllowedHeaders: []string{"*"},
@ -236,8 +210,8 @@ func (g *Gateway) dial(ctx context.Context, network, addr string) (*grpc.ClientC
// "addr" must be a valid TCP address with a port number.
func (g *Gateway) dialTCP(ctx context.Context, addr string) (*grpc.ClientConn, error) {
security := grpc.WithInsecure()
if len(g.remoteCert) > 0 {
creds, err := credentials.NewClientTLSFromFile(g.remoteCert, "")
if len(g.cfg.remoteCert) > 0 {
creds, err := credentials.NewClientTLSFromFile(g.cfg.remoteCert, "")
if err != nil {
return nil, err
}
@ -245,7 +219,7 @@ func (g *Gateway) dialTCP(ctx context.Context, addr string) (*grpc.ClientConn, e
}
opts := []grpc.DialOption{
security,
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(int(g.maxCallRecvMsgSize))),
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(int(g.cfg.maxCallRecvMsgSize))),
}
return grpc.DialContext(ctx, addr, opts...)
@ -266,16 +240,16 @@ func (g *Gateway) dialUnix(ctx context.Context, addr string) (*grpc.ClientConn,
opts := []grpc.DialOption{
grpc.WithInsecure(),
grpc.WithContextDialer(f),
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(int(g.maxCallRecvMsgSize))),
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(int(g.cfg.maxCallRecvMsgSize))),
}
return grpc.DialContext(ctx, addr, opts...)
}
func (g *Gateway) registerApiMiddleware() {
proxy := &apimiddleware.ApiProxyMiddleware{
GatewayAddress: g.gatewayAddr,
EndpointCreator: g.apiMiddlewareEndpointFactory,
g.proxy = &apimiddleware.ApiProxyMiddleware{
GatewayAddress: g.cfg.gatewayAddr,
EndpointCreator: g.cfg.apiMiddlewareEndpointFactory,
}
log.Info("Starting API middleware")
proxy.Run(g.router)
g.proxy.Run(g.cfg.router)
}

View File

@ -40,26 +40,30 @@ func TestGateway_Customized(t *testing.T) {
size := uint64(100)
endpointFactory := &mockEndpointFactory{}
g := New(
context.Background(),
[]*PbMux{},
func(handler http.Handler, writer http.ResponseWriter, request *http.Request) {
opts := []Option{
WithRouter(r),
WithRemoteCert(cert),
WithAllowedOrigins(origins),
WithMaxCallRecvMsgSize(size),
WithApiMiddleware(endpointFactory),
WithMuxHandler(func(
_ *apimiddleware.ApiProxyMiddleware,
_ http.HandlerFunc,
_ http.ResponseWriter,
_ *http.Request,
) {
}),
}
},
"",
"",
).WithRouter(r).
WithRemoteCert(cert).
WithAllowedOrigins(origins).
WithMaxCallRecvMsgSize(size).
WithApiMiddleware(endpointFactory)
g, err := New(context.Background(), opts...)
require.NoError(t, err)
assert.Equal(t, r, g.router)
assert.Equal(t, cert, g.remoteCert)
require.Equal(t, 1, len(g.allowedOrigins))
assert.Equal(t, origins[0], g.allowedOrigins[0])
assert.Equal(t, size, g.maxCallRecvMsgSize)
assert.Equal(t, endpointFactory, g.apiMiddlewareEndpointFactory)
assert.Equal(t, r, g.cfg.router)
assert.Equal(t, cert, g.cfg.remoteCert)
require.Equal(t, 1, len(g.cfg.allowedOrigins))
assert.Equal(t, origins[0], g.cfg.allowedOrigins[0])
assert.Equal(t, size, g.cfg.maxCallRecvMsgSize)
assert.Equal(t, endpointFactory, g.cfg.apiMiddlewareEndpointFactory)
}
func TestGateway_StartStop(t *testing.T) {
@ -75,23 +79,27 @@ func TestGateway_StartStop(t *testing.T) {
selfAddress := fmt.Sprintf("%s:%d", rpcHost, ctx.Int(flags.RPCPort.Name))
gatewayAddress := fmt.Sprintf("%s:%d", gatewayHost, gatewayPort)
g := New(
ctx.Context,
[]*PbMux{},
func(handler http.Handler, writer http.ResponseWriter, request *http.Request) {
opts := []Option{
WithGatewayAddr(gatewayAddress),
WithRemoteAddr(selfAddress),
WithMuxHandler(func(
_ *apimiddleware.ApiProxyMiddleware,
_ http.HandlerFunc,
_ http.ResponseWriter,
_ *http.Request,
) {
}),
}
},
selfAddress,
gatewayAddress,
)
g, err := New(context.Background(), opts...)
require.NoError(t, err)
g.Start()
go func() {
require.LogsContain(t, hook, "Starting gRPC gateway")
require.LogsDoNotContain(t, hook, "Starting API middleware")
}()
err := g.Stop()
err = g.Stop()
require.NoError(t, err)
}
@ -106,15 +114,15 @@ func TestGateway_NilHandler_NotFoundHandlerRegistered(t *testing.T) {
selfAddress := fmt.Sprintf("%s:%d", rpcHost, ctx.Int(flags.RPCPort.Name))
gatewayAddress := fmt.Sprintf("%s:%d", gatewayHost, gatewayPort)
g := New(
ctx.Context,
[]*PbMux{},
/* muxHandler */ nil,
selfAddress,
gatewayAddress,
)
opts := []Option{
WithGatewayAddr(gatewayAddress),
WithRemoteAddr(selfAddress),
}
g, err := New(context.Background(), opts...)
require.NoError(t, err)
writer := httptest.NewRecorder()
g.router.ServeHTTP(writer, &http.Request{Method: "GET", Host: "localhost", URL: &url.URL{Path: "/foo"}})
g.cfg.router.ServeHTTP(writer, &http.Request{Method: "GET", Host: "localhost", URL: &url.URL{Path: "/foo"}})
assert.Equal(t, http.StatusNotFound, writer.Code)
}

81
api/gateway/options.go Normal file
View File

@ -0,0 +1,81 @@
package gateway
import (
"github.com/gorilla/mux"
"github.com/prysmaticlabs/prysm/api/gateway/apimiddleware"
)
type Option func(g *Gateway) error
func (g *Gateway) SetRouter(r *mux.Router) *Gateway {
g.cfg.router = r
return g
}
func WithPbHandlers(handlers []*PbMux) Option {
return func(g *Gateway) error {
g.cfg.pbHandlers = handlers
return nil
}
}
func WithMuxHandler(m MuxHandler) Option {
return func(g *Gateway) error {
g.cfg.muxHandler = m
return nil
}
}
func WithGatewayAddr(addr string) Option {
return func(g *Gateway) error {
g.cfg.gatewayAddr = addr
return nil
}
}
func WithRemoteAddr(addr string) Option {
return func(g *Gateway) error {
g.cfg.remoteAddr = addr
return nil
}
}
// WithRouter allows adding a custom mux router to the gateway.
func WithRouter(r *mux.Router) Option {
return func(g *Gateway) error {
g.cfg.router = r
return nil
}
}
// WithAllowedOrigins allows adding a set of allowed origins to the gateway.
func WithAllowedOrigins(origins []string) Option {
return func(g *Gateway) error {
g.cfg.allowedOrigins = origins
return nil
}
}
// WithRemoteCert allows adding a custom certificate to the gateway,
func WithRemoteCert(cert string) Option {
return func(g *Gateway) error {
g.cfg.remoteCert = cert
return nil
}
}
// WithMaxCallRecvMsgSize allows specifying the maximum allowed gRPC message size.
func WithMaxCallRecvMsgSize(size uint64) Option {
return func(g *Gateway) error {
g.cfg.maxCallRecvMsgSize = size
return nil
}
}
// WithApiMiddleware allows adding an API middleware proxy to the gateway.
func WithApiMiddleware(endpointFactory apimiddleware.EndpointFactory) Option {
return func(g *Gateway) error {
g.cfg.apiMiddlewareEndpointFactory = endpointFactory
return nil
}
}

View File

@ -830,19 +830,22 @@ func (b *BeaconNode) registerGRPCGateway() error {
muxs = append(muxs, gatewayConfig.EthPbMux)
}
g := apigateway.New(
b.ctx,
muxs,
gatewayConfig.Handler,
selfAddress,
gatewayAddress,
).WithAllowedOrigins(allowedOrigins).
WithRemoteCert(selfCert).
WithMaxCallRecvMsgSize(maxCallSize)
if flags.EnableHTTPEthAPI(httpModules) {
g.WithApiMiddleware(&apimiddleware.BeaconEndpointFactory{})
opts := []apigateway.Option{
apigateway.WithGatewayAddr(gatewayAddress),
apigateway.WithRemoteAddr(selfAddress),
apigateway.WithPbHandlers(muxs),
apigateway.WithMuxHandler(gatewayConfig.Handler),
apigateway.WithRemoteCert(selfCert),
apigateway.WithMaxCallRecvMsgSize(maxCallSize),
apigateway.WithAllowedOrigins(allowedOrigins),
}
if flags.EnableHTTPEthAPI(httpModules) {
opts = append(opts, apigateway.WithApiMiddleware(&apimiddleware.BeaconEndpointFactory{}))
}
g, err := apigateway.New(b.ctx, opts...)
if err != nil {
return err
}
return b.services.RegisterService(g)
}

View File

@ -52,23 +52,28 @@ func main() {
if gatewayConfig.EthPbMux != nil {
muxs = append(muxs, gatewayConfig.EthPbMux)
}
opts := []gateway.Option{
gateway.WithPbHandlers(muxs),
gateway.WithMuxHandler(gatewayConfig.Handler),
gateway.WithRemoteAddr(*beaconRPC),
gateway.WithGatewayAddr(fmt.Sprintf("%s:%d", *host, *port)),
gateway.WithAllowedOrigins(strings.Split(*allowedOrigins, ",")),
gateway.WithMaxCallRecvMsgSize(uint64(*grpcMaxMsgSize)),
}
gw := gateway.New(
context.Background(),
muxs,
gatewayConfig.Handler,
*beaconRPC,
fmt.Sprintf("%s:%d", *host, *port),
).WithAllowedOrigins(strings.Split(*allowedOrigins, ",")).
WithMaxCallRecvMsgSize(uint64(*grpcMaxMsgSize))
if flags.EnableHTTPEthAPI(*httpModules) {
gw.WithApiMiddleware(&apimiddleware.BeaconEndpointFactory{})
opts = append(opts, gateway.WithApiMiddleware(&apimiddleware.BeaconEndpointFactory{}))
}
gw, err := gateway.New(context.Background(), opts...)
if err != nil {
log.Fatal(err)
}
r := mux.NewRouter()
r.HandleFunc("/swagger/", gateway.SwaggerServer())
r.HandleFunc("/healthz", healthzServer(gw))
gw = gw.WithRouter(r)
gw.SetRouter(r)
gw.Start()

View File

@ -12,27 +12,6 @@ proto_library(
"events_service.proto",
"node_service.proto",
"validator_service.proto",
],
visibility = ["//visibility:public"],
deps = [
"//proto/eth/ext:proto",
"//proto/eth/v1:proto",
"//proto/eth/v2:proto",
"@com_google_protobuf//:descriptor_proto",
"@com_google_protobuf//:empty_proto",
"@com_google_protobuf//:timestamp_proto",
"@go_googleapis//google/api:annotations_proto",
"@com_github_grpc_ecosystem_grpc_gateway_v2//proto/gateway:event_source_proto",
],
)
# We create a custom proto library for key_management.proto as it requires a different
# compiler plugin for grpc gateway than the others. Namely, it requires adding the option
# --allow_delete_body=true to allow grpc gateway endpoints to take in DELETE HTTP requests
# with a request body properly.
proto_library(
name = "custom_proto",
srcs = [
"key_management.proto",
],
visibility = ["//visibility:public"],
@ -67,35 +46,32 @@ go_proto_library(
],
)
go_proto_library(
name = "custom_go_proto",
compilers = ["@com_github_prysmaticlabs_protoc_gen_go_cast//:go_cast_grpc",],
importpath = "github.com/prysmaticlabs/prysm/proto/eth/service",
proto = ":custom_proto",
visibility = ["//visibility:public"],
deps = [
"//proto/eth/ext:go_default_library",
"//proto/eth/v1:go_default_library",
"//proto/eth/v2:go_default_library",
"@io_bazel_rules_go//proto/wkt:descriptor_go_proto",
"@io_bazel_rules_go//proto/wkt:empty_go_proto",
"@com_github_prysmaticlabs_eth2_types//:go_default_library",
"@go_googleapis//google/api:annotations_go_proto",
"@com_github_golang_protobuf//proto:go_default_library",
"@com_github_grpc_ecosystem_grpc_gateway_v2//proto/gateway:go_default_library",
],
)
go_proto_compiler(
name = "allow_delete_body_gateway_compiler",
options = [
"logtostderr=true",
"allow_repeated_fields_in_body=true",
"allow_delete_body=true",
],
plugin = "@com_github_grpc_ecosystem_grpc_gateway_v2//protoc-gen-grpc-gateway:protoc-gen-grpc-gateway",
options = ["allow_delete_body=true"],
suffix = ".pb.gw.go",
visibility = ["//visibility:public"],
deps = [
"@com_github_grpc_ecosystem_grpc_gateway_v2//runtime:go_default_library",
"@com_github_grpc_ecosystem_grpc_gateway_v2//utilities:go_default_library",
"@org_golang_google_grpc//:go_default_library",
"@org_golang_google_grpc//codes:go_default_library",
"@org_golang_google_grpc//grpclog:go_default_library",
"@org_golang_google_grpc//metadata:go_default_library",
"@org_golang_google_grpc//status:go_default_library",
"@org_golang_google_protobuf//proto:go_default_library",
],
)
go_proto_library(
name = "go_grpc_gateway_library",
compilers = [
"@com_github_grpc_ecosystem_grpc_gateway_v2//protoc-gen-grpc-gateway:go_gen_grpc_gateway",
"allow_delete_body_gateway_compiler",
],
embed = [":go_proto"],
importpath = "github.com/prysmaticlabs/prysm/proto/eth/service",
@ -113,30 +89,9 @@ go_proto_library(
],
)
go_proto_library(
name = "custom_go_grpc_gateway_library",
compilers = [
"allow_delete_body_gateway_compiler",
],
embed = [":custom_go_proto"],
importpath = "github.com/prysmaticlabs/prysm/proto/eth/service",
protos = [":custom_proto"],
visibility = ["//proto:__subpackages__"],
deps = [
"//proto/eth/ext:go_default_library",
"@io_bazel_rules_go//proto/wkt:empty_go_proto",
"@com_github_grpc_ecosystem_grpc_gateway_v2//protoc-gen-openapiv2/options:options_go_proto",
"@com_github_prysmaticlabs_go_bitfield//:go_default_library",
"@go_googleapis//google/api:annotations_go_proto",
"@io_bazel_rules_go//proto/wkt:timestamp_go_proto",
"@io_bazel_rules_go//proto/wkt:descriptor_go_proto",
"@com_github_grpc_ecosystem_grpc_gateway_v2//proto/gateway:go_default_library",
],
)
go_library(
name = "go_default_library",
embed = [":go_grpc_gateway_library", ":custom_go_grpc_gateway_library"],
embed = [":go_grpc_gateway_library"],
importpath = "github.com/prysmaticlabs/prysm/proto/eth/service",
visibility = ["//visibility:public"],
)

View File

@ -89,18 +89,15 @@ func local_request_KeyManagement_ImportKeystores_0(ctx context.Context, marshale
}
var (
filter_KeyManagement_DeleteKeystores_0 = &utilities.DoubleArray{Encoding: map[string]int{}, Base: []int(nil), Check: []int(nil)}
)
func request_KeyManagement_DeleteKeystores_0(ctx context.Context, marshaler runtime.Marshaler, client KeyManagementClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {
var protoReq DeleteKeystoresRequest
var metadata runtime.ServerMetadata
if err := req.ParseForm(); err != nil {
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
newReader, berr := utilities.IOReaderFactory(req.Body)
if berr != nil {
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr)
}
if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_KeyManagement_DeleteKeystores_0); err != nil {
if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF {
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
}
@ -113,10 +110,11 @@ func local_request_KeyManagement_DeleteKeystores_0(ctx context.Context, marshale
var protoReq DeleteKeystoresRequest
var metadata runtime.ServerMetadata
if err := req.ParseForm(); err != nil {
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
newReader, berr := utilities.IOReaderFactory(req.Body)
if berr != nil {
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr)
}
if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_KeyManagement_DeleteKeystores_0); err != nil {
if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF {
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
}

View File

@ -29,6 +29,7 @@ go_library(
],
deps = [
"//api/gateway:go_default_library",
"//api/gateway/apimiddleware:go_default_library",
"//async/event:go_default_library",
"//cmd:go_default_library",
"//cmd/validator/flags:go_default_library",
@ -38,6 +39,7 @@ go_library(
"//monitoring/backup:go_default_library",
"//monitoring/prometheus:go_default_library",
"//monitoring/tracing:go_default_library",
"//proto/eth/service:go_default_library",
"//proto/prysm/v1alpha1:go_default_library",
"//proto/prysm/v1alpha1/validator-client:go_default_library",
"//runtime:go_default_library",
@ -52,6 +54,7 @@ go_library(
"//validator/keymanager:go_default_library",
"//validator/keymanager/imported:go_default_library",
"//validator/rpc:go_default_library",
"//validator/rpc/apimiddleware:go_default_library",
"//validator/web:go_default_library",
"@com_github_grpc_ecosystem_grpc_gateway_v2//runtime:go_default_library",
"@com_github_pkg_errors//:go_default_library",

View File

@ -17,6 +17,7 @@ import (
gwruntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/pkg/errors"
"github.com/prysmaticlabs/prysm/api/gateway"
"github.com/prysmaticlabs/prysm/api/gateway/apimiddleware"
"github.com/prysmaticlabs/prysm/async/event"
"github.com/prysmaticlabs/prysm/cmd"
"github.com/prysmaticlabs/prysm/cmd/validator/flags"
@ -26,6 +27,7 @@ import (
"github.com/prysmaticlabs/prysm/monitoring/backup"
"github.com/prysmaticlabs/prysm/monitoring/prometheus"
tracing2 "github.com/prysmaticlabs/prysm/monitoring/tracing"
ethpbservice "github.com/prysmaticlabs/prysm/proto/eth/service"
pb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1"
validatorpb "github.com/prysmaticlabs/prysm/proto/prysm/v1alpha1/validator-client"
"github.com/prysmaticlabs/prysm/runtime"
@ -40,6 +42,7 @@ import (
"github.com/prysmaticlabs/prysm/validator/keymanager"
"github.com/prysmaticlabs/prysm/validator/keymanager/imported"
"github.com/prysmaticlabs/prysm/validator/rpc"
validatorMiddleware "github.com/prysmaticlabs/prysm/validator/rpc/apimiddleware"
"github.com/prysmaticlabs/prysm/validator/web"
"github.com/sirupsen/logrus"
"github.com/urfave/cli/v2"
@ -487,12 +490,14 @@ func (c *ValidatorClient) registerRPCGatewayService(cliCtx *cli.Context) error {
validatorpb.RegisterAccountsHandler,
validatorpb.RegisterBeaconHandler,
validatorpb.RegisterSlashingProtectionHandler,
ethpbservice.RegisterKeyManagementHandler,
}
mux := gwruntime.NewServeMux(
gwmux := gwruntime.NewServeMux(
gwruntime.WithMarshalerOption(gwruntime.MIMEWildcard, &gwruntime.HTTPBodyMarshaler{
Marshaler: &gwruntime.JSONPb{
MarshalOptions: protojson.MarshalOptions{
EmitUnpopulated: true,
UseProtoNames: true,
},
UnmarshalOptions: protojson.UnmarshalOptions{
DiscardUnknown: true,
@ -503,28 +508,42 @@ func (c *ValidatorClient) registerRPCGatewayService(cliCtx *cli.Context) error {
"text/event-stream", &gwruntime.EventSourceJSONPb{},
),
)
muxHandler := func(h http.Handler, w http.ResponseWriter, req *http.Request) {
if strings.HasPrefix(req.URL.Path, "/api") {
http.StripPrefix("/api", h).ServeHTTP(w, req)
muxHandler := func(apiMware *apimiddleware.ApiProxyMiddleware, h http.HandlerFunc, w http.ResponseWriter, req *http.Request) {
// The validator gateway handler requires this special logic as it serves two kinds of APIs, namely
// the standard validator keymanager API under the /eth namespace, and the Prysm internal
// validator API under the /api namespace. Finally, it also serves requests to host the validator web UI.
if strings.HasPrefix(req.URL.Path, "/api/eth/") {
req.URL.Path = strings.Replace(req.URL.Path, "/api", "", 1)
// If the prefix has /eth/, we handle it with the standard API gateway middleware.
apiMware.ServeHTTP(w, req)
} else if strings.HasPrefix(req.URL.Path, "/api") {
req.URL.Path = strings.Replace(req.URL.Path, "/api", "", 1)
// Else, we handle with the Prysm API gateway without a middleware.
h(w, req)
} else {
// Finally, we handle with the web server.
web.Handler(w, req)
}
}
pbHandler := &gateway.PbMux{
Registrations: registrations,
Patterns: []string{"/accounts/", "/v2/"},
Mux: mux,
Patterns: []string{"/accounts/", "/v2/", "/internal/eth/v1/"},
Mux: gwmux,
}
opts := []gateway.Option{
gateway.WithRemoteAddr(rpcAddr),
gateway.WithGatewayAddr(gatewayAddress),
gateway.WithMaxCallRecvMsgSize(maxCallSize),
gateway.WithPbHandlers([]*gateway.PbMux{pbHandler}),
gateway.WithAllowedOrigins(allowedOrigins),
gateway.WithApiMiddleware(&validatorMiddleware.ValidatorEndpointFactory{}),
gateway.WithMuxHandler(muxHandler),
}
gw, err := gateway.New(cliCtx.Context, opts...)
if err != nil {
return err
}
gw := gateway.New(
cliCtx.Context,
[]*gateway.PbMux{pbHandler},
muxHandler,
rpcAddr,
gatewayAddress,
).WithAllowedOrigins(allowedOrigins).WithMaxCallRecvMsgSize(maxCallSize)
return c.services.RegisterService(gw)
}

View File

@ -0,0 +1,15 @@
load("@prysm//tools/go:def.bzl", "go_library")
go_library(
name = "go_default_library",
srcs = [
"endpoint_factory.go",
"structs.go",
],
importpath = "github.com/prysmaticlabs/prysm/validator/rpc/apimiddleware",
visibility = ["//visibility:public"],
deps = [
"//api/gateway/apimiddleware:go_default_library",
"@com_github_pkg_errors//:go_default_library",
],
)

View File

@ -0,0 +1,38 @@
package apimiddleware
import (
"github.com/pkg/errors"
"github.com/prysmaticlabs/prysm/api/gateway/apimiddleware"
)
// ValidatorEndpointFactory creates endpoints used for running validator API calls through the API Middleware.
type ValidatorEndpointFactory struct {
}
func (f *ValidatorEndpointFactory) IsNil() bool {
return f == nil
}
// Paths is a collection of all valid validator API paths.
func (*ValidatorEndpointFactory) Paths() []string {
return []string{
"/eth/v1/keystores",
}
}
// Create returns a new endpoint for the provided API path.
func (*ValidatorEndpointFactory) Create(path string) (*apimiddleware.Endpoint, error) {
endpoint := apimiddleware.DefaultEndpoint()
switch path {
case "/eth/v1/keystores":
endpoint.GetResponse = &listKeystoresResponseJson{}
endpoint.PostRequest = &importKeystoresRequestJson{}
endpoint.PostResponse = &importKeystoresResponseJson{}
endpoint.DeleteRequest = &deleteKeystoresRequestJson{}
endpoint.DeleteResponse = &deleteKeystoresResponseJson{}
default:
return nil, errors.New("invalid path")
}
endpoint.Path = path
return &endpoint, nil
}

View File

@ -0,0 +1,34 @@
package apimiddleware
type listKeystoresResponseJson struct {
Keystores []*keystoreJson `json:"keystores"`
}
type keystoreJson struct {
ValidatingPubkey string `json:"validating_pubkey" hex:"true"`
DerivationPath string `json:"derivation_path"`
}
type importKeystoresRequestJson struct {
Keystores []string `json:"keystores"`
Passwords []string `json:"passwords"`
SlashingProtection string `json:"slashing_protection"`
}
type importKeystoresResponseJson struct {
Statuses []*statusJson `json:"statuses"`
}
type deleteKeystoresRequestJson struct {
PublicKeys []string `json:"public_keys" hex:"true"`
}
type statusJson struct {
Status string `json:"status"`
Message string `json:"message"`
}
type deleteKeystoresResponseJson struct {
Statuses []*statusJson `json:"statuses"`
SlashingProtection string `json:"slashing_protection"`
}

View File

@ -97,6 +97,9 @@ func (s *Server) DeleteKeystores(
if !ok {
return nil, status.Error(codes.Internal, "Keymanager kind cannot delete keys")
}
if len(req.PublicKeys) == 0 {
return &ethpbservice.DeleteKeystoresResponse{Statuses: make([]*ethpbservice.DeletedKeystoreStatus, 0)}, nil
}
statuses, err := deleter.DeleteKeystores(ctx, req.PublicKeys)
if err != nil {
return nil, status.Errorf(codes.Internal, "Could not delete keys: %v", err)

View File

@ -203,7 +203,6 @@ func TestServer_ImportKeystores(t *testing.T) {
}
})
}
func TestServer_DeleteKeystores(t *testing.T) {
ctx := context.Background()
srv := setupServerWithWallet(t)
@ -247,6 +246,14 @@ func TestServer_DeleteKeystores(t *testing.T) {
})
require.NoError(t, err)
t.Run("no slashing protection response if no keys in request even if we have a history in DB", func(t *testing.T) {
resp, err := srv.DeleteKeystores(context.Background(), &ethpbservice.DeleteKeystoresRequest{
PublicKeys: nil,
})
require.NoError(t, err)
require.Equal(t, "", resp.SlashingProtection)
})
// For ease of test setup, we'll give each public key a string identifier.
publicKeysWithId := map[string][48]byte{
"a": publicKeys[0],