Source file src/cmd/compile/internal/midway/rewrite.go

     1  // Copyright 2026 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package midway
     6  
     7  import (
     8  	"cmd/compile/internal/syntax"
     9  	"cmd/compile/internal/types2"
    10  	"fmt"
    11  	"internal/buildcfg"
    12  	"strings"
    13  )
    14  
    15  // "Midway" rewriting
    16  //
// Go attempts to provide a package similar to the "Highway" library
    18  // for C++ (https://google.github.io/highway).  The library package is "simd"
    19  // and defines vector types with unspecified widths that are bound to particular
    20  // machine dependent types as late as program execution.  This is accomplished
    21  // by rewriting code that depends on these types into code that references
    22  // architecture-specific types, perhaps more than once, and if necessary
    23  // dynamically choosing which version to execute based on hardware attributes.
    24  //
    25  // The rewriting takes place early in the compiler, after type checking but
    26  // before conversion to "unified" IR.  To ensure that types are correctly set
    27  // on the modified version of the code, type checking information is reset and
// the type checking phase is re-run.  This places some limits on the shape of
    29  // the rewrites, but it also ensures that the rewritten code is well-formed.
    30  //
    31  // Rewritten code does not reference "archsimd" types directly, but instead
    32  // references types in a "bridge" package that filters the available methods
    33  // and adds a few more.  The package used relies on a builder/compiler hack;
    34  // the compiler's type checker enforces export naming conventions, but the
    35  // build system limits visibility to unrelated "internal" packages and can be
    36  // modified to allow access in special cases (like this one).  This allows the
    37  // rewritten code to reference types, functions, and methods that are not
    38  // accessible otherwise.
    39  
    40  type Rewriter struct {
    41  	pkg      *types2.Package
    42  	analyzer *Analyzer
    43  	info     *types2.Info
    44  	sizes    []int
    45  }
    46  
    47  func NewRewriter(pkg *types2.Package, info *types2.Info, analyzer *Analyzer, sizes []int) *Rewriter {
    48  	return &Rewriter{
    49  		pkg:      pkg,
    50  		info:     info,
    51  		analyzer: analyzer,
    52  		sizes:    sizes,
    53  	}
    54  }
    55  
    56  func (r *Rewriter) Rewrite(files []*syntax.File) {
    57  
    58  	// First duplicate and specialize all dependent functions and variables.
    59  	for _, fileAST := range files {
    60  
    61  		var newDecls []syntax.Decl
    62  		for _, k := range r.sizes {
    63  			newDecls = r.generateForSize(fileAST, k, newDecls)
    64  		}
    65  
    66  		// Then replace original functions with dispatchers.
    67  		r.generateDispatchers(fileAST)
    68  
    69  		fileAST.DeclList = append(fileAST.DeclList, newDecls...)
    70  	}
    71  }
    72  
    73  func (r *Rewriter) generateDispatchers(fileAST *syntax.File) {
    74  	var newDecls []syntax.Decl
    75  
    76  	for _, decl := range fileAST.DeclList {
    77  		switch d := decl.(type) {
    78  		case *syntax.FuncDecl:
    79  			if d.Name == nil {
    80  				newDecls = append(newDecls, d)
    81  				continue
    82  			}
    83  			obj := r.info.Defs[d.Name]
    84  			if !r.analyzer.dependentObj[obj] || r.analyzer.inSimd {
    85  				newDecls = append(newDecls, d)
    86  				continue
    87  			}
    88  
    89  			sig, ok := obj.Type().(*types2.Signature)
    90  			if !ok {
    91  				newDecls = append(newDecls, d)
    92  				continue
    93  			}
    94  
    95  			if r.analyzer.HasDependentSignature(sig) {
    96  				// Drop dependent signatures entirely
    97  				continue
    98  			}
    99  
   100  			// Clean signature -> Replace body with dispatcher
   101  			d.Body = r.createDispatcherBody(d, sig)
   102  			newDecls = append(newDecls, d)
   103  
   104  		case *syntax.VarDecl:
   105  			// Filter specs conceptually based on dependents
   106  			keep := false
   107  			for _, name := range d.NameList {
   108  				if !r.analyzer.dependentObj[r.info.Defs[name]] {
   109  					keep = true
   110  					break // Keep entire var decl if any name is clean, else drop
   111  				}
   112  			}
   113  			if keep {
   114  				newDecls = append(newDecls, d)
   115  			}
   116  		case *syntax.TypeDecl:
   117  			if !r.analyzer.dependentObj[r.info.Defs[d.Name]] || r.analyzer.inSimd {
   118  				newDecls = append(newDecls, d)
   119  			}
   120  		default:
   121  			newDecls = append(newDecls, decl)
   122  		}
   123  	}
   124  
   125  	fileAST.DeclList = newDecls
   126  
   127  	if !r.analyzer.inSimd {
   128  		// Inject an import to the bridge package (if not exists)
   129  		hasArchSimd := false
   130  		var simdImport *syntax.ImportDecl
   131  		for _, decl := range fileAST.DeclList {
   132  			if imp, ok := decl.(*syntax.ImportDecl); ok {
   133  				if imp.Path.Value == `"`+archFullPkg+`"` {
   134  					hasArchSimd = true
   135  				}
   136  				if imp.Path.Value == `"`+simdPkg+`"` {
   137  					simdImport = imp
   138  				}
   139  
   140  			}
   141  		}
   142  		p := simdImport.Pos()
   143  		if !hasArchSimd {
   144  			r.injectImport(fileAST, archFullPkg, p)
   145  		}
   146  
   147  		// Ensure at least one use of "simd"
   148  		// var _ = simd.VectorBitLen()
   149  		fun := &syntax.SelectorExpr{
   150  			X:   syntax.NewName(p, simdPkg), // Assume this is resolvable
   151  			Sel: syntax.NewName(p, vectorSizeFn),
   152  		}
   153  		fun.SetPos(p)
   154  		call := &syntax.CallExpr{Fun: fun}
   155  		call.SetPos(p)
   156  
   157  		name := syntax.NewName(p, "_")
   158  
   159  		varDecl := &syntax.VarDecl{NameList: []*syntax.Name{name}, Values: call}
   160  		varDecl.SetPos(p)
   161  		fileAST.DeclList = append(fileAST.DeclList, varDecl)
   162  	}
   163  }
   164  
   165  func (r *Rewriter) injectImport(fileAST *syntax.File, toImport string, simdImportPos syntax.Pos) {
   166  	importDecl := &syntax.ImportDecl{
   167  		Path: &syntax.BasicLit{Value: `"` + toImport + `"`, Kind: syntax.StringLit},
   168  	}
   169  	importDecl.Path.SetPos(simdImportPos)
   170  	importDecl.SetPos(simdImportPos)
   171  	fileAST.DeclList = append([]syntax.Decl{importDecl}, fileAST.DeclList...)
   172  }
   173  
   174  func (r *Rewriter) createDispatcherBody(d *syntax.FuncDecl, sig *types2.Signature) *syntax.BlockStmt {
   175  
   176  	// Build call arguments from the function parameters
   177  	args := func() []syntax.Expr {
   178  		var args []syntax.Expr
   179  		if d.Type.ParamList != nil {
   180  			for _, field := range d.Type.ParamList {
   181  				if field.Name != nil {
   182  					paramName := syntax.NewName(field.Pos(), field.Name.Value)
   183  					args = append(args, paramName)
   184  				}
   185  			}
   186  		}
   187  		return args
   188  	}
   189  
   190  	// Slap a pos on an expression
   191  	pe := func(e syntax.Expr) syntax.Expr {
   192  		e.SetPos(d.Pos())
   193  		return e
   194  	}
   195  	// Slap a pos on a statement
   196  	ps := func(e syntax.Stmt) syntax.Stmt {
   197  		e.SetPos(d.Pos())
   198  		return e
   199  	}
   200  
   201  	// switch ast node.
   202  	// the goal is something like (for now, till there are finer-grained choices)
   203  	// switch simd.VectorSize() {
   204  	//   case 128: if simd.Emulated() { call the specialize-for-emulation-code(args) }
   205  	//             else { call the specialize-for-128-code(args) }
   206  	//   case 256: call the specialize-for-256-code(args)
   207  	//   etc
   208  	// }
   209  	//
   210  	// the cases above deal with the usual `return call(...)` vs `call(...); return`
   211  	switchStmt := &syntax.SwitchStmt{
   212  		Tag: pe(&syntax.CallExpr{
   213  			Fun: pe(&syntax.SelectorExpr{
   214  				X:   syntax.NewName(d.Pos(), simdPkg), // Assume this is resolvable
   215  				Sel: syntax.NewName(d.Pos(), vectorSizeFn),
   216  			}),
   217  		}),
   218  		Body: []*syntax.CaseClause{},
   219  	}
   220  
   221  	var emulation syntax.Stmt
   222  
   223  	for _, k := range r.sizes {
   224  		fnName := fmt.Sprintf("%s@simd%d", d.Name.Value, k)
   225  		fnIdent := syntax.NewName(d.Pos(), fnName)
   226  
   227  		callExpr := pe(&syntax.CallExpr{
   228  			Fun:     pe(fnIdent),
   229  			ArgList: args(),
   230  		})
   231  
   232  		// callReturnStmt is either `return call(...)` or `call(...); return`
   233  		var callReturnStmt syntax.Stmt
   234  		if d.Type.ResultList != nil && len(d.Type.ResultList) > 0 {
   235  			callReturnStmt = &syntax.ReturnStmt{Results: callExpr}
   236  		} else {
   237  			callReturnStmt = &syntax.BlockStmt{
   238  				List: []syntax.Stmt{
   239  					ps(&syntax.ExprStmt{X: callExpr}),
   240  					ps(&syntax.ReturnStmt{}),
   241  				},
   242  				Rbrace: d.Pos(),
   243  			}
   244  		}
   245  		callReturnStmt.SetPos(d.Pos())
   246  
   247  		if k == 0 {
   248  			// emulation == `if simd.Emulated() { callReturnStmt }`
   249  			// save it for the first part of the 128 case.
   250  			cond := pe(&syntax.CallExpr{
   251  				Fun: pe(&syntax.SelectorExpr{
   252  					X:   syntax.NewName(d.Pos(), simdPkg), // Assume this is resolvable
   253  					Sel: syntax.NewName(d.Pos(), emulatedFn),
   254  				})})
   255  
   256  			blockStmt, ok := callReturnStmt.(*syntax.BlockStmt)
   257  			if !ok {
   258  				blockStmt = &syntax.BlockStmt{
   259  					List:   []syntax.Stmt{callReturnStmt},
   260  					Rbrace: d.Pos(),
   261  				}
   262  				blockStmt.SetPos(d.Pos())
   263  			}
   264  
   265  			emulation = ps(&syntax.IfStmt{
   266  				Cond: cond,
   267  				Then: blockStmt,
   268  			})
   269  			continue
   270  		}
   271  
   272  		var caseBody []syntax.Stmt
   273  		// assume that 128 is a case; when we do scalable simd, this may change.
   274  		// For now, if there is emulation, it is 128-bit (only).
   275  		if emulation != nil && k == 128 {
   276  			caseBody = append(caseBody, emulation)
   277  			emulation = nil
   278  		}
   279  
   280  		caseClause := &syntax.CaseClause{
   281  			Cases: pe(&syntax.BasicLit{Kind: syntax.IntLit, Value: fmt.Sprintf("%d", k)}),
   282  			Body:  append(caseBody, callReturnStmt),
   283  		}
   284  		caseClause.SetPos(d.Pos())
   285  		switchStmt.Body = append(switchStmt.Body, caseClause)
   286  	}
   287  
   288  	fnName := "panic"
   289  	fnIdent := pe(syntax.NewName(d.Pos(), fnName))
   290  
   291  	callExpr := pe(&syntax.CallExpr{
   292  		Fun:     fnIdent,
   293  		ArgList: []syntax.Expr{pe(&syntax.BasicLit{Value: "\"unsupported vector size in simd-rewritten code\"", Kind: syntax.StringLit})},
   294  	})
   295  
   296  	panicStmt := &syntax.ExprStmt{X: callExpr}
   297  	blockStmt := &syntax.BlockStmt{List: []syntax.Stmt{ps(switchStmt), ps(panicStmt)}}
   298  
   299  	blockStmt.SetPos(d.Pos())
   300  
   301  	return blockStmt
   302  }
   303  
   304  func (r *Rewriter) generateForSize(fileAST *syntax.File, k int, newDecls []syntax.Decl) []syntax.Decl {
   305  	copier := NewDeepCopier(r.pkg, r.info, k, r.analyzer, fmt.Sprintf("@simd%d", k))
   306  	for _, decl := range fileAST.DeclList {
   307  		if r.shouldIncludeDecl(decl) {
   308  			newDecl := copier.CopyDecl(decl)
   309  			newDecls = append(newDecls, newDecl)
   310  		}
   311  	}
   312  	return newDecls
   313  }
   314  
   315  func nameToElemBitWidth(name string) int {
   316  	var width int
   317  	switch name {
   318  	case "Int8s", "Uint8s", "Mask8s":
   319  		width = 8
   320  	case "Int16s", "Uint16s", "Mask16s":
   321  		width = 16
   322  	case "Int32s", "Uint32s", "Float32s", "Mask32s":
   323  		width = 32
   324  	case "Int64s", "Uint64s", "Float64s", "Mask64s":
   325  		width = 64
   326  	}
   327  	return width
   328  }
   329  
   330  func (r *Rewriter) shouldIncludeDecl(decl syntax.Decl) bool {
   331  	// Files (and declarations) in the simd package are excluded
   332  	// from processing, except for those that whose name begins
   333  	// with "tofrom_".
   334  	if r.analyzer.inSimd {
   335  		theFile := decl.Pos().Base().Filename()
   336  
   337  		lastSlash := strings.LastIndex(theFile, simdPkg+"/")
   338  		lastBackslash := strings.LastIndex(theFile, simdPkg+"\\")
   339  
   340  		// Windows paths can be chaos, all we care, is whether the very last part
   341  		// of the path is any-path-separator + "tofrom_" + anything-else, given that
   342  		// we already know that we are in the simd package.
   343  		maxSlash := max(lastSlash, lastBackslash)
   344  		if maxSlash == -1 {
   345  			return false
   346  		}
   347  		if !strings.HasPrefix(theFile[maxSlash:], simdPkg+"/tofrom_") &&
   348  			!strings.HasPrefix(theFile[maxSlash:], simdPkg+"\\tofrom_") {
   349  			return false
   350  		}
   351  	}
   352  
   353  	switch d := decl.(type) {
   354  	case *syntax.FuncDecl:
   355  		if d.Name != nil {
   356  			return r.analyzer.dependentObj[r.info.Defs[d.Name]]
   357  		}
   358  	case *syntax.TypeDecl:
   359  		return r.analyzer.dependentObj[r.info.Defs[d.Name]]
   360  	case *syntax.VarDecl:
   361  		for _, name := range d.NameList {
   362  			if r.analyzer.dependentObj[r.info.Defs[name]] {
   363  				return true
   364  			}
   365  		}
   366  	}
   367  	return false
   368  }
   369  
   370  // Generate an API matching the standalone compilation call
   371  func RewriteWrapper(pkg *types2.Package, info *types2.Info, files []*syntax.File) bool {
   372  	if !buildcfg.Experiment.SIMD {
   373  		return false
   374  	}
   375  
   376  	switch buildcfg.GOARCH {
   377  	case "wasm", "amd64", "arm64":
   378  	default:
   379  		return false
   380  	}
   381  
   382  	sizes := rewriteSizes()
   383  	if len(sizes) == 0 {
   384  		return false
   385  	}
   386  	analyzer := NewAnalyzer(pkg, info)
   387  	if !analyzer.Analyze(files) {
   388  		return false
   389  	}
   390  
   391  	CheckPositions(files, "before midway")
   392  
   393  	rewriter := NewRewriter(pkg, info, analyzer, sizes)
   394  	rewriter.Rewrite(files)
   395  
   396  	CheckPositions(files, "after midway")
   397  
   398  	return true
   399  }
   400  

View as plain text