1
2
3
4
5 package hpke
6
7 import (
8 "bytes"
9 "crypto"
10 "crypto/ecdh"
11 "crypto/internal/fips140/drbg"
12 "crypto/internal/rand"
13 "crypto/mlkem"
14 "crypto/sha3"
15 "errors"
16 "internal/byteorder"
17 )
18
19 var mlkem768X25519 = &hybridKEM{
20 id: 0x647a,
21 label: `\./` +
22 `/^\`,
23 curve: ecdh.X25519(),
24
25 curveSeedSize: 32,
26 curvePointSize: 32,
27 pqEncapsKeySize: mlkem.EncapsulationKeySize768,
28 pqCiphertextSize: mlkem.CiphertextSize768,
29
30 pqNewPublicKey: func(data []byte) (crypto.Encapsulator, error) {
31 return mlkem.NewEncapsulationKey768(data)
32 },
33 pqNewPrivateKey: func(data []byte) (crypto.Decapsulator, error) {
34 return mlkem.NewDecapsulationKey768(data)
35 },
36 pqGenerateKey: func() (crypto.Decapsulator, error) {
37 return mlkem.GenerateKey768()
38 },
39 }
40
41
42
// MLKEM768X25519 returns the hybrid KEM that combines ML-KEM-768 with X25519
// (KEM ID 0x647a). Every call returns the same package-level instance.
func MLKEM768X25519() KEM {
	return mlkem768X25519
}
46
47 var mlkem768P256 = &hybridKEM{
48 id: 0x0050,
49 label: "MLKEM768-P256",
50 curve: ecdh.P256(),
51
52 curveSeedSize: 32,
53 curvePointSize: 65,
54 pqEncapsKeySize: mlkem.EncapsulationKeySize768,
55 pqCiphertextSize: mlkem.CiphertextSize768,
56
57 pqNewPublicKey: func(data []byte) (crypto.Encapsulator, error) {
58 return mlkem.NewEncapsulationKey768(data)
59 },
60 pqNewPrivateKey: func(data []byte) (crypto.Decapsulator, error) {
61 return mlkem.NewDecapsulationKey768(data)
62 },
63 pqGenerateKey: func() (crypto.Decapsulator, error) {
64 return mlkem.GenerateKey768()
65 },
66 }
67
68
// MLKEM768P256 returns the hybrid KEM that combines ML-KEM-768 with ECDH over
// P-256 (KEM ID 0x0050). Every call returns the same package-level instance.
func MLKEM768P256() KEM {
	return mlkem768P256
}
72
73 var mlkem1024P384 = &hybridKEM{
74 id: 0x0051,
75 label: "MLKEM1024-P384",
76 curve: ecdh.P384(),
77
78 curveSeedSize: 48,
79 curvePointSize: 97,
80 pqEncapsKeySize: mlkem.EncapsulationKeySize1024,
81 pqCiphertextSize: mlkem.CiphertextSize1024,
82
83 pqNewPublicKey: func(data []byte) (crypto.Encapsulator, error) {
84 return mlkem.NewEncapsulationKey1024(data)
85 },
86 pqNewPrivateKey: func(data []byte) (crypto.Decapsulator, error) {
87 return mlkem.NewDecapsulationKey1024(data)
88 },
89 pqGenerateKey: func() (crypto.Decapsulator, error) {
90 return mlkem.GenerateKey1024()
91 },
92 }
93
94
// MLKEM1024P384 returns the hybrid KEM that combines ML-KEM-1024 with ECDH over
// P-384 (KEM ID 0x0051). Every call returns the same package-level instance.
func MLKEM1024P384() KEM {
	return mlkem1024P384
}
98
// hybridKEM describes a KEM built from an ML-KEM parameter set and an ECDH
// curve. All encoded public keys and encapsulations place the post-quantum
// part first and the curve part second.
type hybridKEM struct {
	id    uint16     // KEM identifier reported by ID
	label string     // bytes hashed last when combining shared secrets
	curve ecdh.Curve // ECDH curve used for the traditional component

	curveSeedSize    int // bytes drawn per candidate curve private key in NewPrivateKey
	curvePointSize   int // encoded length of a curve public key
	pqEncapsKeySize  int // encoded length of the post-quantum encapsulation key
	pqCiphertextSize int // length of the post-quantum ciphertext

	// Constructors for the post-quantum component.
	pqNewPublicKey  func(data []byte) (crypto.Encapsulator, error)
	pqNewPrivateKey func(data []byte) (crypto.Decapsulator, error)
	pqGenerateKey   func() (crypto.Decapsulator, error)
}
113
// ID returns the 16-bit identifier configured for this KEM.
func (kem *hybridKEM) ID() uint16 {
	return kem.id
}
117
// encSize returns the length in bytes of an encapsulation: the post-quantum
// ciphertext followed by one encoded curve point.
func (kem *hybridKEM) encSize() int {
	return kem.pqCiphertextSize + kem.curvePointSize
}
121
122 func (kem *hybridKEM) sharedSecret(ssPQ, ssT, ctT, ekT []byte) []byte {
123 h := sha3.New256()
124 h.Write(ssPQ)
125 h.Write(ssT)
126 h.Write(ctT)
127 h.Write(ekT)
128 h.Write([]byte(kem.label))
129 return h.Sum(nil)
130 }
131
// hybridPublicKey is a public key for a hybridKEM. The field order matters:
// the type is constructed with positional literals elsewhere in this file.
type hybridPublicKey struct {
	kem *hybridKEM          // KEM configuration this key belongs to
	t   *ecdh.PublicKey     // traditional (ECDH) component
	pq  crypto.Encapsulator // post-quantum component
}
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151 func NewHybridPublicKey(pq crypto.Encapsulator, t *ecdh.PublicKey) (PublicKey, error) {
152 switch t.Curve() {
153 case ecdh.X25519():
154 if _, ok := pq.(*mlkem.EncapsulationKey768); !ok {
155 return nil, errors.New("invalid PQ KEM for X25519 hybrid")
156 }
157 return &hybridPublicKey{mlkem768X25519, t, pq}, nil
158 case ecdh.P256():
159 if _, ok := pq.(*mlkem.EncapsulationKey768); !ok {
160 return nil, errors.New("invalid PQ KEM for P-256 hybrid")
161 }
162 return &hybridPublicKey{mlkem768P256, t, pq}, nil
163 case ecdh.P384():
164 if _, ok := pq.(*mlkem.EncapsulationKey1024); !ok {
165 return nil, errors.New("invalid PQ KEM for P-384 hybrid")
166 }
167 return &hybridPublicKey{mlkem1024P384, t, pq}, nil
168 default:
169 return nil, errors.New("unsupported curve")
170 }
171 }
172
173 func (kem *hybridKEM) NewPublicKey(data []byte) (PublicKey, error) {
174 if len(data) != kem.pqEncapsKeySize+kem.curvePointSize {
175 return nil, errors.New("invalid public key size")
176 }
177 pq, err := kem.pqNewPublicKey(data[:kem.pqEncapsKeySize])
178 if err != nil {
179 return nil, err
180 }
181 k, err := kem.curve.NewPublicKey(data[kem.pqEncapsKeySize:])
182 if err != nil {
183 return nil, err
184 }
185 return NewHybridPublicKey(pq, k)
186 }
187
// KEM returns the hybrid KEM this public key belongs to.
func (pk *hybridPublicKey) KEM() KEM {
	return pk.kem
}
191
192 func (pk *hybridPublicKey) Bytes() []byte {
193 return append(pk.pq.Bytes(), pk.t.Bytes()...)
194 }
195
// testingOnlyEncapsulate, when non-nil, supplies the post-quantum shared secret
// and ciphertext used by the encap methods in place of the values produced by
// the real encapsulation. Judging by its name it is meant to be set only by
// tests.
var testingOnlyEncapsulate func() (ss, ct []byte)
197
198 func (pk *hybridPublicKey) encap() (sharedSecret []byte, encapPub []byte, err error) {
199 skE, err := pk.t.Curve().GenerateKey(rand.Reader)
200 if err != nil {
201 return nil, nil, err
202 }
203 if testingOnlyGenerateKey != nil {
204 skE = testingOnlyGenerateKey()
205 }
206 ssT, err := skE.ECDH(pk.t)
207 if err != nil {
208 return nil, nil, err
209 }
210 ctT := skE.PublicKey().Bytes()
211
212 ssPQ, ctPQ := pk.pq.Encapsulate()
213 if testingOnlyEncapsulate != nil {
214 ssPQ, ctPQ = testingOnlyEncapsulate()
215 }
216
217 ss := pk.kem.sharedSecret(ssPQ, ssT, ctT, pk.t.Bytes())
218 ct := append(ctPQ, ctT...)
219 return ss, ct, nil
220 }
221
// hybridPrivateKey is a private key for a hybridKEM. The field order matters:
// the type is constructed with positional literals elsewhere in this file.
type hybridPrivateKey struct {
	kem  *hybridKEM          // KEM configuration this key belongs to
	seed []byte              // secret the key was expanded from; nil if built from component keys
	t    ecdh.KeyExchanger   // traditional (ECDH) component
	pq   crypto.Decapsulator // post-quantum component
}
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
// NewHybridPrivateKey builds a hybrid PrivateKey from an existing post-quantum
// decapsulation key and an ECDH key exchanger. The curve of t selects the KEM
// and pq must use the matching ML-KEM parameter set.
//
// No seed is recorded for keys built this way, so the Bytes method of the
// returned key reports an error.
func NewHybridPrivateKey(pq crypto.Decapsulator, t ecdh.KeyExchanger) (PrivateKey, error) {
	return newHybridPrivateKey(pq, t, nil)
}
247
248 func (kem *hybridKEM) GenerateKey() (PrivateKey, error) {
249 seed := make([]byte, 32)
250 drbg.Read(seed)
251 return kem.NewPrivateKey(seed)
252 }
253
254 func (kem *hybridKEM) NewPrivateKey(priv []byte) (PrivateKey, error) {
255 if len(priv) != 32 {
256 return nil, errors.New("hpke: invalid hybrid KEM secret length")
257 }
258
259 s := sha3.NewSHAKE256()
260 s.Write(priv)
261
262 seedPQ := make([]byte, mlkem.SeedSize)
263 s.Read(seedPQ)
264 pq, err := kem.pqNewPrivateKey(seedPQ)
265 if err != nil {
266 return nil, err
267 }
268
269 seedT := make([]byte, kem.curveSeedSize)
270 for {
271 s.Read(seedT)
272 k, err := kem.curve.NewPrivateKey(seedT)
273 if err != nil {
274 continue
275 }
276 return newHybridPrivateKey(pq, k, priv)
277 }
278 }
279
280 func newHybridPrivateKey(pq crypto.Decapsulator, t ecdh.KeyExchanger, seed []byte) (PrivateKey, error) {
281 switch t.Curve() {
282 case ecdh.X25519():
283 if _, ok := pq.Encapsulator().(*mlkem.EncapsulationKey768); !ok {
284 return nil, errors.New("invalid PQ KEM for X25519 hybrid")
285 }
286 return &hybridPrivateKey{mlkem768X25519, bytes.Clone(seed), t, pq}, nil
287 case ecdh.P256():
288 if _, ok := pq.Encapsulator().(*mlkem.EncapsulationKey768); !ok {
289 return nil, errors.New("invalid PQ KEM for P-256 hybrid")
290 }
291 return &hybridPrivateKey{mlkem768P256, bytes.Clone(seed), t, pq}, nil
292 case ecdh.P384():
293 if _, ok := pq.Encapsulator().(*mlkem.EncapsulationKey1024); !ok {
294 return nil, errors.New("invalid PQ KEM for P-384 hybrid")
295 }
296 return &hybridPrivateKey{mlkem1024P384, bytes.Clone(seed), t, pq}, nil
297 default:
298 return nil, errors.New("unsupported curve")
299 }
300 }
301
302 func (kem *hybridKEM) DeriveKeyPair(ikm []byte) (PrivateKey, error) {
303 suiteID := byteorder.BEAppendUint16([]byte("KEM"), kem.id)
304 dk, err := SHAKE256().labeledDerive(suiteID, ikm, "DeriveKeyPair", nil, 32)
305 if err != nil {
306 return nil, err
307 }
308 return kem.NewPrivateKey(dk)
309 }
310
// KEM returns the hybrid KEM this private key belongs to.
func (k *hybridPrivateKey) KEM() KEM {
	return k.kem
}
314
315 func (k *hybridPrivateKey) Bytes() ([]byte, error) {
316 if k.seed == nil {
317 return nil, errors.New("private key seed not available")
318 }
319 return k.seed, nil
320 }
321
322 func (k *hybridPrivateKey) PublicKey() PublicKey {
323 return &hybridPublicKey{
324 kem: k.kem,
325 t: k.t.PublicKey(),
326 pq: k.pq.Encapsulator(),
327 }
328 }
329
330 func (k *hybridPrivateKey) decap(enc []byte) ([]byte, error) {
331 if len(enc) != k.kem.pqCiphertextSize+k.kem.curvePointSize {
332 return nil, errors.New("invalid encapsulated key size")
333 }
334 ctPQ, ctT := enc[:k.kem.pqCiphertextSize], enc[k.kem.pqCiphertextSize:]
335 ssPQ, err := k.pq.Decapsulate(ctPQ)
336 if err != nil {
337 return nil, err
338 }
339 pub, err := k.t.Curve().NewPublicKey(ctT)
340 if err != nil {
341 return nil, err
342 }
343 ssT, err := k.t.ECDH(pub)
344 if err != nil {
345 return nil, err
346 }
347 ss := k.kem.sharedSecret(ssPQ, ssT, ctT, k.t.PublicKey().Bytes())
348 return ss, nil
349 }
350
351 var mlkem768 = &mlkemKEM{
352 id: 0x0041,
353 ciphertextSize: mlkem.CiphertextSize768,
354 newPublicKey: func(data []byte) (crypto.Encapsulator, error) {
355 return mlkem.NewEncapsulationKey768(data)
356 },
357 newPrivateKey: func(data []byte) (crypto.Decapsulator, error) {
358 return mlkem.NewDecapsulationKey768(data)
359 },
360 generateKey: func() (crypto.Decapsulator, error) {
361 return mlkem.GenerateKey768()
362 },
363 }
364
365
// MLKEM768 returns the ML-KEM-768 KEM (KEM ID 0x0041). Every call returns the
// same package-level instance.
func MLKEM768() KEM {
	return mlkem768
}
369
370 var mlkem1024 = &mlkemKEM{
371 id: 0x0042,
372 ciphertextSize: mlkem.CiphertextSize1024,
373 newPublicKey: func(data []byte) (crypto.Encapsulator, error) {
374 return mlkem.NewEncapsulationKey1024(data)
375 },
376 newPrivateKey: func(data []byte) (crypto.Decapsulator, error) {
377 return mlkem.NewDecapsulationKey1024(data)
378 },
379 generateKey: func() (crypto.Decapsulator, error) {
380 return mlkem.GenerateKey1024()
381 },
382 }
383
384
// MLKEM1024 returns the ML-KEM-1024 KEM (KEM ID 0x0042). Every call returns the
// same package-level instance.
func MLKEM1024() KEM {
	return mlkem1024
}
388
// mlkemKEM describes a KEM backed by a single ML-KEM parameter set.
type mlkemKEM struct {
	id             uint16 // KEM identifier reported by ID
	ciphertextSize int    // length of an encapsulation, reported by encSize

	// Constructors for the underlying ML-KEM keys.
	newPublicKey  func(data []byte) (crypto.Encapsulator, error)
	newPrivateKey func(data []byte) (crypto.Decapsulator, error)
	generateKey   func() (crypto.Decapsulator, error)
}
396
// ID returns the 16-bit identifier configured for this KEM.
func (kem *mlkemKEM) ID() uint16 {
	return kem.id
}
400
// encSize returns the length in bytes of an encapsulation, which for this KEM
// is the configured ML-KEM ciphertext size.
func (kem *mlkemKEM) encSize() int {
	return kem.ciphertextSize
}
404
// mlkemPublicKey is a public key for an mlkemKEM. The field order matters: the
// type is constructed with positional literals elsewhere in this file.
type mlkemPublicKey struct {
	kem *mlkemKEM           // KEM configuration this key belongs to
	pq  crypto.Encapsulator // underlying ML-KEM encapsulation key
}
409
410
411
412
413
414
415
416
417
418
419
420
421 func NewMLKEMPublicKey(pub crypto.Encapsulator) (PublicKey, error) {
422 switch pub.(type) {
423 case *mlkem.EncapsulationKey768:
424 return &mlkemPublicKey{mlkem768, pub}, nil
425 case *mlkem.EncapsulationKey1024:
426 return &mlkemPublicKey{mlkem1024, pub}, nil
427 default:
428 return nil, errors.New("unsupported public key type")
429 }
430 }
431
432 func (kem *mlkemKEM) NewPublicKey(data []byte) (PublicKey, error) {
433 pq, err := kem.newPublicKey(data)
434 if err != nil {
435 return nil, err
436 }
437 return NewMLKEMPublicKey(pq)
438 }
439
// KEM returns the ML-KEM KEM this public key belongs to.
func (pk *mlkemPublicKey) KEM() KEM {
	return pk.kem
}
443
// Bytes returns the encoding of the underlying encapsulation key, as produced
// by that key's own Bytes method.
func (pk *mlkemPublicKey) Bytes() []byte {
	return pk.pq.Bytes()
}
447
448 func (pk *mlkemPublicKey) encap() (sharedSecret []byte, encapPub []byte, err error) {
449 ss, ct := pk.pq.Encapsulate()
450 if testingOnlyEncapsulate != nil {
451 ss, ct = testingOnlyEncapsulate()
452 }
453 return ss, ct, nil
454 }
455
// mlkemPrivateKey is a private key for an mlkemKEM. The field order matters:
// the type is constructed with positional literals elsewhere in this file.
type mlkemPrivateKey struct {
	kem *mlkemKEM           // KEM configuration this key belongs to
	pq  crypto.Decapsulator // underlying ML-KEM decapsulation key
}
460
461
462
463
464
465
466
467
468
469
470
471
472 func NewMLKEMPrivateKey(priv crypto.Decapsulator) (PrivateKey, error) {
473 switch priv.Encapsulator().(type) {
474 case *mlkem.EncapsulationKey768:
475 return &mlkemPrivateKey{mlkem768, priv}, nil
476 case *mlkem.EncapsulationKey1024:
477 return &mlkemPrivateKey{mlkem1024, priv}, nil
478 default:
479 return nil, errors.New("unsupported public key type")
480 }
481 }
482
483 func (kem *mlkemKEM) GenerateKey() (PrivateKey, error) {
484 pq, err := kem.generateKey()
485 if err != nil {
486 return nil, err
487 }
488 return NewMLKEMPrivateKey(pq)
489 }
490
491 func (kem *mlkemKEM) NewPrivateKey(priv []byte) (PrivateKey, error) {
492 pq, err := kem.newPrivateKey(priv)
493 if err != nil {
494 return nil, err
495 }
496 return NewMLKEMPrivateKey(pq)
497 }
498
499 func (kem *mlkemKEM) DeriveKeyPair(ikm []byte) (PrivateKey, error) {
500 suiteID := byteorder.BEAppendUint16([]byte("KEM"), kem.id)
501 dk, err := SHAKE256().labeledDerive(suiteID, ikm, "DeriveKeyPair", nil, 64)
502 if err != nil {
503 return nil, err
504 }
505 return kem.NewPrivateKey(dk)
506 }
507
// KEM returns the ML-KEM KEM this private key belongs to.
func (k *mlkemPrivateKey) KEM() KEM {
	return k.kem
}
511
512 func (k *mlkemPrivateKey) Bytes() ([]byte, error) {
513 pq, ok := k.pq.(interface {
514 Bytes() []byte
515 })
516 if !ok {
517 return nil, errors.New("private key seed not available")
518 }
519 return pq.Bytes(), nil
520 }
521
522 func (k *mlkemPrivateKey) PublicKey() PublicKey {
523 return &mlkemPublicKey{
524 kem: k.kem,
525 pq: k.pq.Encapsulator(),
526 }
527 }
528
// decap recovers the shared secret from the ciphertext enc by delegating to the
// underlying decapsulation key; any validation of enc is left to that key.
func (k *mlkemPrivateKey) decap(enc []byte) ([]byte, error) {
	return k.pq.Decapsulate(enc)
}
532
View as plain text