Source file src/crypto/tls/bogo_shim_test.go

     1  // Copyright 2024 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package tls
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/internal/cryptotest"
    10  	"crypto/x509"
    11  	"encoding/base64"
    12  	"encoding/json"
    13  	"encoding/pem"
    14  	"flag"
    15  	"fmt"
    16  	"internal/byteorder"
    17  	"internal/testenv"
    18  	"io"
    19  	"log"
    20  	"net"
    21  	"os"
    22  	"path/filepath"
    23  	"runtime"
    24  	"slices"
    25  	"strconv"
    26  	"strings"
    27  	"testing"
    28  
    29  	"golang.org/x/crypto/cryptobyte"
    30  )
    31  
// Command-line flags read by the BoGo shim (bogoShim). The names presumably
// mirror the flags BoringSSL's BoGo runner passes to its shims — TODO confirm
// against the runner. Flags assigned to the blank identifier are registered
// only so that flag parsing accepts them; the shim never reads their values.
// All usage strings are intentionally empty.
var (
	// Address of the runner (localhost:port) and whether the shim acts as the
	// TLS server (true) or client (false).
	port   = flag.String("port", "", "")
	server = flag.Bool("server", false, "")

	// When set, bogoShim prints "No" and returns without connecting.
	isHandshakerSupported = flag.Bool("is-handshaker-supported", false, "")

	// Certificate/key pair loaded with LoadX509KeyPair into Config.Certificates.
	keyfile  = flag.String("key-file", "", "")
	certfile = flag.String("cert-file", "", "")

	// PEM certificate file whose first certificate becomes the only root in
	// Config.RootCAs.
	trustCert = flag.String("trust-cert", "", "")

	// Protocol version bounds, and the version the handshake must negotiate
	// (0 disables the check).
	minVersion    = flag.Int("min-version", VersionSSL30, "")
	maxVersion    = flag.Int("max-version", VersionTLS13, "")
	expectVersion = flag.Int("expect-version", 0, "")

	// Each flag disables one protocol version by trimming the low or high end
	// of the enabled version range.
	noTLS1  = flag.Bool("no-tls1", false, "")
	noTLS11 = flag.Bool("no-tls11", false, "")
	noTLS12 = flag.Bool("no-tls12", false, "")
	noTLS13 = flag.Bool("no-tls13", false, "")

	// Sets Config.ClientAuth to RequireAnyClientCert.
	requireAnyClientCertificate = flag.Bool("require-any-client-certificate", false, "")

	// On the first connection, write "hello" before entering the echo loop.
	shimWritesFirst = flag.Bool("shim-writes-first", false, "")

	// Number of additional connections made after the first one.
	resumeCount = flag.Int("resume-count", 0, "")

	// Repeatable decimal CurveIDs appended to Config.CurvePreferences, and the
	// decimal CurveID each connection is expected to end up using.
	curves        = flagStringSlice("curves", "")
	expectedCurve = flag.String("expect-curve-id", "", "")

	// Identifier written as a little-endian uint64 at the start of every
	// connection to the runner.
	shimID = flag.Uint64("shim-id", 0, "")
	_      = flag.Bool("ipv6", false, "")

	// Encrypted Client Hello settings and expectations. Config lists, keys and
	// retry configs are base64 encoded; the three ech-server-* / ech-is-retry-config
	// flags are repeatable and must be given the same number of times ("1" in
	// -ech-is-retry-config marks that key's config as a retry config).
	echConfigListB64           = flag.String("ech-config-list", "", "")
	expectECHAccepted          = flag.Bool("expect-ech-accept", false, "")
	expectHRR                  = flag.Bool("expect-hrr", false, "")
	expectNoHRR                = flag.Bool("expect-no-hrr", false, "")
	expectedECHRetryConfigs    = flag.String("expect-ech-retry-configs", "", "")
	expectNoECHRetryConfigs    = flag.Bool("expect-no-ech-retry-configs", false, "")
	onInitialExpectECHAccepted = flag.Bool("on-initial-expect-ech-accept", false, "")
	_                          = flag.Bool("expect-no-ech-name-override", false, "")
	_                          = flag.String("expect-ech-name-override", "", "")
	_                          = flag.Bool("reverify-on-resume", false, "")
	onResumeECHConfigListB64   = flag.String("on-resume-ech-config-list", "", "")
	_                          = flag.Bool("on-resume-expect-reject-early-data", false, "")
	onResumeExpectECHAccepted  = flag.Bool("on-resume-expect-ech-accept", false, "")
	_                          = flag.Bool("on-resume-expect-no-ech-name-override", false, "")
	expectedServerName         = flag.String("expect-server-name", "", "")
	echServerConfig            = flagStringSlice("ech-server-config", "")
	echServerKey               = flagStringSlice("ech-server-key", "")
	echServerRetryConfig       = flagStringSlice("ech-is-retry-config", "")

	// Fail if a completed handshake reports that the session was resumed.
	expectSessionMiss = flag.Bool("expect-session-miss", false, "")

	// Early-data flags: accepted but ignored.
	_ = flag.Bool("enable-early-data", false, "")
	_ = flag.Bool("on-resume-expect-accept-early-data", false, "")
	_ = flag.Bool("expect-ticket-supports-early-data", false, "")
	_ = flag.Bool("on-resume-shim-writes-first", false, "")

	// ALPN settings and expectations. -advertise-alpn and
	// -expect-advertised-alpn are sequences of 8-bit-length-prefixed protocol
	// names; -reject-alpn and -decline-alpn replace Config.NextProtos.
	advertiseALPN        = flag.String("advertise-alpn", "", "")
	expectALPN           = flag.String("expect-alpn", "", "")
	rejectALPN           = flag.Bool("reject-alpn", false, "")
	declineALPN          = flag.Bool("decline-alpn", false, "")
	expectAdvertisedALPN = flag.String("expect-advertised-alpn", "", "")
	selectALPN           = flag.String("select-alpn", "", "")

	// Overrides Config.ServerName (default "test").
	hostName = flag.String("host-name", "", "")

	// Sets Config.ClientAuth to VerifyClientCertIfGiven.
	verifyPeer = flag.Bool("verify-peer", false, "")
	_          = flag.Bool("use-custom-verify-callback", false, "")

	// Pause the process (via pauseProcess) before the echo loop so a
	// debugger can attach.
	waitForDebugger = flag.Bool("wait-for-debugger", false, "")
)
   104  
   105  type stringSlice []string
   106  
   107  func flagStringSlice(name, usage string) *stringSlice {
   108  	f := &stringSlice{}
   109  	flag.Var(f, name, usage)
   110  	return f
   111  }
   112  
   113  func (saf *stringSlice) String() string {
   114  	return strings.Join(*saf, ",")
   115  }
   116  
   117  func (saf *stringSlice) Set(s string) error {
   118  	*saf = append(*saf, s)
   119  	return nil
   120  }
   121  
   122  func bogoShim() {
   123  	if *isHandshakerSupported {
   124  		fmt.Println("No")
   125  		return
   126  	}
   127  
   128  	// Test with both the default and insecure cipher suites.
   129  	var ciphersuites []uint16
   130  	for _, s := range append(CipherSuites(), InsecureCipherSuites()...) {
   131  		ciphersuites = append(ciphersuites, s.ID)
   132  	}
   133  
   134  	cfg := &Config{
   135  		ServerName: "test",
   136  
   137  		MinVersion: uint16(*minVersion),
   138  		MaxVersion: uint16(*maxVersion),
   139  
   140  		ClientSessionCache: NewLRUClientSessionCache(0),
   141  
   142  		CipherSuites: ciphersuites,
   143  
   144  		GetConfigForClient: func(chi *ClientHelloInfo) (*Config, error) {
   145  
   146  			if *expectAdvertisedALPN != "" {
   147  
   148  				s := cryptobyte.String(*expectAdvertisedALPN)
   149  
   150  				var expectedALPNs []string
   151  
   152  				for !s.Empty() {
   153  					var alpn cryptobyte.String
   154  					if !s.ReadUint8LengthPrefixed(&alpn) {
   155  						return nil, fmt.Errorf("unexpected error while parsing arguments for -expect-advertised-alpn")
   156  					}
   157  					expectedALPNs = append(expectedALPNs, string(alpn))
   158  				}
   159  
   160  				if !slices.Equal(chi.SupportedProtos, expectedALPNs) {
   161  					return nil, fmt.Errorf("unexpected ALPN: got %q, want %q", chi.SupportedProtos, expectedALPNs)
   162  				}
   163  			}
   164  			return nil, nil
   165  		},
   166  	}
   167  
   168  	if *noTLS1 {
   169  		cfg.MinVersion = VersionTLS11
   170  		if *noTLS11 {
   171  			cfg.MinVersion = VersionTLS12
   172  			if *noTLS12 {
   173  				cfg.MinVersion = VersionTLS13
   174  				if *noTLS13 {
   175  					log.Fatalf("no supported versions enabled")
   176  				}
   177  			}
   178  		}
   179  	} else if *noTLS13 {
   180  		cfg.MaxVersion = VersionTLS12
   181  		if *noTLS12 {
   182  			cfg.MaxVersion = VersionTLS11
   183  			if *noTLS11 {
   184  				cfg.MaxVersion = VersionTLS10
   185  				if *noTLS1 {
   186  					log.Fatalf("no supported versions enabled")
   187  				}
   188  			}
   189  		}
   190  	}
   191  
   192  	if *advertiseALPN != "" {
   193  		alpns := *advertiseALPN
   194  		for len(alpns) > 0 {
   195  			alpnLen := int(alpns[0])
   196  			cfg.NextProtos = append(cfg.NextProtos, alpns[1:1+alpnLen])
   197  			alpns = alpns[alpnLen+1:]
   198  		}
   199  	}
   200  
   201  	if *rejectALPN {
   202  		cfg.NextProtos = []string{"unnegotiableprotocol"}
   203  	}
   204  
   205  	if *declineALPN {
   206  		cfg.NextProtos = []string{}
   207  	}
   208  	if *selectALPN != "" {
   209  		cfg.NextProtos = []string{*selectALPN}
   210  	}
   211  
   212  	if *hostName != "" {
   213  		cfg.ServerName = *hostName
   214  	}
   215  
   216  	if *keyfile != "" || *certfile != "" {
   217  		pair, err := LoadX509KeyPair(*certfile, *keyfile)
   218  		if err != nil {
   219  			log.Fatalf("load key-file err: %s", err)
   220  		}
   221  		cfg.Certificates = []Certificate{pair}
   222  	}
   223  	if *trustCert != "" {
   224  		pool := x509.NewCertPool()
   225  		certFile, err := os.ReadFile(*trustCert)
   226  		if err != nil {
   227  			log.Fatalf("load trust-cert err: %s", err)
   228  		}
   229  		block, _ := pem.Decode(certFile)
   230  		cert, err := x509.ParseCertificate(block.Bytes)
   231  		if err != nil {
   232  			log.Fatalf("parse trust-cert err: %s", err)
   233  		}
   234  		pool.AddCert(cert)
   235  		cfg.RootCAs = pool
   236  	}
   237  
   238  	if *requireAnyClientCertificate {
   239  		cfg.ClientAuth = RequireAnyClientCert
   240  	}
   241  	if *verifyPeer {
   242  		cfg.ClientAuth = VerifyClientCertIfGiven
   243  	}
   244  
   245  	if *echConfigListB64 != "" {
   246  		echConfigList, err := base64.StdEncoding.DecodeString(*echConfigListB64)
   247  		if err != nil {
   248  			log.Fatalf("parse ech-config-list err: %s", err)
   249  		}
   250  		cfg.EncryptedClientHelloConfigList = echConfigList
   251  		cfg.MinVersion = VersionTLS13
   252  	}
   253  
   254  	if len(*curves) != 0 {
   255  		for _, curveStr := range *curves {
   256  			id, err := strconv.Atoi(curveStr)
   257  			if err != nil {
   258  				log.Fatalf("failed to parse curve id %q: %s", curveStr, err)
   259  			}
   260  			cfg.CurvePreferences = append(cfg.CurvePreferences, CurveID(id))
   261  		}
   262  	}
   263  
   264  	if len(*echServerConfig) != 0 {
   265  		if len(*echServerConfig) != len(*echServerKey) || len(*echServerConfig) != len(*echServerRetryConfig) {
   266  			log.Fatal("-ech-server-config, -ech-server-key, and -ech-is-retry-config mismatch")
   267  		}
   268  
   269  		for i, c := range *echServerConfig {
   270  			configBytes, err := base64.StdEncoding.DecodeString(c)
   271  			if err != nil {
   272  				log.Fatalf("parse ech-server-config err: %s", err)
   273  			}
   274  			privBytes, err := base64.StdEncoding.DecodeString((*echServerKey)[i])
   275  			if err != nil {
   276  				log.Fatalf("parse ech-server-key err: %s", err)
   277  			}
   278  
   279  			cfg.EncryptedClientHelloKeys = append(cfg.EncryptedClientHelloKeys, EncryptedClientHelloKey{
   280  				Config:      configBytes,
   281  				PrivateKey:  privBytes,
   282  				SendAsRetry: (*echServerRetryConfig)[i] == "1",
   283  			})
   284  		}
   285  	}
   286  
   287  	for i := 0; i < *resumeCount+1; i++ {
   288  		if i > 0 && (*onResumeECHConfigListB64 != "") {
   289  			echConfigList, err := base64.StdEncoding.DecodeString(*onResumeECHConfigListB64)
   290  			if err != nil {
   291  				log.Fatalf("parse ech-config-list err: %s", err)
   292  			}
   293  			cfg.EncryptedClientHelloConfigList = echConfigList
   294  		}
   295  
   296  		conn, err := net.Dial("tcp", net.JoinHostPort("localhost", *port))
   297  		if err != nil {
   298  			log.Fatalf("dial err: %s", err)
   299  		}
   300  		defer conn.Close()
   301  
   302  		// Write the shim ID we were passed as a little endian uint64
   303  		shimIDBytes := make([]byte, 8)
   304  		byteorder.LEPutUint64(shimIDBytes, *shimID)
   305  		if _, err := conn.Write(shimIDBytes); err != nil {
   306  			log.Fatalf("failed to write shim id: %s", err)
   307  		}
   308  
   309  		var tlsConn *Conn
   310  		if *server {
   311  			tlsConn = Server(conn, cfg)
   312  		} else {
   313  			tlsConn = Client(conn, cfg)
   314  		}
   315  
   316  		if i == 0 && *shimWritesFirst {
   317  			if _, err := tlsConn.Write([]byte("hello")); err != nil {
   318  				log.Fatalf("write err: %s", err)
   319  			}
   320  		}
   321  
   322  		// If we were instructed to wait for a debugger, then send SIGSTOP to ourselves.
   323  		// When the debugger attaches it will continue the process.
   324  		if *waitForDebugger {
   325  			pauseProcess()
   326  		}
   327  
   328  		for {
   329  			buf := make([]byte, 500)
   330  			var n int
   331  			n, err = tlsConn.Read(buf)
   332  			if err != nil {
   333  				break
   334  			}
   335  			buf = buf[:n]
   336  			for i := range buf {
   337  				buf[i] ^= 0xff
   338  			}
   339  			if _, err = tlsConn.Write(buf); err != nil {
   340  				break
   341  			}
   342  		}
   343  		if err != io.EOF {
   344  			retryErr, ok := err.(*ECHRejectionError)
   345  			if !ok {
   346  				log.Fatalf("unexpected error type returned: %v", err)
   347  			}
   348  			if *expectNoECHRetryConfigs && len(retryErr.RetryConfigList) > 0 {
   349  				log.Fatalf("expected no ECH retry configs, got some")
   350  			}
   351  			if *expectedECHRetryConfigs != "" {
   352  				expectedRetryConfigs, err := base64.StdEncoding.DecodeString(*expectedECHRetryConfigs)
   353  				if err != nil {
   354  					log.Fatalf("failed to decode expected retry configs: %s", err)
   355  				}
   356  				if !bytes.Equal(retryErr.RetryConfigList, expectedRetryConfigs) {
   357  					log.Fatalf("unexpected retry list returned: got %x, want %x", retryErr.RetryConfigList, expectedRetryConfigs)
   358  				}
   359  			}
   360  			log.Fatalf("conn error: %s", err)
   361  		}
   362  
   363  		cs := tlsConn.ConnectionState()
   364  		if cs.HandshakeComplete {
   365  			if *expectALPN != "" && cs.NegotiatedProtocol != *expectALPN {
   366  				log.Fatalf("unexpected protocol negotiated: want %q, got %q", *expectALPN, cs.NegotiatedProtocol)
   367  			}
   368  
   369  			if *selectALPN != "" && cs.NegotiatedProtocol != *selectALPN {
   370  				log.Fatalf("unexpected protocol negotiated: want %q, got %q", *selectALPN, cs.NegotiatedProtocol)
   371  			}
   372  
   373  			if *expectVersion != 0 && cs.Version != uint16(*expectVersion) {
   374  				log.Fatalf("expected ssl version %q, got %q", uint16(*expectVersion), cs.Version)
   375  			}
   376  			if *declineALPN && cs.NegotiatedProtocol != "" {
   377  				log.Fatal("unexpected ALPN protocol")
   378  			}
   379  			if *expectECHAccepted && !cs.ECHAccepted {
   380  				log.Fatal("expected ECH to be accepted, but connection state shows it was not")
   381  			} else if i == 0 && *onInitialExpectECHAccepted && !cs.ECHAccepted {
   382  				log.Fatal("expected ECH to be accepted, but connection state shows it was not")
   383  			} else if i > 0 && *onResumeExpectECHAccepted && !cs.ECHAccepted {
   384  				log.Fatal("expected ECH to be accepted on resumption, but connection state shows it was not")
   385  			} else if i == 0 && !*expectECHAccepted && cs.ECHAccepted {
   386  				log.Fatal("did not expect ECH, but it was accepted")
   387  			}
   388  
   389  			if *expectHRR && !cs.testingOnlyDidHRR {
   390  				log.Fatal("expected HRR but did not do it")
   391  			}
   392  
   393  			if *expectNoHRR && cs.testingOnlyDidHRR {
   394  				log.Fatal("expected no HRR but did do it")
   395  			}
   396  
   397  			if *expectSessionMiss && cs.DidResume {
   398  				log.Fatal("unexpected session resumption")
   399  			}
   400  
   401  			if *expectedServerName != "" && cs.ServerName != *expectedServerName {
   402  				log.Fatalf("unexpected server name: got %q, want %q", cs.ServerName, *expectedServerName)
   403  			}
   404  		}
   405  
   406  		if *expectedCurve != "" {
   407  			expectedCurveID, err := strconv.Atoi(*expectedCurve)
   408  			if err != nil {
   409  				log.Fatalf("failed to parse -expect-curve-id: %s", err)
   410  			}
   411  			if tlsConn.curveID != CurveID(expectedCurveID) {
   412  				log.Fatalf("unexpected curve id: want %d, got %d", expectedCurveID, tlsConn.curveID)
   413  			}
   414  		}
   415  	}
   416  }
   417  
   418  func TestBogoSuite(t *testing.T) {
   419  	if testing.Short() {
   420  		t.Skip("skipping in short mode")
   421  	}
   422  	if testenv.Builder() != "" && runtime.GOOS == "windows" {
   423  		t.Skip("#66913: windows network connections are flakey on builders")
   424  	}
   425  	skipFIPS(t)
   426  
   427  	// In order to make Go test caching work as expected, we stat the
   428  	// bogo_config.json file, so that the Go testing hooks know that it is
   429  	// important for this test and will invalidate a cached test result if the
   430  	// file changes.
   431  	if _, err := os.Stat("bogo_config.json"); err != nil {
   432  		t.Fatal(err)
   433  	}
   434  
   435  	var bogoDir string
   436  	if *bogoLocalDir != "" {
   437  		bogoDir = *bogoLocalDir
   438  	} else {
   439  		const boringsslModVer = "v0.0.0-20241120195446-5cce3fbd23e1"
   440  		bogoDir = cryptotest.FetchModule(t, "boringssl.googlesource.com/boringssl.git", boringsslModVer)
   441  	}
   442  
   443  	cwd, err := os.Getwd()
   444  	if err != nil {
   445  		t.Fatal(err)
   446  	}
   447  
   448  	resultsFile := filepath.Join(t.TempDir(), "results.json")
   449  
   450  	args := []string{
   451  		"test",
   452  		".",
   453  		fmt.Sprintf("-shim-config=%s", filepath.Join(cwd, "bogo_config.json")),
   454  		fmt.Sprintf("-shim-path=%s", os.Args[0]),
   455  		"-shim-extra-flags=-bogo-mode",
   456  		"-allow-unimplemented",
   457  		"-loose-errors", // TODO(roland): this should be removed eventually
   458  		fmt.Sprintf("-json-output=%s", resultsFile),
   459  	}
   460  	if *bogoFilter != "" {
   461  		args = append(args, fmt.Sprintf("-test=%s", *bogoFilter))
   462  	}
   463  
   464  	cmd := testenv.Command(t, testenv.GoToolPath(t), args...)
   465  	out := &strings.Builder{}
   466  	cmd.Stderr = out
   467  	cmd.Dir = filepath.Join(bogoDir, "ssl/test/runner")
   468  	err = cmd.Run()
   469  	// NOTE: we don't immediately check the error, because the failure could be either because
   470  	// the runner failed for some unexpected reason, or because a test case failed, and we
   471  	// cannot easily differentiate these cases. We check if the JSON results file was written,
   472  	// which should only happen if the failure was because of a test failure, and use that
   473  	// to determine the failure mode.
   474  
   475  	resultsJSON, jsonErr := os.ReadFile(resultsFile)
   476  	if jsonErr != nil {
   477  		if err != nil {
   478  			t.Fatalf("bogo failed: %s\n%s", err, out)
   479  		}
   480  		t.Fatalf("failed to read results JSON file: %s", jsonErr)
   481  	}
   482  
   483  	var results bogoResults
   484  	if err := json.Unmarshal(resultsJSON, &results); err != nil {
   485  		t.Fatalf("failed to parse results JSON: %s", err)
   486  	}
   487  
   488  	// assertResults contains test results we want to make sure
   489  	// are present in the output. They are only checked if -bogo-filter
   490  	// was not passed.
   491  	assertResults := map[string]string{
   492  		"CurveTest-Client-MLKEM-TLS13": "PASS",
   493  		"CurveTest-Server-MLKEM-TLS13": "PASS",
   494  	}
   495  
   496  	for name, result := range results.Tests {
   497  		// This is not really the intended way to do this... but... it works?
   498  		t.Run(name, func(t *testing.T) {
   499  			if result.Actual == "FAIL" && result.IsUnexpected {
   500  				t.Fatal(result.Error)
   501  			}
   502  			if expectedResult, ok := assertResults[name]; ok && expectedResult != result.Actual {
   503  				t.Fatalf("unexpected result: got %s, want %s", result.Actual, assertResults[name])
   504  			}
   505  			delete(assertResults, name)
   506  			if result.Actual == "SKIP" {
   507  				t.Skip()
   508  			}
   509  		})
   510  	}
   511  	if *bogoFilter == "" {
   512  		// Anything still in assertResults did not show up in the results, so we should fail
   513  		for name, expectedResult := range assertResults {
   514  			t.Run(name, func(t *testing.T) {
   515  				t.Fatalf("expected test to run with result %s, but it was not present in the test results", expectedResult)
   516  			})
   517  		}
   518  	}
   519  }
   520  
// bogoResults is a copy of boringssl.googlesource.com/boringssl/testresults.Results
//
// It describes the JSON file the BoGo runner writes when given -json-output,
// which TestBogoSuite decodes. Tests maps each BoGo test name to its outcome.
// For each outcome:
//   - Actual is the result string; TestBogoSuite treats "PASS", "FAIL" and
//     "SKIP" specially.
//   - IsUnexpected is set by the runner when Actual differs from Expected
//     (presumably; TODO confirm against the runner).
//   - Error holds the failure message, if any.
type bogoResults struct {
	Version           int            `json:"version"`
	Interrupted       bool           `json:"interrupted"`
	PathDelimiter     string         `json:"path_delimiter"`
	SecondsSinceEpoch float64        `json:"seconds_since_epoch"`
	NumFailuresByType map[string]int `json:"num_failures_by_type"`
	Tests             map[string]struct {
		Actual       string `json:"actual"`
		Expected     string `json:"expected"`
		IsUnexpected bool   `json:"is_unexpected"`
		Error        string `json:"error,omitempty"`
	} `json:"tests"`
}
   535  

View as plain text