Source file src/cmd/compile/internal/ssa/poset.go

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssa
     6  
     7  import (
     8  	"fmt"
     9  	"os"
    10  	"slices"
    11  )
    12  
    13  // If true, check poset integrity after every mutation
    14  var debugPoset = false
    15  
    16  const uintSize = 32 << (^uint(0) >> 63) // 32 or 64
    17  
    18  // bitset is a bit array for dense indexes.
    19  type bitset []uint
    20  
    21  func computeBitsetSize(n int) int {
    22  	return (n + uintSize - 1) / uintSize
    23  }
    24  
    25  func newBitset(n int) bitset {
    26  	return make(bitset, computeBitsetSize(n))
    27  }
    28  
    29  func (c *Cache) allocBitset(n int) bitset {
    30  	return bitset(c.allocUintSlice(computeBitsetSize(n)))
    31  }
    32  
    33  func (c *Cache) freeBitset(bs bitset) {
    34  	c.freeUintSlice([]uint(bs))
    35  }
    36  
    37  func (bs bitset) Reset() {
    38  	clear(bs)
    39  }
    40  
    41  func (bs bitset) Set(idx uint32) {
    42  	bs[idx/uintSize] |= 1 << (idx % uintSize)
    43  }
    44  
    45  func (bs bitset) Clear(idx uint32) {
    46  	bs[idx/uintSize] &^= 1 << (idx % uintSize)
    47  }
    48  
    49  func (bs bitset) Test(idx uint32) bool {
    50  	return bs[idx/uintSize]&(1<<(idx%uintSize)) != 0
    51  }
    52  
    53  type undoType uint8
    54  
    55  const (
    56  	undoInvalid    undoType = iota
    57  	undoCheckpoint          // a checkpoint to group undo passes
    58  	undoSetChl              // change back left child of undo.idx to undo.edge
    59  	undoSetChr              // change back right child of undo.idx to undo.edge
    60  	undoNonEqual            // forget that SSA value undo.ID is non-equal to undo.idx (another ID)
    61  	undoNewNode             // remove new node created for SSA value undo.ID
    62  	undoAliasNode           // unalias SSA value undo.ID so that it points back to node index undo.idx
    63  	undoNewRoot             // remove node undo.idx from root list
    64  	undoChangeRoot          // remove node undo.idx from root list, and put back undo.edge.Target instead
    65  	undoMergeRoot           // remove node undo.idx from root list, and put back its children instead
    66  )
    67  
    68  // posetUndo represents an undo pass to be performed.
    69  // It's a union of fields that can be used to store information,
    70  // and typ is the discriminant, that specifies which kind
    71  // of operation must be performed. Not all fields are always used.
    72  type posetUndo struct {
    73  	typ  undoType
    74  	idx  uint32
    75  	ID   ID
    76  	edge posetEdge
    77  }
    78  
    79  const (
    80  	// Make poset handle values as unsigned numbers.
    81  	// (TODO: remove?)
    82  	posetFlagUnsigned = 1 << iota
    83  )
    84  
    85  // A poset edge. The zero value is the null/empty edge.
    86  // Packs target node index (31 bits) and strict flag (1 bit).
    87  type posetEdge uint32
    88  
    89  func newedge(t uint32, strict bool) posetEdge {
    90  	s := uint32(0)
    91  	if strict {
    92  		s = 1
    93  	}
    94  	return posetEdge(t<<1 | s)
    95  }
    96  func (e posetEdge) Target() uint32 { return uint32(e) >> 1 }
    97  func (e posetEdge) Strict() bool   { return uint32(e)&1 != 0 }
    98  func (e posetEdge) String() string {
    99  	s := fmt.Sprint(e.Target())
   100  	if e.Strict() {
   101  		s += "*"
   102  	}
   103  	return s
   104  }
   105  
   106  // posetNode is a node of a DAG within the poset.
   107  type posetNode struct {
   108  	l, r posetEdge
   109  }
   110  
   111  // poset is a union-find data structure that can represent a partially ordered set
   112  // of SSA values. Given a binary relation that creates a partial order (eg: '<'),
   113  // clients can record relations between SSA values using SetOrder, and later
   114  // check relations (in the transitive closure) with Ordered. For instance,
   115  // if SetOrder is called to record that A<B and B<C, Ordered will later confirm
   116  // that A<C.
   117  //
   118  // It is possible to record equality relations between SSA values with SetEqual and check
   119  // equality with Equal. Equality propagates into the transitive closure for the partial
   120  // order so that if we know that A<B<C and later learn that A==D, Ordered will return
   121  // true for D<C.
   122  //
   123  // It is also possible to record inequality relations between nodes with SetNonEqual;
   124  // non-equality relations are not transitive, but they can still be useful: for instance
   125  // if we know that A<=B and later we learn that A!=B, we can deduce that A<B.
   126  // NonEqual can be used to check whether it is known that the nodes are different, either
   127  // because SetNonEqual was called before, or because we know that they are strictly ordered.
   128  //
   129  // poset will refuse to record new relations that contradict existing relations:
   130  // for instance if A<B<C, calling SetOrder for C<A will fail returning false; also
   131  // calling SetEqual for C==A will fail.
   132  //
   133  // poset is implemented as a forest of DAGs; in each DAG, if there is a path (directed)
   134  // from node A to B, it means that A<B (or A<=B). Equality is represented by mapping
   135  // two SSA values to the same DAG node; when a new equality relation is recorded
   136  // between two existing nodes, the nodes are merged, adjusting incoming and outgoing edges.
   137  //
   138  // poset is designed to be memory efficient and do little allocations during normal usage.
   139  // Most internal data structures are pre-allocated and flat, so for instance adding a
   140  // new relation does not cause any allocation. For performance reasons,
   141  // each node has only up to two outgoing edges (like a binary tree), so intermediate
   142  // "extra" nodes are required to represent more than two relations. For instance,
   143  // to record that A<I, A<J, A<K (with no known relation between I,J,K), we create the
   144  // following DAG:
   145  //
   146  //	  A
   147  //	 / \
   148  //	I  extra
   149  //	    /  \
   150  //	   J    K
   151  type poset struct {
   152  	lastidx uint32            // last generated dense index
   153  	flags   uint8             // internal flags
   154  	values  map[ID]uint32     // map SSA values to dense indexes
   155  	nodes   []posetNode       // nodes (in all DAGs)
   156  	roots   []uint32          // list of root nodes (forest)
   157  	noneq   map[uint32]bitset // non-equal relations
   158  	undo    []posetUndo       // undo chain
   159  }
   160  
   161  func newPoset() *poset {
   162  	return &poset{
   163  		values: make(map[ID]uint32),
   164  		nodes:  make([]posetNode, 1, 16),
   165  		roots:  make([]uint32, 0, 4),
   166  		noneq:  make(map[uint32]bitset),
   167  		undo:   make([]posetUndo, 0, 4),
   168  	}
   169  }
   170  
   171  func (po *poset) SetUnsigned(uns bool) {
   172  	if uns {
   173  		po.flags |= posetFlagUnsigned
   174  	} else {
   175  		po.flags &^= posetFlagUnsigned
   176  	}
   177  }
   178  
   179  // Handle children
   180  func (po *poset) setchl(i uint32, l posetEdge) { po.nodes[i].l = l }
   181  func (po *poset) setchr(i uint32, r posetEdge) { po.nodes[i].r = r }
   182  func (po *poset) chl(i uint32) uint32          { return po.nodes[i].l.Target() }
   183  func (po *poset) chr(i uint32) uint32          { return po.nodes[i].r.Target() }
   184  func (po *poset) children(i uint32) (posetEdge, posetEdge) {
   185  	return po.nodes[i].l, po.nodes[i].r
   186  }
   187  
   188  // upush records a new undo step. It can be used for simple
   189  // undo passes that record up to one index and one edge.
   190  func (po *poset) upush(typ undoType, p uint32, e posetEdge) {
   191  	po.undo = append(po.undo, posetUndo{typ: typ, idx: p, edge: e})
   192  }
   193  
   194  // upushnew pushes an undo pass for a new node
   195  func (po *poset) upushnew(id ID, idx uint32) {
   196  	po.undo = append(po.undo, posetUndo{typ: undoNewNode, ID: id, idx: idx})
   197  }
   198  
   199  // upushneq pushes a new undo pass for a nonequal relation
   200  func (po *poset) upushneq(idx1 uint32, idx2 uint32) {
   201  	po.undo = append(po.undo, posetUndo{typ: undoNonEqual, ID: ID(idx1), idx: idx2})
   202  }
   203  
   204  // upushalias pushes a new undo pass for aliasing two nodes
   205  func (po *poset) upushalias(id ID, i2 uint32) {
   206  	po.undo = append(po.undo, posetUndo{typ: undoAliasNode, ID: id, idx: i2})
   207  }
   208  
   209  // addchild adds i2 as direct child of i1.
   210  func (po *poset) addchild(i1, i2 uint32, strict bool) {
   211  	i1l, i1r := po.children(i1)
   212  	e2 := newedge(i2, strict)
   213  
   214  	if i1l == 0 {
   215  		po.setchl(i1, e2)
   216  		po.upush(undoSetChl, i1, 0)
   217  	} else if i1r == 0 {
   218  		po.setchr(i1, e2)
   219  		po.upush(undoSetChr, i1, 0)
   220  	} else {
   221  		// If n1 already has two children, add an intermediate extra
   222  		// node to record the relation correctly (without relating
   223  		// n2 to other existing nodes). Use a non-deterministic value
   224  		// to decide whether to append on the left or the right, to avoid
   225  		// creating degenerated chains.
   226  		//
   227  		//      n1
   228  		//     /  \
   229  		//   i1l  extra
   230  		//        /   \
   231  		//      i1r   n2
   232  		//
   233  		extra := po.newnode(nil)
   234  		if (i1^i2)&1 != 0 { // non-deterministic
   235  			po.setchl(extra, i1r)
   236  			po.setchr(extra, e2)
   237  			po.setchr(i1, newedge(extra, false))
   238  			po.upush(undoSetChr, i1, i1r)
   239  		} else {
   240  			po.setchl(extra, i1l)
   241  			po.setchr(extra, e2)
   242  			po.setchl(i1, newedge(extra, false))
   243  			po.upush(undoSetChl, i1, i1l)
   244  		}
   245  	}
   246  }
   247  
   248  // newnode allocates a new node bound to SSA value n.
   249  // If n is nil, this is an extra node (= only used internally).
   250  func (po *poset) newnode(n *Value) uint32 {
   251  	i := po.lastidx + 1
   252  	po.lastidx++
   253  	po.nodes = append(po.nodes, posetNode{})
   254  	if n != nil {
   255  		if po.values[n.ID] != 0 {
   256  			panic("newnode for Value already inserted")
   257  		}
   258  		po.values[n.ID] = i
   259  		po.upushnew(n.ID, i)
   260  	} else {
   261  		po.upushnew(0, i)
   262  	}
   263  	return i
   264  }
   265  
   266  // lookup searches for a SSA value into the forest of DAGS, and return its node.
   267  func (po *poset) lookup(n *Value) (uint32, bool) {
   268  	i, f := po.values[n.ID]
   269  	return i, f
   270  }
   271  
   272  // aliasnewnode records that a single node n2 (not in the poset yet) is an alias
   273  // of the master node n1.
   274  func (po *poset) aliasnewnode(n1, n2 *Value) {
   275  	i1, i2 := po.values[n1.ID], po.values[n2.ID]
   276  	if i1 == 0 || i2 != 0 {
   277  		panic("aliasnewnode invalid arguments")
   278  	}
   279  
   280  	po.values[n2.ID] = i1
   281  	po.upushalias(n2.ID, 0)
   282  }
   283  
