diff --git a/crypto/bls/blst/public_key.go b/crypto/bls/blst/public_key.go index 16d52a41b..f722a5a5e 100644 --- a/crypto/bls/blst/public_key.go +++ b/crypto/bls/blst/public_key.go @@ -12,7 +12,7 @@ import ( "github.com/prysmaticlabs/prysm/v4/crypto/bls/common" ) -var maxKeys = 1_000_000 +var maxKeys = 2_000_000 var pubkeyCache *nonblocking.LRU[[48]byte, common.PublicKey] // PublicKey used in the BLS signature scheme. @@ -22,12 +22,19 @@ type PublicKey struct { // PublicKeyFromBytes creates a BLS public key from a BigEndian byte slice. func PublicKeyFromBytes(pubKey []byte) (common.PublicKey, error) { + return publicKeyFromBytes(pubKey, true) +} + +func publicKeyFromBytes(pubKey []byte, cacheCopy bool) (common.PublicKey, error) { if len(pubKey) != params.BeaconConfig().BLSPubkeyLength { return nil, fmt.Errorf("public key must be %d bytes", params.BeaconConfig().BLSPubkeyLength) } newKey := (*[fieldparams.BLSPubkeyLength]byte)(pubKey) if cv, ok := pubkeyCache.Get(*newKey); ok { - return cv.Copy(), nil + if cacheCopy { + return cv.Copy(), nil + } + return cv, nil } // Subgroup check NOT done when decompressing pubkey. p := new(blstPublicKey).Uncompress(pubKey) @@ -54,7 +61,7 @@ func AggregatePublicKeys(pubs [][]byte) (common.PublicKey, error) { agg := new(blstAggregatePublicKey) mulP1 := make([]*blstPublicKey, 0, len(pubs)) for _, pubkey := range pubs { - pubKeyObj, err := PublicKeyFromBytes(pubkey) + pubKeyObj, err := publicKeyFromBytes(pubkey, false) if err != nil { return nil, err } diff --git a/crypto/bls/blst/public_key_test.go b/crypto/bls/blst/public_key_test.go index cd83038c1..b27ef29fa 100644 --- a/crypto/bls/blst/public_key_test.go +++ b/crypto/bls/blst/public_key_test.go @@ -5,6 +5,7 @@ package blst_test import ( "bytes" "errors" + "sync" "testing" "github.com/prysmaticlabs/prysm/v4/crypto/bls/blst" @@ -92,6 +93,84 @@ func TestPublicKey_Aggregate(t *testing.T) { require.DeepEqual(t, resKey.Marshal(), aggKey.Marshal(), "Pubkey does not match up") } +func TestPublicKey_Aggregation_NoCorruption(t *testing.T) { + pubkeys := []common.PublicKey{} + for i := 0; i < 100; i++ { + priv, err := blst.RandKey() + require.NoError(t, err) + pubkey := priv.PublicKey() + pubkeys = append(pubkeys, pubkey) + } + + compressedKeys := [][]byte{} + // Fill up the cache + for _, pkey := range pubkeys { + _, err := blst.PublicKeyFromBytes(pkey.Marshal()) + require.NoError(t, err) + compressedKeys = append(compressedKeys, pkey.Marshal()) + } + + wg := new(sync.WaitGroup) + + // Aggregate different sets of keys. + wg.Add(1) + go func() { + _, err := blst.AggregatePublicKeys(compressedKeys) + require.NoError(t, err) + wg.Done() + }() + + wg.Add(1) + go func() { + _, err := blst.AggregatePublicKeys(compressedKeys[:10]) + require.NoError(t, err) + wg.Done() + }() + + wg.Add(1) + go func() { + _, err := blst.AggregatePublicKeys(compressedKeys[:40]) + require.NoError(t, err) + wg.Done() + }() + + wg.Add(1) + go func() { + _, err := blst.AggregatePublicKeys(compressedKeys[20:60]) + require.NoError(t, err) + wg.Done() + }() + + wg.Add(1) + go func() { + _, err := blst.AggregatePublicKeys(compressedKeys[80:]) + require.NoError(t, err) + wg.Done() + }() + + wg.Add(1) + go func() { + _, err := blst.AggregatePublicKeys(compressedKeys[60:90]) + require.NoError(t, err) + wg.Done() + }() + + wg.Add(1) + go func() { + _, err := blst.AggregatePublicKeys(compressedKeys[40:99]) + require.NoError(t, err) + wg.Done() + }() + + wg.Wait() + + for _, pkey := range pubkeys { + cachedPubkey, err := blst.PublicKeyFromBytes(pkey.Marshal()) + require.NoError(t, err) + assert.Equal(t, true, cachedPubkey.Equals(pkey)) + } +} + func TestPublicKeysEmpty(t *testing.T) { var pubs [][]byte _, err := blst.AggregatePublicKeys(pubs)