Source file src/crypto/hpke/pq.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package hpke
     6  
     7  import (
     8  	"bytes"
     9  	"crypto"
    10  	"crypto/ecdh"
    11  	"crypto/internal/fips140/drbg"
    12  	"crypto/internal/rand"
    13  	"crypto/mlkem"
    14  	"crypto/sha3"
    15  	"errors"
    16  	"internal/byteorder"
    17  )
    18  
    19  var mlkem768X25519 = &hybridKEM{
    20  	id: 0x647a,
    21  	label: /**/ `\./` +
    22  		/*   */ `/^\`,
    23  	curve: ecdh.X25519(),
    24  
    25  	curveSeedSize:    32,
    26  	curvePointSize:   32,
    27  	pqEncapsKeySize:  mlkem.EncapsulationKeySize768,
    28  	pqCiphertextSize: mlkem.CiphertextSize768,
    29  
    30  	pqNewPublicKey: func(data []byte) (crypto.Encapsulator, error) {
    31  		return mlkem.NewEncapsulationKey768(data)
    32  	},
    33  	pqNewPrivateKey: func(data []byte) (crypto.Decapsulator, error) {
    34  		return mlkem.NewDecapsulationKey768(data)
    35  	},
    36  	pqGenerateKey: func() (crypto.Decapsulator, error) {
    37  		return mlkem.GenerateKey768()
    38  	},
    39  }
    40  
    41  // MLKEM768X25519 returns a KEM implementing MLKEM768-X25519 (a.k.a. X-Wing)
    42  // from draft-ietf-hpke-pq.
    43  func MLKEM768X25519() KEM {
    44  	return mlkem768X25519
    45  }
    46  
    47  var mlkem768P256 = &hybridKEM{
    48  	id:    0x0050,
    49  	label: "MLKEM768-P256",
    50  	curve: ecdh.P256(),
    51  
    52  	curveSeedSize:    32,
    53  	curvePointSize:   65,
    54  	pqEncapsKeySize:  mlkem.EncapsulationKeySize768,
    55  	pqCiphertextSize: mlkem.CiphertextSize768,
    56  
    57  	pqNewPublicKey: func(data []byte) (crypto.Encapsulator, error) {
    58  		return mlkem.NewEncapsulationKey768(data)
    59  	},
    60  	pqNewPrivateKey: func(data []byte) (crypto.Decapsulator, error) {
    61  		return mlkem.NewDecapsulationKey768(data)
    62  	},
    63  	pqGenerateKey: func() (crypto.Decapsulator, error) {
    64  		return mlkem.GenerateKey768()
    65  	},
    66  }
    67  
    68  // MLKEM768P256 returns a KEM implementing MLKEM768-P256 from draft-ietf-hpke-pq.
    69  func MLKEM768P256() KEM {
    70  	return mlkem768P256
    71  }
    72  
    73  var mlkem1024P384 = &hybridKEM{
    74  	id:    0x0051,
    75  	label: "MLKEM1024-P384",
    76  	curve: ecdh.P384(),
    77  
    78  	curveSeedSize:    48,
    79  	curvePointSize:   97,
    80  	pqEncapsKeySize:  mlkem.EncapsulationKeySize1024,
    81  	pqCiphertextSize: mlkem.CiphertextSize1024,
    82  
    83  	pqNewPublicKey: func(data []byte) (crypto.Encapsulator, error) {
    84  		return mlkem.NewEncapsulationKey1024(data)
    85  	},
    86  	pqNewPrivateKey: func(data []byte) (crypto.Decapsulator, error) {
    87  		return mlkem.NewDecapsulationKey1024(data)
    88  	},
    89  	pqGenerateKey: func() (crypto.Decapsulator, error) {
    90  		return mlkem.GenerateKey1024()
    91  	},
    92  }
    93  
    94  // MLKEM1024P384 returns a KEM implementing MLKEM1024-P384 from draft-ietf-hpke-pq.
    95  func MLKEM1024P384() KEM {
    96  	return mlkem1024P384
    97  }
    98  
// hybridKEM holds the parameters of a hybrid KEM that pairs an ML-KEM
// instance (the "PQ" half) with an ECDH curve (the traditional, "T", half).
// The instances defined in this file are mlkem768X25519, mlkem768P256 and
// mlkem1024P384.
type hybridKEM struct {
	// id is the KEM identifier returned by ID and encoded into the suite ID
	// used by DeriveKeyPair.
	id uint16
	// label is written last into the SHA3-256 input in sharedSecret.
	label string
	// curve is the ECDH curve of the traditional half.
	curve ecdh.Curve

	// curveSeedSize is how many SHAKE256 output bytes NewPrivateKey consumes
	// per attempt at building a curve private key.
	curveSeedSize int
	// curvePointSize is the length of an encoded curve public key, which is
	// also the length of the traditional part of an encapsulated key.
	curvePointSize int
	// pqEncapsKeySize is the length of an encoded ML-KEM encapsulation key.
	pqEncapsKeySize int
	// pqCiphertextSize is the length of an ML-KEM ciphertext.
	pqCiphertextSize int

	// pqNewPublicKey parses an encoded ML-KEM encapsulation key.
	pqNewPublicKey func(data []byte) (crypto.Encapsulator, error)
	// pqNewPrivateKey builds an ML-KEM decapsulation key from an
	// mlkem.SeedSize-byte seed.
	pqNewPrivateKey func(data []byte) (crypto.Decapsulator, error)
	// pqGenerateKey generates a random ML-KEM decapsulation key.
	// NOTE(review): not referenced by the hybridKEM methods in this file;
	// GenerateKey derives both halves from a seed instead.
	pqGenerateKey func() (crypto.Decapsulator, error)
}
   113  
   114  func (kem *hybridKEM) ID() uint16 {
   115  	return kem.id
   116  }
   117  
   118  func (kem *hybridKEM) encSize() int {
   119  	return kem.pqCiphertextSize + kem.curvePointSize
   120  }
   121  
   122  func (kem *hybridKEM) sharedSecret(ssPQ, ssT, ctT, ekT []byte) []byte {
   123  	h := sha3.New256()
   124  	h.Write(ssPQ)
   125  	h.Write(ssT)
   126  	h.Write(ctT)
   127  	h.Write(ekT)
   128  	h.Write([]byte(kem.label))
   129  	return h.Sum(nil)
   130  }
   131  
// hybridPublicKey is a hybrid KEM public key. It holds a curve public key t
// and an ML-KEM encapsulation key pq, tagged with the hybridKEM they belong
// to. Its encoding (Bytes) is pq's encoding followed by t's.
type hybridPublicKey struct {
	kem *hybridKEM
	// t is the recipient's traditional (ECDH) public key.
	t   *ecdh.PublicKey
	// pq is the recipient's post-quantum encapsulation key.
	pq  crypto.Encapsulator
}
   137  
   138  // NewHybridPublicKey returns a PublicKey implementing one of
   139  //
   140  //   - MLKEM768-X25519 (a.k.a. X-Wing)
   141  //   - MLKEM768-P256
   142  //   - MLKEM1024-P384
   143  //
   144  // from draft-ietf-hpke-pq, depending on the underlying curve of t
   145  // ([ecdh.X25519], [ecdh.P256], or [ecdh.P384]) and the type of pq (either
   146  // *[mlkem.EncapsulationKey768] or *[mlkem.EncapsulationKey1024]).
   147  //
   148  // This function is meant for applications that already have instantiated
   149  // crypto/ecdh and crypto/mlkem public keys. Otherwise, applications should use
   150  // the [KEM.NewPublicKey] method of e.g. [MLKEM768X25519].
   151  func NewHybridPublicKey(pq crypto.Encapsulator, t *ecdh.PublicKey) (PublicKey, error) {
   152  	switch t.Curve() {
   153  	case ecdh.X25519():
   154  		if _, ok := pq.(*mlkem.EncapsulationKey768); !ok {
   155  			return nil, errors.New("invalid PQ KEM for X25519 hybrid")
   156  		}
   157  		return &hybridPublicKey{mlkem768X25519, t, pq}, nil
   158  	case ecdh.P256():
   159  		if _, ok := pq.(*mlkem.EncapsulationKey768); !ok {
   160  			return nil, errors.New("invalid PQ KEM for P-256 hybrid")
   161  		}
   162  		return &hybridPublicKey{mlkem768P256, t, pq}, nil
   163  	case ecdh.P384():
   164  		if _, ok := pq.(*mlkem.EncapsulationKey1024); !ok {
   165  			return nil, errors.New("invalid PQ KEM for P-384 hybrid")
   166  		}
   167  		return &hybridPublicKey{mlkem1024P384, t, pq}, nil
   168  	default:
   169  		return nil, errors.New("unsupported curve")
   170  	}
   171  }
   172  
   173  func (kem *hybridKEM) NewPublicKey(data []byte) (PublicKey, error) {
   174  	if len(data) != kem.pqEncapsKeySize+kem.curvePointSize {
   175  		return nil, errors.New("invalid public key size")
   176  	}
   177  	pq, err := kem.pqNewPublicKey(data[:kem.pqEncapsKeySize])
   178  	if err != nil {
   179  		return nil, err
   180  	}
   181  	k, err := kem.curve.NewPublicKey(data[kem.pqEncapsKeySize:])
   182  	if err != nil {
   183  		return nil, err
   184  	}
   185  	return NewHybridPublicKey(pq, k)
   186  }
   187  
   188  func (pk *hybridPublicKey) KEM() KEM {
   189  	return pk.kem
   190  }
   191  
   192  func (pk *hybridPublicKey) Bytes() []byte {
   193  	return append(pk.pq.Bytes(), pk.t.Bytes()...)
   194  }
   195  
// testingOnlyEncapsulate, when non-nil, overrides the ML-KEM shared secret and
// ciphertext in hybridPublicKey.encap and mlkemPublicKey.encap. The values it
// returns replace those produced by the real Encapsulate call. As the name
// suggests, it is presumably only set by tests (e.g. to reproduce fixed
// vectors); confirm before relying on it.
var testingOnlyEncapsulate func() (ss, ct []byte)
   197  
   198  func (pk *hybridPublicKey) encap() (sharedSecret []byte, encapPub []byte, err error) {
   199  	skE, err := pk.t.Curve().GenerateKey(rand.Reader)
   200  	if err != nil {
   201  		return nil, nil, err
   202  	}
   203  	if testingOnlyGenerateKey != nil {
   204  		skE = testingOnlyGenerateKey()
   205  	}
   206  	ssT, err := skE.ECDH(pk.t)
   207  	if err != nil {
   208  		return nil, nil, err
   209  	}
   210  	ctT := skE.PublicKey().Bytes()
   211  
   212  	ssPQ, ctPQ := pk.pq.Encapsulate()
   213  	if testingOnlyEncapsulate != nil {
   214  		ssPQ, ctPQ = testingOnlyEncapsulate()
   215  	}
   216  
   217  	ss := pk.kem.sharedSecret(ssPQ, ssT, ctT, pk.t.Bytes())
   218  	ct := append(ctPQ, ctT...)
   219  	return ss, ct, nil
   220  }
   221  
// hybridPrivateKey is a hybrid KEM private key. It pairs a curve private key t
// with an ML-KEM decapsulation key pq, tagged with their hybridKEM. seed holds
// the secret the key was expanded from in hybridKEM.NewPrivateKey. It is nil
// for keys assembled with NewHybridPrivateKey, in which case Bytes reports an
// error.
type hybridPrivateKey struct {
	kem  *hybridKEM
	seed []byte // can be nil
	// t is the traditional (ECDH) private key.
	t    ecdh.KeyExchanger
	// pq is the post-quantum decapsulation key.
	pq   crypto.Decapsulator
}
   228  
   229  // NewHybridPrivateKey returns a PrivateKey implementing
   230  //
   231  //   - MLKEM768-X25519 (a.k.a. X-Wing)
   232  //   - MLKEM768-P256
   233  //   - MLKEM1024-P384
   234  //
   235  // from draft-ietf-hpke-pq, depending on the underlying curve of t
   236  // ([ecdh.X25519], [ecdh.P256], or [ecdh.P384]) and the type of pq.Encapsulator()
   237  // (either *[mlkem.EncapsulationKey768] or *[mlkem.EncapsulationKey1024]).
   238  //
   239  // This function is meant for applications that already have instantiated
   240  // crypto/ecdh and crypto/mlkem private keys, or another implementation of a
   241  // [ecdh.KeyExchanger] and [crypto.Decapsulator] (e.g. a hardware key).
   242  // Otherwise, applications should use the [KEM.NewPrivateKey] method of e.g.
   243  // [MLKEM768X25519].
   244  func NewHybridPrivateKey(pq crypto.Decapsulator, t ecdh.KeyExchanger) (PrivateKey, error) {
   245  	return newHybridPrivateKey(pq, t, nil)
   246  }
   247  
   248  func (kem *hybridKEM) GenerateKey() (PrivateKey, error) {
   249  	seed := make([]byte, 32)
   250  	drbg.Read(seed)
   251  	return kem.NewPrivateKey(seed)
   252  }
   253  
   254  func (kem *hybridKEM) NewPrivateKey(priv []byte) (PrivateKey, error) {
   255  	if len(priv) != 32 {
   256  		return nil, errors.New("hpke: invalid hybrid KEM secret length")
   257  	}
   258  
   259  	s := sha3.NewSHAKE256()
   260  	s.Write(priv)
   261  
   262  	seedPQ := make([]byte, mlkem.SeedSize)
   263  	s.Read(seedPQ)
   264  	pq, err := kem.pqNewPrivateKey(seedPQ)
   265  	if err != nil {
   266  		return nil, err
   267  	}
   268  
   269  	seedT := make([]byte, kem.curveSeedSize)
   270  	for {
   271  		s.Read(seedT)
   272  		k, err := kem.curve.NewPrivateKey(seedT)
   273  		if err != nil {
   274  			continue
   275  		}
   276  		return newHybridPrivateKey(pq, k, priv)
   277  	}
   278  }
   279  
   280  func newHybridPrivateKey(pq crypto.Decapsulator, t ecdh.KeyExchanger, seed []byte) (PrivateKey, error) {
   281  	switch t.Curve() {
   282  	case ecdh.X25519():
   283  		if _, ok := pq.Encapsulator().(*mlkem.EncapsulationKey768); !ok {
   284  			return nil, errors.New("invalid PQ KEM for X25519 hybrid")
   285  		}
   286  		return &hybridPrivateKey{mlkem768X25519, bytes.Clone(seed), t, pq}, nil
   287  	case ecdh.P256():
   288  		if _, ok := pq.Encapsulator().(*mlkem.EncapsulationKey768); !ok {
   289  			return nil, errors.New("invalid PQ KEM for P-256 hybrid")
   290  		}
   291  		return &hybridPrivateKey{mlkem768P256, bytes.Clone(seed), t, pq}, nil
   292  	case ecdh.P384():
   293  		if _, ok := pq.Encapsulator().(*mlkem.EncapsulationKey1024); !ok {
   294  			return nil, errors.New("invalid PQ KEM for P-384 hybrid")
   295  		}
   296  		return &hybridPrivateKey{mlkem1024P384, bytes.Clone(seed), t, pq}, nil
   297  	default:
   298  		return nil, errors.New("unsupported curve")
   299  	}
   300  }
   301  
   302  func (kem *hybridKEM) DeriveKeyPair(ikm []byte) (PrivateKey, error) {
   303  	suiteID := byteorder.BEAppendUint16([]byte("KEM"), kem.id)
   304  	dk, err := SHAKE256().labeledDerive(suiteID, ikm, "DeriveKeyPair", nil, 32)
   305  	if err != nil {
   306  		return nil, err
   307  	}
   308  	return kem.NewPrivateKey(dk)
   309  }
   310  
   311  func (k *hybridPrivateKey) KEM() KEM {
   312  	return k.kem
   313  }
   314  
   315  func (k *hybridPrivateKey) Bytes() ([]byte, error) {
   316  	if k.seed == nil {
   317  		return nil, errors.New("private key seed not available")
   318  	}
   319  	return k.seed, nil
   320  }
   321  
   322  func (k *hybridPrivateKey) PublicKey() PublicKey {
   323  	return &hybridPublicKey{
   324  		kem: k.kem,
   325  		t:   k.t.PublicKey(),
   326  		pq:  k.pq.Encapsulator(),
   327  	}
   328  }
   329  
   330  func (k *hybridPrivateKey) decap(enc []byte) ([]byte, error) {
   331  	if len(enc) != k.kem.pqCiphertextSize+k.kem.curvePointSize {
   332  		return nil, errors.New("invalid encapsulated key size")
   333  	}
   334  	ctPQ, ctT := enc[:k.kem.pqCiphertextSize], enc[k.kem.pqCiphertextSize:]
   335  	ssPQ, err := k.pq.Decapsulate(ctPQ)
   336  	if err != nil {
   337  		return nil, err
   338  	}
   339  	pub, err := k.t.Curve().NewPublicKey(ctT)
   340  	if err != nil {
   341  		return nil, err
   342  	}
   343  	ssT, err := k.t.ECDH(pub)
   344  	if err != nil {
   345  		return nil, err
   346  	}
   347  	ss := k.kem.sharedSecret(ssPQ, ssT, ctT, k.t.PublicKey().Bytes())
   348  	return ss, nil
   349  }
   350  
   351  var mlkem768 = &mlkemKEM{
   352  	id:             0x0041,
   353  	ciphertextSize: mlkem.CiphertextSize768,
   354  	newPublicKey: func(data []byte) (crypto.Encapsulator, error) {
   355  		return mlkem.NewEncapsulationKey768(data)
   356  	},
   357  	newPrivateKey: func(data []byte) (crypto.Decapsulator, error) {
   358  		return mlkem.NewDecapsulationKey768(data)
   359  	},
   360  	generateKey: func() (crypto.Decapsulator, error) {
   361  		return mlkem.GenerateKey768()
   362  	},
   363  }
   364  
   365  // MLKEM768 returns a KEM implementing ML-KEM-768 from draft-ietf-hpke-pq.
   366  func MLKEM768() KEM {
   367  	return mlkem768
   368  }
   369  
   370  var mlkem1024 = &mlkemKEM{
   371  	id:             0x0042,
   372  	ciphertextSize: mlkem.CiphertextSize1024,
   373  	newPublicKey: func(data []byte) (crypto.Encapsulator, error) {
   374  		return mlkem.NewEncapsulationKey1024(data)
   375  	},
   376  	newPrivateKey: func(data []byte) (crypto.Decapsulator, error) {
   377  		return mlkem.NewDecapsulationKey1024(data)
   378  	},
   379  	generateKey: func() (crypto.Decapsulator, error) {
   380  		return mlkem.GenerateKey1024()
   381  	},
   382  }
   383  
   384  // MLKEM1024 returns a KEM implementing ML-KEM-1024 from draft-ietf-hpke-pq.
   385  func MLKEM1024() KEM {
   386  	return mlkem1024
   387  }
   388  
// mlkemKEM holds the parameters of a standalone (non-hybrid) ML-KEM KEM. The
// instances defined in this file are mlkem768 and mlkem1024.
type mlkemKEM struct {
	// id is the KEM identifier returned by ID and encoded into the suite ID
	// used by DeriveKeyPair.
	id             uint16
	// ciphertextSize is the ML-KEM ciphertext length, which is also encSize.
	ciphertextSize int
	// newPublicKey parses an encoded ML-KEM encapsulation key.
	newPublicKey   func(data []byte) (crypto.Encapsulator, error)
	// newPrivateKey parses an ML-KEM decapsulation key from its seed bytes
	// (DeriveKeyPair supplies 64 of them).
	newPrivateKey  func(data []byte) (crypto.Decapsulator, error)
	// generateKey generates a random ML-KEM decapsulation key.
	generateKey    func() (crypto.Decapsulator, error)
}
   396  
   397  func (kem *mlkemKEM) ID() uint16 {
   398  	return kem.id
   399  }
   400  
   401  func (kem *mlkemKEM) encSize() int {
   402  	return kem.ciphertextSize
   403  }
   404  
// mlkemPublicKey is a standalone ML-KEM public key: an encapsulation key
// tagged with the mlkemKEM it belongs to.
type mlkemPublicKey struct {
	kem *mlkemKEM
	// pq is the ML-KEM encapsulation key.
	pq  crypto.Encapsulator
}
   409  
   410  // NewMLKEMPublicKey returns a KEMPublicKey implementing
   411  //
   412  //   - ML-KEM-768
   413  //   - ML-KEM-1024
   414  //
   415  // from draft-ietf-hpke-pq, depending on the type of pub
   416  // (*[mlkem.EncapsulationKey768] or *[mlkem.EncapsulationKey1024]).
   417  //
   418  // This function is meant for applications that already have an instantiated
   419  // crypto/mlkem public key. Otherwise, applications should use the
   420  // [KEM.NewPublicKey] method of e.g. [MLKEM768].
   421  func NewMLKEMPublicKey(pub crypto.Encapsulator) (PublicKey, error) {
   422  	switch pub.(type) {
   423  	case *mlkem.EncapsulationKey768:
   424  		return &mlkemPublicKey{mlkem768, pub}, nil
   425  	case *mlkem.EncapsulationKey1024:
   426  		return &mlkemPublicKey{mlkem1024, pub}, nil
   427  	default:
   428  		return nil, errors.New("unsupported public key type")
   429  	}
   430  }
   431  
   432  func (kem *mlkemKEM) NewPublicKey(data []byte) (PublicKey, error) {
   433  	pq, err := kem.newPublicKey(data)
   434  	if err != nil {
   435  		return nil, err
   436  	}
   437  	return NewMLKEMPublicKey(pq)
   438  }
   439  
   440  func (pk *mlkemPublicKey) KEM() KEM {
   441  	return pk.kem
   442  }
   443  
   444  func (pk *mlkemPublicKey) Bytes() []byte {
   445  	return pk.pq.Bytes()
   446  }
   447  
   448  func (pk *mlkemPublicKey) encap() (sharedSecret []byte, encapPub []byte, err error) {
   449  	ss, ct := pk.pq.Encapsulate()
   450  	if testingOnlyEncapsulate != nil {
   451  		ss, ct = testingOnlyEncapsulate()
   452  	}
   453  	return ss, ct, nil
   454  }
   455  
// mlkemPrivateKey is a standalone ML-KEM private key: a decapsulation key
// tagged with the mlkemKEM it belongs to. Bytes can only serialize it when pq
// also has a Bytes() []byte method.
type mlkemPrivateKey struct {
	kem *mlkemKEM
	// pq is the ML-KEM decapsulation key.
	pq  crypto.Decapsulator
}
   460  
   461  // NewMLKEMPrivateKey returns a KEMPrivateKey implementing
   462  //
   463  //   - ML-KEM-768
   464  //   - ML-KEM-1024
   465  //
   466  // from draft-ietf-hpke-pq, depending on the type of priv.Encapsulator()
   467  // (either *[mlkem.EncapsulationKey768] or *[mlkem.EncapsulationKey1024]).
   468  //
   469  // This function is meant for applications that already have an instantiated
   470  // crypto/mlkem private key. Otherwise, applications should use the
   471  // [KEM.NewPrivateKey] method of e.g. [MLKEM768].
   472  func NewMLKEMPrivateKey(priv crypto.Decapsulator) (PrivateKey, error) {
   473  	switch priv.Encapsulator().(type) {
   474  	case *mlkem.EncapsulationKey768:
   475  		return &mlkemPrivateKey{mlkem768, priv}, nil
   476  	case *mlkem.EncapsulationKey1024:
   477  		return &mlkemPrivateKey{mlkem1024, priv}, nil
   478  	default:
   479  		return nil, errors.New("unsupported public key type")
   480  	}
   481  }
   482  
   483  func (kem *mlkemKEM) GenerateKey() (PrivateKey, error) {
   484  	pq, err := kem.generateKey()
   485  	if err != nil {
   486  		return nil, err
   487  	}
   488  	return NewMLKEMPrivateKey(pq)
   489  }
   490  
   491  func (kem *mlkemKEM) NewPrivateKey(priv []byte) (PrivateKey, error) {
   492  	pq, err := kem.newPrivateKey(priv)
   493  	if err != nil {
   494  		return nil, err
   495  	}
   496  	return NewMLKEMPrivateKey(pq)
   497  }
   498  
   499  func (kem *mlkemKEM) DeriveKeyPair(ikm []byte) (PrivateKey, error) {
   500  	suiteID := byteorder.BEAppendUint16([]byte("KEM"), kem.id)
   501  	dk, err := SHAKE256().labeledDerive(suiteID, ikm, "DeriveKeyPair", nil, 64)
   502  	if err != nil {
   503  		return nil, err
   504  	}
   505  	return kem.NewPrivateKey(dk)
   506  }
   507  
   508  func (k *mlkemPrivateKey) KEM() KEM {
   509  	return k.kem
   510  }
   511  
   512  func (k *mlkemPrivateKey) Bytes() ([]byte, error) {
   513  	pq, ok := k.pq.(interface {
   514  		Bytes() []byte
   515  	})
   516  	if !ok {
   517  		return nil, errors.New("private key seed not available")
   518  	}
   519  	return pq.Bytes(), nil
   520  }
   521  
   522  func (k *mlkemPrivateKey) PublicKey() PublicKey {
   523  	return &mlkemPublicKey{
   524  		kem: k.kem,
   525  		pq:  k.pq.Encapsulator(),
   526  	}
   527  }
   528  
   529  func (k *mlkemPrivateKey) decap(enc []byte) ([]byte, error) {
   530  	return k.pq.Decapsulate(enc)
   531  }
   532  

View as plain text