mirror of
https://gitlab.com/pulsechaincom/prysm-pulse.git
synced 2025-01-03 00:27:38 +00:00
Avoid Public Key Copies During Aggregation (#12944)
* Add in optimization * add better test --------- Co-authored-by: prylabs-bulldozer[bot] <58059840+prylabs-bulldozer[bot]@users.noreply.github.com>
This commit is contained in:
parent
42c192d97d
commit
0919b2245f
@ -12,7 +12,7 @@ import (
|
||||
"github.com/prysmaticlabs/prysm/v4/crypto/bls/common"
|
||||
)
|
||||
|
||||
var maxKeys = 1_000_000
|
||||
var maxKeys = 2_000_000
|
||||
var pubkeyCache *nonblocking.LRU[[48]byte, common.PublicKey]
|
||||
|
||||
// PublicKey used in the BLS signature scheme.
|
||||
@ -22,13 +22,20 @@ type PublicKey struct {
|
||||
|
||||
// PublicKeyFromBytes creates a BLS public key from a BigEndian byte slice.
|
||||
func PublicKeyFromBytes(pubKey []byte) (common.PublicKey, error) {
|
||||
return publicKeyFromBytes(pubKey, true)
|
||||
}
|
||||
|
||||
func publicKeyFromBytes(pubKey []byte, cacheCopy bool) (common.PublicKey, error) {
|
||||
if len(pubKey) != params.BeaconConfig().BLSPubkeyLength {
|
||||
return nil, fmt.Errorf("public key must be %d bytes", params.BeaconConfig().BLSPubkeyLength)
|
||||
}
|
||||
newKey := (*[fieldparams.BLSPubkeyLength]byte)(pubKey)
|
||||
if cv, ok := pubkeyCache.Get(*newKey); ok {
|
||||
if cacheCopy {
|
||||
return cv.Copy(), nil
|
||||
}
|
||||
return cv, nil
|
||||
}
|
||||
// Subgroup check NOT done when decompressing pubkey.
|
||||
p := new(blstPublicKey).Uncompress(pubKey)
|
||||
if p == nil {
|
||||
@ -54,7 +61,7 @@ func AggregatePublicKeys(pubs [][]byte) (common.PublicKey, error) {
|
||||
agg := new(blstAggregatePublicKey)
|
||||
mulP1 := make([]*blstPublicKey, 0, len(pubs))
|
||||
for _, pubkey := range pubs {
|
||||
pubKeyObj, err := PublicKeyFromBytes(pubkey)
|
||||
pubKeyObj, err := publicKeyFromBytes(pubkey, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -5,6 +5,7 @@ package blst_test
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/prysmaticlabs/prysm/v4/crypto/bls/blst"
|
||||
@ -92,6 +93,84 @@ func TestPublicKey_Aggregate(t *testing.T) {
|
||||
require.DeepEqual(t, resKey.Marshal(), aggKey.Marshal(), "Pubkey does not match up")
|
||||
}
|
||||
|
||||
func TestPublicKey_Aggregation_NoCorruption(t *testing.T) {
|
||||
pubkeys := []common.PublicKey{}
|
||||
for i := 0; i < 100; i++ {
|
||||
priv, err := blst.RandKey()
|
||||
require.NoError(t, err)
|
||||
pubkey := priv.PublicKey()
|
||||
pubkeys = append(pubkeys, pubkey)
|
||||
}
|
||||
|
||||
compressedKeys := [][]byte{}
|
||||
// Fill up the cache
|
||||
for _, pkey := range pubkeys {
|
||||
_, err := blst.PublicKeyFromBytes(pkey.Marshal())
|
||||
require.NoError(t, err)
|
||||
compressedKeys = append(compressedKeys, pkey.Marshal())
|
||||
}
|
||||
|
||||
wg := new(sync.WaitGroup)
|
||||
|
||||
// Aggregate different sets of keys.
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
_, err := blst.AggregatePublicKeys(compressedKeys)
|
||||
require.NoError(t, err)
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
_, err := blst.AggregatePublicKeys(compressedKeys[:10])
|
||||
require.NoError(t, err)
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
_, err := blst.AggregatePublicKeys(compressedKeys[:40])
|
||||
require.NoError(t, err)
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
_, err := blst.AggregatePublicKeys(compressedKeys[20:60])
|
||||
require.NoError(t, err)
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
_, err := blst.AggregatePublicKeys(compressedKeys[80:])
|
||||
require.NoError(t, err)
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
_, err := blst.AggregatePublicKeys(compressedKeys[60:90])
|
||||
require.NoError(t, err)
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
_, err := blst.AggregatePublicKeys(compressedKeys[40:99])
|
||||
require.NoError(t, err)
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
for _, pkey := range pubkeys {
|
||||
cachedPubkey, err := blst.PublicKeyFromBytes(pkey.Marshal())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, true, cachedPubkey.Equals(pkey))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPublicKeysEmpty(t *testing.T) {
|
||||
var pubs [][]byte
|
||||
_, err := blst.AggregatePublicKeys(pubs)
|
||||
|
Loading…
Reference in New Issue
Block a user