src/unique/canonmap.go
package unique

import (
	"internal/abi"
	"internal/goarch"
	"runtime"
	"sync"
	"sync/atomic"
	"unsafe"
	"weak"
)
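// canonMap is a map of T -> *T. It controls the creation of a canonical *T,
// and entries are automatically deleted once their canonical *T is no longer
// referenced.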
type canonMap[T comparable] struct {
	root atomic.Pointer[indirect[T]]
	hash func(unsafe.Pointer, uintptr) uintptr
	seed uintptr
}

func newCanonMap[T comparable]() *canonMap[T] {
	cm := new(canonMap[T])
	cm.root.Store(newIndirectNode[T](nil))

	// Borrow the hash function for T from the runtime's map implementation,
	// by way of a throwaway map type, and seed it randomly.
	var m map[T]struct{}
	mapType := abi.TypeOf(m).MapType()
	cm.hash = mapType.Hasher
	cm.seed = uintptr(runtime_rand())
	return cm
}
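// Load checks the map for key. If it is present, Load returns the canonical
// pointer for it; otherwise it returns nil.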
func (m *canonMap[T]) Load(key T) *T {
	hash := m.hash(abi.NoEscape(unsafe.Pointer(&key)), m.seed)

	i := m.root.Load()
	hashShift := 8 * goarch.PtrSize
	for hashShift != 0 {
		hashShift -= nChildrenLog2

		n := i.children[(hash>>hashShift)&nChildrenMask].Load()
		if n == nil {
			return nil
		}
		if n.isEntry {
			v, _ := n.entry().lookup(key)
			return v
		}
		i = n.indirect()
	}
	panic("unique.canonMap: ran out of hash bits while iterating")
}
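// LoadOrStore checks the map for key. If it is present, LoadOrStore returns
// the existing canonical pointer; otherwise it inserts key and returns the
// newly created canonical pointer.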
func (m *canonMap[T]) LoadOrStore(key T) *T {
	hash := m.hash(abi.NoEscape(unsafe.Pointer(&key)), m.seed)

	var i *indirect[T]
	var hashShift uint
	var slot *atomic.Pointer[node[T]]
	var n *node[T]
	for {
		// Find the key or a candidate location for insertion.
		i = m.root.Load()
		hashShift = 8 * goarch.PtrSize
		haveInsertPoint := false
		for hashShift != 0 {
			hashShift -= nChildrenLog2

			slot = &i.children[(hash>>hashShift)&nChildrenMask]
			n = slot.Load()
			if n == nil {
				// A nil slot is a candidate for insertion.
				haveInsertPoint = true
				break
			}
			if n.isEntry {
				// An entry is as deep as the walk can go. If the key
				// isn't in its list, the entry must be expanded below.
				if v, _ := n.entry().lookup(key); v != nil {
					return v
				}
				haveInsertPoint = true
				break
			}
			i = n.indirect()
		}
		if !haveInsertPoint {
			panic("unique.canonMap: ran out of hash bits while iterating")
		}

		// Grab the lock and double-check what we saw.
		i.mu.Lock()
		n = slot.Load()
		if (n == nil || n.isEntry) && !i.dead.Load() {
			// What we saw is still true, so we can continue with the insert.
			break
		}
		// Something changed under us; start over.
		i.mu.Unlock()
	}
	// The lock on i, taken when we broke out of the loop above, is still
	// held here. Breaking out with the lock held (rather than moving the
	// rest into a helper function) keeps all the iteration state in scope
	// and lets a single defer handle the unlock.
	defer i.mu.Unlock()

	var oldEntry *entry[T]
	if n != nil {
		oldEntry = n.entry()
		if v, _ := oldEntry.lookup(key); v != nil {
			// Easy case: the key appeared while we were taking the lock.
			return v
		}
	}
	newEntry, canon, wp := newEntryNode(key, hash)
	// Prune dead weak pointers now, so that lists of dead entries don't
	// linger until the whole entry list is swept.
	oldEntry = oldEntry.prune()
	if oldEntry == nil {
		// Easy case: the slot is free, so store the new entry directly.
		slot.Store(&newEntry.node)
	} else {
		// The slot holds an entry for a different key, so expand it into a
		// subtree that holds both entries.
		//
		// Publish the subtree last, so that readers never observe a state
		// in which oldEntry has been unlinked from the tree.
		slot.Store(m.expand(oldEntry, newEntry, hash, hashShift, i))
	}
	// Once the canonical pointer becomes unreachable, remove its entry.
	runtime.AddCleanup(canon, func(_ struct{}) {
		m.cleanup(hash, wp)
	}, struct{}{})
	return canon
}
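// expand takes oldEntry and newEntry, whose hashes agree on all bits above
// hashShift, and produces a subtree of indirect nodes deep enough to hold
// both. newHash is the hash of the key in newEntry. The caller must hold the
// lock on parent.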
func (m *canonMap[T]) expand(oldEntry, newEntry *entry[T], newHash uintptr, hashShift uint, parent *indirect[T]) *node[T] {
	// Check for a full hash collision.
	oldHash := oldEntry.hash
	if oldHash == newHash {
		// The hashes are identical, so the two entries must share a slot.
		// Chain the old entry onto the new entry's overflow list.
		newEntry.overflow.Store(oldEntry)
		return &newEntry.node
	}
	// Otherwise, create an indirect node for every level at which the two
	// hashes still agree, and place the entries where they diverge.
	newIndirect := newIndirectNode(parent)
	top := newIndirect
	for {
		if hashShift == 0 {
			panic("unique.canonMap: ran out of hash bits while inserting")
		}
		hashShift -= nChildrenLog2
		oi := (oldHash >> hashShift) & nChildrenMask
		ni := (newHash >> hashShift) & nChildrenMask
		if oi != ni {
			newIndirect.children[oi].Store(&oldEntry.node)
			newIndirect.children[ni].Store(&newEntry.node)
			break
		}
		nextIndirect := newIndirectNode(newIndirect)
		newIndirect.children[oi].Store(&nextIndirect.node)
		newIndirect = nextIndirect
	}
	return &top.node
}
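// cleanup deletes the entry holding wp from the map, if it is still present.
// It runs as a runtime cleanup once the canonical pointer for wp's key has
// become unreachable. A new entry for the same key may have been inserted in
// the meantime, so only entries whose weak pointer matches wp exactly are
// removed. cleanup also unlinks any indirect nodes left empty by the removal.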
func (m *canonMap[T]) cleanup(hash uintptr, wp weak.Pointer[T]) {
	var i *indirect[T]
	var hashShift uint
	var slot *atomic.Pointer[node[T]]
	var n *node[T]
	for {
		// Find wp in the map by following hash.
		i = m.root.Load()
		hashShift = 8 * goarch.PtrSize
		haveEntry := false
		for hashShift != 0 {
			hashShift -= nChildrenLog2

			slot = &i.children[(hash>>hashShift)&nChildrenMask]
			n = slot.Load()
			if n == nil {
				// The entry is already gone; nothing to do.
				return
			}
			if n.isEntry {
				if !n.entry().hasWeakPointer(wp) {
					// The entry list doesn't contain wp; it was
					// already pruned.
					return
				}
				haveEntry = true
				break
			}
			i = n.indirect()
		}
		if !haveEntry {
			panic("unique.canonMap: ran out of hash bits while iterating")
		}

		// Grab the lock and double-check what we saw.
		i.mu.Lock()
		n = slot.Load()
		if n != nil && n.isEntry {
			// Prune the entry list, dropping every entry whose weak
			// pointer has gone nil, and reinstall what remains.
			newEntry := n.entry().prune()
			if newEntry == nil {
				slot.Store(nil)
			} else {
				slot.Store(&newEntry.node)
			}

			// Walk up the tree, deleting indirect nodes that are now
			// empty. Each parent's lock is taken before the child is
			// marked dead and unlinked, so concurrent inserters that
			// locked the child will notice and retry.
			for i.parent != nil && i.empty() {
				if hashShift == 8*goarch.PtrSize {
					panic("unique.canonMap: ran out of hash bits while iterating")
				}
				hashShift += nChildrenLog2

				// Delete the current node in its parent.
				parent := i.parent
				parent.mu.Lock()
				i.dead.Store(true)
				parent.children[(hash>>hashShift)&nChildrenMask].Store(nil)
				i.mu.Unlock()
				i = parent
			}
			i.mu.Unlock()
			return
		}

		// Something changed under us; start over.
		i.mu.Unlock()
	}
}
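// node is the header common to both node types in the tree. A *node is
// always really a *indirect or a *entry, discriminated by isEntry.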
type node[T comparable] struct {
	isEntry bool
}

func (n *node[T]) entry() *entry[T] {
	if !n.isEntry {
		panic("called entry on non-entry node")
	}
	return (*entry[T])(unsafe.Pointer(n))
}

func (n *node[T]) indirect() *indirect[T] {
	if n.isEntry {
		panic("called indirect on entry node")
	}
	return (*indirect[T])(unsafe.Pointer(n))
}

const (
	// nChildrenLog2 is the base-2 logarithm of the fan-out of an indirect
	// node: each level of the tree consumes nChildrenLog2 bits of the hash.
	nChildrenLog2 = 4
	nChildren     = 1 << nChildrenLog2
	nChildrenMask = nChildren - 1
)
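// indirect is an internal node of the tree, with children indexed by
// successive nChildrenLog2-bit chunks of a key's hash.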
type indirect[T comparable] struct {
	node[T]
	dead     atomic.Bool
	parent   *indirect[T]
	mu       sync.Mutex
	children [nChildren]atomic.Pointer[node[T]]
}

func newIndirectNode[T comparable](parent *indirect[T]) *indirect[T] {
	return &indirect[T]{node: node[T]{isEntry: false}, parent: parent}
}

func (i *indirect[T]) empty() bool {
	for j := range i.children {
		if i.children[j].Load() != nil {
			return false
		}
	}
	return true
}
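// entry is a leaf of the tree, holding a weak pointer to the canonical value
// for a key. Entries whose keys' hashes collide on every bit are chained
// through the overflow list.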
type entry[T comparable] struct {
	node[T]
	overflow atomic.Pointer[entry[T]]
	key      weak.Pointer[T]
	hash     uintptr
}

// newEntryNode allocates the canonical value for key and returns the new
// entry, the canonical (strong) pointer, and the weak pointer stored in the
// entry.
func newEntryNode[T comparable](key T, hash uintptr) (*entry[T], *T, weak.Pointer[T]) {
	k := new(T)
	*k = key
	wp := weak.Make(k)
	return &entry[T]{
		node: node[T]{isEntry: true},
		key:  wp,
		hash: hash,
	}, k, wp
}
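// lookup searches the entry list for a live value equal to key. On success
// it returns the canonical pointer and the weak pointer that produced it;
// otherwise it returns zero values.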
func (e *entry[T]) lookup(key T) (*T, weak.Pointer[T]) {
	for e != nil {
		s := e.key.Value()
		if s != nil && *s == key {
			return s, e.key
		}
		e = e.overflow.Load()
	}
	return nil, weak.Pointer[T]{}
}
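// hasWeakPointer reports whether wp appears anywhere in the entry list.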
func (e *entry[T]) hasWeakPointer(wp weak.Pointer[T]) bool {
	for e != nil {
		if e.key == wp {
			return true
		}
		e = e.overflow.Load()
	}
	return false
}
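// prune drops every entry whose weak pointer has gone nil from the list and
// returns the new head, which is nil if no live entries remain. The receiver
// may be nil. The caller must hold the lock on the indirect node whose slot
// holds this list.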
func (e *entry[T]) prune() *entry[T] {
	// Advance to the first live entry; it becomes the new head.
	for e != nil {
		if e.key.Value() != nil {
			break
		}
		e = e.overflow.Load()
	}
	if e == nil {
		return nil
	}

	// Unlink dead entries from the rest of the list, keeping i pointed at
	// the overflow slot that links to the current entry.
	newHead := e
	i := &e.overflow
	e = i.Load()
	for e != nil {
		if e.key.Value() != nil {
			i = &e.overflow
		} else {
			i.Store(e.overflow.Load())
		}
		e = e.overflow.Load()
	}
	return newHead
}

// Implemented in runtime.
//
//go:linkname runtime_rand runtime.rand
func runtime_rand() uint64