// aliasnodes records that all the nodes i2s are aliases of a single master node n1.
// aliasnodes takes care of rearranging the DAG, changing references of parent/children
// of nodes in i2s, so that they point to n1 instead.
// Complexity is O(n) (with n being the total number of nodes in the poset, not just
// the number of nodes being aliased).
//
// Every edge change and every remapped SSA value is pushed onto the undo
// chain. It panics if n1 is not in the poset, or if n1's own node is in i2s.
func (po *poset) aliasnodes(n1 *Value, i2s bitset) {
	i1 := po.values[n1.ID]
	if i1 == 0 {
		panic("aliasnode for non-existing node")
	}
	if i2s.Test(i1) {
		panic("aliasnode i2s contains n1 node")
	}

	// Go through all the nodes to adjust parent/children of nodes in i2s.
	// The range expression is evaluated once, so extra nodes appended by
	// addchild below are not visited, and n is a copy of the node taken
	// before this iteration modifies it (l and r keep the original edges).
	for idx, n := range po.nodes {
		// Do not touch i1 itself, otherwise we can create useless self-loops
		if uint32(idx) == i1 {
			continue
		}
		l, r := n.l, n.r

		// Rename all references to i2s into i1
		if i2s.Test(l.Target()) {
			po.setchl(uint32(idx), newedge(i1, l.Strict()))
			po.upush(undoSetChl, uint32(idx), l)
		}
		if i2s.Test(r.Target()) {
			po.setchr(uint32(idx), newedge(i1, r.Strict()))
			po.upush(undoSetChr, uint32(idx), r)
		}

		// Connect all children of i2s to i1 (unless those children
		// are in i2s as well, in which case it would be useless),
		// then detach the aliased node from its children. The undo
		// records store the original edges l and r.
		if i2s.Test(uint32(idx)) {
			if l != 0 && !i2s.Test(l.Target()) {
				po.addchild(i1, l.Target(), l.Strict())
			}
			if r != 0 && !i2s.Test(r.Target()) {
				po.addchild(i1, r.Target(), r.Strict())
			}
			po.setchl(uint32(idx), 0)
			po.setchr(uint32(idx), 0)
			po.upush(undoSetChl, uint32(idx), l)
			po.upush(undoSetChr, uint32(idx), r)
		}
	}

	// Reassign all existing IDs that point to a node in i2s to i1.
	// When called from collapsepath this includes the value at the far end
	// of the collapsed path, since findpaths always marks the destination.
	// (Overwriting existing keys while ranging over a map is allowed.)
	for k, v := range po.values {
		if i2s.Test(v) {
			po.values[k] = i1
			po.upushalias(k, v)
		}
	}
}
   341  
   342  func (po *poset) isroot(r uint32) bool {
   343  	for i := range po.roots {
   344  		if po.roots[i] == r {
   345  			return true
   346  		}
   347  	}
   348  	return false
   349  }
   350  
   351  func (po *poset) changeroot(oldr, newr uint32) {
   352  	for i := range po.roots {
   353  		if po.roots[i] == oldr {
   354  			po.roots[i] = newr
   355  			return
   356  		}
   357  	}
   358  	panic("changeroot on non-root")
   359  }
   360  
   361  func (po *poset) removeroot(r uint32) {
   362  	for i := range po.roots {
   363  		if po.roots[i] == r {
   364  			po.roots = slices.Delete(po.roots, i, i+1)
   365  			return
   366  		}
   367  	}
   368  	panic("removeroot on non-root")
   369  }
   370  
   371  // dfs performs a depth-first search within the DAG whose root is r.
   372  // f is the visit function called for each node; if it returns true,
   373  // the search is aborted and true is returned. The root node is
   374  // visited too.
   375  // If strict, ignore edges across a path until at least one
   376  // strict edge is found. For instance, for a chain A<=B<=C<D<=E<F,
   377  // a strict walk visits D,E,F.
   378  // If the visit ends, false is returned.
   379  func (po *poset) dfs(r uint32, strict bool, f func(i uint32) bool) bool {
   380  	closed := newBitset(int(po.lastidx + 1))
   381  	open := make([]uint32, 1, 64)
   382  	open[0] = r
   383  
   384  	if strict {
   385  		// Do a first DFS; walk all paths and stop when we find a strict
   386  		// edge, building a "next" list of nodes reachable through strict
   387  		// edges. This will be the bootstrap open list for the real DFS.
   388  		next := make([]uint32, 0, 64)
   389  
   390  		for len(open) > 0 {
   391  			i := open[len(open)-1]
   392  			open = open[:len(open)-1]
   393  
   394  			// Don't visit the same node twice. Notice that all nodes
   395  			// across non-strict paths are still visited at least once, so
   396  			// a non-strict path can never obscure a strict path to the
   397  			// same node.
   398  			if !closed.Test(i) {
   399  				closed.Set(i)
   400  
   401  				l, r := po.children(i)
   402  				if l != 0 {
   403  					if l.Strict() {
   404  						next = append(next, l.Target())
   405  					} else {
   406  						open = append(open, l.Target())
   407  					}
   408  				}
   409  				if r != 0 {
   410  					if r.Strict() {
   411  						next = append(next, r.Target())
   412  					} else {
   413  						open = append(open, r.Target())
   414  					}
   415  				}
   416  			}
   417  		}
   418  		open = next
   419  		closed.Reset()
   420  	}
   421  
   422  	for len(open) > 0 {
   423  		i := open[len(open)-1]
   424  		open = open[:len(open)-1]
   425  
   426  		if !closed.Test(i) {
   427  			if f(i) {
   428  				return true
   429  			}
   430  			closed.Set(i)
   431  			l, r := po.children(i)
   432  			if l != 0 {
   433  				open = append(open, l.Target())
   434  			}
   435  			if r != 0 {
   436  				open = append(open, r.Target())
   437  			}
   438  		}
   439  	}
   440  	return false
   441  }
   442  
   443  // Returns true if there is a path from i1 to i2.
   444  // If strict ==  true: if the function returns true, then i1 <  i2.
   445  // If strict == false: if the function returns true, then i1 <= i2.
   446  // If the function returns false, no relation is known.
   447  func (po *poset) reaches(i1, i2 uint32, strict bool) bool {
   448  	return po.dfs(i1, strict, func(n uint32) bool {
   449  		return n == i2
   450  	})
   451  }
   452  
   453  // findroot finds i's root, that is which DAG contains i.
   454  // Returns the root; if i is itself a root, it is returned.
   455  // Panic if i is not in any DAG.
   456  func (po *poset) findroot(i uint32) uint32 {
   457  	// TODO(rasky): if needed, a way to speed up this search is
   458  	// storing a bitset for each root using it as a mini bloom filter
   459  	// of nodes present under that root.
   460  	for _, r := range po.roots {
   461  		if po.reaches(r, i, false) {
   462  			return r
   463  		}
   464  	}
   465  	panic("findroot didn't find any root")
   466  }
   467  
   468  // mergeroot merges two DAGs into one DAG by creating a new extra root
   469  func (po *poset) mergeroot(r1, r2 uint32) uint32 {
   470  	r := po.newnode(nil)
   471  	po.setchl(r, newedge(r1, false))
   472  	po.setchr(r, newedge(r2, false))
   473  	po.changeroot(r1, r)
   474  	po.removeroot(r2)
   475  	po.upush(undoMergeRoot, r, 0)
   476  	return r
   477  }
   478  
   479  // collapsepath marks n1 and n2 as equal and collapses as equal all
   480  // nodes across all paths between n1 and n2. If a strict edge is
   481  // found, the function does not modify the DAG and returns false.
   482  // Complexity is O(n).
   483  func (po *poset) collapsepath(n1, n2 *Value) bool {
   484  	i1, i2 := po.values[n1.ID], po.values[n2.ID]
   485  	if po.reaches(i1, i2, true) {
   486  		return false
   487  	}
   488  
   489  	// Find all the paths from i1 to i2
   490  	paths := po.findpaths(i1, i2)
   491  	// Mark all nodes in all the paths as aliases of n1
   492  	// (excluding n1 itself)
   493  	paths.Clear(i1)
   494  	po.aliasnodes(n1, paths)
   495  	return true
   496  }
   497  
   498  // findpaths is a recursive function that calculates all paths from cur to dst
   499  // and return them as a bitset (the index of a node is set in the bitset if
   500  // that node is on at least one path from cur to dst).
   501  // We do a DFS from cur (stopping going deep any time we reach dst, if ever),
   502  // and mark as part of the paths any node that has a children which is already
   503  // part of the path (or is dst itself).
   504  func (po *poset) findpaths(cur, dst uint32) bitset {
   505  	seen := newBitset(int(po.lastidx + 1))
   506  	path := newBitset(int(po.lastidx + 1))
   507  	path.Set(dst)
   508  	po.findpaths1(cur, dst, seen, path)
   509  	return path
   510  }
   511  
   512  func (po *poset) findpaths1(cur, dst uint32, seen bitset, path bitset) {
   513  	if cur == dst {
   514  		return
   515  	}
   516  	seen.Set(cur)
   517  	l, r := po.chl(cur), po.chr(cur)
   518  	if !seen.Test(l) {
   519  		po.findpaths1(l, dst, seen, path)
   520  	}
   521  	if !seen.Test(r) {
   522  		po.findpaths1(r, dst, seen, path)
   523  	}
   524  	if path.Test(l) || path.Test(r) {
   525  		path.Set(cur)
   526  	}
   527  }
   528  
   529  // Check whether it is recorded that i1!=i2
   530  func (po *poset) isnoneq(i1, i2 uint32) bool {
   531  	if i1 == i2 {
   532  		return false
   533  	}
   534  	if i1 < i2 {
   535  		i1, i2 = i2, i1
   536  	}
   537  
   538  	// Check if we recorded a non-equal relation before
   539  	if bs, ok := po.noneq[i1]; ok && bs.Test(i2) {
   540  		return true
   541  	}
   542  	return false
   543  }
   544  
   545  // Record that i1!=i2
   546  func (po *poset) setnoneq(n1, n2 *Value) {
   547  	i1, f1 := po.lookup(n1)
   548  	i2, f2 := po.lookup(n2)
   549  
   550  	// If any of the nodes do not exist in the poset, allocate them. Since
   551  	// we don't know any relation (in the partial order) about them, they must
   552  	// become independent roots.
   553  	if !f1 {
   554  		i1 = po.newnode(n1)
   555  		po.roots = append(po.roots, i1)
   556  		po.upush(undoNewRoot, i1, 0)
   557  	}
   558  	if !f2 {
   559  		i2 = po.newnode(n2)
   560  		po.roots = append(po.roots, i2)
   561  		po.upush(undoNewRoot, i2, 0)
   562  	}
   563  
   564  	if i1 == i2 {
   565  		panic("setnoneq on same node")
   566  	}
   567  	if i1 < i2 {
   568  		i1, i2 = i2, i1
   569  	}
   570  	bs := po.noneq[i1]
   571  	if bs == nil {
   572  		// Given that we record non-equality relations using the
   573  		// higher index as a key, the bitsize will never change size.
   574  		// TODO(rasky): if memory is a problem, consider allocating
   575  		// a small bitset and lazily grow it when higher indices arrive.
   576  		bs = newBitset(int(i1))
   577  		po.noneq[i1] = bs
   578  	} else if bs.Test(i2) {
   579  		// Already recorded
   580  		return
   581  	}
   582  	bs.Set(i2)
   583  	po.upushneq(i1, i2)
   584  }
   585  
   586  // CheckIntegrity verifies internal integrity of a poset. It is intended
   587  // for debugging purposes.
   588  func (po *poset) CheckIntegrity() {
   589  	// Verify that each node appears in a single DAG
   590  	seen := newBitset(int(po.lastidx + 1))
   591  	for _, r := range po.roots {
   592  		if r == 0 {
   593  			panic("empty root")
   594  		}
   595  
   596  		po.dfs(r, false, func(i uint32) bool {
   597  			if seen.Test(i) {
   598  				panic("duplicate node")
   599  			}
   600  			seen.Set(i)
   601  			return false
   602  		})
   603  	}
   604  
   605  	// Verify that values contain the minimum set
   606  	for id, idx := range po.values {
   607  		if !seen.Test(idx) {
   608  			panic(fmt.Errorf("spurious value [%d]=%d", id, idx))
   609  		}
   610  	}
   611  
   612  	// Verify that only existing nodes have non-zero children
   613  	for i, n := range po.nodes {
   614  		if n.l|n.r != 0 {
   615  			if !seen.Test(uint32(i)) {
   616  				panic(fmt.Errorf("children of unknown node %d->%v", i, n))
   617  			}
   618  			if n.l.Target() == uint32(i) || n.r.Target() == uint32(i) {
   619  				panic(fmt.Errorf("self-loop on node %d", i))
   620  			}
   621  		}
   622  	}
   623  }
   624  
   625  // CheckEmpty checks that a poset is completely empty.
   626  // It can be used for debugging purposes, as a poset is supposed to
   627  // be empty after it's fully rolled back through Undo.
   628  func (po *poset) CheckEmpty() error {
   629  	if len(po.nodes) != 1 {
   630  		return fmt.Errorf("non-empty nodes list: %v", po.nodes)
   631  	}
   632  	if len(po.values) != 0 {
   633  		return fmt.Errorf("non-empty value map: %v", po.values)
   634  	}
   635  	if len(po.roots) != 0 {
   636  		return fmt.Errorf("non-empty root list: %v", po.roots)
   637  	}
   638  	if len(po.undo) != 0 {
   639  		return fmt.Errorf("non-empty undo list: %v", po.undo)
   640  	}
   641  	if po.lastidx != 0 {
   642  		return fmt.Errorf("lastidx index is not zero: %v", po.lastidx)
   643  	}
   644  	for _, bs := range po.noneq {
   645  		for _, x := range bs {
   646  			if x != 0 {
   647  				return fmt.Errorf("non-empty noneq map")
   648  			}
   649  		}
   650  	}
   651  	return nil
   652  }
   653  
   654  // DotDump dumps the poset in graphviz format to file fn, with the specified title.
   655  func (po *poset) DotDump(fn string, title string) error {
   656  	f, err := os.Create(fn)
   657  	if err != nil {
   658  		return err
   659  	}
   660  	defer f.Close()
   661  
   662  	// Create reverse index mapping (taking aliases into account)
   663  	names := make(map[uint32]string)
   664  	for id, i := range po.values {
   665  		s := names[i]
   666  		if s == "" {
   667  			s = fmt.Sprintf("v%d", id)
   668  		} else {
   669  			s += fmt.Sprintf(", v%d", id)
   670  		}
   671  		names[i] = s
   672  	}
   673  
   674  	fmt.Fprintf(f, "digraph poset {\n")
   675  	fmt.Fprintf(f, "\tedge [ fontsize=10 ]\n")
   676  	for ridx, r := range po.roots {
   677  		fmt.Fprintf(f, "\tsubgraph root%d {\n", ridx)
   678  		po.dfs(r, false, func(i uint32) bool {
   679  			fmt.Fprintf(f, "\t\tnode%d [label=<%s <font point-size=\"6\">[%d]</font>>]\n", i, names[i], i)
   680  			chl, chr := po.children(i)
   681  			for _, ch := range []posetEdge{chl, chr} {
   682  				if ch != 0 {
   683  					if ch.Strict() {
   684  						fmt.Fprintf(f, "\t\tnode%d -> node%d [label=\" <\" color=\"red\"]\n", i, ch.Target())
   685  					} else {
   686  						fmt.Fprintf(f, "\t\tnode%d -> node%d [label=\" <=\" color=\"green\"]\n", i, ch.Target())
   687  					}
   688  				}
   689  			}
   690  			return false
   691  		})
   692  		fmt.Fprintf(f, "\t}\n")
   693  	}
   694  	fmt.Fprintf(f, "\tlabelloc=\"t\"\n")
   695  	fmt.Fprintf(f, "\tlabeldistance=\"3.0\"\n")
   696  	fmt.Fprintf(f, "\tlabel=%q\n", title)
   697  	fmt.Fprintf(f, "}\n")
   698  	return nil
   699  }
   700  
   701  // Ordered reports whether n1<n2. It returns false either when it is
   702  // certain that n1<n2 is false, or if there is not enough information
   703  // to tell.
   704  // Complexity is O(n).
   705  func (po *poset) Ordered(n1, n2 *Value) bool {
   706  	if debugPoset {
   707  		defer po.CheckIntegrity()
   708  	}
   709  	if n1.ID == n2.ID {
   710  		panic("should not call Ordered with n1==n2")
   711  	}
   712  
   713  	i1, f1 := po.lookup(n1)
   714  	i2, f2 := po.lookup(n2)
   715  	if !f1 || !f2 {
   716  		return false
   717  	}
   718  
   719  	return i1 != i2 && po.reaches(i1, i2, true)
   720  }
   721  
   722  // OrderedOrEqual reports whether n1<=n2. It returns false either when it is
   723  // certain that n1<=n2 is false, or if there is not enough information
   724  // to tell.
   725  // Complexity is O(n).
   726  func (po *poset) OrderedOrEqual(n1, n2 *Value) bool {
   727  	if debugPoset {
   728  		defer po.CheckIntegrity()
   729  	}
   730  	if n1.ID == n2.ID {
   731  		panic("should not call Ordered with n1==n2")
   732  	}
   733  
   734  	i1, f1 := po.lookup(n1)
   735  	i2, f2 := po.lookup(n2)
   736  	if !f1 || !f2 {
   737  		return false
   738  	}
   739  
   740  	return i1 == i2 || po.reaches(i1, i2, false)
   741  }
   742  
   743  // Equal reports whether n1==n2. It returns false either when it is
   744  // certain that n1==n2 is false, or if there is not enough information
   745  // to tell.
   746  // Complexity is O(1).
   747  func (po *poset) Equal(n1, n2 *Value) bool {
   748  	if debugPoset {
   749  		defer po.CheckIntegrity()
   750  	}
   751  	if n1.ID == n2.ID {
   752  		panic("should not call Equal with n1==n2")
   753  	}
   754  
   755  	i1, f1 := po.lookup(n1)
   756  	i2, f2 := po.lookup(n2)
   757  	return f1 && f2 && i1 == i2
   758  }
   759  
   760  // NonEqual reports whether n1!=n2. It returns false either when it is
   761  // certain that n1!=n2 is false, or if there is not enough information
   762  // to tell.
   763  // Complexity is O(n) (because it internally calls Ordered to see if we
   764  // can infer n1!=n2 from n1<n2 or n2<n1).
   765  func (po *poset) NonEqual(n1, n2 *Value) bool {
   766  	if debugPoset {
   767  		defer po.CheckIntegrity()
   768  	}
   769  	if n1.ID == n2.ID {
   770  		panic("should not call NonEqual with n1==n2")
   771  	}
   772  
   773  	// If we never saw the nodes before, we don't
   774  	// have a recorded non-equality.
   775  	i1, f1 := po.lookup(n1)
   776  	i2, f2 := po.lookup(n2)
   777  	if !f1 || !f2 {
   778  		return false
   779  	}
   780  
   781  	// Check if we recorded inequality
   782  	if po.isnoneq(i1, i2) {
   783  		return true
   784  	}
   785  
   786  	// Check if n1<n2 or n2<n1, in which case we can infer that n1!=n2
   787  	if po.Ordered(n1, n2) || po.Ordered(n2, n1) {
   788  		return true
   789  	}
   790  
   791  	return false
   792  }
   793  
