Source file src/cmd/compile/internal/ssa/poset.go

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssa
     6  
     7  import (
     8  	"fmt"
     9  	"os"
    10  	"slices"
    11  )
    12  
    13  // If true, check poset integrity after every mutation
    14  var debugPoset = false
    15  
    16  const uintSize = 32 << (^uint(0) >> 63) // 32 or 64
    17  
    18  // bitset is a bit array for dense indexes.
    19  type bitset []uint
    20  
    21  func newBitset(n int) bitset {
    22  	return make(bitset, (n+uintSize-1)/uintSize)
    23  }
    24  
    25  func (bs bitset) Reset() {
    26  	clear(bs)
    27  }
    28  
    29  func (bs bitset) Set(idx uint32) {
    30  	bs[idx/uintSize] |= 1 << (idx % uintSize)
    31  }
    32  
    33  func (bs bitset) Clear(idx uint32) {
    34  	bs[idx/uintSize] &^= 1 << (idx % uintSize)
    35  }
    36  
    37  func (bs bitset) Test(idx uint32) bool {
    38  	return bs[idx/uintSize]&(1<<(idx%uintSize)) != 0
    39  }
    40  
    41  type undoType uint8
    42  
    43  const (
    44  	undoInvalid    undoType = iota
    45  	undoCheckpoint          // a checkpoint to group undo passes
    46  	undoSetChl              // change back left child of undo.idx to undo.edge
    47  	undoSetChr              // change back right child of undo.idx to undo.edge
    48  	undoNonEqual            // forget that SSA value undo.ID is non-equal to undo.idx (another ID)
    49  	undoNewNode             // remove new node created for SSA value undo.ID
    50  	undoAliasNode           // unalias SSA value undo.ID so that it points back to node index undo.idx
    51  	undoNewRoot             // remove node undo.idx from root list
    52  	undoChangeRoot          // remove node undo.idx from root list, and put back undo.edge.Target instead
    53  	undoMergeRoot           // remove node undo.idx from root list, and put back its children instead
    54  )
    55  
    56  // posetUndo represents an undo pass to be performed.
    57  // It's a union of fields that can be used to store information,
    58  // and typ is the discriminant, that specifies which kind
    59  // of operation must be performed. Not all fields are always used.
    60  type posetUndo struct {
    61  	typ  undoType
    62  	idx  uint32
    63  	ID   ID
    64  	edge posetEdge
    65  }
    66  
    67  const (
    68  	// Make poset handle values as unsigned numbers.
    69  	// (TODO: remove?)
    70  	posetFlagUnsigned = 1 << iota
    71  )
    72  
    73  // A poset edge. The zero value is the null/empty edge.
    74  // Packs target node index (31 bits) and strict flag (1 bit).
    75  type posetEdge uint32
    76  
    77  func newedge(t uint32, strict bool) posetEdge {
    78  	s := uint32(0)
    79  	if strict {
    80  		s = 1
    81  	}
    82  	return posetEdge(t<<1 | s)
    83  }
    84  func (e posetEdge) Target() uint32 { return uint32(e) >> 1 }
    85  func (e posetEdge) Strict() bool   { return uint32(e)&1 != 0 }
    86  func (e posetEdge) String() string {
    87  	s := fmt.Sprint(e.Target())
    88  	if e.Strict() {
    89  		s += "*"
    90  	}
    91  	return s
    92  }
    93  
    94  // posetNode is a node of a DAG within the poset.
    95  type posetNode struct {
    96  	l, r posetEdge
    97  }
    98  
    99  // poset is a union-find data structure that can represent a partially ordered set
   100  // of SSA values. Given a binary relation that creates a partial order (eg: '<'),
   101  // clients can record relations between SSA values using SetOrder, and later
   102  // check relations (in the transitive closure) with Ordered. For instance,
   103  // if SetOrder is called to record that A<B and B<C, Ordered will later confirm
   104  // that A<C.
   105  //
   106  // It is possible to record equality relations between SSA values with SetEqual and check
   107  // equality with Equal. Equality propagates into the transitive closure for the partial
   108  // order so that if we know that A<B<C and later learn that A==D, Ordered will return
   109  // true for D<C.
   110  //
   111  // It is also possible to record inequality relations between nodes with SetNonEqual;
   112  // non-equality relations are not transitive, but they can still be useful: for instance
   113  // if we know that A<=B and later we learn that A!=B, we can deduce that A<B.
   114  // NonEqual can be used to check whether it is known that the nodes are different, either
   115  // because SetNonEqual was called before, or because we know that they are strictly ordered.
   116  //
   117  // poset will refuse to record new relations that contradict existing relations:
   118  // for instance if A<B<C, calling SetOrder for C<A will fail returning false; also
   119  // calling SetEqual for C==A will fail.
   120  //
   121  // poset is implemented as a forest of DAGs; in each DAG, if there is a path (directed)
   122  // from node A to B, it means that A<B (or A<=B). Equality is represented by mapping
   123  // two SSA values to the same DAG node; when a new equality relation is recorded
   124  // between two existing nodes, the nodes are merged, adjusting incoming and outgoing edges.
   125  //
   126  // poset is designed to be memory efficient and do little allocations during normal usage.
   127  // Most internal data structures are pre-allocated and flat, so for instance adding a
   128  // new relation does not cause any allocation. For performance reasons,
   129  // each node has only up to two outgoing edges (like a binary tree), so intermediate
   130  // "extra" nodes are required to represent more than two relations. For instance,
   131  // to record that A<I, A<J, A<K (with no known relation between I,J,K), we create the
   132  // following DAG:
   133  //
   134  //	  A
   135  //	 / \
   136  //	I  extra
   137  //	    /  \
   138  //	   J    K
   139  type poset struct {
   140  	lastidx uint32            // last generated dense index
   141  	flags   uint8             // internal flags
   142  	values  map[ID]uint32     // map SSA values to dense indexes
   143  	nodes   []posetNode       // nodes (in all DAGs)
   144  	roots   []uint32          // list of root nodes (forest)
   145  	noneq   map[uint32]bitset // non-equal relations
   146  	undo    []posetUndo       // undo chain
   147  }
   148  
   149  func newPoset() *poset {
   150  	return &poset{
   151  		values: make(map[ID]uint32),
   152  		nodes:  make([]posetNode, 1, 16),
   153  		roots:  make([]uint32, 0, 4),
   154  		noneq:  make(map[uint32]bitset),
   155  		undo:   make([]posetUndo, 0, 4),
   156  	}
   157  }
   158  
   159  func (po *poset) SetUnsigned(uns bool) {
   160  	if uns {
   161  		po.flags |= posetFlagUnsigned
   162  	} else {
   163  		po.flags &^= posetFlagUnsigned
   164  	}
   165  }
   166  
   167  // Handle children
   168  func (po *poset) setchl(i uint32, l posetEdge) { po.nodes[i].l = l }
   169  func (po *poset) setchr(i uint32, r posetEdge) { po.nodes[i].r = r }
   170  func (po *poset) chl(i uint32) uint32          { return po.nodes[i].l.Target() }
   171  func (po *poset) chr(i uint32) uint32          { return po.nodes[i].r.Target() }
   172  func (po *poset) children(i uint32) (posetEdge, posetEdge) {
   173  	return po.nodes[i].l, po.nodes[i].r
   174  }
   175  
   176  // upush records a new undo step. It can be used for simple
   177  // undo passes that record up to one index and one edge.
   178  func (po *poset) upush(typ undoType, p uint32, e posetEdge) {
   179  	po.undo = append(po.undo, posetUndo{typ: typ, idx: p, edge: e})
   180  }
   181  
   182  // upushnew pushes an undo pass for a new node
   183  func (po *poset) upushnew(id ID, idx uint32) {
   184  	po.undo = append(po.undo, posetUndo{typ: undoNewNode, ID: id, idx: idx})
   185  }
   186  
   187  // upushneq pushes a new undo pass for a nonequal relation
   188  func (po *poset) upushneq(idx1 uint32, idx2 uint32) {
   189  	po.undo = append(po.undo, posetUndo{typ: undoNonEqual, ID: ID(idx1), idx: idx2})
   190  }
   191  
   192  // upushalias pushes a new undo pass for aliasing two nodes
   193  func (po *poset) upushalias(id ID, i2 uint32) {
   194  	po.undo = append(po.undo, posetUndo{typ: undoAliasNode, ID: id, idx: i2})
   195  }
   196  
   197  // addchild adds i2 as direct child of i1.
   198  func (po *poset) addchild(i1, i2 uint32, strict bool) {
   199  	i1l, i1r := po.children(i1)
   200  	e2 := newedge(i2, strict)
   201  
   202  	if i1l == 0 {
   203  		po.setchl(i1, e2)
   204  		po.upush(undoSetChl, i1, 0)
   205  	} else if i1r == 0 {
   206  		po.setchr(i1, e2)
   207  		po.upush(undoSetChr, i1, 0)
   208  	} else {
   209  		// If n1 already has two children, add an intermediate extra
   210  		// node to record the relation correctly (without relating
   211  		// n2 to other existing nodes). Use a non-deterministic value
   212  		// to decide whether to append on the left or the right, to avoid
   213  		// creating degenerated chains.
   214  		//
   215  		//      n1
   216  		//     /  \
   217  		//   i1l  extra
   218  		//        /   \
   219  		//      i1r   n2
   220  		//
   221  		extra := po.newnode(nil)
   222  		if (i1^i2)&1 != 0 { // non-deterministic
   223  			po.setchl(extra, i1r)
   224  			po.setchr(extra, e2)
   225  			po.setchr(i1, newedge(extra, false))
   226  			po.upush(undoSetChr, i1, i1r)
   227  		} else {
   228  			po.setchl(extra, i1l)
   229  			po.setchr(extra, e2)
   230  			po.setchl(i1, newedge(extra, false))
   231  			po.upush(undoSetChl, i1, i1l)
   232  		}
   233  	}
   234  }
   235  
   236  // newnode allocates a new node bound to SSA value n.
   237  // If n is nil, this is an extra node (= only used internally).
   238  func (po *poset) newnode(n *Value) uint32 {
   239  	i := po.lastidx + 1
   240  	po.lastidx++
   241  	po.nodes = append(po.nodes, posetNode{})
   242  	if n != nil {
   243  		if po.values[n.ID] != 0 {
   244  			panic("newnode for Value already inserted")
   245  		}
   246  		po.values[n.ID] = i
   247  		po.upushnew(n.ID, i)
   248  	} else {
   249  		po.upushnew(0, i)
   250  	}
   251  	return i
   252  }
   253  
   254  // lookup searches for a SSA value into the forest of DAGS, and return its node.
   255  func (po *poset) lookup(n *Value) (uint32, bool) {
   256  	i, f := po.values[n.ID]
   257  	return i, f
   258  }
   259  
   260  // aliasnewnode records that a single node n2 (not in the poset yet) is an alias
   261  // of the master node n1.
   262  func (po *poset) aliasnewnode(n1, n2 *Value) {
   263  	i1, i2 := po.values[n1.ID], po.values[n2.ID]
   264  	if i1 == 0 || i2 != 0 {
   265  		panic("aliasnewnode invalid arguments")
   266  	}
   267  
   268  	po.values[n2.ID] = i1
   269  	po.upushalias(n2.ID, 0)
   270  }
   271  
