erigon-pulse/cl/persistence/base_encoding/uint64_diff.go

246 lines
6.3 KiB
Go
Raw Normal View History

package base_encoding
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"sync"
"github.com/klauspost/compress/zstd"
"github.com/ledgerwatch/erigon/cl/utils"
)
// compressorPool hands out reusable zstd encoders. Encoders are created with
// a nil destination and the zstd.SpeedFastest level; users are expected to
// call Reset with their own writer after Get. New panics if the encoder
// cannot be constructed.
var compressorPool = sync.Pool{
	New: func() any {
		enc, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(zstd.SpeedFastest))
		if err == nil {
			return enc
		}
		panic(err)
	},
}
// bufferPool hands out reusable *bytes.Buffer values. A buffer taken from the
// pool may still hold data from a previous user, so callers must Reset it.
var bufferPool = sync.Pool{
	New: func() any { return new(bytes.Buffer) },
}
// plainUint64BufferPool hands out pointers to reusable []uint64 scratch
// slices. A freshly created slice has length 1028; callers reslice it to
// length zero before appending.
var plainUint64BufferPool = sync.Pool{
	New: func() any {
		scratch := new([]uint64)
		*scratch = make([]uint64, 1028)
		return scratch
	},
}
// plainBytesBufferPool hands out pointers to reusable []byte scratch slices.
// A freshly created slice has length 1028.
var plainBytesBufferPool = sync.Pool{
	New: func() any {
		scratch := new([]byte)
		*scratch = make([]byte, 1028)
		return scratch
	},
}
// repeatedPatternBufferPool hands out pointers to reusable
// []repeatedPatternEntry scratch slices. A freshly created slice has length
// 1028; callers reslice it to length zero before appending.
var repeatedPatternBufferPool = sync.Pool{
	New: func() any {
		scratch := new([]repeatedPatternEntry)
		*scratch = make([]repeatedPatternEntry, 1028)
		return scratch
	},
}
// repeatedPatternEntry is one run of the run-length encoding used by the diff
// functions in this file: the value val occurring count times in a row.
// On the wire each entry is written as the big-endian count followed by the
// big-endian val.
type repeatedPatternEntry struct {
	// val is the value shared by every element of the run.
	val uint64
	// count is the number of consecutive occurrences of val.
	count uint32
}
// ComputeCompressedSerializedUint64ListDiff writes to w a diff that describes
// how to obtain the serialized uint64 list new from old.
//
// Both inputs are read as consecutive 8-byte little-endian words. For every
// word of new the wrapping difference new[k]-old[k] is taken; words that lie
// beyond the end of old are kept verbatim. The resulting sequence is
// run-length encoded and emitted as:
//
//   - a big-endian uint32 holding the number of runs, written to w directly;
//   - for each run, a big-endian uint32 count followed by a big-endian uint64
//     value, written through a pooled zstd encoder that targets w.
//
// An empty new produces a diff with zero runs. An error is returned when old
// is longer than new, when len(new) is not a multiple of 8, or when writing
// to w or to the encoder fails.
func ComputeCompressedSerializedUint64ListDiff(w io.Writer, old, new []byte) error {
	if len(old) > len(new) {
		return fmt.Errorf("old list is longer than new list")
	}
	// A trailing partial word cannot be decoded as a uint64; reject it up
	// front instead of panicking while decoding below.
	if len(new)%8 != 0 {
		return fmt.Errorf("new list length %d is not a multiple of 8", len(new))
	}

	compressor := compressorPool.Get().(*zstd.Encoder)
	defer compressorPool.Put(compressor)
	compressor.Reset(w)

	// The scratch slices come from pools. The deferred closures store the
	// (possibly grown) slices back through the pooled pointers on every exit
	// path, so capacity gained here is not lost when an error is returned.
	plainBufferPtr := plainUint64BufferPool.Get().(*[]uint64)
	plainBuffer := (*plainBufferPtr)[:0]
	defer func() {
		*plainBufferPtr = plainBuffer[:0]
		plainUint64BufferPool.Put(plainBufferPtr)
	}()

	repeatedPatternPtr := repeatedPatternBufferPool.Get().(*[]repeatedPatternEntry)
	repeatedPattern := (*repeatedPatternPtr)[:0]
	defer func() {
		*repeatedPatternPtr = repeatedPattern[:0]
		repeatedPatternBufferPool.Put(repeatedPatternPtr)
	}()

	// Per-word delta; words without a full counterpart in old are stored as-is.
	for i := 0; i < len(new); i += 8 {
		word := binary.LittleEndian.Uint64(new[i : i+8])
		if i+8 <= len(old) {
			word -= binary.LittleEndian.Uint64(old[i : i+8])
		}
		plainBuffer = append(plainBuffer, word)
	}

	// Collapse consecutive equal deltas into runs. An empty delta list simply
	// yields zero runs.
	for _, word := range plainBuffer {
		if last := len(repeatedPattern) - 1; last >= 0 && repeatedPattern[last].val == word {
			repeatedPattern[last].count++
			continue
		}
		repeatedPattern = append(repeatedPattern, repeatedPatternEntry{val: word, count: 1})
	}

	// The run count is written uncompressed, ahead of the zstd stream.
	if err := binary.Write(w, binary.BigEndian, uint32(len(repeatedPattern))); err != nil {
		return err
	}

	// Each run is 12 bytes on the wire: count (4) then value (8).
	var temp [12]byte
	for _, entry := range repeatedPattern {
		binary.BigEndian.PutUint32(temp[:4], entry.count)
		binary.BigEndian.PutUint64(temp[4:], entry.val)
		if _, err := compressor.Write(temp[:]); err != nil {
			return err
		}
	}
	return compressor.Close()
}
// ComputeCompressedSerializedEffectiveBalancesDiff writes to w a diff of one
// 8-byte field between two serialized lists of fixed-size records.
//
// old and new are walked in steps of 121 bytes (validatorSize). From each
// record the little-endian uint64 at byte offsets 80..88 is read — judging by
// the function name this is the effective balance of a serialized validator
// (NOTE(review): confirm against the record layout). For every record of new
// the wrapping difference between its field and the same field in old is
// taken; records whose field is not fully present in old are kept verbatim.
// The resulting sequence is run-length encoded and emitted as:
//
//   - a big-endian uint32 holding the number of runs, written to w directly;
//   - for each run, a big-endian uint32 count followed by a big-endian uint64
//     value, written through a pooled zstd encoder that targets w.
//
// An empty new produces a diff with zero runs. An error is returned when old
// is longer than new, when the last record of new is too short to contain the
// field, or when writing to w or to the encoder fails.
func ComputeCompressedSerializedEffectiveBalancesDiff(w io.Writer, old, new []byte) error {
	const (
		validatorSize = 121 // stride between consecutive records
		balanceStart  = 80  // first byte of the diffed field within a record
		balanceEnd    = 88  // one past the last byte of the diffed field
	)

	if len(old) > len(new) {
		return fmt.Errorf("old list is longer than new list")
	}
	// A trailing record that stops before the end of the field would make the
	// slicing below panic or read past len(new); reject it up front.
	if rem := len(new) % validatorSize; rem != 0 && rem < balanceEnd {
		return fmt.Errorf("new list ends with a truncated record of %d bytes", rem)
	}

	compressor := compressorPool.Get().(*zstd.Encoder)
	defer compressorPool.Put(compressor)
	compressor.Reset(w)

	// The scratch slices come from pools. The deferred closures store the
	// (possibly grown) slices back through the pooled pointers on every exit
	// path, so capacity gained here is not lost when an error is returned.
	plainBufferPtr := plainUint64BufferPool.Get().(*[]uint64)
	plainBuffer := (*plainBufferPtr)[:0]
	defer func() {
		*plainBufferPtr = plainBuffer[:0]
		plainUint64BufferPool.Put(plainBufferPtr)
	}()

	repeatedPatternPtr := repeatedPatternBufferPool.Get().(*[]repeatedPatternEntry)
	repeatedPattern := (*repeatedPatternPtr)[:0]
	defer func() {
		*repeatedPatternPtr = repeatedPattern[:0]
		repeatedPatternBufferPool.Put(repeatedPatternPtr)
	}()

	// Per-record delta of the field; records whose field is not fully present
	// in old are stored as-is.
	for i := 0; i < len(new); i += validatorSize {
		balance := binary.LittleEndian.Uint64(new[i+balanceStart : i+balanceEnd])
		if i+balanceEnd <= len(old) {
			balance -= binary.LittleEndian.Uint64(old[i+balanceStart : i+balanceEnd])
		}
		plainBuffer = append(plainBuffer, balance)
	}

	// Collapse consecutive equal deltas into runs. An empty delta list simply
	// yields zero runs.
	for _, delta := range plainBuffer {
		if last := len(repeatedPattern) - 1; last >= 0 && repeatedPattern[last].val == delta {
			repeatedPattern[last].count++
			continue
		}
		repeatedPattern = append(repeatedPattern, repeatedPatternEntry{val: delta, count: 1})
	}

	// The run count is written uncompressed, ahead of the zstd stream.
	if err := binary.Write(w, binary.BigEndian, uint32(len(repeatedPattern))); err != nil {
		return err
	}

	// Each run is 12 bytes on the wire: count (4) then value (8).
	var temp [12]byte
	for _, entry := range repeatedPattern {
		binary.BigEndian.PutUint32(temp[:4], entry.count)
		binary.BigEndian.PutUint64(temp[4:], entry.val)
		if _, err := compressor.Write(temp[:]); err != nil {
			return err
		}
	}
	return compressor.Close()
}
// ApplyCompressedSerializedUint64ListDiff rebuilds a serialized uint64 list
// from old and a diff.
//
// diff starts with a big-endian uint32 run count, followed by a zstd stream
// in which every run is a big-endian uint32 count and a big-endian uint64
// value. Each run contributes count words to the result: while a full 8-byte
// word is still available in old, the output word is that little-endian word
// plus the run value (wrapping); afterwards the run value is emitted as-is.
//
// The result is appended to out[:0], so out's storage is reused when it is
// large enough. On any failure the returned slice is nil; a stream that ends
// before all announced runs were read yields io.EOF.
func ApplyCompressedSerializedUint64ListDiff(old, out []byte, diff []byte) ([]byte, error) {
	out = out[:0]

	// Stage the diff in a pooled buffer that serves both the header read and
	// the decompressor.
	src := bufferPool.Get().(*bytes.Buffer)
	defer bufferPool.Put(src)
	src.Reset()
	if _, err := src.Write(diff); err != nil {
		return nil, err
	}

	// Uncompressed header: number of runs.
	var runs uint32
	if err := binary.Read(src, binary.BigEndian, &runs); err != nil {
		return nil, err
	}

	dec, err := zstd.NewReader(src)
	if err != nil {
		return nil, err
	}
	defer dec.Close()

	var (
		scratch = make([]byte, 8)
		offset  int // byte position of the next word within old
	)
	for left := int(runs); left > 0; left-- {
		// Run length.
		got, readErr := utils.ReadZSTD(dec, scratch[:4])
		if readErr != nil && !errors.Is(readErr, io.EOF) {
			return nil, readErr
		}
		if got != 4 {
			return nil, io.EOF
		}
		repeat := binary.BigEndian.Uint32(scratch[:4])

		// Run value.
		got, readErr = utils.ReadZSTD(dec, scratch)
		if readErr != nil && !errors.Is(readErr, io.EOF) {
			return nil, readErr
		}
		if got != 8 {
			return nil, io.EOF
		}
		delta := binary.BigEndian.Uint64(scratch)

		for k := 0; k < int(repeat); k++ {
			word := delta
			if offset+8 <= len(old) {
				// Still inside old: the run value is a delta on top of it.
				word += binary.LittleEndian.Uint64(old[offset : offset+8])
			}
			out = binary.LittleEndian.AppendUint64(out, word)
			offset += 8
		}
	}
	return out, nil
}