// setOrder records that n1<n2 (strict == true) or n1<=n2 (strict == false).
// It returns false if this is a contradiction. Every path that returns false
// does so before modifying the poset. All modifications are pushed onto the
// undo chain.
// Implements SetOrder() and SetOrderOrEqual()
func (po *poset) setOrder(n1, n2 *Value, strict bool) bool {
	i1, f1 := po.lookup(n1)
	i2, f2 := po.lookup(n2)

	switch {
	case !f1 && !f2:
		// Neither n1 nor n2 are in the poset, so they are not related
		// in any way to existing nodes.
		// Create a new DAG to record the relation.
		i1, i2 = po.newnode(n1), po.newnode(n2)
		po.roots = append(po.roots, i1)
		po.upush(undoNewRoot, i1, 0)
		po.addchild(i1, i2, strict)

	case f1 && !f2:
		// n1 is in one of the DAGs, while n2 is not. Add n2 as a child
		// of n1.
		i2 = po.newnode(n2)
		po.addchild(i1, i2, strict)

	case !f1 && f2:
		// n1 is not in any DAG but n2 is. If n2 is a root, we can put
		// n1 in its place as a root; otherwise, we need to create a new
		// extra root to record the relation.
		i1 = po.newnode(n1)

		if po.isroot(i2) {
			// The undo record removes i1 from the root list and puts
			// i2 (the edge target) back.
			po.changeroot(i2, i1)
			po.upush(undoChangeRoot, i1, newedge(i2, strict))
			po.addchild(i1, i2, strict)
			return true
		}

		// Search for i2's root; this requires a O(n) search on all
		// DAGs
		r := po.findroot(i2)

		// Re-parent as follows:
		//
		//                  extra
		//     r            /   \
		//      \   ===>   r    i1
		//      i2          \   /
		//                    i2
		//
		// extra is an internal node (no SSA value), so its edges to r
		// and i1 are non-strict; only the i1->i2 edge carries the new
		// relation.
		extra := po.newnode(nil)
		po.changeroot(r, extra)
		po.upush(undoChangeRoot, extra, newedge(r, false))
		po.addchild(extra, r, false)
		po.addchild(extra, i1, false)
		po.addchild(i1, i2, strict)

	case f1 && f2:
		// If the nodes are aliased, fail only if we're setting a strict order
		// (that is, we cannot set n1<n2 if n1==n2).
		if i1 == i2 {
			return !strict
		}

		// If we are trying to record n1<=n2 but we learned that n1!=n2,
		// record n1<n2, as it provides more information.
		if !strict && po.isnoneq(i1, i2) {
			strict = true
		}

		// Both n1 and n2 are in the poset. This is the complex part of the algorithm
		// as we need to find many different cases and DAG shapes.

		// Check if n1 somehow reaches n2
		if po.reaches(i1, i2, false) {
			// This is the table of all cases we need to handle:
			//
			//      DAG          New      Action
			//      ---------------------------------------------------
			// #1:  N1<=X<=N2 |  N1<=N2 | do nothing
			// #2:  N1<=X<=N2 |  N1<N2  | add strict edge (N1<N2)
			// #3:  N1<X<N2   |  N1<=N2 | do nothing (we already know more)
			// #4:  N1<X<N2   |  N1<N2  | do nothing

			// Check if we're in case #2
			if strict && !po.reaches(i1, i2, true) {
				po.addchild(i1, i2, true)
				return true
			}

			// Case #1, #3, or #4: nothing to do
			return true
		}

		// Check if n2 somehow reaches n1
		if po.reaches(i2, i1, false) {
			// This is the table of all cases we need to handle:
			//
			//      DAG           New      Action
			//      ---------------------------------------------------
			// #5:  N2<=X<=N1  |  N1<=N2 | collapse path (learn that N1=X=N2)
			// #6:  N2<=X<=N1  |  N1<N2  | contradiction
			// #7:  N2<X<N1    |  N1<=N2 | contradiction in the path
			// #8:  N2<X<N1    |  N1<N2  | contradiction

			if strict {
				// Cases #6 and #8: contradiction
				return false
			}

			// We're in case #5 or #7. Try to collapse path, and that will
			// fail if it realizes that we are in case #7.
			return po.collapsepath(n2, n1)
		}

		// We don't know of any existing relation between n1 and n2. They could
		// be part of the same DAG or not.
		// Find their roots to check whether they are in the same DAG.
		r1, r2 := po.findroot(i1), po.findroot(i2)
		if r1 != r2 {
			// We need to merge the two DAGs to record a relation between the nodes
			po.mergeroot(r1, r2)
		}

		// Connect n1 and n2
		po.addchild(i1, i2, strict)
	}

	return true
}
   922  
   923  // SetOrder records that n1<n2. Returns false if this is a contradiction
   924  // Complexity is O(1) if n2 was never seen before, or O(n) otherwise.
   925  func (po *poset) SetOrder(n1, n2 *Value) bool {
   926  	if debugPoset {
   927  		defer po.CheckIntegrity()
   928  	}
   929  	if n1.ID == n2.ID {
   930  		panic("should not call SetOrder with n1==n2")
   931  	}
   932  	return po.setOrder(n1, n2, true)
   933  }
   934  
   935  // SetOrderOrEqual records that n1<=n2. Returns false if this is a contradiction
   936  // Complexity is O(1) if n2 was never seen before, or O(n) otherwise.
   937  func (po *poset) SetOrderOrEqual(n1, n2 *Value) bool {
   938  	if debugPoset {
   939  		defer po.CheckIntegrity()
   940  	}
   941  	if n1.ID == n2.ID {
   942  		panic("should not call SetOrder with n1==n2")
   943  	}
   944  	return po.setOrder(n1, n2, false)
   945  }
   946  
   947  // SetEqual records that n1==n2. Returns false if this is a contradiction
   948  // (that is, if it is already recorded that n1<n2 or n2<n1).
   949  // Complexity is O(1) if n2 was never seen before, or O(n) otherwise.
   950  func (po *poset) SetEqual(n1, n2 *Value) bool {
   951  	if debugPoset {
   952  		defer po.CheckIntegrity()
   953  	}
   954  	if n1.ID == n2.ID {
   955  		panic("should not call Add with n1==n2")
   956  	}
   957  
   958  	i1, f1 := po.lookup(n1)
   959  	i2, f2 := po.lookup(n2)
   960  
   961  	switch {
   962  	case !f1 && !f2:
   963  		i1 = po.newnode(n1)
   964  		po.roots = append(po.roots, i1)
   965  		po.upush(undoNewRoot, i1, 0)
   966  		po.aliasnewnode(n1, n2)
   967  	case f1 && !f2:
   968  		po.aliasnewnode(n1, n2)
   969  	case !f1 && f2:
   970  		po.aliasnewnode(n2, n1)
   971  	case f1 && f2:
   972  		if i1 == i2 {
   973  			// Already aliased, ignore
   974  			return true
   975  		}
   976  
   977  		// If we recorded that n1!=n2, this is a contradiction.
   978  		if po.isnoneq(i1, i2) {
   979  			return false
   980  		}
   981  
   982  		// If we already knew that n1<=n2, we can collapse the path to
   983  		// record n1==n2 (and vice versa).
   984  		if po.reaches(i1, i2, false) {
   985  			return po.collapsepath(n1, n2)
   986  		}
   987  		if po.reaches(i2, i1, false) {
   988  			return po.collapsepath(n2, n1)
   989  		}
   990  
   991  		r1 := po.findroot(i1)
   992  		r2 := po.findroot(i2)
   993  		if r1 != r2 {
   994  			// Merge the two DAGs so we can record relations between the nodes
   995  			po.mergeroot(r1, r2)
   996  		}
   997  
   998  		// Set n2 as alias of n1. This will also update all the references
   999  		// to n2 to become references to n1
  1000  		i2s := newBitset(int(po.lastidx) + 1)
  1001  		i2s.Set(i2)
  1002  		po.aliasnodes(n1, i2s)
  1003  	}
  1004  	return true
  1005  }
  1006  
  1007  // SetNonEqual records that n1!=n2. Returns false if this is a contradiction
  1008  // (that is, if it is already recorded that n1==n2).
  1009  // Complexity is O(n).
  1010  func (po *poset) SetNonEqual(n1, n2 *Value) bool {
  1011  	if debugPoset {
  1012  		defer po.CheckIntegrity()
  1013  	}
  1014  	if n1.ID == n2.ID {
  1015  		panic("should not call SetNonEqual with n1==n2")
  1016  	}
  1017  
  1018  	// Check whether the nodes are already in the poset
  1019  	i1, f1 := po.lookup(n1)
  1020  	i2, f2 := po.lookup(n2)
  1021  
  1022  	// If either node wasn't present, we just record the new relation
  1023  	// and exit.
  1024  	if !f1 || !f2 {
  1025  		po.setnoneq(n1, n2)
  1026  		return true
  1027  	}
  1028  
  1029  	// See if we already know this, in which case there's nothing to do.
  1030  	if po.isnoneq(i1, i2) {
  1031  		return true
  1032  	}
  1033  
  1034  	// Check if we're contradicting an existing equality relation
  1035  	if po.Equal(n1, n2) {
  1036  		return false
  1037  	}
  1038  
  1039  	// Record non-equality
  1040  	po.setnoneq(n1, n2)
  1041  
  1042  	// If we know that i1<=i2 but not i1<i2, learn that as we
  1043  	// now know that they are not equal. Do the same for i2<=i1.
  1044  	// Do this check only if both nodes were already in the DAG,
  1045  	// otherwise there cannot be an existing relation.
  1046  	if po.reaches(i1, i2, false) && !po.reaches(i1, i2, true) {
  1047  		po.addchild(i1, i2, true)
  1048  	}
  1049  	if po.reaches(i2, i1, false) && !po.reaches(i2, i1, true) {
  1050  		po.addchild(i2, i1, true)
  1051  	}
  1052  
  1053  	return true
  1054  }
  1055  
  1056  // Checkpoint saves the current state of the DAG so that it's possible
  1057  // to later undo this state.
  1058  // Complexity is O(1).
  1059  func (po *poset) Checkpoint() {
  1060  	po.undo = append(po.undo, posetUndo{typ: undoCheckpoint})
  1061  }
  1062  