// aliasnodes records that all the nodes in i2s are aliases of a single master
// node, the node of n1. It rearranges the DAG so that:
//   - every edge (from any node other than the master) that points to a node
//     in i2s is redirected to the master node;
//   - every child of a node in i2s that is not itself in i2s becomes a child
//     of the master node, and the nodes in i2s are left without children;
//   - every SSA value bound to a node in i2s is re-bound to the master node.
//
// All changes are recorded in the undo chain. It panics if n1 has no node,
// or if i2s contains n1's node.
// Complexity is O(n) (with n being the total number of nodes in the poset, not just
// the number of nodes being aliased).
func (po *poset) aliasnodes(n1 *Value, i2s bitset) {
	i1 := po.values[n1.ID]
	if i1 == 0 {
		panic("aliasnode for non-existing node")
	}
	if i2s.Test(i1) {
		panic("aliasnode i2s contains n1 node")
	}

	// Go through all the nodes to adjust parent/children of nodes in i2s.
	// The range expression is evaluated once, so extra nodes appended by
	// addchild below are not visited; n is a copy of each node's original
	// edges.
	for idx, n := range po.nodes {
		// Do not touch i1 itself, otherwise we can create useless self-loops
		if uint32(idx) == i1 {
			continue
		}
		l, r := n.l, n.r

		// Rename all references to i2s into i1
		if i2s.Test(l.Target()) {
			po.setchl(uint32(idx), newedge(i1, l.Strict()))
			po.upush(undoSetChl, uint32(idx), l)
		}
		if i2s.Test(r.Target()) {
			po.setchr(uint32(idx), newedge(i1, r.Strict()))
			po.upush(undoSetChr, uint32(idx), r)
		}

		// Connect all children of i2s to i1 (unless those children
		// are in i2s as well, in which case it would be useless).
		// The aliased node is then emptied; l and r still hold its
		// original edges, which are pushed so that undo restores them.
		if i2s.Test(uint32(idx)) {
			if l != 0 && !i2s.Test(l.Target()) {
				po.addchild(i1, l.Target(), l.Strict())
			}
			if r != 0 && !i2s.Test(r.Target()) {
				po.addchild(i1, r.Target(), r.Strict())
			}
			po.setchl(uint32(idx), 0)
			po.setchr(uint32(idx), 0)
			po.upush(undoSetChl, uint32(idx), l)
			po.upush(undoSetChr, uint32(idx), r)
		}
	}

	// Reassign all existing IDs that point to a node in i2s to i1,
	// remembering the old node index so that undo can restore it.
	for k, v := range po.values {
		if i2s.Test(v) {
			po.values[k] = i1
			po.upushalias(k, v)
		}
	}
}
   329  
   330  func (po *poset) isroot(r uint32) bool {
   331  	for i := range po.roots {
   332  		if po.roots[i] == r {
   333  			return true
   334  		}
   335  	}
   336  	return false
   337  }
   338  
   339  func (po *poset) changeroot(oldr, newr uint32) {
   340  	for i := range po.roots {
   341  		if po.roots[i] == oldr {
   342  			po.roots[i] = newr
   343  			return
   344  		}
   345  	}
   346  	panic("changeroot on non-root")
   347  }
   348  
   349  func (po *poset) removeroot(r uint32) {
   350  	for i := range po.roots {
   351  		if po.roots[i] == r {
   352  			po.roots = slices.Delete(po.roots, i, i+1)
   353  			return
   354  		}
   355  	}
   356  	panic("removeroot on non-root")
   357  }
   358  
