Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make it constexpr #29

Merged
merged 9 commits into from
Aug 5, 2024
183 changes: 119 additions & 64 deletions README.md

Large diffs are not rendered by default.

16 changes: 14 additions & 2 deletions include/ml_dsa/internals/math/field.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,18 @@ struct zq_t
// Modulo Multiplication
inline constexpr zq_t operator*(const zq_t rhs) const
{
#ifdef __SIZEOF_INT128__
__extension__ using uint128_t = unsigned __int128;

const uint64_t t = static_cast<uint64_t>(this->v) * static_cast<uint64_t>(rhs.v); // (23+23) significant bits, from LSB
const uint128_t tR = static_cast<uint128_t>(t) * static_cast<uint128_t>(R); // (23+23+24) significant bits, from LSB

const uint64_t res = static_cast<uint64_t>(tR >> 46); // 24 significant bits, from LSB
const uint64_t resQ = res * static_cast<uint64_t>(Q); // (24+23) significant bits, from LSB

const uint32_t reduced = reduce_once(static_cast<uint32_t>(t - resQ));
return reduced;
#else
const uint64_t t0 = static_cast<uint64_t>(this->v);
const uint64_t t1 = static_cast<uint64_t>(rhs.v);
const uint64_t t2 = t0 * t1;
Expand Down Expand Up @@ -101,6 +113,7 @@ struct zq_t

const uint32_t t7 = reduce_once(t6);
return zq_t(t7);
#endif
}
inline constexpr void operator*=(const zq_t rhs) { *this = *this * rhs; }

Expand Down Expand Up @@ -175,8 +188,7 @@ struct zq_t
return t5;
}

