Source file src/cmd/compile/internal/ssa/known_bits.go

     1  // Copyright 2026 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssa
     6  
     7  import (
     8  	"slices"
     9  	"strings"
    10  )
    11  
// fold computes which bits of v are known at compile time.
//
// It returns value, holding the bit values, and known, a mask with a 1 in
// every position whose bit in value is known. Results are memoized in
// kb.entries (indexed by v.ID); kb.seenValues records which IDs have already
// been visited.
//
// The named results are rewritten by the deferred function below before they
// reach the caller, so the individual cases of the switch do not need to
// uphold the storage invariants themselves.
func (kb *knownBitsState) fold(v *Value) (value, known int64) {
	// Already visited (or currently being visited further up the call stack):
	// reuse whatever is stored for this ID.
	if kb.seenValues.Test(uint32(v.ID)) {
		return kb.entries[v.ID].value, kb.entries[v.ID].known
	}
	defer func() {
		// Normalize the result so that it maintains the invariants, then
		// store it. The steps below run in the order 3, 2, 1.

		// 3. booleans are stored as 1 byte values that are either 0 or 1:
		//    keep only bit 0 of value and mark every other bit as known.
		if v.Type.IsBoolean() {
			value &= 1
			known |= ^1
		}

		// 2. all values are sign-extended to int64 (inspired by RISC-V's xlen=64)
		//    The width is taken from the type of v itself, which is what lets
		//    the truncation cases below return their argument unchanged.
		switch v.Type.Size() {
		case 1:
			value = int64(int8(value))
			known = int64(int8(known))
		case 2:
			value = int64(int16(value))
			known = int64(int16(known))
		case 4:
			value = int64(int32(value))
			known = int64(int32(known))
		case 8:
		default:
			panic("unreachable; unknown integer size")
		}

		// 1. unknown bits are always set to 0 inside value
		value &= known

		kb.entries[v.ID].known = known
		kb.entries[v.ID].value = value
		if v.Block.Func.pass.debug > 1 {
			v.Block.Func.Warnl(v.Pos, "known bits state %v: %v", v, kb.entries[v.ID])
		}
	}()
	// Mark v as seen before recursing into its arguments. If v is reached
	// again through a cycle, the early return above hands back the entry as
	// it currently stands instead of recursing forever; knownBits clears
	// kb.entries up front, so that entry reads as "nothing known".
	kb.seenValues.Set(uint32(v.ID)) // set seen early to give up on loops

	switch v.Op {
	// TODO: rotates, ...
	case OpConst64, OpConst32, OpConst16, OpConst8, OpConstBool:
		// Every bit of a constant is known.
		return v.AuxInt, -1
	case OpAnd64, OpAnd32, OpAnd16, OpAnd8, OpAndB:
		x, xk := kb.fold(v.Args[0])
		y, yk := kb.fold(v.Args[1])
		// A result bit is known when both inputs are known ones, or when
		// either input is a known zero. Because unknown bits are stored as 0
		// in value (invariant 1), a 1 in x&y is necessarily a known one.
		onesInBoth := x & y
		zerosInX := ^x & xk
		zerosInY := ^y & yk
		return x & y, onesInBoth | zerosInX | zerosInY
	case OpOr64, OpOr32, OpOr16, OpOr8, OpOrB:
		x, xk := kb.fold(v.Args[0])
		y, yk := kb.fold(v.Args[1])
		// A result bit is known when either input is a known one, or when
		// both inputs are known zeros. As above, a set bit in x or y is a
		// known one thanks to invariant 1.
		zerosInBoth := ^x & ^y & (xk & yk)
		onesInX := x
		onesInY := y
		return x | y, onesInX | onesInY | zerosInBoth
	case OpXor64, OpXor32, OpXor16, OpXor8:
		x, xk := kb.fold(v.Args[0])
		y, yk := kb.fold(v.Args[1])
		// A result bit is known only where both input bits are known.
		return x ^ y, xk & yk
	case OpCom64, OpCom32, OpCom16, OpCom8, OpNot:
		x, xk := kb.fold(v.Args[0])
		// Inverting flips the values but not which bits are known. The
		// deferred normalization clears the bits of value that are unknown
		// and, for booleans, everything above bit 0.
		return ^x, xk
	case OpPhi:
		// Intersect the arguments arriving over live incoming edges: a bit
		// stays known only if it is known in every such argument and has the
		// same value in all of them. If no edge is live the result is 0, 0.
		set := false
		for i, arg := range v.Args {
			if !kb.isLiveInEdge(v.Block, uint(i)) {
				continue
			}
			a, k := kb.fold(arg)
			if !set {
				value, known = a, k
				set = true
			} else {
				// Drop bits on which this argument disagrees with the value
				// accumulated so far, then bits this argument does not know.
				known &^= value ^ a
				known &= k
			}
			// Nothing left to learn once every bit is unknown.
			if known == 0 {
				break
			}
		}
		return value, known
	case OpCopy, OpCvtBoolToUint8,
		OpSignExt8to16, OpSignExt8to32, OpSignExt8to64, OpSignExt16to32, OpSignExt16to64, OpSignExt32to64,
		// The defer block handles maintaining the sign-extension invariant using v.Type.Size()
		// thus we can just pass Truncs as-is.
		OpTrunc64to32, OpTrunc64to16, OpTrunc64to8, OpTrunc32to16, OpTrunc32to8, OpTrunc16to8:
		// Entries are already stored sign-extended to 64 bits (invariant 2),
		// so these operations do not change the stored representation.
		return kb.fold(v.Args[0])
	case OpEq64, OpEq32, OpEq16, OpEq8, OpEqB:
		x, xk := kb.fold(v.Args[0])
		y, yk := kb.fold(v.Args[1])
		differentBits := x ^ y
		// A bit that is known on both sides and differs proves inequality.
		if differentBits&xk&yk != 0 {
			return 0, -1
		}
		// Both sides fully known: compare them directly.
		if xk == -1 && yk == -1 {
			return boolToAuxInt(x == y), -1
		}
		// Undecided: bit 0 is unknown, every other bit is a known zero.
		return 0, -1 << 1
	case OpNeq64, OpNeq32, OpNeq16, OpNeq8, OpNeqB:
		x, xk := kb.fold(v.Args[0])
		y, yk := kb.fold(v.Args[1])
		differentBits := x ^ y
		// A bit that is known on both sides and differs proves inequality.
		if differentBits&xk&yk != 0 {
			return 1, -1
		}
		// Both sides fully known: compare them directly.
		if xk == -1 && yk == -1 {
			return boolToAuxInt(x != y), -1
		}
		// Undecided: bit 0 is unknown, every other bit is a known zero.
		return 0, -1 << 1
	case OpZeroExt8to16, OpZeroExt8to32, OpZeroExt8to64, OpZeroExt16to32, OpZeroExt16to64, OpZeroExt32to64:
		x, k := kb.fold(v.Args[0])
		// The argument is stored sign-extended; clear everything above the
		// source width and mark those bits as known (zeros).
		srcSize := v.Args[0].Type.Size() * 8
		mask := int64(1<<srcSize - 1)
		return x & mask, k | ^mask
	case OpLsh8x8, OpLsh16x8, OpLsh32x8, OpLsh64x8,
		OpLsh8x16, OpLsh16x16, OpLsh32x16, OpLsh64x16,
		OpLsh8x32, OpLsh16x32, OpLsh32x32, OpLsh64x32,
		OpLsh8x64, OpLsh16x64, OpLsh32x64, OpLsh64x64:
		return kb.computeKnownBitsForShift(v, func(x, xk, xSize, shift int64) (value, known int64) {
			// Left shift by a constant: the low `shift` bits become known zeros.
			return x << shift, xk<<shift | (1<<shift - 1)
		})
	case OpRsh8Ux8, OpRsh16Ux8, OpRsh32Ux8, OpRsh64Ux8,
		OpRsh8Ux16, OpRsh16Ux16, OpRsh32Ux16, OpRsh64Ux16,
		OpRsh8Ux32, OpRsh16Ux32, OpRsh32Ux32, OpRsh64Ux32,
		OpRsh8Ux64, OpRsh16Ux64, OpRsh32Ux64, OpRsh64Ux64:
		return kb.computeKnownBitsForShift(v, func(x, xk, xSize, shift int64) (value, known int64) {
			// Logical right shift by a constant. First undo the sign-extended
			// storage: bits at and above xSize are treated as known zeros.
			x &= (1<<xSize - 1)
			xk |= -1 << xSize
			// Then shift, marking the top `shift` bits as known zeros.
			return int64(uint64(x) >> shift), int64(uint64(xk)>>shift | (^uint64(0) << (64 - shift)))
		})
	case OpRsh8x8, OpRsh16x8, OpRsh32x8, OpRsh64x8,
		OpRsh8x16, OpRsh16x16, OpRsh32x16, OpRsh64x16,
		OpRsh8x32, OpRsh16x32, OpRsh32x32, OpRsh64x32,
		OpRsh8x64, OpRsh16x64, OpRsh32x64, OpRsh64x64:
		return kb.computeKnownBitsForShift(v, func(x, xk, xSize, shift int64) (value, known int64) {
			// Arithmetic right shift by a constant: shifting both value and
			// known as signed numbers replicates the top bit of each, so the
			// shifted-in bits are known exactly when the sign bit is.
			return x >> shift, xk >> shift
		})
	default:
		// Unhandled operation: nothing is known about the result.
		return 0, 0
	}
}
   155  
   156  // knownBits does constant folding across bitfields
   157  func knownBits(f *Func) {
   158  	kb := &knownBitsState{
   159  		entries:         f.Cache.allocKnownBitsEntriesSlice(f.NumValues()),
   160  		seenValues:      f.Cache.allocBitset(f.NumValues()),
   161  		reachableBlocks: f.Cache.allocBitset(f.NumBlocks()),
   162  	}
   163  	defer f.Cache.freeKnownBitsEntriesSlice(kb.entries)
   164  	defer f.Cache.freeBitset(kb.seenValues)
   165  	defer f.Cache.freeBitset(kb.reachableBlocks)
   166  	clear(kb.seenValues)
   167  	clear(kb.entries)
   168  	clear(kb.reachableBlocks)
   169  
   170  	blocks := f.postorder()
   171  	for _, b := range blocks {
   172  		kb.reachableBlocks.Set(uint32(b.ID))
   173  	}
   174  
   175  	for _, b := range slices.Backward(blocks) {
   176  		for _, v := range b.Values {
   177  			if v.Uses == 0 || !(v.Type.IsInteger() || v.Type.IsBoolean()) {
   178  				continue
   179  			}
   180  			switch v.Op {
   181  			case OpConst64, OpConst32, OpConst16, OpConst8, OpConstBool:
   182  				continue
   183  			}
   184  			val, k := kb.fold(v)
   185  			if k != -1 {
   186  				continue
   187  			}
   188  			if f.pass.debug > 0 {
   189  				var pval any = val
   190  				if v.Type.IsBoolean() {
   191  					pval = val != 0
   192  				}
   193  				f.Warnl(v.Pos, "known value of %v (%v): %v", v, v.Op, pval)
   194  			}
   195  			var c *Value
   196  			switch v.Type.Size() {
   197  			case 1:
   198  				if v.Type.IsBoolean() {
   199  					c = f.ConstBool(v.Type, val != 0)
   200  					break
   201  				}
   202  				c = f.ConstInt8(v.Type, int8(val))
   203  			case 2:
   204  				c = f.ConstInt16(v.Type, int16(val))
   205  			case 4:
   206  				c = f.ConstInt32(v.Type, int32(val))
   207  			case 8:
   208  				c = f.ConstInt64(v.Type, val)
   209  			default:
   210  				panic("unreachable; unknown integer size")
   211  			}
   212  			v.copyOf(c)
   213  		}
   214  	}
   215  }
   216  
