Source file src/internal/synctest/synctest_test.go

     1  // Copyright 2024 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package synctest_test
     6  
     7  import (
     8  	"fmt"
     9  	"internal/synctest"
    10  	"iter"
    11  	"reflect"
    12  	"slices"
    13  	"strconv"
    14  	"sync"
    15  	"testing"
    16  	"time"
    17  )
    18  
    19  func TestNow(t *testing.T) {
    20  	start := time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC).In(time.Local)
    21  	synctest.Run(func() {
    22  		// Time starts at 2000-1-1 00:00:00.
    23  		if got, want := time.Now(), start; !got.Equal(want) {
    24  			t.Errorf("at start: time.Now = %v, want %v", got, want)
    25  		}
    26  		go func() {
    27  			// New goroutines see the same fake clock.
    28  			if got, want := time.Now(), start; !got.Equal(want) {
    29  				t.Errorf("time.Now = %v, want %v", got, want)
    30  			}
    31  		}()
    32  		// Time advances after a sleep.
    33  		time.Sleep(1 * time.Second)
    34  		if got, want := time.Now(), start.Add(1*time.Second); !got.Equal(want) {
    35  			t.Errorf("after sleep: time.Now = %v, want %v", got, want)
    36  		}
    37  	})
    38  }
    39  
    40  // TestMonotonicClock exercises comparing times from within a bubble
    41  // with ones from outside the bubble.
    42  func TestMonotonicClock(t *testing.T) {
    43  	start := time.Now()
    44  	synctest.Run(func() {
    45  		time.Sleep(time.Until(start.Round(0)))
    46  		if got, want := time.Now().In(time.UTC), start.In(time.UTC); !got.Equal(want) {
    47  			t.Fatalf("time.Now() = %v, want %v", got, want)
    48  		}
    49  
    50  		wait := 1 * time.Second
    51  		time.Sleep(wait)
    52  		if got := time.Since(start); got != wait {
    53  			t.Fatalf("time.Since(start) = %v, want %v", got, wait)
    54  		}
    55  		if got := time.Now().Sub(start); got != wait {
    56  			t.Fatalf("time.Now().Sub(start) = %v, want %v", got, wait)
    57  		}
    58  	})
    59  }
    60  
    61  func TestRunEmpty(t *testing.T) {
    62  	synctest.Run(func() {
    63  	})
    64  }
    65  
    66  func TestSimpleWait(t *testing.T) {
    67  	synctest.Run(func() {
    68  		synctest.Wait()
    69  	})
    70  }
    71  
    72  func TestGoroutineWait(t *testing.T) {
    73  	synctest.Run(func() {
    74  		go func() {}()
    75  		synctest.Wait()
    76  	})
    77  }
    78  
    79  // TestWait starts a collection of goroutines.
    80  // It checks that synctest.Wait waits for all goroutines to exit before returning.
    81  func TestWait(t *testing.T) {
    82  	synctest.Run(func() {
    83  		done := false
    84  		ch := make(chan int)
    85  		var f func()
    86  		f = func() {
    87  			count := <-ch
    88  			if count == 0 {
    89  				done = true
    90  			} else {
    91  				go f()
    92  				ch <- count - 1
    93  			}
    94  		}
    95  		go f()
    96  		ch <- 100
    97  		synctest.Wait()
    98  		if !done {
    99  			t.Fatalf("done = false, want true")
   100  		}
   101  	})
   102  }
   103  
   104  func TestMallocs(t *testing.T) {
   105  	for i := 0; i < 100; i++ {
   106  		synctest.Run(func() {
   107  			done := false
   108  			ch := make(chan []byte)
   109  			var f func()
   110  			f = func() {
   111  				b := <-ch
   112  				if len(b) == 0 {
   113  					done = true
   114  				} else {
   115  					go f()
   116  					ch <- make([]byte, len(b)-1)
   117  				}
   118  			}
   119  			go f()
   120  			ch <- make([]byte, 100)
   121  			synctest.Wait()
   122  			if !done {
   123  				t.Fatalf("done = false, want true")
   124  			}
   125  		})
   126  	}
   127  }
   128  
   129  func TestTimerReadBeforeDeadline(t *testing.T) {
   130  	synctest.Run(func() {
   131  		start := time.Now()
   132  		tm := time.NewTimer(5 * time.Second)
   133  		<-tm.C
   134  		if got, want := time.Since(start), 5*time.Second; got != want {
   135  			t.Errorf("after sleep: time.Since(start) = %v, want %v", got, want)
   136  		}
   137  	})
   138  }
   139  
   140  func TestTimerReadAfterDeadline(t *testing.T) {
   141  	synctest.Run(func() {
   142  		delay := 1 * time.Second
   143  		want := time.Now().Add(delay)
   144  		tm := time.NewTimer(delay)
   145  		time.Sleep(2 * delay)
   146  		got := <-tm.C
   147  		if got != want {
   148  			t.Errorf("<-tm.C = %v, want %v", got, want)
   149  		}
   150  	})
   151  }
   152  
   153  func TestTimerReset(t *testing.T) {
   154  	synctest.Run(func() {
   155  		start := time.Now()
   156  		tm := time.NewTimer(1 * time.Second)
   157  		if got, want := <-tm.C, start.Add(1*time.Second); got != want {
   158  			t.Errorf("first sleep: <-tm.C = %v, want %v", got, want)
   159  		}
   160  
   161  		tm.Reset(2 * time.Second)
   162  		if got, want := <-tm.C, start.Add((1+2)*time.Second); got != want {
   163  			t.Errorf("second sleep: <-tm.C = %v, want %v", got, want)
   164  		}
   165  
   166  		tm.Reset(3 * time.Second)
   167  		time.Sleep(1 * time.Second)
   168  		tm.Reset(3 * time.Second)
   169  		if got, want := <-tm.C, start.Add((1+2+4)*time.Second); got != want {
   170  			t.Errorf("third sleep: <-tm.C = %v, want %v", got, want)
   171  		}
   172  	})
   173  }
   174  
   175  func TestTimeAfter(t *testing.T) {
   176  	synctest.Run(func() {
   177  		i := 0
   178  		time.AfterFunc(1*time.Second, func() {
   179  			// Ensure synctest group membership propagates through the AfterFunc.
   180  			i++ // 1
   181  			go func() {
   182  				time.Sleep(1 * time.Second)
   183  				i++ // 2
   184  			}()
   185  		})
   186  		time.Sleep(3 * time.Second)
   187  		synctest.Wait()
   188  		if got, want := i, 2; got != want {
   189  			t.Errorf("after sleep and wait: i = %v, want %v", got, want)
   190  		}
   191  	})
   192  }
   193  
   194  func TestTimerAfterBubbleExit(t *testing.T) {
   195  	run := false
   196  	synctest.Run(func() {
   197  		time.AfterFunc(1*time.Second, func() {
   198  			run = true
   199  		})
   200  	})
   201  	if run {
   202  		t.Errorf("timer ran before bubble exit")
   203  	}
   204  }
   205  
   206  func TestTimerFromOutsideBubble(t *testing.T) {
   207  	tm := time.NewTimer(10 * time.Millisecond)
   208  	synctest.Run(func() {
   209  		<-tm.C
   210  	})
   211  	if tm.Stop() {
   212  		t.Errorf("synctest.Run unexpectedly returned before timer fired")
   213  	}
   214  }
   215  
   216  func TestChannelFromOutsideBubble(t *testing.T) {
   217  	choutside := make(chan struct{})
   218  	for _, test := range []struct {
   219  		desc    string
   220  		outside func(ch chan int)
   221  		inside  func(ch chan int)
   222  	}{{
   223  		desc:    "read closed",
   224  		outside: func(ch chan int) { close(ch) },
   225  		inside:  func(ch chan int) { <-ch },
   226  	}, {
   227  		desc:    "read value",
   228  		outside: func(ch chan int) { ch <- 0 },
   229  		inside:  func(ch chan int) { <-ch },
   230  	}, {
   231  		desc:    "write value",
   232  		outside: func(ch chan int) { <-ch },
   233  		inside:  func(ch chan int) { ch <- 0 },
   234  	}, {
   235  		desc:    "select outside only",
   236  		outside: func(ch chan int) { close(ch) },
   237  		inside: func(ch chan int) {
   238  			select {
   239  			case <-ch:
   240  			case <-choutside:
   241  			}
   242  		},
   243  	}, {
   244  		desc:    "select mixed",
   245  		outside: func(ch chan int) { close(ch) },
   246  		inside: func(ch chan int) {
   247  			ch2 := make(chan struct{})
   248  			select {
   249  			case <-ch:
   250  			case <-ch2:
   251  			}
   252  		},
   253  	}} {
   254  		t.Run(test.desc, func(t *testing.T) {
   255  			ch := make(chan int)
   256  			time.AfterFunc(1*time.Millisecond, func() {
   257  				test.outside(ch)
   258  			})
   259  			synctest.Run(func() {
   260  				test.inside(ch)
   261  			})
   262  		})
   263  	}
   264  }
   265  
   266  func TestTimerFromInsideBubble(t *testing.T) {
   267  	for _, test := range []struct {
   268  		desc      string
   269  		f         func(tm *time.Timer)
   270  		wantPanic string
   271  	}{{
   272  		desc: "read channel",
   273  		f: func(tm *time.Timer) {
   274  			<-tm.C
   275  		},
   276  		wantPanic: "receive on synctest channel from outside bubble",
   277  	}, {
   278  		desc: "Reset",
   279  		f: func(tm *time.Timer) {
   280  			tm.Reset(1 * time.Second)
   281  		},
   282  		wantPanic: "reset of synctest timer from outside bubble",
   283  	}, {
   284  		desc: "Stop",
   285  		f: func(tm *time.Timer) {
   286  			tm.Stop()
   287  		},
   288  		wantPanic: "stop of synctest timer from outside bubble",
   289  	}} {
   290  		t.Run(test.desc, func(t *testing.T) {
   291  			donec := make(chan struct{})
   292  			ch := make(chan *time.Timer)
   293  			go func() {
   294  				defer close(donec)
   295  				defer wantPanic(t, test.wantPanic)
   296  				test.f(<-ch)
   297  			}()
   298  			synctest.Run(func() {
   299  				tm := time.NewTimer(1 * time.Second)
   300  				ch <- tm
   301  			})
   302  			<-donec
   303  		})
   304  	}
   305  }
   306  
   307  func TestDeadlockRoot(t *testing.T) {
   308  	defer wantPanic(t, "deadlock: all goroutines in bubble are blocked")
   309  	synctest.Run(func() {
   310  		select {}
   311  	})
   312  }
   313  
   314  func TestDeadlockChild(t *testing.T) {
   315  	defer wantPanic(t, "deadlock: all goroutines in bubble are blocked")
   316  	synctest.Run(func() {
   317  		go func() {
   318  			select {}
   319  		}()
   320  	})
   321  }
   322  
   323  func TestDeadlockTicker(t *testing.T) {
   324  	defer wantPanic(t, "deadlock: all goroutines in bubble are blocked")
   325  	synctest.Run(func() {
   326  		go func() {
   327  			for range time.Tick(1 * time.Second) {
   328  				t.Errorf("ticker unexpectedly ran")
   329  				return
   330  			}
   331  		}()
   332  	})
   333  }
   334  
   335  func TestCond(t *testing.T) {
   336  	synctest.Run(func() {
   337  		var mu sync.Mutex
   338  		cond := sync.NewCond(&mu)
   339  		start := time.Now()
   340  		const waitTime = 1 * time.Millisecond
   341  
   342  		go func() {
   343  			// Signal the cond.
   344  			time.Sleep(waitTime)
   345  			mu.Lock()
   346  			cond.Signal()
   347  			mu.Unlock()
   348  
   349  			// Broadcast to the cond.
   350  			time.Sleep(waitTime)
   351  			mu.Lock()
   352  			cond.Broadcast()
   353  			mu.Unlock()
   354  		}()
   355  
   356  		// Wait for cond.Signal.
   357  		mu.Lock()
   358  		cond.Wait()
   359  		mu.Unlock()
   360  		if got, want := time.Since(start), waitTime; got != want {
   361  			t.Errorf("after cond.Signal: time elapsed = %v, want %v", got, want)
   362  		}
   363  
   364  		// Wait for cond.Broadcast in two goroutines.
   365  		waiterDone := false
   366  		go func() {
   367  			mu.Lock()
   368  			cond.Wait()
   369  			mu.Unlock()
   370  			waiterDone = true
   371  		}()
   372  		mu.Lock()
   373  		cond.Wait()
   374  		mu.Unlock()
   375  		synctest.Wait()
   376  		if !waiterDone {
   377  			t.Errorf("after cond.Broadcast: waiter not done")
   378  		}
   379  		if got, want := time.Since(start), 2*waitTime; got != want {
   380  			t.Errorf("after cond.Broadcast: time elapsed = %v, want %v", got, want)
   381  		}
   382  	})
   383  }
   384  
   385  func TestIteratorPush(t *testing.T) {
   386  	synctest.Run(func() {
   387  		seq := func(yield func(time.Time) bool) {
   388  			for yield(time.Now()) {
   389  				time.Sleep(1 * time.Second)
   390  			}
   391  		}
   392  		var got []time.Time
   393  		go func() {
   394  			for now := range seq {
   395  				got = append(got, now)
   396  				if len(got) >= 3 {
   397  					break
   398  				}
   399  			}
   400  		}()
   401  		want := []time.Time{
   402  			time.Now(),
   403  			time.Now().Add(1 * time.Second),
   404  			time.Now().Add(2 * time.Second),
   405  		}
   406  		time.Sleep(5 * time.Second)
   407  		synctest.Wait()
   408  		if !slices.Equal(got, want) {
   409  			t.Errorf("got: %v; want: %v", got, want)
   410  		}
   411  	})
   412  }
   413  
   414  func TestIteratorPull(t *testing.T) {
   415  	synctest.Run(func() {
   416  		seq := func(yield func(time.Time) bool) {
   417  			for yield(time.Now()) {
   418  				time.Sleep(1 * time.Second)
   419  			}
   420  		}
   421  		var got []time.Time
   422  		go func() {
   423  			next, stop := iter.Pull(seq)
   424  			defer stop()
   425  			for len(got) < 3 {
   426  				now, _ := next()
   427  				got = append(got, now)
   428  			}
   429  		}()
   430  		want := []time.Time{
   431  			time.Now(),
   432  			time.Now().Add(1 * time.Second),
   433  			time.Now().Add(2 * time.Second),
   434  		}
   435  		time.Sleep(5 * time.Second)
   436  		synctest.Wait()
   437  		if !slices.Equal(got, want) {
   438  			t.Errorf("got: %v; want: %v", got, want)
   439  		}
   440  	})
   441  }
   442  
   443  func TestReflectFuncOf(t *testing.T) {
   444  	mkfunc := func(name string, i int) {
   445  		reflect.FuncOf([]reflect.Type{
   446  			reflect.StructOf([]reflect.StructField{{
   447  				Name: name + strconv.Itoa(i),
   448  				Type: reflect.TypeOf(0),
   449  			}}),
   450  		}, nil, false)
   451  	}
   452  	go func() {
   453  		for i := 0; i < 100000; i++ {
   454  			mkfunc("A", i)
   455  		}
   456  	}()
   457  	synctest.Run(func() {
   458  		for i := 0; i < 100000; i++ {
   459  			mkfunc("A", i)
   460  		}
   461  	})
   462  }
   463  
   464  func TestWaitGroup(t *testing.T) {
   465  	synctest.Run(func() {
   466  		var wg sync.WaitGroup
   467  		wg.Add(1)
   468  		const delay = 1 * time.Second
   469  		go func() {
   470  			time.Sleep(delay)
   471  			wg.Done()
   472  		}()
   473  		start := time.Now()
   474  		wg.Wait()
   475  		if got := time.Since(start); got != delay {
   476  			t.Fatalf("WaitGroup.Wait() took %v, want %v", got, delay)
   477  		}
   478  	})
   479  }
   480  
   481  func TestHappensBefore(t *testing.T) {
   482  	// Use two parallel goroutines accessing different vars to ensure that
   483  	// we correctly account for multiple goroutines in the bubble.
   484  	var v1 int
   485  	var v2 int
   486  	synctest.Run(func() {
   487  		v1++ // 1
   488  		v2++ // 1
   489  
   490  		// Wait returns after these goroutines exit.
   491  		go func() {
   492  			v1++ // 2
   493  		}()
   494  		go func() {
   495  			v2++ // 2
   496  		}()
   497  		synctest.Wait()
   498  
   499  		v1++ // 3
   500  		v2++ // 3
   501  
   502  		// Wait returns after these goroutines block.
   503  		ch1 := make(chan struct{})
   504  		go func() {
   505  			v1++ // 4
   506  			<-ch1
   507  		}()
   508  		go func() {
   509  			v2++ // 4
   510  			<-ch1
   511  		}()
   512  		synctest.Wait()
   513  
   514  		v1++ // 5
   515  		v2++ // 5
   516  		close(ch1)
   517  
   518  		// Wait returns after these timers run.
   519  		time.AfterFunc(0, func() {
   520  			v1++ // 6
   521  		})
   522  		time.AfterFunc(0, func() {
   523  			v2++ // 6
   524  		})
   525  		synctest.Wait()
   526  
   527  		v1++ // 7
   528  		v2++ // 7
   529  
   530  		// Wait returns after these timer goroutines block.
   531  		ch2 := make(chan struct{})
   532  		time.AfterFunc(0, func() {
   533  			v1++ // 8
   534  			<-ch2
   535  		})
   536  		time.AfterFunc(0, func() {
   537  			v2++ // 8
   538  			<-ch2
   539  		})
   540  		synctest.Wait()
   541  
   542  		v1++ // 9
   543  		v2++ // 9
   544  		close(ch2)
   545  	})
   546  	// This Run happens after the previous Run returns.
   547  	synctest.Run(func() {
   548  		go func() {
   549  			go func() {
   550  				v1++ // 10
   551  			}()
   552  		}()
   553  		go func() {
   554  			go func() {
   555  				v2++ // 10
   556  			}()
   557  		}()
   558  	})
   559  	// These tests happen after Run returns.
   560  	if got, want := v1, 10; got != want {
   561  		t.Errorf("v1 = %v, want %v", got, want)
   562  	}
   563  	if got, want := v2, 10; got != want {
   564  		t.Errorf("v2 = %v, want %v", got, want)
   565  	}
   566  }
   567  
   568  func wantPanic(t *testing.T, want string) {
   569  	if e := recover(); e != nil {
   570  		if got := fmt.Sprint(e); got != want {
   571  			t.Errorf("got panic message %q, want %q", got, want)
   572  		}
   573  	} else {
   574  		t.Errorf("got no panic, want one")
   575  	}
   576  }
   577  

View as plain text