// Undo restores the state of the poset to the previous checkpoint.
// Complexity depends on the type of operations that were performed
// since the last checkpoint; each Set* operation creates an undo
// pass which Undo has to revert with a worst-case complexity of O(n).
//
// Passes are popped from the undo stack in LIFO order and reverted one
// by one. Popping an undoCheckpoint marker consumes it and stops. If the
// stack runs out before a checkpoint is found, every pass on the stack
// has been reverted, and (when debugPoset is set) the poset is verified
// to be empty. Undo panics if the undo stack is empty on entry, or if a
// pass is inconsistent with the current poset state.
func (po *poset) Undo() {
	if len(po.undo) == 0 {
		panic("empty undo stack")
	}
	if debugPoset {
		defer po.CheckIntegrity()
	}

	for len(po.undo) > 0 {
		// Pop the most recently recorded pass.
		pass := po.undo[len(po.undo)-1]
		po.undo = po.undo[:len(po.undo)-1]

		switch pass.typ {
		case undoCheckpoint:
			// Reached the checkpoint delimiting this group of passes.
			return

		case undoSetChl:
			// Restore the previous left child edge of node pass.idx.
			po.setchl(pass.idx, pass.edge)

		case undoSetChr:
			// Restore the previous right child edge of node pass.idx.
			po.setchr(pass.idx, pass.edge)

		case undoNonEqual:
			// Forget that SSA value pass.ID is non-equal to pass.idx (another ID).
			po.noneq[uint32(pass.ID)].Clear(pass.idx)

		case undoNewNode:
			// Nodes are removed in reverse creation order, so the node being
			// removed must be the most recently allocated one.
			if pass.idx != po.lastidx {
				panic("invalid newnode index")
			}
			// NOTE(review): ID 0 presumably marks a node created without an
			// SSA value (e.g. newnode(nil)); confirm against upush callers.
			if pass.ID != 0 {
				if po.values[pass.ID] != pass.idx {
					panic("invalid newnode undo pass")
				}
				delete(po.values, pass.ID)
			}
			// Clear the node's child edges, then drop it from the node table.
			po.setchl(pass.idx, 0)
			po.setchr(pass.idx, 0)
			po.nodes = po.nodes[:pass.idx]
			po.lastidx--

		case undoAliasNode:
			// pass.idx is the node SSA value pass.ID mapped to before it was
			// aliased, or 0 if the value was first introduced as an alias.
			ID, prev := pass.ID, pass.idx
			cur := po.values[ID]
			if prev == 0 {
				// Born as an alias, die as an alias
				delete(po.values, ID)
			} else {
				if cur == prev {
					panic("invalid aliasnode undo pass")
				}
				// Give it back previous value
				po.values[ID] = prev
			}

		case undoNewRoot:
			// A root being removed must have no children left at this point.
			i := pass.idx
			l, r := po.children(i)
			if l|r != 0 {
				panic("non-empty root in undo newroot")
			}
			po.removeroot(i)

		case undoChangeRoot:
			// Revert a root replacement: root i (which must be childless by
			// now) is replaced again by the previous root, pass.edge.Target().
			i := pass.idx
			l, r := po.children(i)
			if l|r != 0 {
				panic("non-empty root in undo changeroot")
			}
			po.changeroot(i, pass.edge.Target())

		case undoMergeRoot:
			// Split a merged DAG: the left child of merge node i takes back
			// i's place among the roots, and the right child becomes a
			// separate root again.
			i := pass.idx
			l, r := po.children(i)
			po.changeroot(i, l.Target())
			po.roots = append(po.roots, r.Target())

		default:
			panic(pass.typ)
		}
	}

	if debugPoset && po.CheckEmpty() != nil {
		panic("poset not empty at the end of undo")
	}
}
  1152  

View as plain text