// knownBitsState is the scratch state shared by one run of knownBits.
type knownBitsState struct {
	// entries holds the memoized result of fold for each value.
	entries         []knownBitsEntry // indexed by Value.ID
	// seenValues marks the values fold has already started working on.
	seenValues      bitset           // indexed by Value.ID (at the bit level)
	// reachableBlocks marks the blocks returned by Func.postorder.
	reachableBlocks bitset           // indexed by Block.ID (at the bit level)
}
   222  
   223  type knownBitsEntry struct {
   224  	// Two invariants:
   225  	// 1. unknown bits are always set to 0 inside value
   226  	// 2. all values are sign-extended to int64 (inspired by RISC-V's xlen=64)
   227  	//    This means let's say you know an 8 bits value is 0b10??????,
   228  	//    known = int64(int8(0b11000000))
   229  	//    value = int64(int8(0b10000000))
   230  	// 3. booleans are stored as 1 byte values who are either 0 or 1.
   231  	known, value int64
   232  }
   233  
   234  func (kbe knownBitsEntry) String() string {
   235  	lut := []rune{ // indexed by knownBit<<1 | valueBit
   236  		0b00: '?',
   237  		0b01: '¿', // violates invariant 1
   238  		0b10: '0',
   239  		0b11: '1',
   240  	}
   241  	var sb strings.Builder
   242  	sb.Grow(64)
   243  	for i := 63; i >= 0; i-- {
   244  		bits := (kbe.known>>i&1)<<1 | (kbe.value >> i & 1)
   245  		sb.WriteRune(lut[bits])
   246  	}
   247  	return sb.String()
   248  }
   249  
   250  func (kb *knownBitsState) isLiveInEdge(b *Block, index uint) bool {
   251  	inEdge := b.Preds[index]
   252  	return kb.isLiveOutEdge(inEdge.b, uint(inEdge.i))
   253  }
   254  
   255  func (kb *knownBitsState) isLiveOutEdge(b *Block, index uint) bool {
   256  	if !kb.reachableBlocks.Test(uint32(b.ID)) {
   257  		return false
   258  	}
   259  
   260  	switch b.Kind {
   261  	case BlockFirst:
   262  		return index == 0
   263  	case BlockPlain, BlockIf, BlockDefer, BlockRet, BlockRetJmp, BlockExit, BlockJumpTable:
   264  		return true
   265  	default:
   266  		panic("unreachable; unknown block kind")
   267  	}
   268  }
   269  
   270  // computeKnownBitsForShift computes the known bits for a shift operation.
   271  // Considering the following piece of code x = x << uint8(i)
   272  // The algorithm is based on two observations:
   273  //
   274  //  1. computing a shift of a lattice by a constant (i) is easy:
   275  //     value, known = x<<i, xk<<i|(1<<i-1)
   276  //     each point in the lattice is shifted by the constant, all new shifted in bits are known zeros.
   277  //
   278  //  2. x = uint8(x) << i is equivalent to
   279  //
   280  //     switch i {
   281  //     case 0:  x0 = x << 0
   282  //     case 1:  x1 = x << 1
   283  //     case 2:  x2 = x << 2
   284  //     case 3:  x3 = x << 3
   285  //     case 4:  x4 = x << 4
   286  //     case 5:  x5 = x << 5
   287  //     case 6:  x6 = x << 6
   288  //     case 7:  x7 = x << 7
   289  //     default: xd = x << 8
   290  //     }
   291  //     x = phi(x0, x1, x2, x3, x4, x5, x6, x7, xd)
   292  //
   293  // The algorithm below then models the phi in the equivalence above using same intersection algorithm phi uses.
   294  // We also leverage known bits of the shift amount to remove "branches" in the switch that are proved to be impossible.
   295  func (kb *knownBitsState) computeKnownBitsForShift(v *Value, doShiftByAConst func(x, xk, xSize, shift int64) (value, known int64)) (value, known int64) {
   296  	xSize := v.Args[0].Type.Size() * 8
   297  	x, xk := kb.fold(v.Args[0])
   298  	y, yk := kb.fold(v.Args[1])
   299  	if uint64(y) >= uint64(xSize) {
   300  		return doShiftByAConst(x, xk, xSize, 64)
   301  	}
   302  
   303  	set := false
   304  	if v.AuxInt == 0 && uint64(^yk) >= uint64(xSize) {
   305  		// this implement the default case of the equivalent switch above.
   306  		// if the shift isn't bounded and there are unknown bits above the shift size we might completely stomp all bits.
   307  
   308  		value, known = doShiftByAConst(x, xk, xSize, 64)
   309  		set = true
   310  	}
   311  
   312  	yk |= ^(xSize - 1)
   313  
   314  	for i := range allPossibleValues(y, yk) {
   315  		a, k := doShiftByAConst(x, xk, xSize, i)
   316  		if !set {
   317  			value, known = a, k
   318  			set = true
   319  		} else {
   320  			known &^= value ^ a
   321  			known &= k
   322  		}
   323  		if known == 0 {
   324  			break
   325  		}
   326  	}
   327  
   328  	return value & known, known
   329  }
   330  
   331  // allPossibleValues iterates over all values that could exist.
   332  // It scales exponentially with the number of unknown bits,
   333  // the exact number of iterations will be uint128(1)<<bits.OnesCount64(^known)
   334  // thus be careful with what values you pass to it.
   335  func allPossibleValues(value, known int64) func(yield func(v int64) bool) {
   336  	unknown := ^known
   337  	return func(yield func(v int64) bool) {
   338  		// This finds the next valid value for the variable bits.
   339  		// It is equivalent to (s|known + 1) & unknown.
   340  		// The s|known step creates blocks of 1s in all the known bits.
   341  		// +1 finds the next possible value, the blocks of 1s set in the previous step allows it to skip over blocks of known bits.
   342  		// & unknown clears garbage generated by the blocks of ones and overflow.
   343  		//
   344  		// You can transform (s|known + 1) & unknown into (s - unknown) & unknown through:
   345  		// (s +  known   + 1) & unknown: s | known → s + known (since s & known == 0)
   346  		// (s + ^unknown + 1) & unknown: known → ^unknown (definition of unknown)
   347  		// (s + -unknown)     & unknown: ^unknown + 1 → -unknown (two's complement negation)
   348  		// (s -  unknown)     & unknown: s + -unknown → s - unknown (arithmetic)
   349  		for s := int64(0); ; s = (s - unknown) & unknown {
   350  			// fixed bits | current variable bits gives the current iteration
   351  			if !yield(value | s) {
   352  				return
   353  			}
   354  			if s == unknown {
   355  				break
   356  			}
   357  		}
   358  	}
   359  }
   360  

View as plain text