// Copyright 2025 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package unique

import (
	"internal/abi"
	"internal/goarch"
	"runtime"
	"sync"
	"sync/atomic"
	"unsafe"
	"weak"
)

// canonMap is a map of T -> *T. The map controls the creation
// of a canonical *T, and elements of the map are automatically
// deleted when the canonical *T is no longer referenced.
type canonMap[T comparable] struct {
	root atomic.Pointer[indirect[T]]
	hash func(unsafe.Pointer, uintptr) uintptr
	seed uintptr
}

func newCanonMap[T comparable]() *canonMap[T] {
	cm := new(canonMap[T])
	cm.root.Store(newIndirectNode[T](nil))

	var m map[T]struct{}
	mapType := abi.TypeOf(m).MapType()
	cm.hash = mapType.Hasher
	cm.seed = uintptr(runtime_rand())
	return cm
}

func (m *canonMap[T]) Load(key T) *T {
	hash := m.hash(abi.NoEscape(unsafe.Pointer(&key)), m.seed)

	i := m.root.Load()
	hashShift := 8 * goarch.PtrSize
	for hashShift != 0 {
		hashShift -= nChildrenLog2

		n := i.children[(hash>>hashShift)&nChildrenMask].Load()
		if n == nil {
			return nil
		}
		if n.isEntry {
			v, _ := n.entry().lookup(key)
			return v
		}
		i = n.indirect()
	}
	panic("unique.canonMap: ran out of hash bits while iterating")
}

func (m *canonMap[T]) LoadOrStore(key T) *T {
	hash := m.hash(abi.NoEscape(unsafe.Pointer(&key)), m.seed)

	var i *indirect[T]
	var hashShift uint
	var slot *atomic.Pointer[node[T]]
	var n *node[T]
	for {
		// Find the key or a candidate location for insertion.
		i = m.root.Load()
		hashShift = 8 * goarch.PtrSize
		haveInsertPoint := false
		for hashShift != 0 {
			hashShift -= nChildrenLog2

			slot = &i.children[(hash>>hashShift)&nChildrenMask]
			n = slot.Load()
			if n == nil {
				// We found a nil slot which is a candidate for insertion.
				haveInsertPoint = true
				break
			}
			if n.isEntry {
				// We found an existing entry, which is as far as we can go.
				// If it stays this way, we'll have to replace it with an
				// indirect node.
				if v, _ := n.entry().lookup(key); v != nil {
					return v
				}
				haveInsertPoint = true
				break
			}
			i = n.indirect()
		}
		if !haveInsertPoint {
			panic("unique.canonMap: ran out of hash bits while iterating")
		}

		// Grab the lock and double-check what we saw.
		i.mu.Lock()
		n = slot.Load()
		if (n == nil || n.isEntry) && !i.dead.Load() {
			// What we saw is still true, so we can continue with the insert.
			break
		}

		// We have to start over.
		i.mu.Unlock()
	}
	// N.B. This lock is held from when we broke out of the outer loop above.
	// We specifically break this out so that we can use defer here safely.
	// One option is to break this out into a new function instead, but
	// there's so much local iteration state used below that this turns out
	// to be cleaner.
	defer i.mu.Unlock()

	var oldEntry *entry[T]
	if n != nil {
		oldEntry = n.entry()
		if v, _ := oldEntry.lookup(key); v != nil {
			// Easy case: by loading again, it turns out exactly what we wanted is here!
			return v
		}
	}
	newEntry, canon, wp := newEntryNode(key, hash)

	// Prune dead pointers. This is to avoid O(n) lookups when we store the exact same
	// value in the set but the cleanup hasn't run yet because it got delayed for some
	// reason.
	oldEntry = oldEntry.prune()

	if oldEntry == nil {
		// Easy case: create a new entry and store it.
		slot.Store(&newEntry.node)
	} else {
		// We possibly need to expand the entry already there into one or more new nodes.
		//
		// Publish the node last, which will make both oldEntry and newEntry visible. We
		// don't want readers to be able to observe that oldEntry isn't in the tree.
		slot.Store(m.expand(oldEntry, newEntry, hash, hashShift, i))
	}
	runtime.AddCleanup(canon, func(_ struct{}) {
		m.cleanup(hash, wp)
	}, struct{}{})
	return canon
}
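// loadOrStoreContractSketch is an illustrative sketch added for exposition;
// it is not part of the original file and its name is hypothetical. It shows
// the property Load and LoadOrStore provide: equal keys resolve to a single
// canonical pointer for as long as that pointer remains reachable.
func loadOrStoreContractSketch(m *canonMap[string], s string) bool {
	a := m.LoadOrStore(s) // creates the canonical *string on first use
	b := m.LoadOrStore(s) // returns the same canonical *string, so a == b
	ok := a == b && *a == s && m.Load(s) == a
	runtime.KeepAlive(a) // keep the canonical pointer live through the checks
	return ok
}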
// expand takes oldEntry and newEntry, whose hashes conflict from the most
// significant bit down to hashShift, and produces a subtree of indirect nodes
// to hold the two new entries. newHash is the hash of the value in the new entry.
func (m *canonMap[T]) expand(oldEntry, newEntry *entry[T], newHash uintptr, hashShift uint, parent *indirect[T]) *node[T] {
	// Check for a hash collision.
	oldHash := oldEntry.hash
	if oldHash == newHash {
		// Store the old entry in the new entry's overflow list, then store
		// the new entry.
		newEntry.overflow.Store(oldEntry)
		return &newEntry.node
	}
	// We have to add an indirect node. Worse still, we may need to add more than one.
	newIndirect := newIndirectNode(parent)
	top := newIndirect
	for {
		if hashShift == 0 {
			panic("unique.canonMap: ran out of hash bits while inserting")
		}
		hashShift -= nChildrenLog2 // hashShift is for the level parent is at. We need to go deeper.
		oi := (oldHash >> hashShift) & nChildrenMask
		ni := (newHash >> hashShift) & nChildrenMask
		if oi != ni {
			newIndirect.children[oi].Store(&oldEntry.node)
			newIndirect.children[ni].Store(&newEntry.node)
			break
		}
		nextIndirect := newIndirectNode(newIndirect)
		newIndirect.children[oi].Store(&nextIndirect.node)
		newIndirect = nextIndirect
	}
	return &top.node
}

// cleanup deletes the entry corresponding to wp in the canon map, if it's
// still in the map. wp must have a Value method that returns nil by the
// time this function is called. hash must be the hash of the value that
// wp once pointed to (that is, the hash of *wp.Value()).
func (m *canonMap[T]) cleanup(hash uintptr, wp weak.Pointer[T]) {
	var i *indirect[T]
	var hashShift uint
	var slot *atomic.Pointer[node[T]]
	var n *node[T]
	for {
		// Find wp in the map by following hash.
		i = m.root.Load()
		hashShift = 8 * goarch.PtrSize
		haveEntry := false
		for hashShift != 0 {
			hashShift -= nChildrenLog2

			slot = &i.children[(hash>>hashShift)&nChildrenMask]
			n = slot.Load()
			if n == nil {
				// We found a nil slot, already deleted.
				return
			}
			if n.isEntry {
				if !n.entry().hasWeakPointer(wp) {
					// The weak pointer was already pruned.
					return
				}
				haveEntry = true
				break
			}
			i = n.indirect()
		}
		if !haveEntry {
			panic("unique.canonMap: ran out of hash bits while iterating")
		}

		// Grab the lock and double-check what we saw.
		i.mu.Lock()
		n = slot.Load()
		if n != nil && n.isEntry {
			// Prune the entry node without thinking too hard. If we end up doing
			// somebody else's work, such as for someone trying to insert an
			// entry with the same hash (probably the same value), then great:
			// they'll back out without taking the lock.
			newEntry := n.entry().prune()
			if newEntry == nil {
				slot.Store(nil)
			} else {
				slot.Store(&newEntry.node)
			}

			// Delete interior nodes that are empty, up the tree.
			//
			// We'll hand-over-hand lock our way up the tree as we do this,
			// since we need to delete each empty node's link in its parent,
			// which requires the parent's lock.
			for i.parent != nil && i.empty() {
				if hashShift == 8*goarch.PtrSize {
					panic("unique.canonMap: ran out of hash bits while iterating")
				}
				hashShift += nChildrenLog2

				// Delete the current node in the parent.
				parent := i.parent
				parent.mu.Lock()
				i.dead.Store(true) // Could be done outside of parent's lock.
				parent.children[(hash>>hashShift)&nChildrenMask].Store(nil)
				i.mu.Unlock()
				i = parent
			}
			i.mu.Unlock()
			return
		}
		// We have to start over.
		i.mu.Unlock()
	}
}
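// divergenceShiftSketch is an illustrative helper added for exposition; it is
// not part of the original file and its name is hypothetical. It mirrors the
// loop in expand above: starting from the given shift, it reports the shift at
// which two hashes first select different children, which is the level at
// which expand stops adding indirect nodes and places the two entries.
func divergenceShiftSketch(oldHash, newHash uintptr, hashShift uint) (uint, bool) {
	for hashShift != 0 {
		hashShift -= nChildrenLog2
		if (oldHash>>hashShift)&nChildrenMask != (newHash>>hashShift)&nChildrenMask {
			return hashShift, true
		}
	}
	// Fully equal hashes never diverge; expand chains such entries on the
	// entry's overflow list instead of adding indirect nodes.
	return 0, false
}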
// node is the header for a node. It's polymorphic and
// is actually either an entry or an indirect.
type node[T comparable] struct {
	isEntry bool
}

func (n *node[T]) entry() *entry[T] {
	if !n.isEntry {
		panic("called entry on non-entry node")
	}
	return (*entry[T])(unsafe.Pointer(n))
}

func (n *node[T]) indirect() *indirect[T] {
	if n.isEntry {
		panic("called indirect on entry node")
	}
	return (*indirect[T])(unsafe.Pointer(n))
}

const (
	// 16 children. This seems to be the sweet spot for
	// load performance: any smaller and we lose out on
	// 50% or more in CPU performance. Any larger and the
	// returns are minuscule (~1% improvement for 32 children).
	nChildrenLog2 = 4
	nChildren     = 1 << nChildrenLog2
	nChildrenMask = nChildren - 1
)

// indirect is an internal node in the hash-trie.
type indirect[T comparable] struct {
	node[T]
	dead     atomic.Bool
	parent   *indirect[T]
	mu       sync.Mutex // Protects mutation to children and any children that are entry nodes.
	children [nChildren]atomic.Pointer[node[T]]
}

func newIndirectNode[T comparable](parent *indirect[T]) *indirect[T] {
	return &indirect[T]{node: node[T]{isEntry: false}, parent: parent}
}

func (i *indirect[T]) empty() bool {
	for j := range i.children {
		if i.children[j].Load() != nil {
			return false
		}
	}
	return true
}

// entry is a leaf node in the hash-trie.
type entry[T comparable] struct {
	node[T]
	overflow atomic.Pointer[entry[T]] // Overflow for hash collisions.
	key      weak.Pointer[T]
	hash     uintptr
}

func newEntryNode[T comparable](key T, hash uintptr) (*entry[T], *T, weak.Pointer[T]) {
	k := new(T)
	*k = key
	wp := weak.Make(k)
	return &entry[T]{
		node: node[T]{isEntry: true},
		key:  wp,
		hash: hash,
	}, k, wp
}

// lookup finds the entry in the overflow chain that has the provided key.
//
// Returns the key's canonical pointer and the weak pointer for that canonical pointer.
func (e *entry[T]) lookup(key T) (*T, weak.Pointer[T]) {
	for e != nil {
		s := e.key.Value()
		if s != nil && *s == key {
			return s, e.key
		}
		e = e.overflow.Load()
	}
	return nil, weak.Pointer[T]{}
}

// hasWeakPointer returns true if the provided weak pointer can be found in the overflow chain.
func (e *entry[T]) hasWeakPointer(wp weak.Pointer[T]) bool {
	for e != nil {
		if e.key == wp {
			return true
		}
		e = e.overflow.Load()
	}
	return false
}

// prune removes all entries in the overflow chain whose keys are nil.
//
// The caller must hold the lock on e's parent node.
func (e *entry[T]) prune() *entry[T] {
	// Prune the head of the list.
	for e != nil {
		if e.key.Value() != nil {
			break
		}
		e = e.overflow.Load()
	}
	if e == nil {
		return nil
	}

	// Prune individual nodes in the list.
	newHead := e
	i := &e.overflow
	e = i.Load()
	for e != nil {
		if e.key.Value() != nil {
			i = &e.overflow
		} else {
			i.Store(e.overflow.Load())
		}
		e = e.overflow.Load()
	}
	return newHead
}

// Pull in runtime.rand so that we don't need to take a dependency
// on math/rand/v2.
//
//go:linkname runtime_rand runtime.rand
func runtime_rand() uint64
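// entryLifecycleSketch is an illustrative sketch added for exposition; it is
// not part of the original file and its name is hypothetical. It shows the
// lifecycle the types above implement: newEntryNode hands back the strong
// canonical pointer plus a weak pointer to it, and lookup resolves that weak
// pointer on every traversal. Once the canonical pointer becomes unreachable,
// the weak pointer resolves to nil and cleanup can prune the entry.
func entryLifecycleSketch() bool {
	e, canon, wp := newEntryNode("gopher", 0x1234)
	// While canon is reachable, the weak pointer still resolves, so lookup
	// finds the canonical pointer on the entry's (single-element) chain.
	v, _ := e.lookup("gopher")
	ok := v == canon && wp.Value() == canon
	runtime.KeepAlive(canon)
	return ok
}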