1
2
3
4
5 package main
6
7
8
9
10 import (
11 "bufio"
12 "bytes"
13 "flag"
14 "fmt"
15 "go/format"
16 "io"
17 "os"
18 "strings"
19 "text/template"
20 )
21
// resultTypeFunc maps an input element type t ("int", "uint" or "float"),
// an element width w and an element count c to the element type, element
// width and element count of the vector produced from that input.
type resultTypeFunc func(t string, w, c int) (ot string, ow int, oc int)

// shapes describes a family of vector shapes. Every vector width listed in
// vecs is paired with every element width listed in ints, uints and floats;
// the element count of each pairing is the vector width divided by the
// element width.
type shapes struct {
	vecs   []int          // vector widths
	ints   []int          // element widths of the signed integer shapes
	uints  []int          // element widths of the unsigned integer shapes
	floats []int          // element widths of the floating point shapes
	output resultTypeFunc // optional mapping from input shape to output shape
}

// shapeAndTemplate pairs a template with the shapes it is expanded for.
type shapeAndTemplate struct {
	s *shapes
	t *template.Template
}

// target returns a copy of sat whose output shape has element type outType
// and element width width. The element count of the input is kept unless the
// resulting vector would be wider than 512 or narrower than 128, in which
// case the count is replaced by 512/width or 128/width respectively.
// The shapes of sat itself are left untouched.
func (sat shapeAndTemplate) target(outType string, width int) shapeAndTemplate {
	retargeted := *sat.s
	retargeted.output = func(t string, w, c int) (ot string, ow int, oc int) {
		switch total := width * c; {
		case total > 512:
			return outType, width, 512 / width
		case total < 128:
			return outType, width, 128 / width
		}
		return outType, width, c
	}
	// sat is a value receiver, so assigning here only affects the copy.
	sat.s = &retargeted
	return sat
}

// targetFixed returns a copy of sat whose output shape is always
// outType/width/count, regardless of the input shape.
// The shapes of sat itself are left untouched.
func (sat shapeAndTemplate) targetFixed(outType string, width, count int) shapeAndTemplate {
	fixed := *sat.s
	fixed.output = func(t string, w, c int) (ot string, ow int, oc int) {
		return outType, width, count
	}
	sat.s = &fixed
	return sat
}

// forAllShapes calls f once for every shape described by s. Shapes are
// visited vector width by vector width; within one vector width the signed
// integer shapes come first, then the unsigned ones, then the floats.
// seq is the zero-based position of the call in that order, t and upperT are
// the lower-case and capitalized element type names, w is the element width
// and c the element count. out is handed through to f unchanged.
func (s *shapes) forAllShapes(f func(seq int, t, upperT string, w, c int, out io.Writer), out io.Writer) {
	families := []struct {
		lower, upper string
		widths       []int
	}{
		{"int", "Int", s.ints},
		{"uint", "Uint", s.uints},
		{"float", "Float", s.floats},
	}

	next := 0
	for _, vecWidth := range s.vecs {
		for _, fam := range families {
			for _, elemWidth := range fam.widths {
				f(next, fam.lower, fam.upper, elemWidth, vecWidth/elemWidth, out)
				next++
			}
		}
	}
}
89
90 var allShapes = &shapes{
91 vecs: []int{128, 256, 512},
92 ints: []int{8, 16, 32, 64},
93 uints: []int{8, 16, 32, 64},
94 floats: []int{32, 64},
95 }
96
97 var intShapes = &shapes{
98 vecs: []int{128, 256, 512},
99 ints: []int{8, 16, 32, 64},
100 }
101
102 var uintShapes = &shapes{
103 vecs: []int{128, 256, 512},
104 uints: []int{8, 16, 32, 64},
105 }
106
107 var floatShapes = &shapes{
108 vecs: []int{128, 256, 512},
109 floats: []int{32, 64},
110 }
111
112 var integerShapes = &shapes{
113 vecs: []int{128, 256, 512},
114 ints: []int{8, 16, 32, 64},
115 uints: []int{8, 16, 32, 64},
116 }
117
118 var avx512Shapes = &shapes{
119 vecs: []int{512},
120 ints: []int{8, 16, 32, 64},
121 uints: []int{8, 16, 32, 64},
122 floats: []int{32, 64},
123 }
124
125 var avx2Shapes = &shapes{
126 vecs: []int{128, 256},
127 ints: []int{8, 16, 32, 64},
128 uints: []int{8, 16, 32, 64},
129 floats: []int{32, 64},
130 }
131
132 var avx2MaskedLoadShapes = &shapes{
133 vecs: []int{128, 256},
134 ints: []int{32, 64},
135 uints: []int{32, 64},
136 floats: []int{32, 64},
137 }
138
139 var avx2SmallLoadPunShapes = &shapes{
140
141 vecs: []int{128, 256},
142 uints: []int{8, 16},
143 }
144
145 var unaryFlaky = &shapes{
146 vecs: []int{128, 256, 512},
147 floats: []int{32, 64},
148 }
149
150 var ternaryFlaky = &shapes{
151 vecs: []int{128, 256, 512},
152 floats: []int{32},
153 }
154
155 var avx2SignedComparisons = &shapes{
156 vecs: []int{128, 256},
157 ints: []int{8, 16, 32, 64},
158 }
159
160 var avx2UnsignedComparisons = &shapes{
161 vecs: []int{128, 256},
162 uints: []int{8, 16, 32, 64},
163 }
164
// templateData is the value every template in this file is executed with.
// It describes one input vector shape and, for conversions, the output shape.
type templateData struct {
	VType  string // vector type name: Base followed by WxC, e.g. Int32x4
	AOrAn  string // "a" or "an", chosen from the first letter of the element type
	EWidth int    // element width
	Vwidth int    // vector width, EWidth * Count
	Count  int    // number of elements
	WxC    string // "<EWidth>x<Count>"
	BxC    string // the same vector viewed as 8-wide elements: "8x<count>"
	Base   string // capitalized element type name, e.g. Int, Uint, Float
	Etype  string // element type name, e.g. int32
	OxFF   string // hexadecimal literal of (1<<Count)-1

	OVType string // output vector type name
	OEtype string // output element type name, e.g. int32
	OEType string // capitalized output element type name, e.g. Int32
	OCount int    // number of output elements
}

