diff --git a/container/thread-safe/map.go b/container/thread-safe/map.go index 96932f075..9306b40c3 100644 --- a/container/thread-safe/map.go +++ b/container/thread-safe/map.go @@ -17,43 +17,80 @@ func NewThreadSafeMap[K comparable, V any](m map[K]V) *Map[K, V] { } } +// view an immutable snapshot of the map +func (m *Map[K, V]) View(fn func(mp map[K]V)) { + m.read(fn) +} + +// Do an action on the thread safe map +func (m *Map[K, V]) Do(fn func(mp map[K]V)) { + m.write(fn) +} + // Keys returns the keys of a thread-safe map. -func (m *Map[K, V]) Keys() []K { - m.lock.RLock() - defer m.lock.RUnlock() - r := make([]K, 0, len(m.items)) - for k := range m.items { - key := k - r = append(r, key) - } +func (m *Map[K, V]) Keys() (r []K) { + m.View(func(mp map[K]V) { + r = make([]K, 0, len(m.items)) + for k := range mp { + key := k + r = append(r, key) + } + }) return r } // Len of the thread-safe map. -func (m *Map[K, V]) Len() int { - m.lock.RLock() - defer m.lock.RUnlock() - return len(m.items) +func (m *Map[K, V]) Len() (l int) { + m.View(func(mp map[K]V) { + l = len(m.items) + }) + return } // Get an item from a thread-safe map. -func (m *Map[K, V]) Get(k K) (V, bool) { - m.lock.RLock() - defer m.lock.RUnlock() - v, ok := m.items[k] +func (m *Map[K, V]) Get(k K) (v V, ok bool) { + m.View(func(mp map[K]V) { + v, ok = mp[k] + }) return v, ok } +// Range runs the function fn(k K, v V) bool for each key value pair +// The keys are determined by a snapshot taken at the beginning of the range call +// If fn returns false, then the loop stops +// Only one invocation of fn will be active at one time, the iteration order is unspecified. +func (m *Map[K, V]) Range(fn func(k K, v V) bool) { + m.View(func(mp map[K]V) { + for k, v := range mp { + if !fn(k, v) { + return + } + } + }) +} + // Put an item into a thread-safe map. func (m *Map[K, V]) Put(k K, v V) { - m.lock.Lock() - defer m.lock.Unlock() - m.items[k] = v + m.Do(func(mp map[K]V) { + mp[k] = v + }) } // Delete an item from a thread-safe map. func (m *Map[K, V]) Delete(k K) { - m.lock.Lock() - defer m.lock.Unlock() - delete(m.items, k) + m.Do(func(mp map[K]V) { + delete(m.items, k) + }) +} + +func (m *Map[K, V]) read(fn func(mp map[K]V)) { + m.lock.RLock() + fn(m.items) + m.lock.RUnlock() +} + +func (m *Map[K, V]) write(fn func(mp map[K]V)) { + m.lock.Lock() + fn(m.items) + m.lock.Unlock() } diff --git a/container/thread-safe/map_test.go b/container/thread-safe/map_test.go index 8bbac1148..ca356fb5c 100644 --- a/container/thread-safe/map_test.go +++ b/container/thread-safe/map_test.go @@ -56,6 +56,19 @@ func BenchmarkMap_Generic(b *testing.B) { } } } +func BenchmarkMap_GenericTx(b *testing.B) { + items := make(map[int]string) + mm := NewThreadSafeMap(items) + for i := 0; i < b.N; i++ { + for j := 0; j < 1000; j++ { + mm.Do(func(mp map[int]string) { + mp[j] = "foo" + _ = mp[j] + delete(mp, j) + }) + } + } +} func TestMap(t *testing.T) { m := map[int]string{