From 261a343ea59adb3bd7e19377b3d6d8d561dfb59e Mon Sep 17 00:00:00 2001
From: Victor Farazdagi <simple.square@gmail.com>
Date: Mon, 1 Jun 2020 16:35:29 +0300
Subject: [PATCH] Fixes intersection functions (uint64, int64, []byte) (#6067)

---
 shared/sliceutil/slice.go      |  72 +++++++++----------
 shared/sliceutil/slice_test.go | 127 ++++++++++++++++++++++++++-------
 2 files changed, 136 insertions(+), 63 deletions(-)

diff --git a/shared/sliceutil/slice.go b/shared/sliceutil/slice.go
index 10b677b91..95356e994 100644
--- a/shared/sliceutil/slice.go
+++ b/shared/sliceutil/slice.go
@@ -41,14 +41,18 @@ func IntersectionUint64(s ...[]uint64) []uint64 {
 		return s[0]
 	}
 	intersect := make([]uint64, 0)
-	for i := 1; i < len(s); i++ {
-		m := make(map[uint64]bool)
-		for j := 0; j < len(s[i-1]); j++ {
-			m[s[i-1][j]] = true
-		}
-		for j := 0; j < len(s[i]); j++ {
-			if _, found := m[s[i][j]]; found {
-				intersect = append(intersect, s[i][j])
+	m := make(map[uint64]int)
+	for _, k := range s[0] {
+		m[k] = 1
+	}
+	for i, num := 1, len(s); i < num; i++ {
+		for _, k := range s[i] {
+			// Increment and check only if item is present in both, and no increment has happened yet.
+			if _, found := m[k]; found && (i-m[k]) == 0 {
+				m[k]++
+				if m[k] == num {
+					intersect = append(intersect, k)
+				}
 			}
 		}
 	}
@@ -152,19 +156,22 @@ func IntersectionInt64(s ...[]int64) []int64 {
 	if len(s) == 1 {
 		return s[0]
 	}
-	set := make([]int64, 0)
-	m := make(map[int64]bool)
-	for i := 1; i < len(s); i++ {
-		for j := 0; j < len(s[i-1]); j++ {
-			m[s[i-1][j]] = true
-		}
-		for j := 0; j < len(s[i]); j++ {
-			if _, found := m[s[i][j]]; found {
-				set = append(set, s[i][j])
+	intersect := make([]int64, 0)
+	m := make(map[int64]int)
+	for _, k := range s[0] {
+		m[k] = 1
+	}
+	for i, num := 1, len(s); i < num; i++ {
+		for _, k := range s[i] {
+			if _, found := m[k]; found && (i-m[k]) == 0 {
+				m[k]++
+				if m[k] == num {
+					intersect = append(intersect, k)
+				}
 			}
 		}
 	}
-	return set
+	return intersect
 }
 
 // UnionInt64 of any number of int64 slices with time
@@ -256,26 +263,19 @@ func IntersectionByteSlices(s ...[][]byte) [][]byte {
 		return s[0]
 	}
 	inter := make([][]byte, 0)
-	for i := 1; i < len(s); i++ {
-		hash := make(map[string]bool)
-		for _, e := range s[i-1] {
-			hash[string(e)] = true
-		}
-		for _, e := range s[i] {
-			if hash[string(e)] {
-				inter = append(inter, e)
+	m := make(map[string]int)
+	for _, k := range s[0] {
+		m[string(k)] = 1
+	}
+	for i, num := 1, len(s); i < num; i++ {
+		for _, k := range s[i] {
+			if _, found := m[string(k)]; found && (i-m[string(k)]) == 0 {
+				m[string(k)]++
+				if m[string(k)] == num {
+					inter = append(inter, k)
+				}
 			}
 		}
-		tmp := make([][]byte, 0)
-		// Remove duplicates from slice.
-		encountered := make(map[string]bool)
-		for _, element := range inter {
-			if !encountered[string(element)] {
-				tmp = append(tmp, element)
-				encountered[string(element)] = true
-			}
-		}
-		inter = tmp
 	}
 	return inter
 }
diff --git a/shared/sliceutil/slice_test.go b/shared/sliceutil/slice_test.go
index c4bb309f6..6118f904e 100644
--- a/shared/sliceutil/slice_test.go
+++ b/shared/sliceutil/slice_test.go
@@ -2,6 +2,7 @@ package sliceutil_test
 
 import (
 	"reflect"
+	"sort"
 	"testing"
 
 	"github.com/prysmaticlabs/prysm/shared/sliceutil"
@@ -32,22 +33,45 @@ func TestIntersectionUint64(t *testing.T) {
 	testCases := []struct {
 		setA []uint64
 		setB []uint64
+		setC []uint64
 		out  []uint64
 	}{
-		{[]uint64{2, 3, 5}, []uint64{3}, []uint64{3}},
-		{[]uint64{2, 3, 5}, []uint64{3, 5}, []uint64{3, 5}},
-		{[]uint64{2, 3, 5}, []uint64{5, 3, 2}, []uint64{5, 3, 2}},
-		{[]uint64{2, 3, 5}, []uint64{2, 3, 5}, []uint64{2, 3, 5}},
-		{[]uint64{2, 3, 5}, []uint64{}, []uint64{}},
-		{[]uint64{}, []uint64{2, 3, 5}, []uint64{}},
-		{[]uint64{}, []uint64{}, []uint64{}},
-		{[]uint64{1}, []uint64{1}, []uint64{1}},
+		{[]uint64{2, 3, 5}, []uint64{3}, []uint64{3}, []uint64{3}},
+		{[]uint64{2, 3, 5}, []uint64{3, 5}, []uint64{5}, []uint64{5}},
+		{[]uint64{2, 3, 5}, []uint64{3, 5}, []uint64{3, 5}, []uint64{3, 5}},
+		{[]uint64{2, 3, 5}, []uint64{5, 3, 2}, []uint64{3, 2, 5}, []uint64{2, 3, 5}},
+		{[]uint64{3, 2, 5}, []uint64{5, 3, 2}, []uint64{3, 2, 5}, []uint64{2, 3, 5}},
+		{[]uint64{3, 3, 5}, []uint64{5, 3, 2}, []uint64{3, 2, 5}, []uint64{3, 5}},
+		{[]uint64{2, 3, 5}, []uint64{2, 3, 5}, []uint64{2, 3, 5}, []uint64{2, 3, 5}},
+		{[]uint64{2, 3, 5}, []uint64{}, []uint64{}, []uint64{}},
+		{[]uint64{2, 3, 5}, []uint64{2, 3, 5}, []uint64{}, []uint64{}},
+		{[]uint64{2, 3}, []uint64{2, 3, 5}, []uint64{5}, []uint64{}},
+		{[]uint64{2, 2, 2}, []uint64{2, 2, 2}, []uint64{}, []uint64{}},
+		{[]uint64{}, []uint64{2, 3, 5}, []uint64{}, []uint64{}},
+		{[]uint64{}, []uint64{}, []uint64{}, []uint64{}},
+		{[]uint64{1}, []uint64{1}, []uint64{}, []uint64{}},
+		{[]uint64{1, 1, 1}, []uint64{1, 1}, []uint64{1, 2, 3}, []uint64{1}},
 	}
 	for _, tt := range testCases {
-		result := sliceutil.IntersectionUint64(tt.setA, tt.setB)
+		setA := append([]uint64{}, tt.setA...)
+		setB := append([]uint64{}, tt.setB...)
+		setC := append([]uint64{}, tt.setC...)
+		result := sliceutil.IntersectionUint64(setA, setB, setC)
+		sort.Slice(result, func(i, j int) bool {
+			return result[i] < result[j]
+		})
 		if !reflect.DeepEqual(result, tt.out) {
 			t.Errorf("got %d, want %d", result, tt.out)
 		}
+		if !reflect.DeepEqual(setA, tt.setA) {
+			t.Errorf("slice modified, got %v, want %v", setA, tt.setA)
+		}
+		if !reflect.DeepEqual(setB, tt.setB) {
+			t.Errorf("slice modified, got %v, want %v", setB, tt.setB)
+		}
+		if !reflect.DeepEqual(setC, tt.setC) {
+			t.Errorf("slice modified, got %v, want %v", setC, tt.setC)
+		}
 	}
 }
 
@@ -73,22 +97,45 @@ func TestIntersectionInt64(t *testing.T) {
 	testCases := []struct {
 		setA []int64
 		setB []int64
+		setC []int64
 		out  []int64
 	}{
-		{[]int64{2, 3, 5}, []int64{3}, []int64{3}},
-		{[]int64{2, 3, 5}, []int64{3, 5}, []int64{3, 5}},
-		{[]int64{2, 3, 5}, []int64{5, 3, 2}, []int64{5, 3, 2}},
-		{[]int64{2, 3, 5}, []int64{2, 3, 5}, []int64{2, 3, 5}},
-		{[]int64{2, 3, 5}, []int64{}, []int64{}},
-		{[]int64{}, []int64{2, 3, 5}, []int64{}},
-		{[]int64{}, []int64{}, []int64{}},
-		{[]int64{1}, []int64{1}, []int64{1}},
+		{[]int64{2, 3, 5}, []int64{3}, []int64{3}, []int64{3}},
+		{[]int64{2, 3, 5}, []int64{3, 5}, []int64{5}, []int64{5}},
+		{[]int64{2, 3, 5}, []int64{3, 5}, []int64{3, 5}, []int64{3, 5}},
+		{[]int64{2, 3, 5}, []int64{5, 3, 2}, []int64{3, 2, 5}, []int64{2, 3, 5}},
+		{[]int64{3, 2, 5}, []int64{5, 3, 2}, []int64{3, 2, 5}, []int64{2, 3, 5}},
+		{[]int64{3, 3, 5}, []int64{5, 3, 2}, []int64{3, 2, 5}, []int64{3, 5}},
+		{[]int64{2, 3, 5}, []int64{2, 3, 5}, []int64{2, 3, 5}, []int64{2, 3, 5}},
+		{[]int64{2, 3, 5}, []int64{}, []int64{}, []int64{}},
+		{[]int64{2, 3, 5}, []int64{2, 3, 5}, []int64{}, []int64{}},
+		{[]int64{2, 3}, []int64{2, 3, 5}, []int64{5}, []int64{}},
+		{[]int64{2, 2, 2}, []int64{2, 2, 2}, []int64{}, []int64{}},
+		{[]int64{}, []int64{2, 3, 5}, []int64{}, []int64{}},
+		{[]int64{}, []int64{}, []int64{}, []int64{}},
+		{[]int64{1}, []int64{1}, []int64{}, []int64{}},
+		{[]int64{1, 1, 1}, []int64{1, 1}, []int64{1, 2, 3}, []int64{1}},
 	}
 	for _, tt := range testCases {
-		result := sliceutil.IntersectionInt64(tt.setA, tt.setB)
+		setA := append([]int64{}, tt.setA...)
+		setB := append([]int64{}, tt.setB...)
+		setC := append([]int64{}, tt.setC...)
+		result := sliceutil.IntersectionInt64(setA, setB, setC)
+		sort.Slice(result, func(i, j int) bool {
+			return result[i] < result[j]
+		})
 		if !reflect.DeepEqual(result, tt.out) {
 			t.Errorf("got %d, want %d", result, tt.out)
 		}
+		if !reflect.DeepEqual(setA, tt.setA) {
+			t.Errorf("slice modified, got %v, want %v", setA, tt.setA)
+		}
+		if !reflect.DeepEqual(setB, tt.setB) {
+			t.Errorf("slice modified, got %v, want %v", setB, tt.setB)
+		}
+		if !reflect.DeepEqual(setC, tt.setC) {
+			t.Errorf("slice modified, got %v, want %v", setC, tt.setC)
+		}
 	}
 }
 
@@ -308,10 +355,12 @@ func TestUnionByteSlices(t *testing.T) {
 
 func TestIntersectionByteSlices(t *testing.T) {
 	testCases := []struct {
+		name   string
 		input  [][][]byte
 		result [][]byte
 	}{
 		{
+			name: "intersect with empty set",
 			input: [][][]byte{
 				{
 					{1, 2, 3},
@@ -321,11 +370,12 @@ func TestIntersectionByteSlices(t *testing.T) {
 					{1, 2},
 					{4, 5},
 				},
+				{},
 			},
-			result: [][]byte{{4, 5}},
+			result: [][]byte{},
 		},
-		// Ensure duplicate elements are removed in the resulting set.
 		{
+			name: "ensure duplicate elements are removed in the resulting set",
 			input: [][][]byte{
 				{
 					{1, 2, 3},
@@ -337,11 +387,15 @@ func TestIntersectionByteSlices(t *testing.T) {
 					{4, 5},
 					{4, 5},
 				},
+				{
+					{4, 5},
+					{4, 5},
+				},
 			},
 			result: [][]byte{{4, 5}},
 		},
-		// Ensure no intersection returns an empty set.
 		{
+			name: "ensure no intersection returns an empty set",
 			input: [][][]byte{
 				{
 					{1, 2, 3},
@@ -350,11 +404,14 @@ func TestIntersectionByteSlices(t *testing.T) {
 				{
 					{1, 2},
 				},
+				{
+					{1, 2},
+				},
 			},
 			result: [][]byte{},
 		},
-		//  Intersection between A and A should return A.
 		{
+			name: "intersection between A and A should return A",
 			input: [][][]byte{
 				{
 					{1, 2},
@@ -362,17 +419,33 @@ func TestIntersectionByteSlices(t *testing.T) {
 				{
 					{1, 2},
 				},
+				{
+					{1, 2},
+				},
 			},
 			result: [][]byte{{1, 2}},
 		},
 	}
 	for _, tt := range testCases {
-		result := sliceutil.IntersectionByteSlices(tt.input...)
-		if !reflect.DeepEqual(result, tt.result) {
-			t.Errorf("IntersectionByteSlices(%v)=%v, wanted: %v",
-				tt.input, result, tt.result)
-		}
+		t.Run(tt.name, func(t *testing.T) {
+			result := sliceutil.IntersectionByteSlices(tt.input...)
+			if !reflect.DeepEqual(result, tt.result) {
+				t.Errorf("IntersectionByteSlices(%v)=%v, wanted: %v",
+					tt.input, result, tt.result)
+			}
+		})
 	}
+	t.Run("properly handle duplicates", func(t *testing.T) {
+		input := [][][]byte{
+			{{1, 2}, {1, 2}},
+			{{1, 2}, {1, 2}},
+			{},
+		}
+		result := sliceutil.IntersectionByteSlices(input...)
+		if !reflect.DeepEqual(result, [][]byte{}) {
+			t.Errorf("IntersectionByteSlices(%v)=%v, wanted: %v", input, result, [][]byte{})
+		}
+	})
 }
 
 func TestSplitCommaSeparated(t *testing.T) {