diff --git a/shared/sliceutil/slice.go b/shared/sliceutil/slice.go index 10b677b91..95356e994 100644 --- a/shared/sliceutil/slice.go +++ b/shared/sliceutil/slice.go @@ -41,14 +41,18 @@ func IntersectionUint64(s ...[]uint64) []uint64 { return s[0] } intersect := make([]uint64, 0) - for i := 1; i < len(s); i++ { - m := make(map[uint64]bool) - for j := 0; j < len(s[i-1]); j++ { - m[s[i-1][j]] = true - } - for j := 0; j < len(s[i]); j++ { - if _, found := m[s[i][j]]; found { - intersect = append(intersect, s[i][j]) + m := make(map[uint64]int) + for _, k := range s[0] { + m[k] = 1 + } + for i, num := 1, len(s); i < num; i++ { + for _, k := range s[i] { + // Increment and check only if item is present in both, and no increment has happened yet. + if _, found := m[k]; found && (i-m[k]) == 0 { + m[k]++ + if m[k] == num { + intersect = append(intersect, k) + } } } } @@ -152,19 +156,22 @@ func IntersectionInt64(s ...[]int64) []int64 { if len(s) == 1 { return s[0] } - set := make([]int64, 0) - m := make(map[int64]bool) - for i := 1; i < len(s); i++ { - for j := 0; j < len(s[i-1]); j++ { - m[s[i-1][j]] = true - } - for j := 0; j < len(s[i]); j++ { - if _, found := m[s[i][j]]; found { - set = append(set, s[i][j]) + intersect := make([]int64, 0) + m := make(map[int64]int) + for _, k := range s[0] { + m[k] = 1 + } + for i, num := 1, len(s); i < num; i++ { + for _, k := range s[i] { + if _, found := m[k]; found && (i-m[k]) == 0 { + m[k]++ + if m[k] == num { + intersect = append(intersect, k) + } } } } - return set + return intersect } // UnionInt64 of any number of int64 slices with time @@ -256,26 +263,19 @@ func IntersectionByteSlices(s ...[][]byte) [][]byte { return s[0] } inter := make([][]byte, 0) - for i := 1; i < len(s); i++ { - hash := make(map[string]bool) - for _, e := range s[i-1] { - hash[string(e)] = true - } - for _, e := range s[i] { - if hash[string(e)] { - inter = append(inter, e) + m := make(map[string]int) + for _, k := range s[0] { + m[string(k)] = 1 + } + for i, num := 1, len(s); i < num; i++ { + for _, k := range s[i] { + if _, found := m[string(k)]; found && (i-m[string(k)]) == 0 { + m[string(k)]++ + if m[string(k)] == num { + inter = append(inter, k) + } } } - tmp := make([][]byte, 0) - // Remove duplicates from slice. - encountered := make(map[string]bool) - for _, element := range inter { - if !encountered[string(element)] { - tmp = append(tmp, element) - encountered[string(element)] = true - } - } - inter = tmp } return inter } diff --git a/shared/sliceutil/slice_test.go b/shared/sliceutil/slice_test.go index c4bb309f6..6118f904e 100644 --- a/shared/sliceutil/slice_test.go +++ b/shared/sliceutil/slice_test.go @@ -2,6 +2,7 @@ package sliceutil_test import ( "reflect" + "sort" "testing" "github.com/prysmaticlabs/prysm/shared/sliceutil" @@ -32,22 +33,45 @@ func TestIntersectionUint64(t *testing.T) { testCases := []struct { setA []uint64 setB []uint64 + setC []uint64 out []uint64 }{ - {[]uint64{2, 3, 5}, []uint64{3}, []uint64{3}}, - {[]uint64{2, 3, 5}, []uint64{3, 5}, []uint64{3, 5}}, - {[]uint64{2, 3, 5}, []uint64{5, 3, 2}, []uint64{5, 3, 2}}, - {[]uint64{2, 3, 5}, []uint64{2, 3, 5}, []uint64{2, 3, 5}}, - {[]uint64{2, 3, 5}, []uint64{}, []uint64{}}, - {[]uint64{}, []uint64{2, 3, 5}, []uint64{}}, - {[]uint64{}, []uint64{}, []uint64{}}, - {[]uint64{1}, []uint64{1}, []uint64{1}}, + {[]uint64{2, 3, 5}, []uint64{3}, []uint64{3}, []uint64{3}}, + {[]uint64{2, 3, 5}, []uint64{3, 5}, []uint64{5}, []uint64{5}}, + {[]uint64{2, 3, 5}, []uint64{3, 5}, []uint64{3, 5}, []uint64{3, 5}}, + {[]uint64{2, 3, 5}, []uint64{5, 3, 2}, []uint64{3, 2, 5}, []uint64{2, 3, 5}}, + {[]uint64{3, 2, 5}, []uint64{5, 3, 2}, []uint64{3, 2, 5}, []uint64{2, 3, 5}}, + {[]uint64{3, 3, 5}, []uint64{5, 3, 2}, []uint64{3, 2, 5}, []uint64{3, 5}}, + {[]uint64{2, 3, 5}, []uint64{2, 3, 5}, []uint64{2, 3, 5}, []uint64{2, 3, 5}}, + {[]uint64{2, 3, 5}, []uint64{}, []uint64{}, []uint64{}}, + {[]uint64{2, 3, 5}, []uint64{2, 3, 5}, []uint64{}, []uint64{}}, + {[]uint64{2, 3}, []uint64{2, 3, 5}, []uint64{5}, []uint64{}}, + {[]uint64{2, 2, 2}, []uint64{2, 2, 2}, []uint64{}, []uint64{}}, + {[]uint64{}, []uint64{2, 3, 5}, []uint64{}, []uint64{}}, + {[]uint64{}, []uint64{}, []uint64{}, []uint64{}}, + {[]uint64{1}, []uint64{1}, []uint64{}, []uint64{}}, + {[]uint64{1, 1, 1}, []uint64{1, 1}, []uint64{1, 2, 3}, []uint64{1}}, } for _, tt := range testCases { - result := sliceutil.IntersectionUint64(tt.setA, tt.setB) + setA := append([]uint64{}, tt.setA...) + setB := append([]uint64{}, tt.setB...) + setC := append([]uint64{}, tt.setC...) + result := sliceutil.IntersectionUint64(setA, setB, setC) + sort.Slice(result, func(i, j int) bool { + return result[i] < result[j] + }) if !reflect.DeepEqual(result, tt.out) { t.Errorf("got %d, want %d", result, tt.out) } + if !reflect.DeepEqual(setA, tt.setA) { + t.Errorf("slice modified, got %v, want %v", setA, tt.setA) + } + if !reflect.DeepEqual(setB, tt.setB) { + t.Errorf("slice modified, got %v, want %v", setB, tt.setB) + } + if !reflect.DeepEqual(setC, tt.setC) { + t.Errorf("slice modified, got %v, want %v", setC, tt.setC) + } } } @@ -73,22 +97,45 @@ func TestIntersectionInt64(t *testing.T) { testCases := []struct { setA []int64 setB []int64 + setC []int64 out []int64 }{ - {[]int64{2, 3, 5}, []int64{3}, []int64{3}}, - {[]int64{2, 3, 5}, []int64{3, 5}, []int64{3, 5}}, - {[]int64{2, 3, 5}, []int64{5, 3, 2}, []int64{5, 3, 2}}, - {[]int64{2, 3, 5}, []int64{2, 3, 5}, []int64{2, 3, 5}}, - {[]int64{2, 3, 5}, []int64{}, []int64{}}, - {[]int64{}, []int64{2, 3, 5}, []int64{}}, - {[]int64{}, []int64{}, []int64{}}, - {[]int64{1}, []int64{1}, []int64{1}}, + {[]int64{2, 3, 5}, []int64{3}, []int64{3}, []int64{3}}, + {[]int64{2, 3, 5}, []int64{3, 5}, []int64{5}, []int64{5}}, + {[]int64{2, 3, 5}, []int64{3, 5}, []int64{3, 5}, []int64{3, 5}}, + {[]int64{2, 3, 5}, []int64{5, 3, 2}, []int64{3, 2, 5}, []int64{2, 3, 5}}, + {[]int64{3, 2, 5}, []int64{5, 3, 2}, []int64{3, 2, 5}, []int64{2, 3, 5}}, + {[]int64{3, 3, 5}, []int64{5, 3, 2}, []int64{3, 2, 5}, []int64{3, 5}}, + {[]int64{2, 3, 5}, []int64{2, 3, 5}, []int64{2, 3, 5}, []int64{2, 3, 5}}, + {[]int64{2, 3, 5}, []int64{}, []int64{}, []int64{}}, + {[]int64{2, 3, 5}, []int64{2, 3, 5}, []int64{}, []int64{}}, + {[]int64{2, 3}, []int64{2, 3, 5}, []int64{5}, []int64{}}, + {[]int64{2, 2, 2}, []int64{2, 2, 2}, []int64{}, []int64{}}, + {[]int64{}, []int64{2, 3, 5}, []int64{}, []int64{}}, + {[]int64{}, []int64{}, []int64{}, []int64{}}, + {[]int64{1}, []int64{1}, []int64{}, []int64{}}, + {[]int64{1, 1, 1}, []int64{1, 1}, []int64{1, 2, 3}, []int64{1}}, } for _, tt := range testCases { - result := sliceutil.IntersectionInt64(tt.setA, tt.setB) + setA := append([]int64{}, tt.setA...) + setB := append([]int64{}, tt.setB...) + setC := append([]int64{}, tt.setC...) + result := sliceutil.IntersectionInt64(setA, setB, setC) + sort.Slice(result, func(i, j int) bool { + return result[i] < result[j] + }) if !reflect.DeepEqual(result, tt.out) { t.Errorf("got %d, want %d", result, tt.out) } + if !reflect.DeepEqual(setA, tt.setA) { + t.Errorf("slice modified, got %v, want %v", setA, tt.setA) + } + if !reflect.DeepEqual(setB, tt.setB) { + t.Errorf("slice modified, got %v, want %v", setB, tt.setB) + } + if !reflect.DeepEqual(setC, tt.setC) { + t.Errorf("slice modified, got %v, want %v", setC, tt.setC) + } } } @@ -308,10 +355,12 @@ func TestUnionByteSlices(t *testing.T) { func TestIntersectionByteSlices(t *testing.T) { testCases := []struct { + name string input [][][]byte result [][]byte }{ { + name: "intersect with empty set", input: [][][]byte{ { {1, 2, 3}, @@ -321,11 +370,12 @@ func TestIntersectionByteSlices(t *testing.T) { {1, 2}, {4, 5}, }, + {}, }, - result: [][]byte{{4, 5}}, + result: [][]byte{}, }, - // Ensure duplicate elements are removed in the resulting set. { + name: "ensure duplicate elements are removed in the resulting set", input: [][][]byte{ { {1, 2, 3}, @@ -337,11 +387,15 @@ func TestIntersectionByteSlices(t *testing.T) { {4, 5}, {4, 5}, }, + { + {4, 5}, + {4, 5}, + }, }, result: [][]byte{{4, 5}}, }, - // Ensure no intersection returns an empty set. { + name: "ensure no intersection returns an empty set", input: [][][]byte{ { {1, 2, 3}, @@ -350,11 +404,14 @@ func TestIntersectionByteSlices(t *testing.T) { { {1, 2}, }, + { + {1, 2}, + }, }, result: [][]byte{}, }, - // Intersection between A and A should return A. { + name: "intersection between A and A should return A", input: [][][]byte{ { {1, 2}, @@ -362,17 +419,33 @@ func TestIntersectionByteSlices(t *testing.T) { { {1, 2}, }, + { + {1, 2}, + }, }, result: [][]byte{{1, 2}}, }, } for _, tt := range testCases { - result := sliceutil.IntersectionByteSlices(tt.input...) - if !reflect.DeepEqual(result, tt.result) { - t.Errorf("IntersectionByteSlices(%v)=%v, wanted: %v", - tt.input, result, tt.result) - } + t.Run(tt.name, func(t *testing.T) { + result := sliceutil.IntersectionByteSlices(tt.input...) + if !reflect.DeepEqual(result, tt.result) { + t.Errorf("IntersectionByteSlices(%v)=%v, wanted: %v", + tt.input, result, tt.result) + } + }) } + t.Run("properly handle duplicates", func(t *testing.T) { + input := [][][]byte{ + {{1, 2}, {1, 2}}, + {{1, 2}, {1, 2}}, + {}, + } + result := sliceutil.IntersectionByteSlices(input...) + if !reflect.DeepEqual(result, [][]byte{}) { + t.Errorf("IntersectionByteSlices(%v)=%v, wanted: %v", input, result, [][]byte{}) + } + }) } func TestSplitCommaSeparated(t *testing.T) {