// As128BitVec returns the name of the 128-wide vector type that has the same
// element type as t, for example Int32x4 for any vector of Int32 elements.
func (t templateData) As128BitVec() string {
	lanes := 128 / t.EWidth
	return fmt.Sprintf("%s%dx%d", t.Base, t.EWidth, lanes)
}
186
187 func oneTemplate(t *template.Template, baseType string, width, count int, out io.Writer, rtf resultTypeFunc) {
188 b := width * count
189 if b < 128 || b > 512 {
190 return
191 }
192
193 ot, ow, oc := baseType, width, count
194 if rtf != nil {
195 ot, ow, oc = rtf(ot, ow, oc)
196 if ow*oc > 512 || ow*oc < 128 || ow < 8 || ow > 64 {
197 return
198 }
199
200 if ot == "float" && ow < 32 {
201 return
202 }
203 }
204 ovType := fmt.Sprintf("%s%dx%d", strings.ToUpper(ot[:1])+ot[1:], ow, oc)
205 oeType := fmt.Sprintf("%s%d", ot, ow)
206 oEType := fmt.Sprintf("%s%d", strings.ToUpper(ot[:1])+ot[1:], ow)
207
208 wxc := fmt.Sprintf("%dx%d", width, count)
209 BaseType := strings.ToUpper(baseType[:1]) + baseType[1:]
210 vType := fmt.Sprintf("%s%s", BaseType, wxc)
211 eType := fmt.Sprintf("%s%d", baseType, width)
212
213 bxc := fmt.Sprintf("%dx%d", 8, count*(width/8))
214 aOrAn := "a"
215 if strings.Contains("aeiou", baseType[:1]) {
216 aOrAn = "an"
217 }
218 oxFF := fmt.Sprintf("0x%x", uint64((1<<count)-1))
219 t.Execute(out, templateData{
220 VType: vType,
221 AOrAn: aOrAn,
222 EWidth: width,
223 Vwidth: b,
224 Count: count,
225 WxC: wxc,
226 BxC: bxc,
227 Base: BaseType,
228 Etype: eType,
229 OxFF: oxFF,
230 OVType: ovType,
231 OEtype: oeType,
232 OCount: oc,
233 OEType: oEType,
234 })
235 }
236
237
238
239 func (sat shapeAndTemplate) forTemplates(out io.Writer) {
240 t, s := sat.t, sat.s
241 vecs := s.vecs
242 ints := s.ints
243 uints := s.uints
244 floats := s.floats
245 for _, v := range vecs {
246 for _, w := range ints {
247 c := v / w
248 oneTemplate(t, "int", w, c, out, sat.s.output)
249 }
250 for _, w := range uints {
251 c := v / w
252 oneTemplate(t, "uint", w, c, out, sat.s.output)
253 }
254 for _, w := range floats {
255 c := v / w
256 oneTemplate(t, "float", w, c, out, sat.s.output)
257 }
258 }
259 }
260
261 func prologue(s string, out io.Writer) {
262 fmt.Fprintf(out,
263 `// Code generated by '%s'; DO NOT EDIT.
264
265 //go:build goexperiment.simd
266
267 package archsimd
268
269 `, s)
270 }
271
272 func ssaPrologue(s string, out io.Writer) {
273 fmt.Fprintf(out,
274 `// Code generated by '%s'; DO NOT EDIT.
275
276 package ssa
277
278 `, s)
279 }
280
281 func unsafePrologue(s string, out io.Writer) {
282 fmt.Fprintf(out,
283 `// Code generated by '%s'; DO NOT EDIT.
284
285 //go:build goexperiment.simd
286
287 package archsimd
288
289 import "unsafe"
290
291 `, s)
292 }
293
294 func testPrologue(t, s string, out io.Writer) {
295 fmt.Fprintf(out,
296 `// Code generated by '%s'; DO NOT EDIT.
297
298 //go:build goexperiment.simd && amd64
299
300 // This file contains functions testing %s.
301 // Each function in this file is specialized for a
302 // particular simd type <BaseType><Width>x<Count>.
303
304 package simd_test
305
306 import (
307 "simd/archsimd"
308 "testing"
309 )
310
311 `, s, t)
312 }
313
314 func curryTestPrologue(t string) func(s string, out io.Writer) {
315 return func(s string, out io.Writer) {
316 testPrologue(t, s, out)
317 }
318 }
319
320 func templateOf(name, temp string) shapeAndTemplate {
321 return shapeAndTemplate{s: allShapes,
322 t: template.Must(template.New(name).Parse(temp))}
323 }
324
325 func shapedTemplateOf(s *shapes, name, temp string) shapeAndTemplate {
326 return shapeAndTemplate{s: s,
327 t: template.Must(template.New(name).Parse(temp))}
328 }
329
// sliceTemplate generates, for every shape in allShapes, a Load<VType>Slice
// function and a StoreSlice method. Both convert the slice to a pointer to an
// array of Count elements and forward to Load<VType> and Store respectively.
var sliceTemplate = templateOf("slice", `
// Load{{.VType}}Slice loads {{.AOrAn}} {{.VType}} from a slice of at least {{.Count}} {{.Etype}}s.
func Load{{.VType}}Slice(s []{{.Etype}}) {{.VType}} {
return Load{{.VType}}((*[{{.Count}}]{{.Etype}})(s))
}

// StoreSlice stores x into a slice of at least {{.Count}} {{.Etype}}s.
func (x {{.VType}}) StoreSlice(s []{{.Etype}}) {
x.Store((*[{{.Count}}]{{.Etype}})(s))
}
`)
341
// unaryTemplate generates, for every shape in allShapes, a test<VType>Unary
// helper. The helper loads each input slice supplied by forSlice into a
// vector, applies f, stores the result and compares it with want(x) through
// checkSlicesLogInput, passing 0.0 where the flaky variant passes its
// flakiness parameter.
var unaryTemplate = templateOf("unary_helpers", `
// test{{.VType}}Unary tests the simd unary method f against the expected behavior generated by want
func test{{.VType}}Unary(t *testing.T, f func(_ archsimd.{{.VType}}) archsimd.{{.VType}}, want func(_ []{{.Etype}}) []{{.Etype}}) {
n := {{.Count}}
t.Helper()
forSlice(t, {{.Etype}}s, n, func(x []{{.Etype}}) bool {
t.Helper()
a := archsimd.Load{{.VType}}Slice(x)
g := make([]{{.Etype}}, n)
f(a).StoreSlice(g)
w := want(x)
return checkSlicesLogInput(t, g, w, 0.0, func() {t.Helper(); t.Logf("x=%v", x)})
})
}
`)
357
// unaryFlakyTemplate generates, for the shapes in unaryFlaky, a
// test<VType>UnaryFlaky helper. It works like the helper produced by
// unaryTemplate but takes an extra flakiness argument that is handed to
// checkSlicesLogInput.
var unaryFlakyTemplate = shapedTemplateOf(unaryFlaky, "unary_flaky_helpers", `
// test{{.VType}}UnaryFlaky tests the simd unary method f against the expected behavior generated by want,
// but using a flakiness parameter because we haven't exactly figured out how simd floating point works
func test{{.VType}}UnaryFlaky(t *testing.T, f func(x archsimd.{{.VType}}) archsimd.{{.VType}}, want func(x []{{.Etype}}) []{{.Etype}}, flakiness float64) {
n := {{.Count}}
t.Helper()
forSlice(t, {{.Etype}}s, n, func(x []{{.Etype}}) bool {
t.Helper()
a := archsimd.Load{{.VType}}Slice(x)
g := make([]{{.Etype}}, n)
f(a).StoreSlice(g)
w := want(x)
return checkSlicesLogInput(t, g, w, flakiness, func() {t.Helper(); t.Logf("x=%v", x)})
})
}
`)
374
// convertTemplate generates test<VType>ConvertTo<OEType> helpers for
// conversions from the input vector type VType to the output vector type
// OVType. It is paired with allShapes here; the output shape is supplied by
// deriving a copy with target. The result slice of the generated helper has
// OCount elements.
var convertTemplate = templateOf("convert_helpers", `
// test{{.VType}}ConvertTo{{.OEType}} tests the simd conversion method f against the expected behavior generated by want.
// This is for count-preserving conversions, so if there is a change in size, then there is a change in vector width,
// (extended to at least 128 bits, or truncated to at most 512 bits).
func test{{.VType}}ConvertTo{{.OEType}}(t *testing.T, f func(x archsimd.{{.VType}}) archsimd.{{.OVType}}, want func(x []{{.Etype}}) []{{.OEtype}}) {
n := {{.Count}}
t.Helper()
forSlice(t, {{.Etype}}s, n, func(x []{{.Etype}}) bool {
t.Helper()
a := archsimd.Load{{.VType}}Slice(x)
g := make([]{{.OEtype}}, {{.OCount}})
f(a).StoreSlice(g)
w := want(x)
return checkSlicesLogInput(t, g, w, 0.0, func() {t.Helper(); t.Logf("x=%v", x)})
})
}
`)
392
var (
	// Copies of convertTemplate, one per output element type and width.
	// target keeps the element count of the input unless the output vector
	// would be wider than 512 or narrower than 128, in which case the count
	// is adjusted to fit.
	unaryToInt8    = convertTemplate.target("int", 8)
	unaryToUint8   = convertTemplate.target("uint", 8)
	unaryToInt16   = convertTemplate.target("int", 16)
	unaryToUint16  = convertTemplate.target("uint", 16)
	unaryToInt32   = convertTemplate.target("int", 32)
	unaryToUint32  = convertTemplate.target("uint", 32)
	unaryToInt64   = convertTemplate.target("int", 64)
	unaryToUint64  = convertTemplate.target("uint", 64)
	unaryToFloat32 = convertTemplate.target("float", 32)
	unaryToFloat64 = convertTemplate.target("float", 64)
)
408
// convertLoTemplate generates test<VType>ConvertLoTo<OVType> helpers for the
// shapes in integerShapes. The output vector type is supplied by deriving a
// copy with targetFixed; the result slice of the generated helper has OCount
// elements.
var convertLoTemplate = shapedTemplateOf(integerShapes, "convert_lo_helpers", `
// test{{.VType}}ConvertLoTo{{.OVType}} tests the simd conversion method f against the expected behavior generated by want.
// This converts only the low {{.OCount}} elements.
func test{{.VType}}ConvertLoTo{{.OVType}}(t *testing.T, f func(x archsimd.{{.VType}}) archsimd.{{.OVType}}, want func(x []{{.Etype}}) []{{.OEtype}}) {
n := {{.Count}}
t.Helper()
forSlice(t, {{.Etype}}s, n, func(x []{{.Etype}}) bool {
t.Helper()
a := archsimd.Load{{.VType}}Slice(x)
g := make([]{{.OEtype}}, {{.OCount}})
f(a).StoreSlice(g)
w := want(x)
return checkSlicesLogInput(t, g, w, 0.0, func() {t.Helper(); t.Logf("x=%v", x)})
})
}
`)
425
var (
	// Copies of convertLoTemplate, one per fixed output vector type.
	// targetFixed makes the output type independent of the input shape, so
	// each of these expands to helpers converting every integer input shape
	// to the one named output type.
	unaryToInt64x2  = convertLoTemplate.targetFixed("int", 64, 2)
	unaryToInt64x4  = convertLoTemplate.targetFixed("int", 64, 4)
	unaryToUint64x2 = convertLoTemplate.targetFixed("uint", 64, 2)
	unaryToUint64x4 = convertLoTemplate.targetFixed("uint", 64, 4)
	unaryToInt32x4  = convertLoTemplate.targetFixed("int", 32, 4)
	unaryToInt32x8  = convertLoTemplate.targetFixed("int", 32, 8)
	unaryToUint32x4 = convertLoTemplate.targetFixed("uint", 32, 4)
	unaryToUint32x8 = convertLoTemplate.targetFixed("uint", 32, 8)
	unaryToInt16x8  = convertLoTemplate.targetFixed("int", 16, 8)
	unaryToUint16x8 = convertLoTemplate.targetFixed("uint", 16, 8)
)
443
// binaryTemplate generates, for every shape in allShapes, a test<VType>Binary
// helper. The helper loads each pair of input slices supplied by forSlicePair
// into vectors, applies f, stores the result and compares it with want(x, y)
// through checkSlicesLogInput.
var binaryTemplate = templateOf("binary_helpers", `
// test{{.VType}}Binary tests the simd binary method f against the expected behavior generated by want
func test{{.VType}}Binary(t *testing.T, f func(_, _ archsimd.{{.VType}}) archsimd.{{.VType}}, want func(_, _ []{{.Etype}}) []{{.Etype}}) {
n := {{.Count}}
t.Helper()
forSlicePair(t, {{.Etype}}s, n, func(x, y []{{.Etype}}) bool {
t.Helper()
a := archsimd.Load{{.VType}}Slice(x)
b := archsimd.Load{{.VType}}Slice(y)
g := make([]{{.Etype}}, n)
f(a, b).StoreSlice(g)
w := want(x, y)
return checkSlicesLogInput(t, g, w, 0.0, func() {t.Helper(); t.Logf("x=%v", x); t.Logf("y=%v", y); })
})
}
`)
460
// ternaryTemplate generates, for every shape in allShapes, a
// test<VType>Ternary helper. The helper loads each triple of input slices
// supplied by forSliceTriple into vectors, applies f, stores the result and
// compares it with want(x, y, z) through checkSlicesLogInput.
var ternaryTemplate = templateOf("ternary_helpers", `
// test{{.VType}}Ternary tests the simd ternary method f against the expected behavior generated by want
func test{{.VType}}Ternary(t *testing.T, f func(_, _, _ archsimd.{{.VType}}) archsimd.{{.VType}}, want func(_, _, _ []{{.Etype}}) []{{.Etype}}) {
n := {{.Count}}
t.Helper()
forSliceTriple(t, {{.Etype}}s, n, func(x, y, z []{{.Etype}}) bool {
t.Helper()
a := archsimd.Load{{.VType}}Slice(x)
b := archsimd.Load{{.VType}}Slice(y)
c := archsimd.Load{{.VType}}Slice(z)
g := make([]{{.Etype}}, n)
f(a, b, c).StoreSlice(g)
w := want(x, y, z)
return checkSlicesLogInput(t, g, w, 0.0, func() {t.Helper(); t.Logf("x=%v", x); t.Logf("y=%v", y); t.Logf("z=%v", z); })
})
}
`)
478
479 var ternaryFlakyTemplate = shapedTemplateOf(ternaryFlaky, "ternary_helpers", `
480 // test{{.VType}}TernaryFlaky tests the simd ternary method f against the expected behavior generated by want,
481 // but using a flakiness parameter because we haven't exactly figured out how simd floating point works
482 func test{{.VType}}TernaryFlaky(t *testing.T, f func(x, y, z archsimd.{{.VType}}) archsimd.{{.VType}}, want func(x, y, z []{{.Etype}}) []{{.Etype}}, flakiness float64) {
483 n := {{.Count}}
484 t.Helper()
485 forSliceTriple(t, {{.Etype}}s, n, func(x, y, z []{{.Etype}}) bool {
486 t.Helper()
487 a := archsimd.Load{{.VType}}Slice(x)
488 b := archsimd.Load{{.VType}}Slice(y)
489 c := archsimd.Load{{.VType}}Slice(z)
490 g := make([]{{.Etype}}, n)
491 f(a, b, c).StoreSlice(g)
492 w := want(x, y, z)
493 return checkSlicesLogInput(t, g, w, flakiness, func() {t.Helper(); t.Logf("x=%v", x); t.Logf("y=%v", y); t.Logf("z=%v", z); })
494 })
495 }
496 `)
497
// compareTemplate generates, for every shape in allShapes, a
// test<VType>Compare helper. The helper converts the mask returned by f to an
// Int<WxC> vector, stores it, widens it with s64 and compares it with
// want(x, y) through checkSlicesLogInput.
var compareTemplate = templateOf("compare_helpers", `
// test{{.VType}}Compare tests the simd comparison method f against the expected behavior generated by want
func test{{.VType}}Compare(t *testing.T, f func(_, _ archsimd.{{.VType}}) archsimd.Mask{{.WxC}}, want func(_, _ []{{.Etype}}) []int64) {
n := {{.Count}}
t.Helper()
forSlicePair(t, {{.Etype}}s, n, func(x, y []{{.Etype}}) bool {
t.Helper()
a := archsimd.Load{{.VType}}Slice(x)
b := archsimd.Load{{.VType}}Slice(y)
g := make([]int{{.EWidth}}, n)
f(a, b).ToInt{{.WxC}}().StoreSlice(g)
w := want(x, y)
return checkSlicesLogInput(t, s64(g), w, 0.0, func() {t.Helper(); t.Logf("x=%v", x); t.Logf("y=%v", y); })
})
}
`)
514
// compareUnaryTemplate generates, for the shapes in floatShapes, a
// test<VType>UnaryCompare helper for methods that take one vector and return
// a mask. The mask is converted to an Int<WxC> vector, stored, widened with
// s64 and compared with want(x) through checkSlicesLogInput.
var compareUnaryTemplate = shapedTemplateOf(floatShapes, "compare_unary_helpers", `
// test{{.VType}}UnaryCompare tests the simd unary comparison method f against the expected behavior generated by want
func test{{.VType}}UnaryCompare(t *testing.T, f func(x archsimd.{{.VType}}) archsimd.Mask{{.WxC}}, want func(x []{{.Etype}}) []int64) {
n := {{.Count}}
t.Helper()
forSlice(t, {{.Etype}}s, n, func(x []{{.Etype}}) bool {
t.Helper()
a := archsimd.Load{{.VType}}Slice(x)
g := make([]int{{.EWidth}}, n)
f(a).ToInt{{.WxC}}().StoreSlice(g)
w := want(x)
return checkSlicesLogInput(t, s64(g), w, 0.0, func() {t.Helper(); t.Logf("x=%v", x)})
})
}
`)
530
531
// compareMaskedTemplate generates, for every shape in allShapes, a
// test<VType>CompareMasked helper. The helper builds a mask from the bool
// slice m supplied by forSlicePairMasked, calls f with it, and zeroes the
// entries of want(x, y) whose mask element is false before comparing.
var compareMaskedTemplate = templateOf("comparemasked_helpers", `
// test{{.VType}}CompareMasked tests the simd masked comparison method f against the expected behavior generated by want
// The mask is applied to the output of want; anything not in the mask, is zeroed.
func test{{.VType}}CompareMasked(t *testing.T,
f func(_, _ archsimd.{{.VType}}, m archsimd.Mask{{.WxC}}) archsimd.Mask{{.WxC}},
want func(_, _ []{{.Etype}}) []int64) {
n := {{.Count}}
t.Helper()
forSlicePairMasked(t, {{.Etype}}s, n, func(x, y []{{.Etype}}, m []bool) bool {
t.Helper()
a := archsimd.Load{{.VType}}Slice(x)
b := archsimd.Load{{.VType}}Slice(y)
k := archsimd.LoadInt{{.WxC}}Slice(toVect[int{{.EWidth}}](m)).ToMask()
g := make([]int{{.EWidth}}, n)
f(a, b, k).ToInt{{.WxC}}().StoreSlice(g)
w := want(x, y)
for i := range m {
if !m[i] {
w[i] = 0
}
}
return checkSlicesLogInput(t, s64(g), w, 0.0, func() {t.Helper(); t.Logf("x=%v", x); t.Logf("y=%v", y); t.Logf("m=%v", m); })
})
}
`)
557
// avx512MaskedLoadSlicePartTemplate generates, for the shapes in
// avx512Shapes, Load<VType>SlicePart and StoreSlicePart. For a slice shorter
// than Count the generated code builds a mask with Mask<WxC>FromBits from
// OxFF shifted right by Count minus the slice length, and performs a masked
// load or store through the pa<VType> helper.
var avx512MaskedLoadSlicePartTemplate = shapedTemplateOf(avx512Shapes, "avx 512 load slice part", `
// Load{{.VType}}SlicePart loads a {{.VType}} from the slice s.
// If s has fewer than {{.Count}} elements, the remaining elements of the vector are filled with zeroes.
// If s has {{.Count}} or more elements, the function is equivalent to Load{{.VType}}Slice.
func Load{{.VType}}SlicePart(s []{{.Etype}}) {{.VType}} {
l := len(s)
if l >= {{.Count}} {
return Load{{.VType}}Slice(s)
}
if l == 0 {
var x {{.VType}}
return x
}
mask := Mask{{.WxC}}FromBits({{.OxFF}} >> ({{.Count}} - l))
return LoadMasked{{.VType}}(pa{{.VType}}(s), mask)
}

// StoreSlicePart stores the {{.Count}} elements of x into the slice s.
// It stores as many elements as will fit in s.
// If s has {{.Count}} or more elements, the method is equivalent to x.StoreSlice.
func (x {{.VType}}) StoreSlicePart(s []{{.Etype}}) {
l := len(s)
if l >= {{.Count}} {
x.StoreSlice(s)
return
}
if l == 0 {
return
}
mask := Mask{{.WxC}}FromBits({{.OxFF}} >> ({{.Count}} - l))
x.StoreMasked(pa{{.VType}}(s), mask)
}
`)
591
// avx2MaskedLoadSlicePartTemplate generates, for the shapes in
// avx2MaskedLoadShapes, Load<VType>SlicePart and StoreSlicePart. For a slice
// shorter than Count the generated code takes the mask from the
// vecMask<EWidth> table (not defined in this file's visible part), starting
// at half its length minus the slice length, loads it as an Int<WxC> vector
// and converts it with asMask.
var avx2MaskedLoadSlicePartTemplate = shapedTemplateOf(avx2MaskedLoadShapes, "avx 2 load slice part", `
// Load{{.VType}}SlicePart loads a {{.VType}} from the slice s.
// If s has fewer than {{.Count}} elements, the remaining elements of the vector are filled with zeroes.
// If s has {{.Count}} or more elements, the function is equivalent to Load{{.VType}}Slice.
func Load{{.VType}}SlicePart(s []{{.Etype}}) {{.VType}} {
l := len(s)
if l >= {{.Count}} {
return Load{{.VType}}Slice(s)
}
if l == 0 {
var x {{.VType}}
return x
}
mask := vecMask{{.EWidth}}[len(vecMask{{.EWidth}})/2-l:]
return LoadMasked{{.VType}}(pa{{.VType}}(s), LoadInt{{.WxC}}Slice(mask).asMask())
}

// StoreSlicePart stores the {{.Count}} elements of x into the slice s.
// It stores as many elements as will fit in s.
// If s has {{.Count}} or more elements, the method is equivalent to x.StoreSlice.
func (x {{.VType}}) StoreSlicePart(s []{{.Etype}}) {
l := len(s)
if l >= {{.Count}} {
x.StoreSlice(s)
return
}
if l == 0 {
return
}
mask := vecMask{{.EWidth}}[len(vecMask{{.EWidth}})/2-l:]
x.StoreMasked(pa{{.VType}}(s), LoadInt{{.WxC}}Slice(mask).asMask())
}
`)
625
// avx2SmallLoadSlicePartTemplate generates, for the shapes in
// avx2SmallLoadPunShapes, Load<VType>SlicePart and StoreSlicePart. The
// generated code reinterprets the slice as a slice of signed integers of the
// same width with unsafe.Slice and delegates to the Int<WxC> SlicePart
// functions, which this template does not generate.
var avx2SmallLoadSlicePartTemplate = shapedTemplateOf(avx2SmallLoadPunShapes, "avx 2 small load slice part", `
// Load{{.VType}}SlicePart loads a {{.VType}} from the slice s.
// If s has fewer than {{.Count}} elements, the remaining elements of the vector are filled with zeroes.
// If s has {{.Count}} or more elements, the function is equivalent to Load{{.VType}}Slice.
func Load{{.VType}}SlicePart(s []{{.Etype}}) {{.VType}} {
if len(s) == 0 {
var zero {{.VType}}
return zero
}
t := unsafe.Slice((*int{{.EWidth}})(unsafe.Pointer(&s[0])), len(s))
return LoadInt{{.WxC}}SlicePart(t).As{{.VType}}()
}

// StoreSlicePart stores the {{.Count}} elements of x into the slice s.
// It stores as many elements as will fit in s.
// If s has {{.Count}} or more elements, the method is equivalent to x.StoreSlice.
func (x {{.VType}}) StoreSlicePart(s []{{.Etype}}) {
if len(s) == 0 {
return
}
t := unsafe.Slice((*int{{.EWidth}})(unsafe.Pointer(&s[0])), len(s))
x.AsInt{{.WxC}}().StoreSlicePart(t)
}
`)
650
651 func (t templateData) CPUfeature() string {
652 switch t.Vwidth {
653 case 128:
654 return "AVX"
655 case 256:
656 return "AVX2"
657 case 512:
658 return "AVX512"
659 }
660 panic(fmt.Errorf("unexpected vector width %d", t.Vwidth))
661 }
662
// avx2SignedComparisonsTemplate generates, for the shapes in
// avx2SignedComparisons, the methods Less, GreaterEqual, LessEqual and
// NotEqual in terms of Greater and Equal. Negation is done by xor-ing the
// mask, converted to an integer vector, with x.Equal(x) converted likewise.
var avx2SignedComparisonsTemplate = shapedTemplateOf(avx2SignedComparisons, "avx2 signed comparisons", `
// Less returns a mask whose elements indicate whether x < y.
//
// Emulated, CPU Feature: {{.CPUfeature}}
func (x {{.VType}}) Less(y {{.VType}}) Mask{{.WxC}} {
return y.Greater(x)
}

// GreaterEqual returns a mask whose elements indicate whether x >= y.
//
// Emulated, CPU Feature: {{.CPUfeature}}
func (x {{.VType}}) GreaterEqual(y {{.VType}}) Mask{{.WxC}} {
ones := x.Equal(x).ToInt{{.WxC}}()
return y.Greater(x).ToInt{{.WxC}}().Xor(ones).asMask()
}

// LessEqual returns a mask whose elements indicate whether x <= y.
//
// Emulated, CPU Feature: {{.CPUfeature}}
func (x {{.VType}}) LessEqual(y {{.VType}}) Mask{{.WxC}} {
ones := x.Equal(x).ToInt{{.WxC}}()
return x.Greater(y).ToInt{{.WxC}}().Xor(ones).asMask()
}

// NotEqual returns a mask whose elements indicate whether x != y.
//
// Emulated, CPU Feature: {{.CPUfeature}}
func (x {{.VType}}) NotEqual(y {{.VType}}) Mask{{.WxC}} {
ones := x.Equal(x).ToInt{{.WxC}}()
return x.Equal(y).ToInt{{.WxC}}().Xor(ones).asMask()
}
`)
695
// bitWiseIntTemplate generates, for the shapes in intShapes, a Not method
// that xors x with x.Equal(x) converted to an integer vector.
var bitWiseIntTemplate = shapedTemplateOf(intShapes, "bitwise int complement", `
// Not returns the bitwise complement of x.
//
// Emulated, CPU Feature: {{.CPUfeature}}
func (x {{.VType}}) Not() {{.VType}} {
return x.Xor(x.Equal(x).ToInt{{.WxC}}())
}
`)
704
// bitWiseUintTemplate generates, for the shapes in uintShapes, a Not method
// that xors x with x.Equal(x) converted to an integer vector and then
// reinterpreted as the unsigned vector type.
var bitWiseUintTemplate = shapedTemplateOf(uintShapes, "bitwise uint complement", `
// Not returns the bitwise complement of x.
//
// Emulated, CPU Feature: {{.CPUfeature}}
func (x {{.VType}}) Not() {{.VType}} {
return x.Xor(x.Equal(x).ToInt{{.WxC}}().As{{.VType}}())
}
`)
713
714
715
716
717
718
719 func (t templateData) CPUfeatureAVX2if8() string {
720 if t.EWidth == 8 {
721 return "AVX2"
722 }
723 return t.CPUfeature()
724 }
725
// avx2UnsignedComparisonsTemplate generates, for the shapes in
// avx2UnsignedComparisons, the methods Greater, Less, GreaterEqual, LessEqual
// and NotEqual. The ordered comparisons reinterpret both operands as signed
// vectors, xor them with a vector named signs and use the signed Greater.
// For 8-wide elements signs is a broadcast of -1 << (EWidth-1); for other
// widths it is x.Equal(x), as an integer vector, shifted left by EWidth-1.
// The 8-wide case reports its CPU feature through CPUfeatureAVX2if8.
var avx2UnsignedComparisonsTemplate = shapedTemplateOf(avx2UnsignedComparisons, "avx2 unsigned comparisons", `
// Greater returns a mask whose elements indicate whether x > y.
//
// Emulated, CPU Feature: {{.CPUfeatureAVX2if8}}
func (x {{.VType}}) Greater(y {{.VType}}) Mask{{.WxC}} {
a, b := x.AsInt{{.WxC}}(), y.AsInt{{.WxC}}()
{{- if eq .EWidth 8}}
signs := BroadcastInt{{.WxC}}(-1 << ({{.EWidth}}-1))
{{- else}}
ones := x.Equal(x).ToInt{{.WxC}}()
signs := ones.ShiftAllLeft({{.EWidth}}-1)
{{- end }}
return a.Xor(signs).Greater(b.Xor(signs))
}

// Less returns a mask whose elements indicate whether x < y.
//
// Emulated, CPU Feature: {{.CPUfeatureAVX2if8}}
func (x {{.VType}}) Less(y {{.VType}}) Mask{{.WxC}} {
a, b := x.AsInt{{.WxC}}(), y.AsInt{{.WxC}}()
{{- if eq .EWidth 8}}
signs := BroadcastInt{{.WxC}}(-1 << ({{.EWidth}}-1))
{{- else}}
ones := x.Equal(x).ToInt{{.WxC}}()
signs := ones.ShiftAllLeft({{.EWidth}}-1)
{{- end }}
return b.Xor(signs).Greater(a.Xor(signs))
}

// GreaterEqual returns a mask whose elements indicate whether x >= y.
//
// Emulated, CPU Feature: {{.CPUfeatureAVX2if8}}
func (x {{.VType}}) GreaterEqual(y {{.VType}}) Mask{{.WxC}} {
a, b := x.AsInt{{.WxC}}(), y.AsInt{{.WxC}}()
ones := x.Equal(x).ToInt{{.WxC}}()
{{- if eq .EWidth 8}}
signs := BroadcastInt{{.WxC}}(-1 << ({{.EWidth}}-1))
{{- else}}
signs := ones.ShiftAllLeft({{.EWidth}}-1)
{{- end }}
return b.Xor(signs).Greater(a.Xor(signs)).ToInt{{.WxC}}().Xor(ones).asMask()
}

// LessEqual returns a mask whose elements indicate whether x <= y.
//
// Emulated, CPU Feature: {{.CPUfeatureAVX2if8}}
func (x {{.VType}}) LessEqual(y {{.VType}}) Mask{{.WxC}} {
a, b := x.AsInt{{.WxC}}(), y.AsInt{{.WxC}}()
ones := x.Equal(x).ToInt{{.WxC}}()
{{- if eq .EWidth 8}}
signs := BroadcastInt{{.WxC}}(-1 << ({{.EWidth}}-1))
{{- else}}
signs := ones.ShiftAllLeft({{.EWidth}}-1)
{{- end }}
return a.Xor(signs).Greater(b.Xor(signs)).ToInt{{.WxC}}().Xor(ones).asMask()
}

// NotEqual returns a mask whose elements indicate whether x != y.
//
// Emulated, CPU Feature: {{.CPUfeature}}
func (x {{.VType}}) NotEqual(y {{.VType}}) Mask{{.WxC}} {
a, b := x.AsInt{{.WxC}}(), y.AsInt{{.WxC}}()
ones := x.Equal(x).ToInt{{.WxC}}()
return a.Equal(b).ToInt{{.WxC}}().Xor(ones).asMask()
}
`)
792
// unsafePATemplate generates, for every shape in allShapes, the pa<VType>
// helper, which converts the address of the first slice element to a pointer
// to an array of Count elements without checking the slice length. The
// generated helper indexes s[0], so it must not be called with an empty
// slice.
var unsafePATemplate = templateOf("unsafe PA helper", `
// pa{{.VType}} returns a type-unsafe pointer to array that can
// only be used with partial load/store operations that only
// access the known-safe portions of the array.
func pa{{.VType}}(s []{{.Etype}}) *[{{.Count}}]{{.Etype}} {
return (*[{{.Count}}]{{.Etype}})(unsafe.Pointer(&s[0]))
}
`)
801
// avx2MaskedTemplate generates, for the shapes in avx2Shapes, the methods
// Masked and Merge. Masked ands x with the mask converted to an integer
// vector. Merge views the mask and both operands as vectors of 8-wide
// integers (shape BxC) and combines them with blend; the reinterpreting
// conversions are omitted where the vector already has that type.
var avx2MaskedTemplate = shapedTemplateOf(avx2Shapes, "avx2 .Masked methods", `
// Masked returns x but with elements zeroed where mask is false.
//
// Emulated, CPU Feature: {{.CPUfeature}}
func (x {{.VType}}) Masked(mask Mask{{.WxC}}) {{.VType}} {
im := mask.ToInt{{.WxC}}()
{{- if eq .Base "Int" }}
return im.And(x)
{{- else}}
return x.AsInt{{.WxC}}().And(im).As{{.VType}}()
{{- end -}}
}

// Merge returns x but with elements set to y where mask is false.
//
// Emulated, CPU Feature: {{.CPUfeature}}
func (x {{.VType}}) Merge(y {{.VType}}, mask Mask{{.WxC}}) {{.VType}} {
{{- if eq .BxC .WxC -}}
im := mask.ToInt{{.BxC}}()
{{- else}}
im := mask.ToInt{{.WxC}}().AsInt{{.BxC}}()
{{- end -}}
{{- if and (eq .Base "Int") (eq .BxC .WxC) }}
return y.blend(x, im)
{{- else}}
ix := x.AsInt{{.BxC}}()
iy := y.AsInt{{.BxC}}()
return iy.blend(ix, im).As{{.VType}}()
{{- end -}}
}
`)
833
834
// avx512MaskedTemplate generates, for the shapes in avx512Shapes, the methods
// Masked and Merge. Masked ands x with the mask converted to an integer
// vector. Merge uses blendMasked with the mask itself, reinterpreting
// non-Int operands as Int<WxC> vectors first.
var avx512MaskedTemplate = shapedTemplateOf(avx512Shapes, "avx512 .Masked methods", `
// Masked returns x but with elements zeroed where mask is false.
//
// Emulated, CPU Feature: AVX512
func (x {{.VType}}) Masked(mask Mask{{.WxC}}) {{.VType}} {
im := mask.ToInt{{.WxC}}()
{{- if eq .Base "Int" }}
return im.And(x)
{{- else}}
return x.AsInt{{.WxC}}().And(im).As{{.VType}}()
{{- end -}}
}

// Merge returns x but with elements set to y where mask is false.
//
// Emulated, CPU Feature: AVX512
func (x {{.VType}}) Merge(y {{.VType}}, mask Mask{{.WxC}}) {{.VType}} {
{{- if eq .Base "Int" }}
return y.blendMasked(x, mask)
{{- else}}
ix := x.AsInt{{.WxC}}()
iy := y.AsInt{{.WxC}}()
return iy.blendMasked(ix, mask).As{{.VType}}()
{{- end -}}
}
`)
861
862 func (t templateData) CPUfeatureBC() string {
863 switch t.Vwidth {
864 case 128:
865 return "AVX2"
866 case 256:
867 return "AVX2"
868 case 512:
869 if t.EWidth <= 16 {
870 return "AVX512BW"
871 }
872 return "AVX512F"
873 }
874 panic(fmt.Errorf("unexpected vector width %d", t.Vwidth))
875 }
876
// broadcastTemplate generates, for every shape in allShapes, a
// Broadcast<VType> function. The generated code sets element 0 of a zero
// 128-wide vector of the same element type (see As128BitVec) to x and calls
// Broadcast1To<Count> on it.
var broadcastTemplate = templateOf("Broadcast functions", `
// Broadcast{{.VType}} returns a vector with the input
// x assigned to all elements of the output.
//
// Emulated, CPU Feature: {{.CPUfeatureBC}}
func Broadcast{{.VType}}(x {{.Etype}}) {{.VType}} {
var z {{.As128BitVec }}
return z.SetElem(0, x).Broadcast1To{{.Count}}()
}
`)
887
// maskCvtTemplate generates, for the shapes in intShapes, a ToMask method
// implemented as a NotEqual comparison against the zero vector.
var maskCvtTemplate = shapedTemplateOf(intShapes, "Mask conversions", `
// ToMask converts from {{.Base}}{{.WxC}} to Mask{{.WxC}}, mask element is set to true when the corresponding vector element is non-zero.
func (from {{.Base}}{{.WxC}}) ToMask() (to Mask{{.WxC}}) {
return from.NotEqual({{.Base}}{{.WxC}}{})
}
`)
894
// stringTemplate generates, for every shape in allShapes, a String method
// that stores the vector into an array of Count elements and formats it with
// sliceToString.
var stringTemplate = shapedTemplateOf(allShapes, "String methods", `
// String returns a string representation of SIMD vector x.
func (x {{.VType}}) String() string {
var s [{{.Count}}]{{.Etype}}
x.Store(&s)
return sliceToString(s[:])
}
`)
903
// Relative directory prefixes of the default output file names used by main.
const (
	// SIMD is the prefix of the generated archsimd source files.
	SIMD = "../../"
	// TD is the prefix of the generated test helper files.
	TD = "../../internal/simd_test/"
	// SSA is the prefix of the generated file of the compiler's ssa package.
	SSA = "../../../../cmd/compile/internal/ssa/"
)
907
// main generates the output files. Each flag names one output file; passing
// an empty name for a flag skips that file. The ssa helper file
// tern_helpers.go has no flag and is always regenerated.
func main() {
	// Output file names; the defaults are relative paths built from the
	// SIMD and TD prefixes.
	sl := flag.String("sl", SIMD+"slice_gen_amd64.go", "file name for slice operations")
	cm := flag.String("cm", SIMD+"compare_gen_amd64.go", "file name for comparison operations")
	mm := flag.String("mm", SIMD+"maskmerge_gen_amd64.go", "file name for mask/merge operations")
	op := flag.String("op", SIMD+"other_gen_amd64.go", "file name for other operations")
	ush := flag.String("ush", SIMD+"unsafe_helpers.go", "file name for unsafe helpers")
	bh := flag.String("bh", TD+"binary_helpers_test.go", "file name for binary test helpers")
	uh := flag.String("uh", TD+"unary_helpers_test.go", "file name for unary test helpers")
	th := flag.String("th", TD+"ternary_helpers_test.go", "file name for ternary test helpers")
	ch := flag.String("ch", TD+"compare_helpers_test.go", "file name for compare test helpers")
	cmh := flag.String("cmh", TD+"comparemasked_helpers_test.go", "file name for compare-masked test helpers")
	flag.Parse()

	// Slice loads and stores, including the partial-slice variants.
	if *sl != "" {
		one(*sl, unsafePrologue,
			sliceTemplate,
			avx512MaskedLoadSlicePartTemplate,
			avx2MaskedLoadSlicePartTemplate,
			avx2SmallLoadSlicePartTemplate,
		)
	}
	// Emulated comparison methods.
	if *cm != "" {
		one(*cm, prologue,
			avx2SignedComparisonsTemplate,
			avx2UnsignedComparisonsTemplate,
		)
	}
	// Masked and Merge methods.
	if *mm != "" {
		one(*mm, prologue,
			avx2MaskedTemplate,
			avx512MaskedTemplate,
		)
	}
	// Broadcast, ToMask, Not and String.
	if *op != "" {
		one(*op, prologue,
			broadcastTemplate,
			maskCvtTemplate,
			bitWiseIntTemplate,
			bitWiseUintTemplate,
			stringTemplate,
		)
	}
	// The pa<VType> pointer-to-array helpers.
	if *ush != "" {
		one(*ush, unsafePrologue, unsafePATemplate)
	}
	// Test helpers for unary methods, conversions included.
	if *uh != "" {
		one(*uh, curryTestPrologue("unary simd methods"), unaryTemplate,
			unaryToInt8, unaryToUint8, unaryToInt16, unaryToUint16,
			unaryToInt32, unaryToUint32, unaryToInt64, unaryToUint64,
			unaryToFloat32, unaryToFloat64,
			unaryToInt64x2, unaryToInt64x4,
			unaryToUint64x2, unaryToUint64x4,
			unaryToInt32x4, unaryToInt32x8,
			unaryToUint32x4, unaryToUint32x8,
			unaryToInt16x8, unaryToUint16x8,
			unaryFlakyTemplate,
		)
	}
	if *bh != "" {
		one(*bh, curryTestPrologue("binary simd methods"), binaryTemplate)
	}
	if *th != "" {
		one(*th, curryTestPrologue("ternary simd methods"), ternaryTemplate, ternaryFlakyTemplate)
	}
	if *ch != "" {
		one(*ch, curryTestPrologue("simd methods that compare two operands"), compareTemplate, compareUnaryTemplate)
	}
	if *cmh != "" {
		one(*cmh, curryTestPrologue("simd methods that compare two operands under a mask"), compareMaskedTemplate)
	}

	// Not template based: written by the two generator functions directly.
	nonTemplateRewrites(SSA+"tern_helpers.go", ssaPrologue, classifyBooleanSIMD, ternOpForLogical)

}
982
983 func ternOpForLogical(out io.Writer) {
984 fmt.Fprintf(out, `
985 func ternOpForLogical(op Op) Op {
986 switch op {
987 `)
988
989 intShapes.forAllShapes(func(seq int, t, upperT string, w, c int, out io.Writer) {
990 wt, ct := w, c
991 if wt < 32 {
992 wt = 32
993 ct = (w * c) / wt
994 }
995 fmt.Fprintf(out, "case OpAndInt%[1]dx%[2]d, OpOrInt%[1]dx%[2]d, OpXorInt%[1]dx%[2]d,OpAndNotInt%[1]dx%[2]d: return OpternInt%dx%d\n", w, c, wt, ct)
996 fmt.Fprintf(out, "case OpAndUint%[1]dx%[2]d, OpOrUint%[1]dx%[2]d, OpXorUint%[1]dx%[2]d,OpAndNotUint%[1]dx%[2]d: return OpternUint%dx%d\n", w, c, wt, ct)
997 }, out)
998
999 fmt.Fprintf(out, `
1000 }
1001 return op
1002 }
1003 `)
1004
1005 }
1006
1007 func classifyBooleanSIMD(out io.Writer) {
1008 fmt.Fprintf(out, `
1009 type SIMDLogicalOP uint8
1010 const (
1011 // boolean simd operations, for reducing expression to VPTERNLOG* instructions
1012 // sloInterior is set for non-root nodes in logical-op expression trees.
1013 // the operations are even-numbered.
1014 sloInterior SIMDLogicalOP = 1
1015 sloNone SIMDLogicalOP = 2 * iota
1016 sloAnd
1017 sloOr
1018 sloAndNot
1019 sloXor
1020 sloNot
1021 )
1022 func classifyBooleanSIMD(v *Value) SIMDLogicalOP {
1023 switch v.Op {
1024 case `)
1025 intShapes.forAllShapes(func(seq int, t, upperT string, w, c int, out io.Writer) {
1026 op := "And"
1027 if seq > 0 {
1028 fmt.Fprintf(out, ",Op%s%s%dx%d", op, upperT, w, c)
1029 } else {
1030 fmt.Fprintf(out, "Op%s%s%dx%d", op, upperT, w, c)
1031 }
1032 seq++
1033 }, out)
1034
1035 fmt.Fprintf(out, `:
1036 return sloAnd
1037
1038 case `)
1039 intShapes.forAllShapes(func(seq int, t, upperT string, w, c int, out io.Writer) {
1040 op := "Or"
1041 if seq > 0 {
1042 fmt.Fprintf(out, ",Op%s%s%dx%d", op, upperT, w, c)
1043 } else {
1044 fmt.Fprintf(out, "Op%s%s%dx%d", op, upperT, w, c)
1045 }
1046 seq++
1047 }, out)
1048
1049 fmt.Fprintf(out, `:
1050 return sloOr
1051
1052 case `)
1053 intShapes.forAllShapes(func(seq int, t, upperT string, w, c int, out io.Writer) {
1054 op := "AndNot"
1055 if seq > 0 {
1056 fmt.Fprintf(out, ",Op%s%s%dx%d", op, upperT, w, c)
1057 } else {
1058 fmt.Fprintf(out, "Op%s%s%dx%d", op, upperT, w, c)
1059 }
1060 seq++
1061 }, out)
1062
1063 fmt.Fprintf(out, `:
1064 return sloAndNot
1065 `)
1066
1067
1068
1069
1070
1071 intShapes.forAllShapes(
1072 func(seq int, t, upperT string, w, c int, out io.Writer) {
1073 fmt.Fprintf(out, "case OpXor%s%dx%d: ", upperT, w, c)
1074 fmt.Fprintf(out, `
1075 if y := v.Args[1]; y.Op == OpEqual%s%dx%d &&
1076 y.Args[0] == y.Args[1] {
1077 return sloNot
1078 }
1079 `, upperT, w, c)
1080 fmt.Fprintf(out, "return sloXor\n")
1081 }, out)
1082
1083 fmt.Fprintf(out, `
1084 }
1085 return sloNone
1086 }
1087 `)
1088 }
1089
1090
1091
// numberLines returns data with every line prefixed by its 1-based line
// number and ": ", each terminated by a newline. It is used to dump generated
// code that failed to format, so that reported positions can be located.
//
// Lines are split the way bufio.ScanLines splits them: the line terminator
// (and one carriage return before it, if any) is dropped, and a trailing
// newline does not produce an extra empty line.
func numberLines(data []byte) string {
	var buf bytes.Buffer
	s := bufio.NewScanner(bytes.NewReader(data))
	// No line can be longer than the whole input, so deriving the limit from
	// len(data) lifts the scanner's default 64 KiB token limit, which would
	// otherwise end the listing silently at the first over-long line.
	s.Buffer(nil, len(data)+1)
	for i := 1; s.Scan(); i++ {
		fmt.Fprintf(&buf, "%d: %s\n", i, s.Bytes())
	}
	if err := s.Err(); err != nil {
		// Make an incomplete listing visible instead of silently truncating it.
		fmt.Fprintf(&buf, "(line numbering stopped: %v)\n", err)
	}
	return buf.String()
}
1101
1102 func nonTemplateRewrites(filename string, prologue func(s string, out io.Writer), rewrites ...func(out io.Writer)) {
1103 if filename == "" {
1104 return
1105 }
1106
1107 ofile := os.Stdout
1108
1109 if filename != "-" {
1110 var err error
1111 ofile, err = os.Create(filename)
1112 if err != nil {
1113 fmt.Fprintf(os.Stderr, "Could not create the output file %s for the generated code, %v", filename, err)
1114 os.Exit(1)
1115 }
1116 }
1117
1118 out := new(bytes.Buffer)
1119
1120 prologue("tmplgen", out)
1121 for _, rewrite := range rewrites {
1122 rewrite(out)
1123 }
1124
1125 b, err := format.Source(out.Bytes())
1126 if err != nil {
1127 fmt.Fprintf(os.Stderr, "There was a problem formatting the generated code for %s, %v\n", filename, err)
1128 fmt.Fprintf(os.Stderr, "%s\n", numberLines(out.Bytes()))
1129 fmt.Fprintf(os.Stderr, "There was a problem formatting the generated code for %s, %v\n", filename, err)
1130 os.Exit(1)
1131 } else {
1132 ofile.Write(b)
1133 ofile.Close()
1134 }
1135
1136 }
1137
1138 func one(filename string, prologue func(s string, out io.Writer), sats ...shapeAndTemplate) {
1139 if filename == "" {
1140 return
1141 }
1142
1143 ofile := os.Stdout
1144
1145 if filename != "-" {
1146 var err error
1147 ofile, err = os.Create(filename)
1148 if err != nil {
1149 fmt.Fprintf(os.Stderr, "Could not create the output file %s for the generated code, %v", filename, err)
1150 os.Exit(1)
1151 }
1152 }
1153
1154 out := new(bytes.Buffer)
1155
1156 prologue("tmplgen", out)
1157 for _, sat := range sats {
1158 sat.forTemplates(out)
1159 }
1160
1161 b, err := format.Source(out.Bytes())
1162 if err != nil {
1163 fmt.Fprintf(os.Stderr, "There was a problem formatting the generated code for %s, %v\n", filename, err)
1164 fmt.Fprintf(os.Stderr, "%s\n", numberLines(out.Bytes()))
1165 fmt.Fprintf(os.Stderr, "There was a problem formatting the generated code for %s, %v\n", filename, err)
1166 os.Exit(1)
1167 } else {
1168 ofile.Write(b)
1169 ofile.Close()
1170 }
1171
1172 }
1173
View as plain text