erigon-pulse/common/eth2shuffle/shuffle_test.go
a 761524a986
eth2 shuffle (#7522)
moves eth2shuffle package to erigon.
2023-05-18 01:06:25 +02:00

144 lines
3.9 KiB
Go

package eth2shuffle_test
import (
"crypto/sha256"
"encoding/csv"
"encoding/hex"
"fmt"
"os"
"strconv"
"strings"
"testing"
"github.com/ledgerwatch/erigon/common/eth2shuffle"
"github.com/stretchr/testify/assert"
)
// getStandardHashFn returns a hash function that computes the SHA-256 digest
// of its input and returns it as a newly allocated 32-byte slice.
func getStandardHashFn() eth2shuffle.HashFn {
	return func(in []byte) []byte {
		digest := sha256.Sum256(in)
		return digest[:]
	}
}
// readEncodedListInput parses a ":"-separated list of base-10 unsigned
// integers, e.g. "3:1:2". An empty input string denotes an empty list.
//
// It returns an error when the number of items differs from requiredLen or
// when an item is not a valid uint64 (negative values are rejected, values up
// to the maximum uint64 are accepted). lineIndex is only used to give context
// in error messages.
func readEncodedListInput(input string, requiredLen int64, lineIndex int) ([]uint64, error) {
	var itemStrs []string
	if input != "" {
		// strings.Split on "" would yield one empty item, so only split
		// non-empty input.
		itemStrs = strings.Split(input, ":")
	}
	if int64(len(itemStrs)) != requiredLen {
		return nil, fmt.Errorf("list on line %d has %d items, expected %d", lineIndex, len(itemStrs), requiredLen)
	}
	items := make([]uint64, len(itemStrs))
	for i, itemStr := range itemStrs {
		// Parse as unsigned: the result type is uint64, so a signed parse
		// would reject large values and silently wrap negative ones.
		item, err := strconv.ParseUint(itemStr, 10, 64)
		if err != nil {
			return nil, fmt.Errorf("list item %d on line %d cannot be parsed: %w", i, lineIndex, err)
		}
		items[i] = item
	}
	return items, nil
}
// TestAgainstSpec checks the shuffling functions against the vectors stored
// in spec/tests.csv. Each CSV record is read as: hex-encoded seed, list size,
// ":"-separated input list, ":"-separated expected shuffled list.
//
// For every record it verifies PermuteIndex, UnpermuteIndex, ShuffleList and
// UnshuffleList. Records are run as parallel subtests.
func TestAgainstSpec(t *testing.T) {
	// Open CSV file
	f, err := os.Open("spec/tests.csv")
	if err != nil {
		t.Fatalf("cannot open spec tests: %v", err)
	}
	defer f.Close()

	// Read the whole file into memory
	lines, err := csv.NewReader(f).ReadAll()
	if err != nil {
		t.Fatalf("cannot read spec tests: %v", err)
	}

	// constant in spec
	rounds := uint8(90)

	for lineIndex, line := range lines {
		if len(line) < 4 {
			t.Fatalf("line %d has %d fields, expected at least 4", lineIndex, len(line))
		}
		parsedSeed, err := hex.DecodeString(line[0])
		if err != nil {
			t.Fatalf("seed on line %d cannot be parsed: %v", lineIndex, err)
		}
		parsedSize, err := strconv.ParseInt(line[1], 10, 32)
		if err != nil {
			t.Fatalf("list size on line %d cannot be parsed: %v", lineIndex, err)
		}
		// Stop on malformed lists: continuing with nil slices would only
		// panic with an index-out-of-range inside the subtests.
		shuffleIn, err := readEncodedListInput(line[2], parsedSize, lineIndex)
		if !assert.NoError(t, err) {
			t.FailNow()
		}
		shuffleOut, err := readEncodedListInput(line[3], parsedSize, lineIndex)
		if !assert.NoError(t, err) {
			t.FailNow()
		}
		listSize := uint64(parsedSize)

		// Everything the closure below captures is declared inside this loop
		// body, so each parallel subtest sees its own copy even though it
		// only starts running after the loop has finished.
		t.Run("", func(st *testing.T) {
			// run every test case in parallel.
			st.Parallel()

			seed := [32]byte{}
			copy(seed[:], parsedSeed)
			hashFn := getStandardHashFn()

			st.Run("PermuteIndex", func(it *testing.T) {
				for i := uint64(0); i < listSize; i++ {
					// calculate the permuted index. (i.e. shuffle single index)
					permuted := eth2shuffle.PermuteIndex(hashFn, rounds, i, listSize, seed)
					// compare with expectation
					if shuffleIn[i] != shuffleOut[permuted] {
						it.Fatalf("index %d permuted to %d: input %d != expected output %d",
							i, permuted, shuffleIn[i], shuffleOut[permuted])
					}
				}
			})

			st.Run("UnpermuteIndex", func(it *testing.T) {
				for i := uint64(0); i < listSize; i++ {
					// calculate the un-permuted index. (i.e. un-shuffle single index)
					unpermuted := eth2shuffle.UnpermuteIndex(hashFn, rounds, i, listSize, seed)
					// compare with expectation
					if shuffleOut[i] != shuffleIn[unpermuted] {
						it.Fatalf("index %d un-permuted to %d: expected output %d != input %d",
							i, unpermuted, shuffleOut[i], shuffleIn[unpermuted])
					}
				}
			})

			st.Run("ShuffleList", func(it *testing.T) {
				// work on a copy; ShuffleList is handed this slice to shuffle.
				testInput := make([]uint64, listSize)
				copy(testInput, shuffleIn)
				// shuffle!
				eth2shuffle.ShuffleList(hashFn, testInput, rounds, seed)
				// compare shuffled list to expected output
				for i := uint64(0); i < listSize; i++ {
					if testInput[i] != shuffleOut[i] {
						it.Fatalf("shuffled item %d: got %d, expected %d", i, testInput[i], shuffleOut[i])
					}
				}
			})

			st.Run("UnshuffleList", func(it *testing.T) {
				// work on a copy; UnshuffleList is handed this slice to un-shuffle.
				testInput := make([]uint64, listSize)
				copy(testInput, shuffleOut)
				// un-shuffle!
				eth2shuffle.UnshuffleList(hashFn, testInput, rounds, seed)
				// compare un-shuffled list to original input
				for i := uint64(0); i < listSize; i++ {
					if testInput[i] != shuffleIn[i] {
						it.Fatalf("un-shuffled item %d: got %d, expected %d", i, testInput[i], shuffleIn[i])
					}
				}
			})
		})
	}
}