Source file src/unique/canonmap.go

// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package unique

import (
	"internal/abi"
	"internal/goarch"
	"runtime"
	"sync"
	"sync/atomic"
	"unsafe"
	"weak"
)

// canonMap is a map of T -> *T. The map controls the creation
// of a canonical *T, and elements of the map are automatically
// deleted when the canonical *T is no longer referenced.
type canonMap[T comparable] struct {
	root atomic.Pointer[indirect[T]]
	hash func(unsafe.Pointer, uintptr) uintptr
	seed uintptr
}

func newCanonMap[T comparable]() *canonMap[T] {
	cm := new(canonMap[T])
	cm.root.Store(newIndirectNode[T](nil))

	var m map[T]struct{}
	mapType := abi.TypeOf(m).MapType()
	cm.hash = mapType.Hasher
	cm.seed = uintptr(runtime_rand())
	return cm
}
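
// exampleCanonMapUsage is an illustrative sketch; the function name is assumed
// and nothing else in this file calls it. It spells out the contract in the
// canonMap doc comment above: LoadOrStore returns the canonical *T for a
// value, equal keys yield the same pointer while it remains reachable, and
// Load only reports an existing canonical pointer without inserting one.
func exampleCanonMapUsage() {
	m := newCanonMap[string]()

	a := m.LoadOrStore("gopher") // creates the canonical *string for "gopher"
	b := m.LoadOrStore("gopher") // finds the entry above and returns the same pointer

	if a != b || m.Load("gopher") != a {
		panic("expected a single canonical pointer per value")
	}

	// Once a, b, and every other copy of the canonical pointer become
	// unreachable, the cleanup registered in LoadOrStore removes the entry.
	runtime.KeepAlive(a)
}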

func (m *canonMap[T]) Load(key T) *T {
	hash := m.hash(abi.NoEscape(unsafe.Pointer(&key)), m.seed)

	i := m.root.Load()
	hashShift := 8 * goarch.PtrSize
	for hashShift != 0 {
		hashShift -= nChildrenLog2

		n := i.children[(hash>>hashShift)&nChildrenMask].Load()
		if n == nil {
			return nil
		}
		if n.isEntry {
			v, _ := n.entry().lookup(key)
			return v
		}
		i = n.indirect()
	}
	panic("unique.canonMap: ran out of hash bits while iterating")
}
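
// hashSliceIndex is an illustrative helper whose name is assumed; the map does
// not use it. It makes the descent in Load explicit: the hash is consumed
// nChildrenLog2 bits at a time from the most significant end, so level 0
// indexes the root with the top bits, level 1 the next slice, and so on, for
// at most 8*goarch.PtrSize/nChildrenLog2 levels (16 on 64-bit hosts).
func hashSliceIndex(hash uintptr, level uint) uintptr {
	shift := uint(8*goarch.PtrSize) - (level+1)*nChildrenLog2
	return (hash >> shift) & nChildrenMask
}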

func (m *canonMap[T]) LoadOrStore(key T) *T {
	hash := m.hash(abi.NoEscape(unsafe.Pointer(&key)), m.seed)

	var i *indirect[T]
	var hashShift uint
	var slot *atomic.Pointer[node[T]]
	var n *node[T]
	for {
		// Find the key or a candidate location for insertion.
		i = m.root.Load()
		hashShift = 8 * goarch.PtrSize
		haveInsertPoint := false
		for hashShift != 0 {
			hashShift -= nChildrenLog2

			slot = &i.children[(hash>>hashShift)&nChildrenMask]
			n = slot.Load()
			if n == nil {
				// We found a nil slot which is a candidate for insertion.
				haveInsertPoint = true
				break
			}
			if n.isEntry {
				// We found an existing entry, which is as far as we can go.
				// If it stays this way, we'll have to replace it with an
				// indirect node.
				if v, _ := n.entry().lookup(key); v != nil {
					return v
				}
				haveInsertPoint = true
				break
			}
			i = n.indirect()
		}
		if !haveInsertPoint {
			panic("unique.canonMap: ran out of hash bits while iterating")
		}

		// Grab the lock and double-check what we saw.
		i.mu.Lock()
		n = slot.Load()
		if (n == nil || n.isEntry) && !i.dead.Load() {
			// What we saw is still true, so we can continue with the insert.
			break
		}
		// We have to start over.
		i.mu.Unlock()
	}
	// N.B. This lock is held from when we broke out of the outer loop above.
	// We specifically break this out so that we can use defer here safely.
	// One option is to break this out into a new function instead, but
	// there's so much local iteration state used below that this turns out
	// to be cleaner.
	defer i.mu.Unlock()

	var oldEntry *entry[T]
	if n != nil {
		oldEntry = n.entry()
		if v, _ := oldEntry.lookup(key); v != nil {
			// Easy case: by loading again, it turns out exactly what we wanted is here!
			return v
		}
	}
	newEntry, canon, wp := newEntryNode(key, hash)
	// Prune dead pointers. This is to avoid O(n) lookups when we store the exact same
	// value in the set but the cleanup hasn't run yet because it got delayed for some
	// reason.
	oldEntry = oldEntry.prune()
	if oldEntry == nil {
		// Easy case: create a new entry and store it.
		slot.Store(&newEntry.node)
	} else {
		// We possibly need to expand the entry already there into one or more new nodes.
		//
		// Publish the node last, which will make both oldEntry and newEntry visible. We
		// don't want readers to be able to observe that oldEntry isn't in the tree.
		slot.Store(m.expand(oldEntry, newEntry, hash, hashShift, i))
	}
	runtime.AddCleanup(canon, func(_ struct{}) {
		m.cleanup(hash, wp)
	}, struct{}{})
	return canon
}
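
// insertSlotStillValid is an illustrative helper; the name is assumed and
// LoadOrStore does not call it. It captures the condition re-checked above
// after taking i.mu: the slot must still hold either nil or an entry node
// (not a newly grown indirect node), and i itself must not have been pruned
// from the tree by a concurrent cleanup. If either changed, the descent is
// retried from the root.
func insertSlotStillValid[T comparable](i *indirect[T], slot *atomic.Pointer[node[T]]) bool {
	n := slot.Load()
	return (n == nil || n.isEntry) && !i.dead.Load()
}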

// expand takes oldEntry and newEntry, whose hashes conflict on all bits from
// the top of the hash down to hashShift, and produces a subtree of indirect
// nodes to hold the two entries. newHash is the hash of the value in the new
// entry.
func (m *canonMap[T]) expand(oldEntry, newEntry *entry[T], newHash uintptr, hashShift uint, parent *indirect[T]) *node[T] {
	// Check for a hash collision.
	oldHash := oldEntry.hash
	if oldHash == newHash {
		// Store the old entry in the new entry's overflow list, then store
		// the new entry.
		newEntry.overflow.Store(oldEntry)
		return &newEntry.node
	}
	// We have to add an indirect node. Worse still, we may need to add more than one.
	newIndirect := newIndirectNode(parent)
	top := newIndirect
	for {
		if hashShift == 0 {
			panic("unique.canonMap: ran out of hash bits while inserting")
		}
		hashShift -= nChildrenLog2 // hashShift is for the level parent is at. We need to go deeper.
		oi := (oldHash >> hashShift) & nChildrenMask
		ni := (newHash >> hashShift) & nChildrenMask
		if oi != ni {
			newIndirect.children[oi].Store(&oldEntry.node)
			newIndirect.children[ni].Store(&newEntry.node)
			break
		}
		nextIndirect := newIndirectNode(newIndirect)
		newIndirect.children[oi].Store(&nextIndirect.node)
		newIndirect = nextIndirect
	}
	return &top.node
}
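
// expandChainDepth is an illustrative helper with an assumed name; expand does
// not use it. It shows how many new indirect nodes expand creates when the
// hashes are distinct: one for each level below hashShift at which the two
// hashes still share the same nChildrenLog2-bit slice, plus one for the level
// where they finally diverge.
func expandChainDepth(oldHash, newHash uintptr, hashShift uint) int {
	depth := 0
	for {
		if hashShift == 0 {
			panic("unique.canonMap: ran out of hash bits while inserting")
		}
		hashShift -= nChildrenLog2
		depth++
		if (oldHash>>hashShift)&nChildrenMask != (newHash>>hashShift)&nChildrenMask {
			return depth
		}
	}
}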

// cleanup deletes the entry corresponding to wp in the canon map, if it's
// still in the map. wp.Value must return nil by the time this function is
// called. hash must be the hash of the value that wp once pointed to.
func (m *canonMap[T]) cleanup(hash uintptr, wp weak.Pointer[T]) {
	var i *indirect[T]
	var hashShift uint
	var slot *atomic.Pointer[node[T]]
	var n *node[T]
	for {
		// Find wp in the map by following hash.
		i = m.root.Load()
		hashShift = 8 * goarch.PtrSize
		haveEntry := false
		for hashShift != 0 {
			hashShift -= nChildrenLog2

			slot = &i.children[(hash>>hashShift)&nChildrenMask]
			n = slot.Load()
			if n == nil {
				// We found a nil slot, already deleted.
				return
			}
			if n.isEntry {
				if !n.entry().hasWeakPointer(wp) {
					// The weak pointer was already pruned.
					return
				}
				haveEntry = true
				break
			}
			i = n.indirect()
		}
		if !haveEntry {
			panic("unique.canonMap: ran out of hash bits while iterating")
		}

		// Grab the lock and double-check what we saw.
		i.mu.Lock()
		n = slot.Load()
		if n != nil && n.isEntry {
			// Prune the entry node without thinking too hard. If we do
			// somebody else's work, such as someone trying to insert an
			// entry with the same hash (probably the same value), then
			// great: they'll back out without taking the lock.
			newEntry := n.entry().prune()
			if newEntry == nil {
				slot.Store(nil)
			} else {
				slot.Store(&newEntry.node)
			}

			// Delete interior nodes that are empty, up the tree.
			//
			// We'll hand-over-hand lock our way up the tree as we do this,
			// since we need to delete each empty node's link in its parent,
			// which requires the parent's lock.
			for i.parent != nil && i.empty() {
				if hashShift == 8*goarch.PtrSize {
					panic("unique.canonMap: ran out of hash bits while iterating")
				}
				hashShift += nChildrenLog2

				// Delete the current node in the parent.
				parent := i.parent
				parent.mu.Lock()
				i.dead.Store(true) // Could be done outside of parent's lock.
				parent.children[(hash>>hashShift)&nChildrenMask].Store(nil)
				i.mu.Unlock()
				i = parent
			}
			i.mu.Unlock()
			return
		}
		// We have to start over.
		i.mu.Unlock()
	}
}
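
// cleanupParentIndex is an illustrative helper with an assumed name; cleanup
// does not call it. It spells out the index arithmetic of the upward walk
// above: if hashShift selects a child slot within node i, then adding
// nChildrenLog2 to it selects i's own slot within i.parent.
func cleanupParentIndex(hash uintptr, hashShift uint) (childIndex, parentIndex uintptr) {
	childIndex = (hash >> hashShift) & nChildrenMask
	parentIndex = (hash >> (hashShift + nChildrenLog2)) & nChildrenMask
	return
}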

// node is the header for a node. It's polymorphic and
// is actually either an entry or an indirect.
type node[T comparable] struct {
	isEntry bool
}

func (n *node[T]) entry() *entry[T] {
	if !n.isEntry {
		panic("called entry on non-entry node")
	}
	return (*entry[T])(unsafe.Pointer(n))
}

func (n *node[T]) indirect() *indirect[T] {
	if n.isEntry {
		panic("called indirect on entry node")
	}
	return (*indirect[T])(unsafe.Pointer(n))
}

const (
	// 16 children. This seems to be the sweet spot for
	// load performance: any smaller and we lose out on
	// 50% or more in CPU performance. Any larger and the
	// returns are minuscule (~1% improvement for 32 children).
	nChildrenLog2 = 4
	nChildren     = 1 << nChildrenLog2
	nChildrenMask = nChildren - 1
)

// indirect is an internal node in the hash-trie.
type indirect[T comparable] struct {
	node[T]
	dead     atomic.Bool
	parent   *indirect[T]
	mu       sync.Mutex // Protects mutation to children and any children that are entry nodes.
	children [nChildren]atomic.Pointer[node[T]]
}

func newIndirectNode[T comparable](parent *indirect[T]) *indirect[T] {
	return &indirect[T]{node: node[T]{isEntry: false}, parent: parent}
}

func (i *indirect[T]) empty() bool {
	for j := range i.children {
		if i.children[j].Load() != nil {
			return false
		}
	}
	return true
}

// entry is a leaf node in the hash-trie.
type entry[T comparable] struct {
	node[T]
	overflow atomic.Pointer[entry[T]] // Overflow for hash collisions.
	key      weak.Pointer[T]
	hash     uintptr
}

func newEntryNode[T comparable](key T, hash uintptr) (*entry[T], *T, weak.Pointer[T]) {
	k := new(T)
	*k = key
	wp := weak.Make(k)
	return &entry[T]{
		node: node[T]{isEntry: true},
		key:  wp,
		hash: hash,
	}, k, wp
}
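
// entryLifetimeSketch is illustrative only; the function name is assumed and
// nothing in this file calls it. It shows the lifetime split newEntryNode
// creates: the returned strong pointer is what keeps the canonical value
// alive, while the entry itself holds only a weak.Pointer, so once every
// strong reference is gone the entry stops resolving and can be pruned.
func entryLifetimeSketch() {
	e, canon, wp := newEntryNode("gopher", 0)
	if wp.Value() == nil {
		// Unreachable here: canon is a live strong reference (kept alive
		// below), so the weak pointer still resolves to the same *string.
		panic("weak pointer should still resolve")
	}
	_ = e
	runtime.KeepAlive(canon)
}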

// lookup finds the entry in the overflow chain that has the provided key.
//
// Returns the key's canonical pointer and the weak pointer for that canonical pointer.
func (e *entry[T]) lookup(key T) (*T, weak.Pointer[T]) {
	for e != nil {
		s := e.key.Value()
		if s != nil && *s == key {
			return s, e.key
		}
		e = e.overflow.Load()
	}
	return nil, weak.Pointer[T]{}
}

// hasWeakPointer returns true if the provided weak pointer can be found in the overflow chain.
func (e *entry[T]) hasWeakPointer(wp weak.Pointer[T]) bool {
	for e != nil {
		if e.key == wp {
			return true
		}
		e = e.overflow.Load()
	}
	return false
}

// prune removes all entries in the overflow chain whose keys are nil.
//
// The caller must hold the lock on e's parent node.
func (e *entry[T]) prune() *entry[T] {
	// Prune the head of the list.
	for e != nil {
		if e.key.Value() != nil {
			break
		}
		e = e.overflow.Load()
	}
	if e == nil {
		return nil
	}

	// Prune individual nodes in the list.
	newHead := e
	i := &e.overflow
	e = i.Load()
	for e != nil {
		if e.key.Value() != nil {
			i = &e.overflow
		} else {
			i.Store(e.overflow.Load())
		}
		e = e.overflow.Load()
	}
	return newHead
}

// Pull in runtime.rand so that we don't need to take a dependency
// on math/rand/v2.
//
//go:linkname runtime_rand runtime.rand
func runtime_rand() uint64
   386  
