Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Relax KEM constraint from kem.AuthScheme to kem.Scheme.
  • Loading branch information
armfazh committed Jan 18, 2025
commit 4527e86e92125b993f522f78fa4924a16f474065
4 changes: 2 additions & 2 deletions hpke/algs.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (k KEM) IsValid() bool {

// Scheme returns an instance of the KEM. The returned scheme supports the
// Auth and AuthPSK modes only if it also implements kem.AuthScheme. Panics if
// the KEM identifier is invalid.
func (k KEM) Scheme() kem.AuthScheme {
func (k KEM) Scheme() kem.Scheme {
switch k {
case KEM_P256_HKDF_SHA256:
return dhkemp256hkdfsha256
Expand Down Expand Up @@ -283,6 +283,6 @@ func init() {
hybridkemX25519Kyber768.kemA = dhkemx25519hkdfsha256
hybridkemX25519Kyber768.kemB = kyber768.Scheme()

kemXwing.kem = xwing.Scheme()
kemXwing.Scheme = xwing.Scheme()
kemXwing.name = "HPKE_KEM_XWING"
}
59 changes: 4 additions & 55 deletions hpke/genericnoauthkem.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,70 +9,19 @@ import (

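// genericNoAuthKEM wraps a generic KEM (kem.Scheme) to be used as an HPKE KEM.
// It embeds only kem.Scheme, so it does not implement kem.AuthScheme; the Auth
// and AuthPSK modes fail with ErrInvalidAuthKEM for KEMs wrapped this way.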
type genericNoAuthKEM struct {
kem kem.Scheme
kem.Scheme
name string
}

func (h genericNoAuthKEM) PrivateKeySize() int { return h.kem.PrivateKeySize() }
func (h genericNoAuthKEM) SeedSize() int { return h.kem.SeedSize() }
func (h genericNoAuthKEM) CiphertextSize() int { return h.kem.CiphertextSize() }
func (h genericNoAuthKEM) PublicKeySize() int { return h.kem.PublicKeySize() }
func (h genericNoAuthKEM) EncapsulationSeedSize() int { return h.kem.EncapsulationSeedSize() }
func (h genericNoAuthKEM) SharedKeySize() int { return h.kem.SharedKeySize() }
func (h genericNoAuthKEM) Name() string { return h.name }

func (h genericNoAuthKEM) AuthDecapsulate(skR kem.PrivateKey,
ct []byte,
pkS kem.PublicKey,
) ([]byte, error) {
panic("AuthDecapsulate is not supported for this KEM")
}

func (h genericNoAuthKEM) AuthEncapsulate(pkr kem.PublicKey, sks kem.PrivateKey) (
ct []byte, ss []byte, err error,
) {
panic("AuthEncapsulate is not supported for this KEM")
}

func (h genericNoAuthKEM) AuthEncapsulateDeterministically(pkr kem.PublicKey, sks kem.PrivateKey, seed []byte) (ct, ss []byte, err error) {
panic("AuthEncapsulateDeterministically is not supported for this KEM")
}

func (h genericNoAuthKEM) Encapsulate(pkr kem.PublicKey) (
ct []byte, ss []byte, err error,
) {
return h.kem.Encapsulate(pkr)
}

func (h genericNoAuthKEM) Decapsulate(skr kem.PrivateKey, ct []byte) ([]byte, error) {
return h.kem.Decapsulate(skr, ct)
}

func (h genericNoAuthKEM) EncapsulateDeterministically(
pkr kem.PublicKey, seed []byte,
) (ct, ss []byte, err error) {
return h.kem.EncapsulateDeterministically(pkr, seed)
}
func (h genericNoAuthKEM) Name() string { return h.name }

// HPKE requires DeriveKeyPair() to take any seed larger than the private key
// size, whereas typical KEMs expect a specific seed size. We'll just use
// SHAKE256 to hash it to the right size as in X-Wing.
func (h genericNoAuthKEM) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) {
seed2 := make([]byte, h.kem.SeedSize())
seed2 := make([]byte, h.Scheme.SeedSize())
hh := sha3.NewShake256()
_, _ = hh.Write(seed)
_, _ = hh.Read(seed2)
return h.kem.DeriveKeyPair(seed2)
}

func (h genericNoAuthKEM) GenerateKeyPair() (kem.PublicKey, kem.PrivateKey, error) {
return h.kem.GenerateKeyPair()
}

func (h genericNoAuthKEM) UnmarshalBinaryPrivateKey(data []byte) (kem.PrivateKey, error) {
return h.kem.UnmarshalBinaryPrivateKey(data)
}

func (h genericNoAuthKEM) UnmarshalBinaryPublicKey(data []byte) (kem.PublicKey, error) {
return h.kem.UnmarshalBinaryPublicKey(data)
return h.Scheme.DeriveKeyPair(seed2)
}
15 changes: 13 additions & 2 deletions hpke/hpke.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,12 @@ func (s *Sender) allSetup(rnd io.Reader) ([]byte, Sealer, error) {
case modeBase, modePSK:
enc, ss, err = scheme.EncapsulateDeterministically(s.pkR, seed)
case modeAuth, modeAuthPSK:
enc, ss, err = scheme.AuthEncapsulateDeterministically(s.pkR, s.skS, seed)
authScheme, ok := scheme.(kem.AuthScheme)
if !ok {
return nil, nil, ErrInvalidAuthKEM
}

enc, ss, err = authScheme.AuthEncapsulateDeterministically(s.pkR, s.skS, seed)
}
if err != nil {
return nil, nil, err
Expand All @@ -246,7 +251,12 @@ func (r *Receiver) allSetup() (Opener, error) {
case modeBase, modePSK:
ss, err = scheme.Decapsulate(r.skR, r.enc)
case modeAuth, modeAuthPSK:
ss, err = scheme.AuthDecapsulate(r.skR, r.enc, r.pkS)
authScheme, ok := scheme.(kem.AuthScheme)
if !ok {
return nil, ErrInvalidAuthKEM
}

ss, err = authScheme.AuthDecapsulate(r.skR, r.enc, r.pkS)
}
if err != nil {
return nil, err
Expand All @@ -263,6 +273,7 @@ var (
ErrInvalidHPKESuite = errors.New("hpke: invalid HPKE suite")
ErrInvalidKDF = errors.New("hpke: invalid KDF identifier")
ErrInvalidKEM = errors.New("hpke: invalid KEM identifier")
ErrInvalidAuthKEM = errors.New("hpke: KEM does not support Auth mode")
ErrInvalidAEAD = errors.New("hpke: invalid AEAD identifier")
ErrInvalidKEMPublicKey = errors.New("hpke: invalid KEM public key")
ErrInvalidKEMPrivateKey = errors.New("hpke: invalid KEM private key")
Expand Down
38 changes: 18 additions & 20 deletions kem/xwing/scheme.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,21 @@ import (
// generic KEM API.

// Returns the generic KEM interface for X-Wing PQ/T hybrid KEM.
func Scheme() kem.Scheme { return &xwing }
func Scheme() kem.Scheme { return scheme{} }

type scheme struct{}

var xwing scheme

func (*scheme) Name() string { return "X-Wing" }
func (*scheme) PublicKeySize() int { return PublicKeySize }
func (*scheme) PrivateKeySize() int { return PrivateKeySize }
func (*scheme) SeedSize() int { return SeedSize }
func (*scheme) EncapsulationSeedSize() int { return EncapsulationSeedSize }
func (*scheme) SharedKeySize() int { return SharedKeySize }
func (*scheme) CiphertextSize() int { return CiphertextSize }
func (*PrivateKey) Scheme() kem.Scheme { return &xwing }
func (*PublicKey) Scheme() kem.Scheme { return &xwing }

func (sch *scheme) Encapsulate(pk kem.PublicKey) (ct, ss []byte, err error) {
func (scheme) Name() string { return "X-Wing" }
func (scheme) PublicKeySize() int { return PublicKeySize }
func (scheme) PrivateKeySize() int { return PrivateKeySize }
func (scheme) SeedSize() int { return SeedSize }
func (scheme) EncapsulationSeedSize() int { return EncapsulationSeedSize }
func (scheme) SharedKeySize() int { return SharedKeySize }
func (scheme) CiphertextSize() int { return CiphertextSize }
func (*PrivateKey) Scheme() kem.Scheme { return scheme{} }
func (*PublicKey) Scheme() kem.Scheme { return scheme{} }

func (sch scheme) Encapsulate(pk kem.PublicKey) (ct, ss []byte, err error) {
var seed [EncapsulationSeedSize]byte
_, err = cryptoRand.Read(seed[:])
if err != nil {
Expand All @@ -38,7 +36,7 @@ func (sch *scheme) Encapsulate(pk kem.PublicKey) (ct, ss []byte, err error) {
return sch.EncapsulateDeterministically(pk, seed[:])
}

func (sch *scheme) EncapsulateDeterministically(
func (scheme) EncapsulateDeterministically(
pk kem.PublicKey, seed []byte,
) ([]byte, []byte, error) {
if len(seed) != EncapsulationSeedSize {
Expand All @@ -56,7 +54,7 @@ func (sch *scheme) EncapsulateDeterministically(
return ct[:], ss[:], nil
}

func (*scheme) UnmarshalBinaryPublicKey(buf []byte) (kem.PublicKey, error) {
func (scheme) UnmarshalBinaryPublicKey(buf []byte) (kem.PublicKey, error) {
var pk PublicKey
if len(buf) != PublicKeySize {
return nil, kem.ErrPubKeySize
Expand All @@ -68,7 +66,7 @@ func (*scheme) UnmarshalBinaryPublicKey(buf []byte) (kem.PublicKey, error) {
return &pk, nil
}

func (*scheme) UnmarshalBinaryPrivateKey(buf []byte) (kem.PrivateKey, error) {
func (scheme) UnmarshalBinaryPrivateKey(buf []byte) (kem.PrivateKey, error) {
var sk PrivateKey
if len(buf) != PrivateKeySize {
return nil, kem.ErrPrivKeySize
Expand Down Expand Up @@ -114,17 +112,17 @@ func (pk *PublicKey) MarshalBinary() ([]byte, error) {
return ret[:], nil
}

func (*scheme) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) {
func (scheme) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) {
sk, pk := DeriveKeyPair(seed)
return pk, sk
}

func (sch *scheme) GenerateKeyPair() (kem.PublicKey, kem.PrivateKey, error) {
func (scheme) GenerateKeyPair() (kem.PublicKey, kem.PrivateKey, error) {
sk, pk, err := GenerateKeyPair(nil)
return pk, sk, err
}

func (*scheme) Decapsulate(sk kem.PrivateKey, ct []byte) ([]byte, error) {
func (scheme) Decapsulate(sk kem.PrivateKey, ct []byte) ([]byte, error) {
if len(ct) != CiphertextSize {
return nil, kem.ErrCiphertextSize
}
Expand Down
Loading