mark in progress (#13750)

This commit is contained in:
Nishant Das 2024-03-16 00:46:26 +08:00 committed by GitHub
parent f343333880
commit 58b8c31c93
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 26 additions and 8 deletions

View File

@ -109,10 +109,6 @@ func (c *SkipSlotCache) Get(ctx context.Context, r [32]byte) (state.BeaconState,
// MarkInProgress a request so that any other similar requests will block on // MarkInProgress a request so that any other similar requests will block on
// Get until MarkNotInProgress is called. // Get until MarkNotInProgress is called.
func (c *SkipSlotCache) MarkInProgress(r [32]byte) error { func (c *SkipSlotCache) MarkInProgress(r [32]byte) error {
if c.disabled {
return nil
}
c.lock.Lock() c.lock.Lock()
defer c.lock.Unlock() defer c.lock.Unlock()
@ -126,10 +122,6 @@ func (c *SkipSlotCache) MarkInProgress(r [32]byte) error {
// MarkNotInProgress will release the lock on a given request. This should be // MarkNotInProgress will release the lock on a given request. This should be
// called after put. // called after put.
func (c *SkipSlotCache) MarkNotInProgress(r [32]byte) { func (c *SkipSlotCache) MarkNotInProgress(r [32]byte) {
if c.disabled {
return
}
c.lock.Lock() c.lock.Lock()
defer c.lock.Unlock() defer c.lock.Unlock()

View File

@ -2,6 +2,7 @@ package cache_test
import ( import (
"context" "context"
"sync"
"testing" "testing"
"github.com/prysmaticlabs/prysm/v5/beacon-chain/cache" "github.com/prysmaticlabs/prysm/v5/beacon-chain/cache"
@ -35,3 +36,28 @@ func TestSkipSlotCache_RoundTrip(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.DeepEqual(t, res.ToProto(), s.ToProto(), "Expected equal protos to return from cache") assert.DeepEqual(t, res.ToProto(), s.ToProto(), "Expected equal protos to return from cache")
} }
func TestSkipSlotCache_DisabledAndEnabled(t *testing.T) {
ctx := context.Background()
c := cache.NewSkipSlotCache()
r := [32]byte{'a'}
c.Disable()
require.NoError(t, c.MarkInProgress(r))
c.Enable()
wg := new(sync.WaitGroup)
wg.Add(1)
go func() {
// Get call will only terminate when
// it is not longer in progress.
obj, err := c.Get(ctx, r)
require.NoError(t, err)
require.IsNil(t, obj)
wg.Done()
}()
c.MarkNotInProgress(r)
wg.Wait()
}