// dfs performs a depth-first search within the DAG whose root is r.
// f is the visit function, called once for each visited node; if it
// returns true, the search is aborted and dfs returns true. If the walk
// completes without f returning true, dfs returns false.
//
// If !strict, every node reachable from r is visited, including r itself.
// If strict, only the nodes reachable from r through a path containing at
// least one strict edge are visited: edges are first followed without
// calling f until a strict edge is crossed. For instance, for a chain
// A<=B<=C<D<=E<F rooted at A, a strict walk visits D,E,F (and not A,B,C).
func (po *poset) dfs(r uint32, strict bool, f func(i uint32) bool) bool {
	// closed marks already expanded nodes; open is the explicit DFS stack.
	closed := newBitset(int(po.lastidx + 1))
	open := make([]uint32, 1, 64)
	open[0] = r

	if strict {
		// Do a first DFS; walk all paths and stop when we find a strict
		// edge, building a "next" list of nodes reachable through strict
		// edges. This will be the bootstrap open list for the real DFS.
		next := make([]uint32, 0, 64)

		for len(open) > 0 {
			i := open[len(open)-1]
			open = open[:len(open)-1]

			// Don't visit the same node twice. Notice that all nodes
			// across non-strict paths are still visited at least once, so
			// a non-strict path can never obscure a strict path to the
			// same node.
			if !closed.Test(i) {
				closed.Set(i)

				l, r := po.children(i)
				if l != 0 {
					if l.Strict() {
						next = append(next, l.Target())
					} else {
						open = append(open, l.Target())
					}
				}
				if r != 0 {
					if r.Strict() {
						next = append(next, r.Target())
					} else {
						open = append(open, r.Target())
					}
				}
			}
		}
		// The real DFS starts from the targets of the strict edges found
		// above, with a fresh closed set.
		open = next
		closed.Reset()
	}

	for len(open) > 0 {
		i := open[len(open)-1]
		open = open[:len(open)-1]

		if !closed.Test(i) {
			if f(i) {
				return true
			}
			closed.Set(i)
			l, r := po.children(i)
			if l != 0 {
				open = append(open, l.Target())
			}
			if r != 0 {
				open = append(open, r.Target())
			}
		}
	}
	return false
}
   430  
   431  // Returns true if there is a path from i1 to i2.
   432  // If strict ==  true: if the function returns true, then i1 <  i2.
   433  // If strict == false: if the function returns true, then i1 <= i2.
   434  // If the function returns false, no relation is known.
   435  func (po *poset) reaches(i1, i2 uint32, strict bool) bool {
   436  	return po.dfs(i1, strict, func(n uint32) bool {
   437  		return n == i2
   438  	})
   439  }
   440  
   441  // findroot finds i's root, that is which DAG contains i.
   442  // Returns the root; if i is itself a root, it is returned.
   443  // Panic if i is not in any DAG.
   444  func (po *poset) findroot(i uint32) uint32 {
   445  	// TODO(rasky): if needed, a way to speed up this search is
   446  	// storing a bitset for each root using it as a mini bloom filter
   447  	// of nodes present under that root.
   448  	for _, r := range po.roots {
   449  		if po.reaches(r, i, false) {
   450  			return r
   451  		}
   452  	}
   453  	panic("findroot didn't find any root")
   454  }
   455  
   456  // mergeroot merges two DAGs into one DAG by creating a new extra root
   457  func (po *poset) mergeroot(r1, r2 uint32) uint32 {
   458  	r := po.newnode(nil)
   459  	po.setchl(r, newedge(r1, false))
   460  	po.setchr(r, newedge(r2, false))
   461  	po.changeroot(r1, r)
   462  	po.removeroot(r2)
   463  	po.upush(undoMergeRoot, r, 0)
   464  	return r
   465  }
   466  
   467  // collapsepath marks n1 and n2 as equal and collapses as equal all
   468  // nodes across all paths between n1 and n2. If a strict edge is
   469  // found, the function does not modify the DAG and returns false.
   470  // Complexity is O(n).
   471  func (po *poset) collapsepath(n1, n2 *Value) bool {
   472  	i1, i2 := po.values[n1.ID], po.values[n2.ID]
   473  	if po.reaches(i1, i2, true) {
   474  		return false
   475  	}
   476  
   477  	// Find all the paths from i1 to i2
   478  	paths := po.findpaths(i1, i2)
   479  	// Mark all nodes in all the paths as aliases of n1
   480  	// (excluding n1 itself)
   481  	paths.Clear(i1)
   482  	po.aliasnodes(n1, paths)
   483  	return true
   484  }
   485  
   486  // findpaths is a recursive function that calculates all paths from cur to dst
   487  // and return them as a bitset (the index of a node is set in the bitset if
   488  // that node is on at least one path from cur to dst).
   489  // We do a DFS from cur (stopping going deep any time we reach dst, if ever),
   490  // and mark as part of the paths any node that has a children which is already
   491  // part of the path (or is dst itself).
   492  func (po *poset) findpaths(cur, dst uint32) bitset {
   493  	seen := newBitset(int(po.lastidx + 1))
   494  	path := newBitset(int(po.lastidx + 1))
   495  	path.Set(dst)
   496  	po.findpaths1(cur, dst, seen, path)
   497  	return path
   498  }
   499  
   500  func (po *poset) findpaths1(cur, dst uint32, seen bitset, path bitset) {
   501  	if cur == dst {
   502  		return
   503  	}
   504  	seen.Set(cur)
   505  	l, r := po.chl(cur), po.chr(cur)
   506  	if !seen.Test(l) {
   507  		po.findpaths1(l, dst, seen, path)
   508  	}
   509  	if !seen.Test(r) {
   510  		po.findpaths1(r, dst, seen, path)
   511  	}
   512  	if path.Test(l) || path.Test(r) {
   513  		path.Set(cur)
   514  	}
   515  }
   516  
   517  // Check whether it is recorded that i1!=i2
   518  func (po *poset) isnoneq(i1, i2 uint32) bool {
   519  	if i1 == i2 {
   520  		return false
   521  	}
   522  	if i1 < i2 {
   523  		i1, i2 = i2, i1
   524  	}
   525  
   526  	// Check if we recorded a non-equal relation before
   527  	if bs, ok := po.noneq[i1]; ok && bs.Test(i2) {
   528  		return true
   529  	}
   530  	return false
   531  }
   532  
   533  // Record that i1!=i2
   534  func (po *poset) setnoneq(n1, n2 *Value) {
   535  	i1, f1 := po.lookup(n1)
   536  	i2, f2 := po.lookup(n2)
   537  
   538  	// If any of the nodes do not exist in the poset, allocate them. Since
   539  	// we don't know any relation (in the partial order) about them, they must
   540  	// become independent roots.
   541  	if !f1 {
   542  		i1 = po.newnode(n1)
   543  		po.roots = append(po.roots, i1)
   544  		po.upush(undoNewRoot, i1, 0)
   545  	}
   546  	if !f2 {
   547  		i2 = po.newnode(n2)
   548  		po.roots = append(po.roots, i2)
   549  		po.upush(undoNewRoot, i2, 0)
   550  	}
   551  
   552  	if i1 == i2 {
   553  		panic("setnoneq on same node")
   554  	}
   555  	if i1 < i2 {
   556  		i1, i2 = i2, i1
   557  	}
   558  	bs := po.noneq[i1]
   559  	if bs == nil {
   560  		// Given that we record non-equality relations using the
   561  		// higher index as a key, the bitsize will never change size.
   562  		// TODO(rasky): if memory is a problem, consider allocating
   563  		// a small bitset and lazily grow it when higher indices arrive.
   564  		bs = newBitset(int(i1))
   565  		po.noneq[i1] = bs
   566  	} else if bs.Test(i2) {
   567  		// Already recorded
   568  		return
   569  	}
   570  	bs.Set(i2)
   571  	po.upushneq(i1, i2)
   572  }
   573  
   574  // CheckIntegrity verifies internal integrity of a poset. It is intended
   575  // for debugging purposes.
   576  func (po *poset) CheckIntegrity() {
   577  	// Verify that each node appears in a single DAG
   578  	seen := newBitset(int(po.lastidx + 1))
   579  	for _, r := range po.roots {
   580  		if r == 0 {
   581  			panic("empty root")
   582  		}
   583  
   584  		po.dfs(r, false, func(i uint32) bool {
   585  			if seen.Test(i) {
   586  				panic("duplicate node")
   587  			}
   588  			seen.Set(i)
   589  			return false
   590  		})
   591  	}
   592  
   593  	// Verify that values contain the minimum set
   594  	for id, idx := range po.values {
   595  		if !seen.Test(idx) {
   596  			panic(fmt.Errorf("spurious value [%d]=%d", id, idx))
   597  		}
   598  	}
   599  
   600  	// Verify that only existing nodes have non-zero children
   601  	for i, n := range po.nodes {
   602  		if n.l|n.r != 0 {
   603  			if !seen.Test(uint32(i)) {
   604  				panic(fmt.Errorf("children of unknown node %d->%v", i, n))
   605  			}
   606  			if n.l.Target() == uint32(i) || n.r.Target() == uint32(i) {
   607  				panic(fmt.Errorf("self-loop on node %d", i))
   608  			}
   609  		}
   610  	}
   611  }
   612  
   613  // CheckEmpty checks that a poset is completely empty.
   614  // It can be used for debugging purposes, as a poset is supposed to
   615  // be empty after it's fully rolled back through Undo.
   616  func (po *poset) CheckEmpty() error {
   617  	if len(po.nodes) != 1 {
   618  		return fmt.Errorf("non-empty nodes list: %v", po.nodes)
   619  	}
   620  	if len(po.values) != 0 {
   621  		return fmt.Errorf("non-empty value map: %v", po.values)
   622  	}
   623  	if len(po.roots) != 0 {
   624  		return fmt.Errorf("non-empty root list: %v", po.roots)
   625  	}
   626  	if len(po.undo) != 0 {
   627  		return fmt.Errorf("non-empty undo list: %v", po.undo)
   628  	}
   629  	if po.lastidx != 0 {
   630  		return fmt.Errorf("lastidx index is not zero: %v", po.lastidx)
   631  	}
   632  	for _, bs := range po.noneq {
   633  		for _, x := range bs {
   634  			if x != 0 {
   635  				return fmt.Errorf("non-empty noneq map")
   636  			}
   637  		}
   638  	}
   639  	return nil
   640  }
   641  
   642  // DotDump dumps the poset in graphviz format to file fn, with the specified title.
   643  func (po *poset) DotDump(fn string, title string) error {
   644  	f, err := os.Create(fn)
   645  	if err != nil {
   646  		return err
   647  	}
   648  	defer f.Close()
   649  
   650  	// Create reverse index mapping (taking aliases into account)
   651  	names := make(map[uint32]string)
   652  	for id, i := range po.values {
   653  		s := names[i]
   654  		if s == "" {
   655  			s = fmt.Sprintf("v%d", id)
   656  		} else {
   657  			s += fmt.Sprintf(", v%d", id)
   658  		}
   659  		names[i] = s
   660  	}
   661  
   662  	fmt.Fprintf(f, "digraph poset {\n")
   663  	fmt.Fprintf(f, "\tedge [ fontsize=10 ]\n")
   664  	for ridx, r := range po.roots {
   665  		fmt.Fprintf(f, "\tsubgraph root%d {\n", ridx)
   666  		po.dfs(r, false, func(i uint32) bool {
   667  			fmt.Fprintf(f, "\t\tnode%d [label=<%s <font point-size=\"6\">[%d]</font>>]\n", i, names[i], i)
   668  			chl, chr := po.children(i)
   669  			for _, ch := range []posetEdge{chl, chr} {
   670  				if ch != 0 {
   671  					if ch.Strict() {
   672  						fmt.Fprintf(f, "\t\tnode%d -> node%d [label=\" <\" color=\"red\"]\n", i, ch.Target())
   673  					} else {
   674  						fmt.Fprintf(f, "\t\tnode%d -> node%d [label=\" <=\" color=\"green\"]\n", i, ch.Target())
   675  					}
   676  				}
   677  			}
   678  			return false
   679  		})
   680  		fmt.Fprintf(f, "\t}\n")
   681  	}
   682  	fmt.Fprintf(f, "\tlabelloc=\"t\"\n")
   683  	fmt.Fprintf(f, "\tlabeldistance=\"3.0\"\n")
   684  	fmt.Fprintf(f, "\tlabel=%q\n", title)
   685  	fmt.Fprintf(f, "}\n")
   686  	return nil
   687  }
   688  
   689  // Ordered reports whether n1<n2. It returns false either when it is
   690  // certain that n1<n2 is false, or if there is not enough information
   691  // to tell.
   692  // Complexity is O(n).
   693  func (po *poset) Ordered(n1, n2 *Value) bool {
   694  	if debugPoset {
   695  		defer po.CheckIntegrity()
   696  	}
   697  	if n1.ID == n2.ID {
   698  		panic("should not call Ordered with n1==n2")
   699  	}
   700  
   701  	i1, f1 := po.lookup(n1)
   702  	i2, f2 := po.lookup(n2)
   703  	if !f1 || !f2 {
   704  		return false
   705  	}
   706  
   707  	return i1 != i2 && po.reaches(i1, i2, true)
   708  }
   709  
   710  // OrderedOrEqual reports whether n1<=n2. It returns false either when it is
   711  // certain that n1<=n2 is false, or if there is not enough information
   712  // to tell.
   713  // Complexity is O(n).
   714  func (po *poset) OrderedOrEqual(n1, n2 *Value) bool {
   715  	if debugPoset {
   716  		defer po.CheckIntegrity()
   717  	}
   718  	if n1.ID == n2.ID {
   719  		panic("should not call Ordered with n1==n2")
   720  	}
   721  
   722  	i1, f1 := po.lookup(n1)
   723  	i2, f2 := po.lookup(n2)
   724  	if !f1 || !f2 {
   725  		return false
   726  	}
   727  
   728  	return i1 == i2 || po.reaches(i1, i2, false)
   729  }
   730  
   731  // Equal reports whether n1==n2. It returns false either when it is
   732  // certain that n1==n2 is false, or if there is not enough information
   733  // to tell.
   734  // Complexity is O(1).
   735  func (po *poset) Equal(n1, n2 *Value) bool {
   736  	if debugPoset {
   737  		defer po.CheckIntegrity()
   738  	}
   739  	if n1.ID == n2.ID {
   740  		panic("should not call Equal with n1==n2")
   741  	}
   742  
   743  	i1, f1 := po.lookup(n1)
   744  	i2, f2 := po.lookup(n2)
   745  	return f1 && f2 && i1 == i2
   746  }
   747  
   748  // NonEqual reports whether n1!=n2. It returns false either when it is
   749  // certain that n1!=n2 is false, or if there is not enough information
   750  // to tell.
   751  // Complexity is O(n) (because it internally calls Ordered to see if we
   752  // can infer n1!=n2 from n1<n2 or n2<n1).
   753  func (po *poset) NonEqual(n1, n2 *Value) bool {
   754  	if debugPoset {
   755  		defer po.CheckIntegrity()
   756  	}
   757  	if n1.ID == n2.ID {
   758  		panic("should not call NonEqual with n1==n2")
   759  	}
   760  
   761  	// If we never saw the nodes before, we don't
   762  	// have a recorded non-equality.
   763  	i1, f1 := po.lookup(n1)
   764  	i2, f2 := po.lookup(n2)
   765  	if !f1 || !f2 {
   766  		return false
   767  	}
   768  
   769  	// Check if we recorded inequality
   770  	if po.isnoneq(i1, i2) {
   771  		return true
   772  	}
   773  
   774  	// Check if n1<n2 or n2<n1, in which case we can infer that n1!=n2
   775  	if po.Ordered(n1, n2) || po.Ordered(n2, n1) {
   776  		return true
   777  	}
   778  
   779  	return false
   780  }
   781  
