Source file src/unique/canonmap_test.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package unique
     6  
     7  import (
     8  	"internal/abi"
     9  	"runtime"
    10  	"strconv"
    11  	"sync"
    12  	"testing"
    13  	"unsafe"
    14  )
    15  
    16  func TestCanonMap(t *testing.T) {
    17  	testCanonMap(t, func() *canonMap[string] {
    18  		return newCanonMap[string]()
    19  	})
    20  }
    21  
    22  func TestCanonMapBadHash(t *testing.T) {
    23  	testCanonMap(t, func() *canonMap[string] {
    24  		return newBadCanonMap[string]()
    25  	})
    26  }
    27  
    28  func TestCanonMapTruncHash(t *testing.T) {
    29  	testCanonMap(t, func() *canonMap[string] {
    30  		// Stub out the good hash function with a different terrible one
    31  		// (truncated hash). Everything should still work as expected.
    32  		// This is useful to test independently to catch issues with
    33  		// near collisions, where only the last few bits of the hash differ.
    34  		return newTruncCanonMap[string]()
    35  	})
    36  }
    37  
    38  func testCanonMap(t *testing.T, newMap func() *canonMap[string]) {
    39  	t.Run("LoadEmpty", func(t *testing.T) {
    40  		m := newMap()
    41  
    42  		for _, s := range testData {
    43  			expectMissing(t, s)(m.Load(s))
    44  		}
    45  	})
    46  	t.Run("LoadOrStore", func(t *testing.T) {
    47  		t.Run("Sequential", func(t *testing.T) {
    48  			m := newMap()
    49  
    50  			var refs []*string
    51  			for _, s := range testData {
    52  				expectMissing(t, s)(m.Load(s))
    53  				refs = append(refs, expectPresent(t, s)(m.LoadOrStore(s)))
    54  				expectPresent(t, s)(m.Load(s))
    55  				expectPresent(t, s)(m.LoadOrStore(s))
    56  			}
    57  			drainCleanupQueue(t)
    58  			for _, s := range testData {
    59  				expectPresent(t, s)(m.Load(s))
    60  				expectPresent(t, s)(m.LoadOrStore(s))
    61  			}
    62  			runtime.KeepAlive(refs)
    63  			refs = nil
    64  			drainCleanupQueue(t)
    65  			for _, s := range testData {
    66  				expectMissing(t, s)(m.Load(s))
    67  				expectPresent(t, s)(m.LoadOrStore(s))
    68  			}
    69  		})
    70  		t.Run("ConcurrentUnsharedKeys", func(t *testing.T) {
    71  			makeKey := func(s string, id int) string {
    72  				return s + "-" + strconv.Itoa(id)
    73  			}
    74  
    75  			// Expand and shrink the map multiple times to try to get
    76  			// insertions and cleanups to overlap.
    77  			m := newMap()
    78  			gmp := runtime.GOMAXPROCS(-1)
    79  			for try := range 3 {
    80  				var wg sync.WaitGroup
    81  				for i := range gmp {
    82  					wg.Add(1)
    83  					go func(id int) {
    84  						defer wg.Done()
    85  
    86  						var refs []*string
    87  						for _, s := range testData {
    88  							key := makeKey(s, id)
    89  							if try == 0 {
    90  								expectMissing(t, key)(m.Load(key))
    91  							}
    92  							refs = append(refs, expectPresent(t, key)(m.LoadOrStore(key)))
    93  							expectPresent(t, key)(m.Load(key))
    94  							expectPresent(t, key)(m.LoadOrStore(key))
    95  						}
    96  						for i, s := range testData {
    97  							key := makeKey(s, id)
    98  							expectPresent(t, key)(m.Load(key))
    99  							if got, want := expectPresent(t, key)(m.LoadOrStore(key)), refs[i]; got != want {
   100  								t.Errorf("canonical entry %p did not match ref %p", got, want)
   101  							}
   102  						}
   103  						// N.B. We avoid trying to test entry cleanup here
   104  						// because it's going to be very flaky, especially
   105  						// in the bad hash cases.
   106  					}(i)
   107  				}
   108  				wg.Wait()
   109  			}
   110  
   111  			// Drain cleanups so everything is deleted.
   112  			drainCleanupQueue(t)
   113  
   114  			// Double-check that it's all gone.
   115  			for id := range gmp {
   116  				makeKey := func(s string) string {
   117  					return s + "-" + strconv.Itoa(id)
   118  				}
   119  				for _, s := range testData {
   120  					key := makeKey(s)
   121  					expectMissing(t, key)(m.Load(key))
   122  				}
   123  			}
   124  		})
   125  	})
   126  }
   127  
   128  func expectMissing[T comparable](t *testing.T, key T) func(got *T) {
   129  	t.Helper()
   130  	return func(got *T) {
   131  		t.Helper()
   132  
   133  		if got != nil {
   134  			t.Errorf("expected key %v to be missing from map, got %p", key, got)
   135  		}
   136  	}
   137  }
   138  
   139  func expectPresent[T comparable](t *testing.T, key T) func(got *T) *T {
   140  	t.Helper()
   141  	return func(got *T) *T {
   142  		t.Helper()
   143  
   144  		if got == nil {
   145  			t.Errorf("expected key %v to be present in map, got %p", key, got)
   146  		}
   147  		if got != nil && *got != key {
   148  			t.Errorf("key %v is present in map, but canonical version has the wrong value: got %v, want %v", key, *got, key)
   149  		}
   150  		return got
   151  	}
   152  }
   153  
   154  // newBadCanonMap creates a new canonMap for the provided key type
   155  // but with an intentionally bad hash function.
   156  func newBadCanonMap[T comparable]() *canonMap[T] {
   157  	// Stub out the good hash function with a terrible one.
   158  	// Everything should still work as expected.
   159  	m := newCanonMap[T]()
   160  	m.hash = func(_ unsafe.Pointer, _ uintptr) uintptr {
   161  		return 0
   162  	}
   163  	return m
   164  }
   165  
   166  // newTruncCanonMap creates a new canonMap for the provided key type
   167  // but with an intentionally bad hash function.
   168  func newTruncCanonMap[T comparable]() *canonMap[T] {
   169  	// Stub out the good hash function with a terrible one.
   170  	// Everything should still work as expected.
   171  	m := newCanonMap[T]()
   172  	var mx map[string]int
   173  	mapType := abi.TypeOf(mx).MapType()
   174  	hasher := mapType.Hasher
   175  	m.hash = func(p unsafe.Pointer, n uintptr) uintptr {
   176  		return hasher(p, n) & ((uintptr(1) << 4) - 1)
   177  	}
   178  	return m
   179  }
   180  

View as plain text