This commit is contained in:
alex.sharov 2023-01-13 11:35:30 +07:00
parent 120c4a20f4
commit 7c475cb52e
3 changed files with 227 additions and 19 deletions

View File

@ -191,25 +191,6 @@ func LastKey(tx Tx, table string) ([]byte, error) {
return k, nil
}
type ArrStream[V any] struct {
arr []V
i int
}
func StreamArray[V any](arr []V) UnaryStream[V] { return &ArrStream[V]{arr: arr} }
func (it *ArrStream[V]) HasNext() bool { return it.i < len(it.arr) }
func (it *ArrStream[V]) Close() {}
func (it *ArrStream[V]) Next() (V, error) {
v := it.arr[it.i]
it.i++
return v, nil
}
func (it *ArrStream[V]) NextBatch() ([]V, error) {
v := it.arr[it.i:]
it.i = len(it.arr)
return v, nil
}
// NextSubtree does []byte++. Returns false if overflow.
func NextSubtree(in []byte) ([]byte, bool) {
r := make([]byte, len(in))

137
kv/stream/stream.go Normal file
View File

@ -0,0 +1,137 @@
package stream
import (
"bytes"
"fmt"
"github.com/ledgerwatch/erigon-lib/kv"
)
type ArrStream[V any] struct {
arr []V
i int
}
func Array[V any](arr []V) kv.UnaryStream[V] { return &ArrStream[V]{arr: arr} }
func (it *ArrStream[V]) HasNext() bool { return it.i < len(it.arr) }
func (it *ArrStream[V]) Close() {}
func (it *ArrStream[V]) Next() (V, error) {
v := it.arr[it.i]
it.i++
return v, nil
}
func (it *ArrStream[V]) NextBatch() ([]V, error) {
v := it.arr[it.i:]
it.i = len(it.arr)
return v, nil
}
// MergePairsStream - merge 2 kv.Pairs streams to 1 in lexicographically order
// 1-st stream has higher priority - when 2 streams return same key
type MergePairsStream struct {
x, y kv.Pairs
xHasNext, yHasNext bool
xNextK, xNextV []byte
yNextK, yNextV []byte
nextErr error
}
func MergePairs(x, y kv.Pairs) *MergePairsStream {
m := &MergePairsStream{x: x, y: y}
m.advanceX()
m.advanceY()
return m
}
func (m *MergePairsStream) HasNext() bool { return m.xHasNext || m.yHasNext }
func (m *MergePairsStream) advanceX() {
if m.nextErr != nil {
m.xNextK, m.xNextV = nil, nil
return
}
m.xHasNext = m.x.HasNext()
if m.xHasNext {
m.xNextK, m.xNextV, m.nextErr = m.x.Next()
}
}
func (m *MergePairsStream) advanceY() {
if m.nextErr != nil {
m.yNextK, m.yNextV = nil, nil
return
}
m.yHasNext = m.y.HasNext()
if m.yHasNext {
m.yNextK, m.yNextV, m.nextErr = m.y.Next()
}
}
func (m *MergePairsStream) Next() ([]byte, []byte, error) {
if m.nextErr != nil {
return nil, nil, m.nextErr
}
if !m.xHasNext && !m.yHasNext {
panic(1)
}
if m.xHasNext && m.yHasNext {
cmp := bytes.Compare(m.xNextK, m.yNextK)
if cmp < 0 {
k, v, err := m.xNextK, m.xNextV, m.nextErr
m.advanceX()
return k, v, err
} else if cmp == 0 {
k, v, err := m.xNextK, m.xNextV, m.nextErr
m.advanceX()
m.advanceY()
return k, v, err
}
k, v, err := m.yNextK, m.yNextV, m.nextErr
m.advanceY()
return k, v, err
}
if m.xHasNext {
k, v, err := m.xNextK, m.xNextV, m.nextErr
m.advanceX()
return k, v, err
}
k, v, err := m.yNextK, m.yNextV, m.nextErr
m.advanceY()
return k, v, err
}
func (m *MergePairsStream) Keys() ([][]byte, error) { return naiveKeys(m) }
func (m *MergePairsStream) Values() ([][]byte, error) { return naiveValues(m) }
func naiveKeys(it kv.Pairs) (keys [][]byte, err error) {
for it.HasNext() {
k, _, err := it.Next()
if err != nil {
return keys, err
}
keys = append(keys, k)
}
return keys, nil
}
func naiveValues(it kv.Pairs) (values [][]byte, err error) {
for it.HasNext() {
_, v, err := it.Next()
if err != nil {
return values, err
}
values = append(values, v)
}
return values, nil
}
// PairsWithErrorStream - return N, keys and then error
type PairsWithErrorStream struct {
errorAt, i int
}
func PairsWithError(errorAt int) *PairsWithErrorStream {
return &PairsWithErrorStream{errorAt: errorAt}
}
func (m *PairsWithErrorStream) HasNext() bool { return true }
func (m *PairsWithErrorStream) Next() ([]byte, []byte, error) {
if m.i >= m.errorAt {
return nil, nil, fmt.Errorf("expected error at iteration: %d", m.errorAt)
}
m.i++
return []byte(fmt.Sprintf("%x", m.i)), []byte(fmt.Sprintf("%x", m.i)), nil
}

90
kv/stream/stream_test.go Normal file
View File

@ -0,0 +1,90 @@
package stream_test
import (
"context"
"testing"
"github.com/ledgerwatch/erigon-lib/kv"
"github.com/ledgerwatch/erigon-lib/kv/memdb"
"github.com/ledgerwatch/erigon-lib/kv/stream"
"github.com/stretchr/testify/require"
)
func TestMerge(t *testing.T) {
db := memdb.NewTestDB(t)
ctx := context.Background()
t.Run("simple", func(t *testing.T) {
require := require.New(t)
tx, _ := db.BeginRw(ctx)
defer tx.Rollback()
_ = tx.Put(kv.AccountsHistory, []byte{1}, []byte{1})
_ = tx.Put(kv.AccountsHistory, []byte{3}, []byte{1})
_ = tx.Put(kv.AccountsHistory, []byte{4}, []byte{1})
_ = tx.Put(kv.PlainState, []byte{2}, []byte{9})
_ = tx.Put(kv.PlainState, []byte{3}, []byte{9})
it, _ := tx.Range(kv.AccountsHistory, nil, nil)
it2, _ := tx.Range(kv.PlainState, nil, nil)
k, err := stream.MergePairs(it, it2).Keys()
require.NoError(err)
require.Equal([][]byte{{1}, {2}, {3}, {4}}, k)
})
t.Run("simple values", func(t *testing.T) {
require := require.New(t)
tx, _ := db.BeginRw(ctx)
defer tx.Rollback()
_ = tx.Put(kv.AccountsHistory, []byte{1}, []byte{1})
_ = tx.Put(kv.AccountsHistory, []byte{3}, []byte{1})
_ = tx.Put(kv.AccountsHistory, []byte{4}, []byte{1})
_ = tx.Put(kv.PlainState, []byte{2}, []byte{9})
_ = tx.Put(kv.PlainState, []byte{3}, []byte{9})
it, _ := tx.Range(kv.AccountsHistory, nil, nil)
it2, _ := tx.Range(kv.PlainState, nil, nil)
v, err := stream.MergePairs(it, it2).Values()
require.NoError(err)
require.Equal([][]byte{{1}, {9}, {1}, {1}}, v)
})
t.Run("empty 1st", func(t *testing.T) {
require := require.New(t)
tx, _ := db.BeginRw(ctx)
defer tx.Rollback()
_ = tx.Put(kv.PlainState, []byte{2}, []byte{9})
_ = tx.Put(kv.PlainState, []byte{3}, []byte{9})
it, _ := tx.Range(kv.AccountsHistory, nil, nil)
it2, _ := tx.Range(kv.PlainState, nil, nil)
k, err := stream.MergePairs(it, it2).Keys()
require.NoError(err)
require.Equal([][]byte{{2}, {3}}, k)
})
t.Run("empty 2nd", func(t *testing.T) {
require := require.New(t)
tx, _ := db.BeginRw(ctx)
defer tx.Rollback()
_ = tx.Put(kv.AccountsHistory, []byte{1}, []byte{1})
_ = tx.Put(kv.AccountsHistory, []byte{3}, []byte{1})
_ = tx.Put(kv.AccountsHistory, []byte{4}, []byte{1})
it, _ := tx.Range(kv.AccountsHistory, nil, nil)
it2, _ := tx.Range(kv.PlainState, nil, nil)
k, err := stream.MergePairs(it, it2).Keys()
require.NoError(err)
require.Equal([][]byte{{1}, {3}, {4}}, k)
})
t.Run("empty both", func(t *testing.T) {
require := require.New(t)
tx, _ := db.BeginRw(ctx)
defer tx.Rollback()
it, _ := tx.Range(kv.AccountsHistory, nil, nil)
it2, _ := tx.Range(kv.PlainState, nil, nil)
m := stream.MergePairs(it, it2)
require.False(m.HasNext())
})
t.Run("error handling", func(t *testing.T) {
require := require.New(t)
tx, _ := db.BeginRw(ctx)
defer tx.Rollback()
it := stream.PairsWithError(10)
it2 := stream.PairsWithError(12)
k, err := stream.MergePairs(it, it2).Keys()
require.Equal("expected error at iteration: 10", err.Error())
require.Equal(10, len(k))
})
}