// setOrder records that n1<n2 (if strict) or n1<=n2 (if !strict); it
// implements SetOrder() and SetOrderOrEqual(). It returns false if the new
// relation contradicts what is already known, in which case the poset is
// left unchanged. Values not yet in the poset get new nodes; the per-case
// comments below describe how the forest of DAGs is adjusted.
func (po *poset) setOrder(n1, n2 *Value, strict bool) bool {
	i1, f1 := po.lookup(n1)
	i2, f2 := po.lookup(n2)

	switch {
	case !f1 && !f2:
		// Neither n1 nor n2 are in the poset, so they are not related
		// in any way to existing nodes.
		// Create a new DAG to record the relation.
		i1, i2 = po.newnode(n1), po.newnode(n2)
		po.roots = append(po.roots, i1)
		po.upush(undoNewRoot, i1, 0)
		po.addchild(i1, i2, strict)

	case f1 && !f2:
		// n1 is in one of the DAGs, while n2 is not. Add n2 as a child
		// of n1.
		i2 = po.newnode(n2)
		po.addchild(i1, i2, strict)

	case !f1 && f2:
		// n1 is not in any DAG but n2 is. If n2 is a root, we can put
		// n1 in its place as a root; otherwise, we need to create a new
		// extra root to record the relation.
		i1 = po.newnode(n1)

		if po.isroot(i2) {
			po.changeroot(i2, i1)
			po.upush(undoChangeRoot, i1, newedge(i2, strict))
			po.addchild(i1, i2, strict)
			return true
		}

		// Search for i2's root; this requires a O(n) search on all
		// DAGs
		r := po.findroot(i2)

		// Re-parent as follows:
		//
		//                  extra
		//     r            /   \
		//      \   ===>   r    i1
		//      i2          \   /
		//                    i2
		//
		extra := po.newnode(nil)
		po.changeroot(r, extra)
		po.upush(undoChangeRoot, extra, newedge(r, false))
		po.addchild(extra, r, false)
		po.addchild(extra, i1, false)
		po.addchild(i1, i2, strict)

	case f1 && f2:
		// If the nodes are aliased, fail only if we're setting a strict order
		// (that is, we cannot set n1<n2 if n1==n2).
		if i1 == i2 {
			return !strict
		}

		// If we are trying to record n1<=n2 but we learned that n1!=n2,
		// record n1<n2, as it provides more information.
		if !strict && po.isnoneq(i1, i2) {
			strict = true
		}

		// Both n1 and n2 are in the poset. This is the complex part of the algorithm
		// as we need to find many different cases and DAG shapes.

		// Check if n1 somehow reaches n2
		if po.reaches(i1, i2, false) {
			// This is the table of all cases we need to handle:
			//
			//      DAG          New      Action
			//      ---------------------------------------------------
			// #1:  N1<=X<=N2 |  N1<=N2 | do nothing
			// #2:  N1<=X<=N2 |  N1<N2  | add strict edge (N1<N2)
			// #3:  N1<X<N2   |  N1<=N2 | do nothing (we already know more)
			// #4:  N1<X<N2   |  N1<N2  | do nothing

			// Check if we're in case #2
			if strict && !po.reaches(i1, i2, true) {
				po.addchild(i1, i2, true)
				return true
			}

			// Case #1, #3, or #4: nothing to do
			return true
		}

		// Check if n2 somehow reaches n1
		if po.reaches(i2, i1, false) {
			// This is the table of all cases we need to handle:
			//
			//      DAG           New      Action
			//      ---------------------------------------------------
			// #5:  N2<=X<=N1  |  N1<=N2 | collapse path (learn that N1=X=N2)
			// #6:  N2<=X<=N1  |  N1<N2  | contradiction
			// #7:  N2<X<N1    |  N1<=N2 | contradiction in the path
			// #8:  N2<X<N1    |  N1<N2  | contradiction

			if strict {
				// Cases #6 and #8: contradiction
				return false
			}

			// We're in case #5 or #7. Try to collapse path, and that will
			// fail if it realizes that we are in case #7.
			// n2 is passed first, so n2's node becomes the master node
			// for every node on the paths from n2 to n1.
			return po.collapsepath(n2, n1)
		}

		// We don't know of any existing relation between n1 and n2. They could
		// be part of the same DAG or not.
		// Find their roots to check whether they are in the same DAG.
		r1, r2 := po.findroot(i1), po.findroot(i2)
		if r1 != r2 {
			// We need to merge the two DAGs to record a relation between the nodes
			po.mergeroot(r1, r2)
		}

		// Connect n1 and n2
		po.addchild(i1, i2, strict)
	}

	return true
}
   910  
   911  // SetOrder records that n1<n2. Returns false if this is a contradiction
   912  // Complexity is O(1) if n2 was never seen before, or O(n) otherwise.
   913  func (po *poset) SetOrder(n1, n2 *Value) bool {
   914  	if debugPoset {
   915  		defer po.CheckIntegrity()
   916  	}
   917  	if n1.ID == n2.ID {
   918  		panic("should not call SetOrder with n1==n2")
   919  	}
   920  	return po.setOrder(n1, n2, true)
   921  }
   922  
   923  // SetOrderOrEqual records that n1<=n2. Returns false if this is a contradiction
   924  // Complexity is O(1) if n2 was never seen before, or O(n) otherwise.
   925  func (po *poset) SetOrderOrEqual(n1, n2 *Value) bool {
   926  	if debugPoset {
   927  		defer po.CheckIntegrity()
   928  	}
   929  	if n1.ID == n2.ID {
   930  		panic("should not call SetOrder with n1==n2")
   931  	}
   932  	return po.setOrder(n1, n2, false)
   933  }
   934  
   935  // SetEqual records that n1==n2. Returns false if this is a contradiction
   936  // (that is, if it is already recorded that n1<n2 or n2<n1).
   937  // Complexity is O(1) if n2 was never seen before, or O(n) otherwise.
   938  func (po *poset) SetEqual(n1, n2 *Value) bool {
   939  	if debugPoset {
   940  		defer po.CheckIntegrity()
   941  	}
   942  	if n1.ID == n2.ID {
   943  		panic("should not call Add with n1==n2")
   944  	}
   945  
   946  	i1, f1 := po.lookup(n1)
   947  	i2, f2 := po.lookup(n2)
   948  
   949  	switch {
   950  	case !f1 && !f2:
   951  		i1 = po.newnode(n1)
   952  		po.roots = append(po.roots, i1)
   953  		po.upush(undoNewRoot, i1, 0)
   954  		po.aliasnewnode(n1, n2)
   955  	case f1 && !f2:
   956  		po.aliasnewnode(n1, n2)
   957  	case !f1 && f2:
   958  		po.aliasnewnode(n2, n1)
   959  	case f1 && f2:
   960  		if i1 == i2 {
   961  			// Already aliased, ignore
   962  			return true
   963  		}
   964  
   965  		// If we recorded that n1!=n2, this is a contradiction.
   966  		if po.isnoneq(i1, i2) {
   967  			return false
   968  		}
   969  
   970  		// If we already knew that n1<=n2, we can collapse the path to
   971  		// record n1==n2 (and vice versa).
   972  		if po.reaches(i1, i2, false) {
   973  			return po.collapsepath(n1, n2)
   974  		}
   975  		if po.reaches(i2, i1, false) {
   976  			return po.collapsepath(n2, n1)
   977  		}
   978  
   979  		r1 := po.findroot(i1)
   980  		r2 := po.findroot(i2)
   981  		if r1 != r2 {
   982  			// Merge the two DAGs so we can record relations between the nodes
   983  			po.mergeroot(r1, r2)
   984  		}
   985  
   986  		// Set n2 as alias of n1. This will also update all the references
   987  		// to n2 to become references to n1
   988  		i2s := newBitset(int(po.lastidx) + 1)
   989  		i2s.Set(i2)
   990  		po.aliasnodes(n1, i2s)
   991  	}
   992  	return true
   993  }
   994  
   995  // SetNonEqual records that n1!=n2. Returns false if this is a contradiction
   996  // (that is, if it is already recorded that n1==n2).
   997  // Complexity is O(n).
   998  func (po *poset) SetNonEqual(n1, n2 *Value) bool {
   999  	if debugPoset {
  1000  		defer po.CheckIntegrity()
  1001  	}
  1002  	if n1.ID == n2.ID {
  1003  		panic("should not call SetNonEqual with n1==n2")
  1004  	}
  1005  
  1006  	// Check whether the nodes are already in the poset
  1007  	i1, f1 := po.lookup(n1)
  1008  	i2, f2 := po.lookup(n2)
  1009  
  1010  	// If either node wasn't present, we just record the new relation
  1011  	// and exit.
  1012  	if !f1 || !f2 {
  1013  		po.setnoneq(n1, n2)
  1014  		return true
  1015  	}
  1016  
  1017  	// See if we already know this, in which case there's nothing to do.
  1018  	if po.isnoneq(i1, i2) {
  1019  		return true
  1020  	}
  1021  
  1022  	// Check if we're contradicting an existing equality relation
  1023  	if po.Equal(n1, n2) {
  1024  		return false
  1025  	}
  1026  
  1027  	// Record non-equality
  1028  	po.setnoneq(n1, n2)
  1029  
  1030  	// If we know that i1<=i2 but not i1<i2, learn that as we
  1031  	// now know that they are not equal. Do the same for i2<=i1.
  1032  	// Do this check only if both nodes were already in the DAG,
  1033  	// otherwise there cannot be an existing relation.
  1034  	if po.reaches(i1, i2, false) && !po.reaches(i1, i2, true) {
  1035  		po.addchild(i1, i2, true)
  1036  	}
  1037  	if po.reaches(i2, i1, false) && !po.reaches(i2, i1, true) {
  1038  		po.addchild(i2, i1, true)
  1039  	}
  1040  
  1041  	return true
  1042  }
  1043  
  1044  // Checkpoint saves the current state of the DAG so that it's possible
  1045  // to later undo this state.
  1046  // Complexity is O(1).
  1047  func (po *poset) Checkpoint() {
  1048  	po.undo = append(po.undo, posetUndo{typ: undoCheckpoint})
  1049  }
  1050  
