1
2
3
4
5 package midway
6
7 import (
8 "cmd/compile/internal/syntax"
9 "cmd/compile/internal/types2"
10 "fmt"
11 "internal/buildcfg"
12 "strings"
13 )
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40 type Rewriter struct {
41 pkg *types2.Package
42 analyzer *Analyzer
43 info *types2.Info
44 sizes []int
45 }
46
47 func NewRewriter(pkg *types2.Package, info *types2.Info, analyzer *Analyzer, sizes []int) *Rewriter {
48 return &Rewriter{
49 pkg: pkg,
50 info: info,
51 analyzer: analyzer,
52 sizes: sizes,
53 }
54 }
55
56 func (r *Rewriter) Rewrite(files []*syntax.File) {
57
58
59 for _, fileAST := range files {
60
61 var newDecls []syntax.Decl
62 for _, k := range r.sizes {
63 newDecls = r.generateForSize(fileAST, k, newDecls)
64 }
65
66
67 r.generateDispatchers(fileAST)
68
69 fileAST.DeclList = append(fileAST.DeclList, newDecls...)
70 }
71 }
72
73 func (r *Rewriter) generateDispatchers(fileAST *syntax.File) {
74 var newDecls []syntax.Decl
75
76 for _, decl := range fileAST.DeclList {
77 switch d := decl.(type) {
78 case *syntax.FuncDecl:
79 if d.Name == nil {
80 newDecls = append(newDecls, d)
81 continue
82 }
83 obj := r.info.Defs[d.Name]
84 if !r.analyzer.dependentObj[obj] || r.analyzer.inSimd {
85 newDecls = append(newDecls, d)
86 continue
87 }
88
89 sig, ok := obj.Type().(*types2.Signature)
90 if !ok {
91 newDecls = append(newDecls, d)
92 continue
93 }
94
95 if r.analyzer.HasDependentSignature(sig) {
96
97 continue
98 }
99
100
101 d.Body = r.createDispatcherBody(d, sig)
102 newDecls = append(newDecls, d)
103
104 case *syntax.VarDecl:
105
106 keep := false
107 for _, name := range d.NameList {
108 if !r.analyzer.dependentObj[r.info.Defs[name]] {
109 keep = true
110 break
111 }
112 }
113 if keep {
114 newDecls = append(newDecls, d)
115 }
116 case *syntax.TypeDecl:
117 if !r.analyzer.dependentObj[r.info.Defs[d.Name]] || r.analyzer.inSimd {
118 newDecls = append(newDecls, d)
119 }
120 default:
121 newDecls = append(newDecls, decl)
122 }
123 }
124
125 fileAST.DeclList = newDecls
126
127 if !r.analyzer.inSimd {
128
129 hasArchSimd := false
130 var simdImport *syntax.ImportDecl
131 for _, decl := range fileAST.DeclList {
132 if imp, ok := decl.(*syntax.ImportDecl); ok {
133 if imp.Path.Value == `"`+archFullPkg+`"` {
134 hasArchSimd = true
135 }
136 if imp.Path.Value == `"`+simdPkg+`"` {
137 simdImport = imp
138 }
139
140 }
141 }
142 p := simdImport.Pos()
143 if !hasArchSimd {
144 r.injectImport(fileAST, archFullPkg, p)
145 }
146
147
148
149 fun := &syntax.SelectorExpr{
150 X: syntax.NewName(p, simdPkg),
151 Sel: syntax.NewName(p, vectorSizeFn),
152 }
153 fun.SetPos(p)
154 call := &syntax.CallExpr{Fun: fun}
155 call.SetPos(p)
156
157 name := syntax.NewName(p, "_")
158
159 varDecl := &syntax.VarDecl{NameList: []*syntax.Name{name}, Values: call}
160 varDecl.SetPos(p)
161 fileAST.DeclList = append(fileAST.DeclList, varDecl)
162 }
163 }
164
165 func (r *Rewriter) injectImport(fileAST *syntax.File, toImport string, simdImportPos syntax.Pos) {
166 importDecl := &syntax.ImportDecl{
167 Path: &syntax.BasicLit{Value: `"` + toImport + `"`, Kind: syntax.StringLit},
168 }
169 importDecl.Path.SetPos(simdImportPos)
170 importDecl.SetPos(simdImportPos)
171 fileAST.DeclList = append([]syntax.Decl{importDecl}, fileAST.DeclList...)
172 }
173
174 func (r *Rewriter) createDispatcherBody(d *syntax.FuncDecl, sig *types2.Signature) *syntax.BlockStmt {
175
176
177 args := func() []syntax.Expr {
178 var args []syntax.Expr
179 if d.Type.ParamList != nil {
180 for _, field := range d.Type.ParamList {
181 if field.Name != nil {
182 paramName := syntax.NewName(field.Pos(), field.Name.Value)
183 args = append(args, paramName)
184 }
185 }
186 }
187 return args
188 }
189
190
191 pe := func(e syntax.Expr) syntax.Expr {
192 e.SetPos(d.Pos())
193 return e
194 }
195
196 ps := func(e syntax.Stmt) syntax.Stmt {
197 e.SetPos(d.Pos())
198 return e
199 }
200
201
202
203
204
205
206
207
208
209
210
211 switchStmt := &syntax.SwitchStmt{
212 Tag: pe(&syntax.CallExpr{
213 Fun: pe(&syntax.SelectorExpr{
214 X: syntax.NewName(d.Pos(), simdPkg),
215 Sel: syntax.NewName(d.Pos(), vectorSizeFn),
216 }),
217 }),
218 Body: []*syntax.CaseClause{},
219 }
220
221 var emulation syntax.Stmt
222
223 for _, k := range r.sizes {
224 fnName := fmt.Sprintf("%s@simd%d", d.Name.Value, k)
225 fnIdent := syntax.NewName(d.Pos(), fnName)
226
227 callExpr := pe(&syntax.CallExpr{
228 Fun: pe(fnIdent),
229 ArgList: args(),
230 })
231
232
233 var callReturnStmt syntax.Stmt
234 if d.Type.ResultList != nil && len(d.Type.ResultList) > 0 {
235 callReturnStmt = &syntax.ReturnStmt{Results: callExpr}
236 } else {
237 callReturnStmt = &syntax.BlockStmt{
238 List: []syntax.Stmt{
239 ps(&syntax.ExprStmt{X: callExpr}),
240 ps(&syntax.ReturnStmt{}),
241 },
242 Rbrace: d.Pos(),
243 }
244 }
245 callReturnStmt.SetPos(d.Pos())
246
247 if k == 0 {
248
249
250 cond := pe(&syntax.CallExpr{
251 Fun: pe(&syntax.SelectorExpr{
252 X: syntax.NewName(d.Pos(), simdPkg),
253 Sel: syntax.NewName(d.Pos(), emulatedFn),
254 })})
255
256 blockStmt, ok := callReturnStmt.(*syntax.BlockStmt)
257 if !ok {
258 blockStmt = &syntax.BlockStmt{
259 List: []syntax.Stmt{callReturnStmt},
260 Rbrace: d.Pos(),
261 }
262 blockStmt.SetPos(d.Pos())
263 }
264
265 emulation = ps(&syntax.IfStmt{
266 Cond: cond,
267 Then: blockStmt,
268 })
269 continue
270 }
271
272 var caseBody []syntax.Stmt
273
274
275 if emulation != nil && k == 128 {
276 caseBody = append(caseBody, emulation)
277 emulation = nil
278 }
279
280 caseClause := &syntax.CaseClause{
281 Cases: pe(&syntax.BasicLit{Kind: syntax.IntLit, Value: fmt.Sprintf("%d", k)}),
282 Body: append(caseBody, callReturnStmt),
283 }
284 caseClause.SetPos(d.Pos())
285 switchStmt.Body = append(switchStmt.Body, caseClause)
286 }
287
288 fnName := "panic"
289 fnIdent := pe(syntax.NewName(d.Pos(), fnName))
290
291 callExpr := pe(&syntax.CallExpr{
292 Fun: fnIdent,
293 ArgList: []syntax.Expr{pe(&syntax.BasicLit{Value: "\"unsupported vector size in simd-rewritten code\"", Kind: syntax.StringLit})},
294 })
295
296 panicStmt := &syntax.ExprStmt{X: callExpr}
297 blockStmt := &syntax.BlockStmt{List: []syntax.Stmt{ps(switchStmt), ps(panicStmt)}}
298
299 blockStmt.SetPos(d.Pos())
300
301 return blockStmt
302 }
303
304 func (r *Rewriter) generateForSize(fileAST *syntax.File, k int, newDecls []syntax.Decl) []syntax.Decl {
305 copier := NewDeepCopier(r.pkg, r.info, k, r.analyzer, fmt.Sprintf("@simd%d", k))
306 for _, decl := range fileAST.DeclList {
307 if r.shouldIncludeDecl(decl) {
308 newDecl := copier.CopyDecl(decl)
309 newDecls = append(newDecls, newDecl)
310 }
311 }
312 return newDecls
313 }
314
// nameToElemBitWidth returns the element width in bits for a recognized
// element-type name such as "Int32s", "Float64s" or "Mask8s". It returns 0
// for any other name.
func nameToElemBitWidth(name string) int {
	switch name {
	case "Int8s", "Uint8s", "Mask8s":
		return 8
	case "Int16s", "Uint16s", "Mask16s":
		return 16
	case "Int32s", "Uint32s", "Float32s", "Mask32s":
		return 32
	case "Int64s", "Uint64s", "Float64s", "Mask64s":
		return 64
	default:
		return 0
	}
}
329
330 func (r *Rewriter) shouldIncludeDecl(decl syntax.Decl) bool {
331
332
333
334 if r.analyzer.inSimd {
335 theFile := decl.Pos().Base().Filename()
336
337 lastSlash := strings.LastIndex(theFile, simdPkg+"/")
338 lastBackslash := strings.LastIndex(theFile, simdPkg+"\\")
339
340
341
342
343 maxSlash := max(lastSlash, lastBackslash)
344 if maxSlash == -1 {
345 return false
346 }
347 if !strings.HasPrefix(theFile[maxSlash:], simdPkg+"/tofrom_") &&
348 !strings.HasPrefix(theFile[maxSlash:], simdPkg+"\\tofrom_") {
349 return false
350 }
351 }
352
353 switch d := decl.(type) {
354 case *syntax.FuncDecl:
355 if d.Name != nil {
356 return r.analyzer.dependentObj[r.info.Defs[d.Name]]
357 }
358 case *syntax.TypeDecl:
359 return r.analyzer.dependentObj[r.info.Defs[d.Name]]
360 case *syntax.VarDecl:
361 for _, name := range d.NameList {
362 if r.analyzer.dependentObj[r.info.Defs[name]] {
363 return true
364 }
365 }
366 }
367 return false
368 }
369
370
371 func RewriteWrapper(pkg *types2.Package, info *types2.Info, files []*syntax.File) bool {
372 if !buildcfg.Experiment.SIMD {
373 return false
374 }
375
376 switch buildcfg.GOARCH {
377 case "wasm", "amd64", "arm64":
378 default:
379 return false
380 }
381
382 sizes := rewriteSizes()
383 if len(sizes) == 0 {
384 return false
385 }
386 analyzer := NewAnalyzer(pkg, info)
387 if !analyzer.Analyze(files) {
388 return false
389 }
390
391 CheckPositions(files, "before midway")
392
393 rewriter := NewRewriter(pkg, info, analyzer, sizes)
394 rewriter.Rewrite(files)
395
396 CheckPositions(files, "after midway")
397
398 return true
399 }
400
View as plain text