// Given a 32-bit unsigned integer `val` such that `val` ∈ [0, 2*Q), this routine can be invoked to reduce `val` modulo
// the prime Q.
// Given a 32-bit unsigned integer `val` such that `val` ∈ [0, 2*Q), this routine can be invoked to reduce `val` modulo the prime Q.
static inline constexpr uint32_t reduce_once(const uint32_t val)
{
const uint32_t t0 = val - Q;
Expand Down
8 changes: 3 additions & 5 deletions include/ml_dsa/internals/ml_dsa.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ static constexpr size_t RND_BYTE_LEN = 32;
//
// See algorithm 1 of ML-DSA draft standard @ https://doi.org/10.6028/NIST.FIPS.204.ipd.
template<size_t k, size_t l, size_t d, uint32_t η>
static inline void
static inline constexpr void
keygen(std::span<const uint8_t, KEYGEN_SEED_BYTE_LEN> ξ,
std::span<uint8_t, ml_dsa_utils::pub_key_len(k, d)> pubkey,
std::span<uint8_t, ml_dsa_utils::sec_key_len(k, l, η, d)> seckey)
Expand Down Expand Up @@ -117,7 +117,7 @@ keygen(std::span<const uint8_t, KEYGEN_SEED_BYTE_LEN> ξ,
//
// See algorithm 2 of ML-DSA draft standard @ https://doi.org/10.6028/NIST.FIPS.204.ipd.
template<size_t k, size_t l, size_t d, uint32_t η, uint32_t γ1, uint32_t γ2, uint32_t τ, uint32_t β, size_t ω, size_t λ>
static inline void
static inline constexpr void
sign(std::span<const uint8_t, RND_BYTE_LEN> rnd,
std::span<const uint8_t, ml_dsa_utils::sec_key_len(k, l, η, d)> seckey,
std::span<const uint8_t> msg,
Expand Down Expand Up @@ -187,7 +187,6 @@ sign(std::span<const uint8_t, RND_BYTE_LEN> rnd,
std::array<uint8_t, (2 * λ) / std::numeric_limits<uint8_t>::digits> c_tilda{};
auto c_tilda_span = std::span(c_tilda);
auto c1_tilda = c_tilda_span.template first<32>();
auto c2_tilda = c_tilda_span.template last<32>();

while (!has_signed) {
std::array<ml_dsa_field::zq_t, l * ml_dsa_ntt::N> y{};
Expand Down Expand Up @@ -293,7 +292,7 @@ sign(std::span<const uint8_t, RND_BYTE_LEN> rnd,
//
// See algorithm 3 of ML-DSA draft standard @ https://doi.org/10.6028/NIST.FIPS.204.ipd.
template<size_t k, size_t l, size_t d, uint32_t γ1, uint32_t γ2, uint32_t τ, uint32_t β, size_t ω, size_t λ>
static inline bool
static inline constexpr bool
verify(std::span<const uint8_t, ml_dsa_utils::pub_key_len(k, d)> pubkey, std::span<const uint8_t> msg, std::span<const uint8_t, ml_dsa_utils::sig_len(k, l, γ1, ω, λ)> sig)
requires(ml_dsa_params::check_verify_params(k, l, d, γ1, γ2, τ, β, ω, λ))
{
Expand All @@ -308,7 +307,6 @@ verify(std::span<const uint8_t, ml_dsa_utils::pub_key_len(k, d)> pubkey, std::sp

auto c_tilda = sig.template first<sigoff1 - sigoff0>();
auto c1_tilda = c_tilda.template first<32>();
auto c2_tilda = c_tilda.template last<32>();
auto z_encoded = sig.template subspan<sigoff1, sigoff2 - sigoff1>();
auto h_encoded = sig.template subspan<sigoff2, sigoff3 - sigoff2>();

Expand Down
14 changes: 14 additions & 0 deletions include/ml_dsa/internals/poly/ntt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ static constexpr auto ζ_NEG_EXP = []() {
static inline constexpr void
ntt(std::span<ml_dsa_field::zq_t, N> poly)
{
#if (not defined __clang__) && (defined __GNUG__)
#pragma GCC unroll 8
#endif
for (int64_t l = LOG2N - 1; l >= 0; l--) {
const size_t len = 1ul << l;
const size_t lenx2 = len << 1;
Expand All @@ -76,6 +79,10 @@ ntt(std::span<ml_dsa_field::zq_t, N> poly)
const size_t k_now = k_beg + (start >> (l + 1));
const ml_dsa_field::zq_t ζ_exp = ζ_EXP[k_now];

#if (not defined __clang__) && (defined __GNUG__)
#pragma GCC unroll 4
#pragma GCC ivdep
#endif
for (size_t i = start; i < start + len; i++) {
auto tmp = ζ_exp * poly[i + len];

Expand All @@ -97,6 +104,9 @@ ntt(std::span<ml_dsa_field::zq_t, N> poly)
static inline constexpr void
intt(std::span<ml_dsa_field::zq_t, N> poly)
{
#if (not defined __clang__) && (defined __GNUG__)
#pragma GCC unroll 8
#endif
for (size_t l = 0; l < LOG2N; l++) {
const size_t len = 1ul << l;
const size_t lenx2 = len << 1;
Expand All @@ -106,6 +116,10 @@ intt(std::span<ml_dsa_field::zq_t, N> poly)
const size_t k_now = k_beg - (start >> (l + 1));
const ml_dsa_field::zq_t neg_ζ_exp = ζ_NEG_EXP[k_now];

#if (not defined __clang__) && (defined __GNUG__)
#pragma GCC unroll 4
#pragma GCC ivdep
#endif
for (size_t i = start; i < start + len; i++) {
const auto tmp = poly[i];

Expand Down
26 changes: 21 additions & 5 deletions include/ml_dsa/internals/poly/poly.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ power2round(std::span<const ml_dsa_field::zq_t, ml_dsa_ntt::N> poly,
static inline constexpr void
mul(std::span<const ml_dsa_field::zq_t, ml_dsa_ntt::N> polya, std::span<const ml_dsa_field::zq_t, ml_dsa_ntt::N> polyb, std::span<ml_dsa_field::zq_t, ml_dsa_ntt::N> polyc)
{
#if (not defined __clang__) && (defined __GNUG__)
#pragma GCC unroll 16
#pragma GCC ivdep
#endif
for (size_t i = 0; i < polya.size(); i++) {
polyc[i] = polya[i] * polyb[i];
}
Expand All @@ -41,6 +45,9 @@ sub_from_x(std::span<ml_dsa_field::zq_t, ml_dsa_ntt::N> poly)
{
constexpr ml_dsa_field::zq_t x_cap(x);

#if defined __clang__
#pragma clang loop unroll(enable) vectorize(enable) interleave(enable)
#endif
for (size_t i = 0; i < poly.size(); i++) {
poly[i] = x_cap - poly[i];
}
Expand Down Expand Up @@ -76,10 +83,18 @@ infinity_norm(std::span<const ml_dsa_field::zq_t, ml_dsa_ntt::N> poly)
auto res = ml_dsa_field::zq_t::zero();

for (size_t i = 0; i < poly.size(); i++) {
#ifdef __clang__
if (poly[i] > qby2) {
res = std::max(res, -poly[i]);
} else {
res = std::max(res, poly[i]);
}
#else
const bool flg = poly[i] > qby2;
const ml_dsa_field::zq_t br[]{ poly[i], -poly[i] };

res = std::max(res, br[flg]);
#endif
}

return res;
Expand Down Expand Up @@ -112,17 +127,18 @@ use_hint(std::span<const ml_dsa_field::zq_t, ml_dsa_ntt::N> polyh,
}

// Given a degree-255 polynomial, this routine counts number of coefficients having value 1.
// Note: this implementation assumes that every coefficient of the input polynomial is either 0 or 1.
// If it is invoked with an arbitrary polynomial, expect a wrong result.
static inline constexpr size_t
count_1s(std::span<const ml_dsa_field::zq_t, ml_dsa_ntt::N> poly)
{
constexpr auto one = ml_dsa_field::zq_t::one();
size_t cnt = 0;
size_t count = 0;

for (size_t i = 0; i < poly.size(); i++) {
cnt += 1 * (poly[i] == one);
for (auto coeff : poly) {
count += coeff.raw();
}

return cnt;
return count;
}

// Given a degree-255 polynomial, this routine shifts each coefficient leftwards, by d bits.
Expand Down
6 changes: 3 additions & 3 deletions include/ml_dsa/ml_dsa_44.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ static constexpr size_t SigningSeedByteLen = ml_dsa::RND_BYTE_LEN;
static constexpr size_t SigByteLen = ml_dsa_utils::sig_len(k, l, γ1, ω, λ);

// Given a 32-byte seed, this routine can be used for generating a fresh ML-DSA-44 keypair.
inline void
constexpr void
keygen(std::span<const uint8_t, KeygenSeedByteLen> ξ, std::span<uint8_t, PubKeyByteLen> pubkey, std::span<uint8_t, SecKeyByteLen> seckey)
{
ml_dsa::keygen<k, l, d, η>(ξ, pubkey, seckey);
Expand All @@ -42,7 +42,7 @@ keygen(std::span<const uint8_t, KeygenSeedByteLen> ξ, std::span<uint8_t, PubKey
//
// The default (and recommended) signing mode is "hedged", i.e. 32 bytes of input randomness are used for signing,
// which results in a randomized signature. For the "deterministic" signing mode, simply fill `rnd` with zero bytes.
inline void
constexpr void
sign(std::span<const uint8_t, SigningSeedByteLen> rnd, std::span<const uint8_t, SecKeyByteLen> seckey, std::span<const uint8_t> msg, std::span<uint8_t, SigByteLen> sig)
{
ml_dsa::sign<k, l, d, η, γ1, γ2, τ, β, ω, λ>(rnd, seckey, msg, sig);
Expand All @@ -51,7 +51,7 @@ sign(std::span<const uint8_t, SigningSeedByteLen> rnd, std::span<const uint8_t,
// Given an ML-DSA-44 public key, a message M and a signature S, this routine verifies whether the signature is valid
// for the provided message. It returns true only if signature verification succeeds; otherwise it returns false.
inline bool
constexpr bool
verify(std::span<const uint8_t, PubKeyByteLen> pubkey, std::span<const uint8_t> msg, std::span<const uint8_t, SigByteLen> sig)
{
return ml_dsa::verify<k, l, d, γ1, γ2, τ, β, ω, λ>(pubkey, msg, sig);
Expand Down
6 changes: 3 additions & 3 deletions include/ml_dsa/ml_dsa_65.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ static constexpr size_t SigningSeedByteLen = ml_dsa::RND_BYTE_LEN;
static constexpr size_t SigByteLen = ml_dsa_utils::sig_len(k, l, γ1, ω, λ);

// Given a 32-byte seed, this routine can be used for generating a fresh ML-DSA-65 keypair.
inline void
constexpr void
keygen(std::span<const uint8_t, KeygenSeedByteLen> ξ, std::span<uint8_t, PubKeyByteLen> pubkey, std::span<uint8_t, SecKeyByteLen> seckey)
{
ml_dsa::keygen<k, l, d, η>(ξ, pubkey, seckey);
Expand All @@ -42,7 +42,7 @@ keygen(std::span<const uint8_t, KeygenSeedByteLen> ξ, std::span<uint8_t, PubKey
//
// The default (and recommended) signing mode is "hedged", i.e. 32 bytes of input randomness are used for signing,
// which results in a randomized signature. For the "deterministic" signing mode, simply fill `rnd` with zero bytes.
inline void
constexpr void
sign(std::span<const uint8_t, SigningSeedByteLen> rnd, std::span<const uint8_t, SecKeyByteLen> seckey, std::span<const uint8_t> msg, std::span<uint8_t, SigByteLen> sig)
{
ml_dsa::sign<k, l, d, η, γ1, γ2, τ, β, ω, λ>(rnd, seckey, msg, sig);
Expand All @@ -51,7 +51,7 @@ sign(std::span<const uint8_t, SigningSeedByteLen> rnd, std::span<const uint8_t,
// Given an ML-DSA-65 public key, a message M and a signature S, this routine verifies whether the signature is valid
// for the provided message. It returns true only if signature verification succeeds; otherwise it returns false.
inline bool
constexpr bool
verify(std::span<const uint8_t, PubKeyByteLen> pubkey, std::span<const uint8_t> msg, std::span<const uint8_t, SigByteLen> sig)
{
return ml_dsa::verify<k, l, d, γ1, γ2, τ, β, ω, λ>(pubkey, msg, sig);
Expand Down
6 changes: 3 additions & 3 deletions include/ml_dsa/ml_dsa_87.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ static constexpr size_t SigningSeedByteLen = ml_dsa::RND_BYTE_LEN;
static constexpr size_t SigByteLen = ml_dsa_utils::sig_len(k, l, γ1, ω, λ);

// Given a 32-byte seed, this routine can be used for generating a fresh ML-DSA-87 keypair.
inline void
constexpr void
keygen(std::span<const uint8_t, KeygenSeedByteLen> ξ, std::span<uint8_t, PubKeyByteLen> pubkey, std::span<uint8_t, SecKeyByteLen> seckey)
{
ml_dsa::keygen<k, l, d, η>(ξ, pubkey, seckey);
Expand All @@ -42,7 +42,7 @@ keygen(std::span<const uint8_t, KeygenSeedByteLen> ξ, std::span<uint8_t, PubKey
//
// The default (and recommended) signing mode is "hedged", i.e. 32 bytes of input randomness are used for signing,
// which results in a randomized signature. For the "deterministic" signing mode, simply fill `rnd` with zero bytes.
inline void
constexpr void
sign(std::span<const uint8_t, SigningSeedByteLen> rnd, std::span<const uint8_t, SecKeyByteLen> seckey, std::span<const uint8_t> msg, std::span<uint8_t, SigByteLen> sig)
{
ml_dsa::sign<k, l, d, η, γ1, γ2, τ, β, ω, λ>(rnd, seckey, msg, sig);
Expand All @@ -51,7 +51,7 @@ sign(std::span<const uint8_t, SigningSeedByteLen> rnd, std::span<const uint8_t,
// Given an ML-DSA-87 public key, a message M and a signature S, this routine verifies whether the signature is valid
// for the provided message. It returns true only if signature verification succeeds; otherwise it returns false.
inline bool
constexpr bool
verify(std::span<const uint8_t, PubKeyByteLen> pubkey, std::span<const uint8_t> msg, std::span<const uint8_t, SigByteLen> sig)
{
return ml_dsa::verify<k, l, d, γ1, γ2, τ, β, ω, λ>(pubkey, msg, sig);
Expand Down