// Undo restores the state of the poset to the previous checkpoint.
// It pops undo passes off the stack and reverts each one in reverse (LIFO)
// order. It stops right after popping the most recent undoCheckpoint marker.
// If the stack runs out without finding a checkpoint, every recorded change
// is reverted. In that case, when debugPoset is set, the poset is required to
// be empty afterwards.
// Undo panics if the undo stack is empty, if it meets an unknown pass type,
// or if a pass is inconsistent with the current state.
// Complexity depends on the type of operations that were performed
// since the last checkpoint; each Set* operation creates an undo
// pass which Undo has to revert with a worst-case complexity of O(n).
func (po *poset) Undo() {
	if len(po.undo) == 0 {
		panic("empty undo stack")
	}
	if debugPoset {
		defer po.CheckIntegrity()
	}

	for len(po.undo) > 0 {
		// Pop the most recently recorded pass.
		pass := po.undo[len(po.undo)-1]
		po.undo = po.undo[:len(po.undo)-1]

		switch pass.typ {
		case undoCheckpoint:
			// Reached the marker delimiting this group of passes; the
			// marker itself has been consumed.
			return

		case undoSetChl:
			// Put back the previous left-child edge of node pass.idx.
			po.setchl(pass.idx, pass.edge)

		case undoSetChr:
			// Put back the previous right-child edge of node pass.idx.
			po.setchr(pass.idx, pass.edge)

		case undoNonEqual:
			// Forget that value pass.ID is non-equal to node pass.idx.
			po.noneq[uint32(pass.ID)].Clear(pass.idx)

		case undoNewNode:
			// Node removal only works on the last allocated index, since
			// the node slice is truncated below.
			if pass.idx != po.lastidx {
				panic("invalid newnode index")
			}
			// A zero ID presumably marks a node created without an SSA
			// value (e.g. newnode(nil)); otherwise drop the value->node
			// mapping, after checking it still points at this node.
			if pass.ID != 0 {
				if po.values[pass.ID] != pass.idx {
					panic("invalid newnode undo pass")
				}
				delete(po.values, pass.ID)
			}
			// Clear the node's children, then shrink the node slice and
			// the last-index counter to discard it.
			po.setchl(pass.idx, 0)
			po.setchr(pass.idx, 0)
			po.nodes = po.nodes[:pass.idx]
			po.lastidx--

		case undoAliasNode:
			// pass.idx holds the node the value mapped to before it was
			// aliased; 0 means it had no node at all.
			ID, prev := pass.ID, pass.idx
			cur := po.values[ID]
			if prev == 0 {
				// Born as an alias, die as an alias
				delete(po.values, ID)
			} else {
				// Undoing an alias must actually change the mapping.
				if cur == prev {
					panic("invalid aliasnode undo pass")
				}
				// Give it back previous value
				po.values[ID] = prev
			}

		case undoNewRoot:
			// Passes recorded after the root was added have already been
			// undone, so the root must have no children left.
			i := pass.idx
			l, r := po.children(i)
			if l|r != 0 {
				panic("non-empty root in undo newroot")
			}
			po.removeroot(i)

		case undoChangeRoot:
			// Node pass.idx replaced another root; it must be childless by
			// now. Put the original root (pass.edge's target) back in its place.
			i := pass.idx
			l, r := po.children(i)
			if l|r != 0 {
				panic("non-empty root in undo changeroot")
			}
			po.changeroot(i, pass.edge.Target())

		case undoMergeRoot:
			// Node pass.idx joined two roots as its children. Its left child
			// takes its place in the root list and its right child is
			// re-added as a separate root.
			i := pass.idx
			l, r := po.children(i)
			po.changeroot(i, l.Target())
			po.roots = append(po.roots, r.Target())

		default:
			panic(pass.typ)
		}
	}

	// Only reached when the whole stack was unwound without meeting a
	// checkpoint, so nothing should remain in the poset.
	if debugPoset && po.CheckEmpty() != nil {
		panic("poset not empty at the end of undo")
	}
}
  1140  

View as plain text