mirror of
https://gitlab.com/pulsechaincom/prysm-pulse.git
synced 2025-01-11 04:00:05 +00:00
d077483577
* v3 import renamings * tidy * fmt * rev * Update beacon-chain/core/epoch/precompute/reward_penalty_test.go * Update beacon-chain/core/helpers/validators_test.go * Update beacon-chain/db/alias.go * Update beacon-chain/db/alias.go * Update beacon-chain/db/alias.go * Update beacon-chain/db/iface/BUILD.bazel * Update beacon-chain/db/kv/kv.go * Update beacon-chain/db/kv/state.go * Update beacon-chain/rpc/prysm/v1alpha1/validator/attester_test.go * Update beacon-chain/rpc/prysm/v1alpha1/validator/attester_test.go * Update beacon-chain/sync/initial-sync/service.go * fix deps * fix bad replacements * fix bad replacements * change back * gohashtree version * fix deps Co-authored-by: Nishant Das <nishdas93@gmail.com> Co-authored-by: Potuz <potuz@prysmaticlabs.com>
128 lines
3.2 KiB
Go
128 lines
3.2 KiB
Go
package sync
|
|
|
|
import (
|
|
"context"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
core "github.com/libp2p/go-libp2p-core"
|
|
"github.com/libp2p/go-libp2p-core/network"
|
|
"github.com/libp2p/go-libp2p-core/protocol"
|
|
"github.com/prysmaticlabs/prysm/v3/beacon-chain/p2p"
|
|
p2ptest "github.com/prysmaticlabs/prysm/v3/beacon-chain/p2p/testing"
|
|
"github.com/prysmaticlabs/prysm/v3/testing/assert"
|
|
"github.com/prysmaticlabs/prysm/v3/testing/util"
|
|
)
|
|
|
|
func TestContextWrite_NoWrites(t *testing.T) {
|
|
p1 := p2ptest.NewTestP2P(t)
|
|
nPeer := p2ptest.NewTestP2P(t)
|
|
p1.Connect(nPeer)
|
|
|
|
wg := new(sync.WaitGroup)
|
|
prID := p2p.RPCPingTopicV1
|
|
wg.Add(1)
|
|
nPeer.BHost.SetStreamHandler(core.ProtocolID(prID), func(stream network.Stream) {
|
|
wg.Done()
|
|
// no-op
|
|
})
|
|
strm, err := p1.BHost.NewStream(context.Background(), nPeer.PeerID(), p2p.RPCPingTopicV1)
|
|
assert.NoError(t, err)
|
|
|
|
// Nothing will be written to the stream
|
|
assert.NoError(t, writeContextToStream(nil, strm, nil))
|
|
if util.WaitTimeout(wg, 1*time.Second) {
|
|
t.Fatal("Did not receive stream within 1 sec")
|
|
}
|
|
}
|
|
|
|
func TestContextRead_NoReads(t *testing.T) {
|
|
p1 := p2ptest.NewTestP2P(t)
|
|
nPeer := p2ptest.NewTestP2P(t)
|
|
p1.Connect(nPeer)
|
|
|
|
wg := new(sync.WaitGroup)
|
|
prID := p2p.RPCPingTopicV1
|
|
wg.Add(1)
|
|
wantedData := []byte{'A', 'B', 'C', 'D'}
|
|
nPeer.BHost.SetStreamHandler(core.ProtocolID(prID), func(stream network.Stream) {
|
|
// No Context will be read from it
|
|
dt, err := readContextFromStream(stream, nil)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, 0, len(dt))
|
|
|
|
// Ensure sent over data hasn't been modified.
|
|
buf := make([]byte, len(wantedData))
|
|
n, err := stream.Read(buf)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, len(wantedData), n)
|
|
assert.DeepEqual(t, wantedData, buf)
|
|
|
|
wg.Done()
|
|
})
|
|
strm, err := p1.BHost.NewStream(context.Background(), nPeer.PeerID(), p2p.RPCPingTopicV1)
|
|
assert.NoError(t, err)
|
|
|
|
n, err := strm.Write(wantedData)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, len(wantedData), n)
|
|
if util.WaitTimeout(wg, 1*time.Second) {
|
|
t.Fatal("Did not receive stream within 1 sec")
|
|
}
|
|
}
|
|
|
|
// Compile-time check that *fakeStream is accepted by withProtocol.
// NOTE(review): withProtocol is not declared in the visible source —
// presumably an interface requiring Protocol(); confirm against its definition.
var _ = withProtocol(&fakeStream{})
|
|
|
|
type fakeStream struct {
|
|
protocol protocol.ID
|
|
}
|
|
|
|
func (fs *fakeStream) Protocol() protocol.ID {
|
|
return fs.protocol
|
|
}
|
|
|
|
func TestValidateVersion(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
version string
|
|
protocol string
|
|
error string
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "bad topic",
|
|
version: p2p.SchemaVersionV1,
|
|
protocol: "random",
|
|
error: "unable to find a valid protocol prefix",
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "valid topic with incorrect version",
|
|
version: p2p.SchemaVersionV1,
|
|
protocol: p2p.RPCBlocksByRootTopicV2,
|
|
error: "doesn't match provided version",
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "valid topic with correct version",
|
|
version: p2p.SchemaVersionV2,
|
|
protocol: p2p.RPCBlocksByRootTopicV2,
|
|
error: "",
|
|
wantErr: false,
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
stream := &fakeStream{protocol: protocol.ID(tt.protocol)}
|
|
err := validateVersion(tt.version, stream)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("validateVersion() error = %v, wantErr %v", err, tt.wantErr)
|
|
}
|
|
if err != nil {
|
|
assert.ErrorContains(t, tt.error, err)
|
|
}
|
|
})
|
|
}
|
|
}
|