Avoid Public Key Copies During Aggregation (#12944)

* Add in optimization

* add better test

---------

Co-authored-by: prylabs-bulldozer[bot] <58059840+prylabs-bulldozer[bot]@users.noreply.github.com>
This commit is contained in:
Nishant Das 2023-09-26 19:13:19 +08:00 committed by GitHub
parent 42c192d97d
commit 0919b2245f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 89 additions and 3 deletions

View File

@ -12,7 +12,7 @@ import (
"github.com/prysmaticlabs/prysm/v4/crypto/bls/common"
)
var maxKeys = 1_000_000
var maxKeys = 2_000_000
var pubkeyCache *nonblocking.LRU[[48]byte, common.PublicKey]
// PublicKey used in the BLS signature scheme.
@ -22,13 +22,20 @@ type PublicKey struct {
// PublicKeyFromBytes creates a BLS public key from a BigEndian byte slice.
func PublicKeyFromBytes(pubKey []byte) (common.PublicKey, error) {
return publicKeyFromBytes(pubKey, true)
}
func publicKeyFromBytes(pubKey []byte, cacheCopy bool) (common.PublicKey, error) {
if len(pubKey) != params.BeaconConfig().BLSPubkeyLength {
return nil, fmt.Errorf("public key must be %d bytes", params.BeaconConfig().BLSPubkeyLength)
}
newKey := (*[fieldparams.BLSPubkeyLength]byte)(pubKey)
if cv, ok := pubkeyCache.Get(*newKey); ok {
if cacheCopy {
return cv.Copy(), nil
}
return cv, nil
}
// Subgroup check NOT done when decompressing pubkey.
p := new(blstPublicKey).Uncompress(pubKey)
if p == nil {
@ -54,7 +61,7 @@ func AggregatePublicKeys(pubs [][]byte) (common.PublicKey, error) {
agg := new(blstAggregatePublicKey)
mulP1 := make([]*blstPublicKey, 0, len(pubs))
for _, pubkey := range pubs {
pubKeyObj, err := PublicKeyFromBytes(pubkey)
pubKeyObj, err := publicKeyFromBytes(pubkey, false)
if err != nil {
return nil, err
}

View File

@ -5,6 +5,7 @@ package blst_test
import (
"bytes"
"errors"
"sync"
"testing"
"github.com/prysmaticlabs/prysm/v4/crypto/bls/blst"
@ -92,6 +93,84 @@ func TestPublicKey_Aggregate(t *testing.T) {
require.DeepEqual(t, resKey.Marshal(), aggKey.Marshal(), "Pubkey does not match up")
}
func TestPublicKey_Aggregation_NoCorruption(t *testing.T) {
pubkeys := []common.PublicKey{}
for i := 0; i < 100; i++ {
priv, err := blst.RandKey()
require.NoError(t, err)
pubkey := priv.PublicKey()
pubkeys = append(pubkeys, pubkey)
}
compressedKeys := [][]byte{}
// Fill up the cache
for _, pkey := range pubkeys {
_, err := blst.PublicKeyFromBytes(pkey.Marshal())
require.NoError(t, err)
compressedKeys = append(compressedKeys, pkey.Marshal())
}
wg := new(sync.WaitGroup)
// Aggregate different sets of keys.
wg.Add(1)
go func() {
_, err := blst.AggregatePublicKeys(compressedKeys)
require.NoError(t, err)
wg.Done()
}()
wg.Add(1)
go func() {
_, err := blst.AggregatePublicKeys(compressedKeys[:10])
require.NoError(t, err)
wg.Done()
}()
wg.Add(1)
go func() {
_, err := blst.AggregatePublicKeys(compressedKeys[:40])
require.NoError(t, err)
wg.Done()
}()
wg.Add(1)
go func() {
_, err := blst.AggregatePublicKeys(compressedKeys[20:60])
require.NoError(t, err)
wg.Done()
}()
wg.Add(1)
go func() {
_, err := blst.AggregatePublicKeys(compressedKeys[80:])
require.NoError(t, err)
wg.Done()
}()
wg.Add(1)
go func() {
_, err := blst.AggregatePublicKeys(compressedKeys[60:90])
require.NoError(t, err)
wg.Done()
}()
wg.Add(1)
go func() {
_, err := blst.AggregatePublicKeys(compressedKeys[40:99])
require.NoError(t, err)
wg.Done()
}()
wg.Wait()
for _, pkey := range pubkeys {
cachedPubkey, err := blst.PublicKeyFromBytes(pkey.Marshal())
require.NoError(t, err)
assert.Equal(t, true, cachedPubkey.Equals(pkey))
}
}
func TestPublicKeysEmpty(t *testing.T) {
var pubs [][]byte
_, err := blst.AggregatePublicKeys(pubs)