Sunday, April 26, 2020

Modular Exponentiation in Montgomery Space


How can we compute the modular exponentiation c = x^e mod n efficiently, especially when the exponent e is large?

A better algorithm than repeated multiplication is exponentiation by squaring [1]: write e in binary, square the base once per bit, and multiply it into the result whenever that bit is set:

/// Modular Exponentiation
/// Inputs: x, e, n 128-bit integer
/// Output: x^e mod n (128-bit integer)
uint128_t modexp(uint128_t x, uint128_t e, uint128_t n) {
  uint128_t base = x;
  x = 1;
  while (e) {
    if (e & 0x1) {
      uint256_t xb = uint256_t::mul(x, base);
      x = xb.mod(n);
    }

    uint256_t xb = uint256_t::mul(base, base);
    base = xb.mod(n);
    e >>= 1;
  }
  return x;
}
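As a minimal sketch of the same square-and-multiply loop at a smaller width, assuming 32-bit operands so that every product fits in a plain 64-bit integer (modexp32 is a hypothetical name, not part of the original implementation):

```cpp
#include <cassert>
#include <cstdint>

// Square-and-multiply on 32-bit operands; every product fits in 64 bits.
// modexp32 is a hypothetical name used only for this illustration.
uint32_t modexp32(uint32_t x, uint32_t e, uint32_t n) {
  assert(n != 0);              // the modulus must be non-zero
  uint64_t result = 1 % n;     // 1 % n also covers n == 1
  uint64_t base = x % n;
  while (e) {
    if (e & 1)                 // this bit of e is set: multiply it in
      result = result * base % n;
    base = base * base % n;    // base becomes x^(2^i) mod n
    e >>= 1;
  }
  return static_cast<uint32_t>(result);
}
```

For instance, modexp32(4, 13, 497) evaluates to 445.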
And the modulo operation, a member of uint256_t:
  uint128_t mod(const uint128_t b) const {
    uint128_t r = 0;
    for (uint16_t x = bits(); x > 0; x--) {
      r <<= 1;
      if (x > 128) {
        if ((high >> (x - 1U - 128)) & 1)
          ++r;
      } else {
        if ((low >> (x - 1)) & 1)
          ++r;
      }
      if (r >= b)
        r -= b;
    }
    return r;
  }
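The same shift-and-subtract idea can be sketched at a smaller width. This is only an illustration, assuming a 64-bit dividend and a 32-bit modulus (mod_bitserial is a hypothetical name); note that the loop body runs once per bit of the dividend:

```cpp
#include <cassert>
#include <cstdint>

// Bit-serial (shift-and-subtract) remainder of a 64-bit value by a 32-bit
// modulus. mod_bitserial is a hypothetical name used for illustration.
uint32_t mod_bitserial(uint64_t a, uint32_t b) {
  assert(b != 0);
  uint64_t r = 0;                    // 64-bit so that r << 1 cannot overflow
  for (int i = 63; i >= 0; --i) {
    r = (r << 1) | ((a >> i) & 1);   // bring down the next bit of a
    if (r >= b)
      r -= b;                        // keep r in [0, b)
  }
  return static_cast<uint32_t>(r);
}
```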
Even though mod uses only bit operations, it is a very expensive operation, because its loop runs once for every bit of the double-width product. Can we do better? In modexp, we compute the double-width products x*base and base*base, then find their smallest nonnegative representatives by performing a mod:
    base*base ≡ c (mod n)
    x*base ≡ d (mod n)

The idea is to use a different representative for x and base, chosen so that division and modulo become very cheap [2][3]. On a computer we choose R = 2^k with R > n (n must be odd, so that R is invertible modulo n), and transform x and base into Montgomery representation as x*R mod n and base*R mod n. Then modulo R just discards the high bits, and division by R just shifts the bits right [2][3].
Let's look at the code; its framework is similar to modexp.
  /// Inputs: x, e 128-bit integer
  /// Output: x^e mod n (128-bit integer)
  ///
  /// n is initialized in constructor
  uint128_t montModExp(uint128_t x, uint128_t e) {
    uint128_t mX = toMontSpace(1);
    uint128_t mBase = toMontSpace(x);

    while (e) {
      if (e & 0x1)
        mX = montMul(mX, mBase);
      mBase = montMul(mBase, mBase);
      e >>= 1;
    }
    return montMul(mX, 1);
  }
First, we transform x and base into Montgomery representation.
This is a 128-bit integer implementation, so we choose R = 2^128.
  /// Transform x to Montgomery Representation.
  /// Inputs: x (128-bit integer)
  /// Output: c (128-bit integer) = x * R mod n
  ///
  /// R = 2^128 in 128-bit integer implementation.
  uint128_t toMontSpace(uint128_t x) {
    x %= N;
    for (int i = 0; i < 128; i++) {
      x <<= 1;
      if (x >= N)
        x -= N;
    }
    return x;
  }
Then we perform the Montgomery multiplications (xR mod n) * (baseR mod n) and (baseR mod n)^2:
  /// Compute Montgomery multiplication.
  /// cR = aR * bR * R^-1 (mod n)
  ///
  /// mA = aR, mB = bR. mA and mB are in Montgomery space. (128-bit integer)
  uint128_t montMul(uint128_t mA, uint128_t mB) {
    return reduce(uint256_t::mul(mA, mB));
  }
Note: this is the critical step, known as Montgomery reduction [2][3][4][5].
Let's see the mathematics and how to implement it.
Let T = (aR mod n)(bR mod n).
The idea is to add a multiple kn of n to T (k may be negative) such that T + kn ≡ 0 (mod R); then TR^-1 ≡ (T + kn)/R (mod n) [2][3].
Why is this correct? If we multiply the two numbers as integers, without reducing each of them (mod n):
(aR)(bR) / R = abR, and reducing (mod n) afterwards gives the correct result abR (mod n).
But we can't compute T / R directly, because after the modulo operations T is in general no longer divisible by R.
However, T + kn is divisible by R, and T + kn ≡ T (mod n).
So we add kn to T such that T + kn ≡ 0 (mod R), and then divide the sum by R exactly.


How do we find k? The algorithm [3][4], where n' is the inverse of n modulo R, so that T - kn ≡ 0 (mod R):
function reduce(T):
    k = (T mod R) * n' mod R
    a = (T - k * n) / R
    if a < 0:
        a += n
    return a


(Source: cp-algorithms, Montgomery multiplication [4])


Note: R = 2^k; in our implementation R = 2^128,
∴ (mod R) = keep only the low 128 bits,
(div R) = shift right by 128 bits, in other words, drop the low 128 bits.
Both are cheap on a computer.
Also, -n < a < n before the correction, so one conditional addition of n is enough.
  /// Compute Montgomery Reduction.
  /// cR = xR^2 * R^-1 (mod n)
  ///
  /// Algorithm:
  ///   q = (xR^2 mod R) * n' mod R
  ///   a = (xR^2 - q * n) / R
  ///   if a < 0:
  ///     a += n
  ///   return a
  ///
  /// xR^2 and q * n have equal low halves, so only the high halves are
  /// subtracted. The signed difference assumes n < 2^127.
  uint128_t reduce(uint256_t xR2) {
    uint128_t q = xR2.low * Ninv;
    int128_t a = xR2.high - uint256_t::mul(q, N).high;
    if (a < 0)
      a += N;
    return a;
  }
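To make the reduction concrete at a size that is easy to check by hand, here is a minimal sketch assuming R = 2^32, so that T is a 64-bit product. reduce32 is a hypothetical 32-bit analogue of reduce, and the inverse n' is passed in by the caller. It compares the high halves before subtracting, so it does not rely on a signed difference:

```cpp
#include <cassert>
#include <cstdint>

// Montgomery reduction with R = 2^32: returns T * R^-1 mod n.
// Assumes n is odd, T < n * 2^32, and n * ninv ≡ 1 (mod 2^32).
// reduce32 is a hypothetical 32-bit analogue of the 128-bit reduce().
uint32_t reduce32(uint64_t T, uint32_t n, uint32_t ninv) {
  assert((n & 1) && static_cast<uint32_t>(n * ninv) == 1);
  uint32_t q = static_cast<uint32_t>(T) * ninv;  // (T mod R) * n' mod R
  uint32_t t_hi = static_cast<uint32_t>(T >> 32);
  uint32_t qn_hi =
      static_cast<uint32_t>((static_cast<uint64_t>(q) * n) >> 32);
  // The low halves of T and q * n are equal, so (T - q * n) / R is just
  // the difference of the high halves; add n back if it is negative.
  return t_hi >= qn_hi ? t_hi - qn_hi : t_hi - qn_hi + n;
}
```

For example, with n = 7 we have 7 * 0xB6DB6DB7 ≡ 1 (mod 2^32) and 2^32 ≡ 4 (mod 7), so R^-1 ≡ 2 and reduce32(3, 7, 0xB6DB6DB7) gives 3 * 2 mod 7 = 6.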
For another way to think about it, see the Montgomery reduction description on Wikipedia [3].

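The next question is how to find n', the inverse of n modulo R, such that nn' ≡ 1 (mod R).
There is a fast inverse trick
(copied directly from https://cp-algorithms.com/algebra/montgomery_multiplication.html#toc-tgt-2):
a * x ≡ 1 (mod 2^k) implies a * x * (2 - a * x) ≡ 1 (mod 2^(2k)).

Proof: Let ax = 1 + m * 2^k. Then ax(2 - ax) = (1 + m * 2^k)(1 - m * 2^k) = 1 - m^2 * 2^(2k) ≡ 1 (mod 2^(2k)).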

    // Compute N^-1 s.t. N * N^-1 ≡ 1 (mod R)
    // Fast inverse trick:
    //    a * x ≡ 1 (mod 2^k) => a * x * (2 - a * x) ≡ 1 (mod 2^(2k))
    //
    // Ninv = 1 (N is odd),
    // (1) N * Ninv ≡ 1 (mod 2^1) => N * Ninv * (2 - N * Ninv) ≡ 1 (mod 2^2)
    //     Update Ninv = Ninv * (2 - N * Ninv)
    // (2) N * Ninv ≡ 1 (mod 2^2) => N * Ninv * (2 - N * Ninv) ≡ 1 (mod 2^4)
    //     Update Ninv = Ninv * (2 - N * Ninv)
    // (3) ...
    // (7) N * Ninv ≡ 1 (mod 2^64) => N * Ninv * (2 - N * Ninv) ≡ 1 (mod 2^128)
    Ninv = 1;
    for (int i = 0; i < 7; ++i)
      Ninv = Ninv * (2 - N * Ninv);
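The doubling trick can be checked quickly at 32 bits. The sketch below assumes R = 2^32, where five iterations are enough (2 → 4 → 8 → 16 → 32 valid bits); inverse_mod_2_32 is a hypothetical helper name:

```cpp
#include <cassert>
#include <cstdint>

// Inverse of an odd n modulo 2^32 by repeated doubling of the valid bits.
// inverse_mod_2_32 is a hypothetical name used for illustration.
uint32_t inverse_mod_2_32(uint32_t n) {
  assert(n & 1);           // an even n has no inverse modulo a power of two
  uint32_t inv = 1;        // n * 1 ≡ 1 (mod 2) because n is odd
  for (int i = 0; i < 5; ++i)
    inv *= 2 - n * inv;    // unsigned arithmetic wraps modulo 2^32
  return inv;
}
```

For n = 3 the iterates are -1, -5, -85, -21845, -1431655765 (mod 2^32), and the last one is 0xAAAAAAAB, since 3 * 0xAAAAAAAB = 0x200000001.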
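That completes the Montgomery exponentiation implementation.
But transforming a number to Montgomery representation with toMontSpace is too slow. We can optimize it using the fact that montMul(x, R^2 mod n) = x * R^2 * R^-1 ≡ x * R (mod n), so a single Montgomery multiplication by a precomputed R^2 mod n performs the transformation.

The algorithm for computing R^2 mod n: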
    R2 = -N % N;
    for (int i = 0; i < 4; i++) {
      R2 <<= 1;
      if (R2 >= N)
        R2 -= N;
    }

    for (int i = 0; i < 5; i++) {
      R2 = montMul(R2, R2);
    }
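Note: -N % N ≡ R - N ≡ R (mod N), computed in 128-bit unsigned arithmetic. The four doublings give 2^4 * R mod N, and each montMul squaring doubles the exponent of 2 while keeping one factor of R (2^8 * R, 2^16 * R, ...), so five squarings give 2^128 * R = R^2 (mod N).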
Then modify our montModExp:
  /// Compute: c = x^e mod n
  ///
  /// x, e, and c are 128-bit integers.
  /// n and r^2 are initialized in the constructor
  uint128_t montModExp(uint128_t x, uint128_t e) {
    // Result x in Montgomery Space
    uint128_t mX = montMul(1, R2);

    // Transform x to Montgomery Space as base
    // mBase = x * r^2 * r^-1 (mod N)
    uint128_t mBase = montMul(x, R2);
    while (e) {
      if (e & 0x1)
        mX = montMul(mX, mBase);
      mBase = montMul(mBase, mBase);
      e >>= 1;
    }

    // Transform x from Montgomery Space to Normal Space.
    return montMul(mX, 1);
  }
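Putting the pieces together, here is a minimal end-to-end sketch at 32 bits, assuming R = 2^32 and an odd modulus. Mont32 and its members are hypothetical names; for brevity R^2 mod n is computed with a plain 64-bit %, not with the shift-and-square trick described above:

```cpp
#include <cassert>
#include <cstdint>

// Minimal 32-bit Montgomery context (R = 2^32). Mont32 and its members are
// hypothetical names; the structure mirrors the 128-bit code above.
struct Mont32 {
  uint32_t n, ninv, r2;

  explicit Mont32(uint32_t mod) : n(mod), ninv(1) {
    assert(n & 1);                          // Montgomery form needs odd n
    for (int i = 0; i < 5; ++i)
      ninv *= 2 - n * ninv;                 // n * ninv ≡ 1 (mod 2^32)
    uint64_t r = (1ULL << 32) % n;          // R mod n
    r2 = static_cast<uint32_t>(r * r % n);  // R^2 mod n via plain 64-bit %
  }

  // T * R^-1 mod n, for T < n * 2^32.
  uint32_t reduce(uint64_t T) const {
    uint32_t q = static_cast<uint32_t>(T) * ninv;
    uint32_t t_hi = static_cast<uint32_t>(T >> 32);
    uint32_t qn_hi =
        static_cast<uint32_t>((static_cast<uint64_t>(q) * n) >> 32);
    return t_hi >= qn_hi ? t_hi - qn_hi : t_hi - qn_hi + n;
  }

  uint32_t mul(uint32_t a, uint32_t b) const {
    return reduce(static_cast<uint64_t>(a) * b);
  }

  uint32_t pow(uint32_t x, uint32_t e) const {
    uint32_t acc = mul(1, r2);              // 1 in Montgomery form
    uint32_t base = mul(x % n, r2);         // x in Montgomery form
    while (e) {
      if (e & 1) acc = mul(acc, base);
      base = mul(base, base);
      e >>= 1;
    }
    return mul(acc, 1);                     // back to normal representation
  }
};
```

For instance, Mont32(497).pow(4, 13) gives 445, the same value as the plain square-and-multiply loop.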
Done! (=゚ω゚)ノ

Performance Comparison:
  • Compile Flags: -std=c++11 -Wall -O3 -DNDEBUG
  • Machine: i7-4770HQ CPU @ 2.20GHz
  • n = 0x9e40fd675571e0af74d65da4ea541cf
    x = 0xfbeab553608bdf65b2ab09bb910317f9
    e = 0x172a202e867b11779604827082342863
  • Test loop: 10000
  • modexp takes 3133.34 msec
  • Montgomery exponentiation takes 18.7284 msec,
    about 167 times faster than modexp.
Full Implementation:
https://github.com/cycheng/coding-for-fun/blob/master/crypto/montgomery.cpp

References:
[1] https://en.wikipedia.org/wiki/Modular_arithmetic
[2] (Traditional Chinese) https://ee.ntu.edu.tw/upload/hischool/doc/2011.07.pdf
[3] https://en.wikipedia.org/wiki/Montgomery_modular_multiplication
[4] https://cp-algorithms.com/algebra/montgomery_multiplication.html
[5] https://cryptography.fandom.com/wiki/Montgomery_reduction

Good Tools:
[1] Big number calculator https://www.boxentriq.com/code-breaking/big-number-calculator
[2] Online LaTeX editor https://www.codecogs.com/latex/eqneditor.php

Wednesday, April 8, 2020

Mathematics behind secp256k1_fe_mul_inner

Brief:
  • The function computes r ≡ a × b (mod p)
    • a, b, r, and p are 256-bit integers
    • p = 2^256 - 0x1000003D1
Source Code:
void secp256k1_fe_mul_inner(const uint32_t *a, const uint32_t *b, uint32_t *r) {
    uint64_t c = (uint64_t)a[0] * b[0];
    uint32_t t0 = c & 0x3FFFFFFUL; c = c >> 26;
    c = c + (uint64_t)a[0] * b[1] +
            (uint64_t)a[1] * b[0];
    uint32_t t1 = c & 0x3FFFFFFUL; c = c >> 26;
    c = c + (uint64_t)a[0] * b[2] +
            (uint64_t)a[1] * b[1] +
            (uint64_t)a[2] * b[0];
    uint32_t t2 = c & 0x3FFFFFFUL; c = c >> 26;
    c = c + (uint64_t)a[0] * b[3] +
            (uint64_t)a[1] * b[2] +
            (uint64_t)a[2] * b[1] +
            (uint64_t)a[3] * b[0];
    uint32_t t3 = c & 0x3FFFFFFUL; c = c >> 26;
    c = c + (uint64_t)a[0] * b[4] +
            (uint64_t)a[1] * b[3] +
            (uint64_t)a[2] * b[2] +
            (uint64_t)a[3] * b[1] +
            (uint64_t)a[4] * b[0];
    uint32_t t4 = c & 0x3FFFFFFUL; c = c >> 26;
    c = c + (uint64_t)a[0] * b[5] +
            (uint64_t)a[1] * b[4] +
            (uint64_t)a[2] * b[3] +
            (uint64_t)a[3] * b[2] +
            (uint64_t)a[4] * b[1] +
            (uint64_t)a[5] * b[0];
    uint32_t t5 = c & 0x3FFFFFFUL; c = c >> 26;
    c = c + (uint64_t)a[0] * b[6] +
            (uint64_t)a[1] * b[5] +
            (uint64_t)a[2] * b[4] +
            (uint64_t)a[3] * b[3] +
            (uint64_t)a[4] * b[2] +
            (uint64_t)a[5] * b[1] +
            (uint64_t)a[6] * b[0];
    uint32_t t6 = c & 0x3FFFFFFUL; c = c >> 26;
    c = c + (uint64_t)a[0] * b[7] +
            (uint64_t)a[1] * b[6] +
            (uint64_t)a[2] * b[5] +
            (uint64_t)a[3] * b[4] +
            (uint64_t)a[4] * b[3] +
            (uint64_t)a[5] * b[2] +
            (uint64_t)a[6] * b[1] +
            (uint64_t)a[7] * b[0];
    uint32_t t7 = c & 0x3FFFFFFUL; c = c >> 26;
    c = c + (uint64_t)a[0] * b[8] +
            (uint64_t)a[1] * b[7] +
            (uint64_t)a[2] * b[6] +
            (uint64_t)a[3] * b[5] +
            (uint64_t)a[4] * b[4] +
            (uint64_t)a[5] * b[3] +
            (uint64_t)a[6] * b[2] +
            (uint64_t)a[7] * b[1] +
            (uint64_t)a[8] * b[0];
    uint32_t t8 = c & 0x3FFFFFFUL; c = c >> 26;
    c = c + (uint64_t)a[0] * b[9] +
            (uint64_t)a[1] * b[8] +
            (uint64_t)a[2] * b[7] +
            (uint64_t)a[3] * b[6] +
            (uint64_t)a[4] * b[5] +
            (uint64_t)a[5] * b[4] +
            (uint64_t)a[6] * b[3] +
            (uint64_t)a[7] * b[2] +
            (uint64_t)a[8] * b[1] +
            (uint64_t)a[9] * b[0];
    uint32_t t9 = c & 0x3FFFFFFUL; c = c >> 26;
    c = c + (uint64_t)a[1] * b[9] +
            (uint64_t)a[2] * b[8] +
            (uint64_t)a[3] * b[7] +
            (uint64_t)a[4] * b[6] +
            (uint64_t)a[5] * b[5] +
            (uint64_t)a[6] * b[4] +
            (uint64_t)a[7] * b[3] +
            (uint64_t)a[8] * b[2] +
            (uint64_t)a[9] * b[1];
    uint32_t t10 = c & 0x3FFFFFFUL; c = c >> 26;
    c = c + (uint64_t)a[2] * b[9] +
            (uint64_t)a[3] * b[8] +
            (uint64_t)a[4] * b[7] +
            (uint64_t)a[5] * b[6] +
            (uint64_t)a[6] * b[5] +
            (uint64_t)a[7] * b[4] +
            (uint64_t)a[8] * b[3] +
            (uint64_t)a[9] * b[2];
    uint32_t t11 = c & 0x3FFFFFFUL; c = c >> 26;
    c = c + (uint64_t)a[3] * b[9] +
            (uint64_t)a[4] * b[8] +
            (uint64_t)a[5] * b[7] +
            (uint64_t)a[6] * b[6] +
            (uint64_t)a[7] * b[5] +
            (uint64_t)a[8] * b[4] +
            (uint64_t)a[9] * b[3];
    uint32_t t12 = c & 0x3FFFFFFUL; c = c >> 26;
    c = c + (uint64_t)a[4] * b[9] +
            (uint64_t)a[5] * b[8] +
            (uint64_t)a[6] * b[7] +
            (uint64_t)a[7] * b[6] +
            (uint64_t)a[8] * b[5] +
            (uint64_t)a[9] * b[4];
    uint32_t t13 = c & 0x3FFFFFFUL; c = c >> 26;
    c = c + (uint64_t)a[5] * b[9] +
            (uint64_t)a[6] * b[8] +
            (uint64_t)a[7] * b[7] +
            (uint64_t)a[8] * b[6] +
            (uint64_t)a[9] * b[5];
    uint32_t t14 = c & 0x3FFFFFFUL; c = c >> 26;
    c = c + (uint64_t)a[6] * b[9] +
            (uint64_t)a[7] * b[8] +
            (uint64_t)a[8] * b[7] +
            (uint64_t)a[9] * b[6];
    uint32_t t15 = c & 0x3FFFFFFUL; c = c >> 26;
    c = c + (uint64_t)a[7] * b[9] +
            (uint64_t)a[8] * b[8] +
            (uint64_t)a[9] * b[7];
    uint32_t t16 = c & 0x3FFFFFFUL; c = c >> 26;
    c = c + (uint64_t)a[8] * b[9] +
            (uint64_t)a[9] * b[8];
    uint32_t t17 = c & 0x3FFFFFFUL; c = c >> 26;
    c = c + (uint64_t)a[9] * b[9];
    uint32_t t18 = c & 0x3FFFFFFUL; c = c >> 26;
    uint32_t t19 = c;

    c = t0 + (uint64_t)t10 * 0x3D10UL;
    t0 = c & 0x3FFFFFFUL; c = c >> 26;
    c = c + t1 + (uint64_t)t10*0x400UL + (uint64_t)t11 * 0x3D10UL;
    t1 = c & 0x3FFFFFFUL; c = c >> 26;
    c = c + t2 + (uint64_t)t11*0x400UL + (uint64_t)t12 * 0x3D10UL;
    t2 = c & 0x3FFFFFFUL; c = c >> 26;
    c = c + t3 + (uint64_t)t12*0x400UL + (uint64_t)t13 * 0x3D10UL;
    r[3] = c & 0x3FFFFFFUL; c = c >> 26;
    c = c + t4 + (uint64_t)t13*0x400UL + (uint64_t)t14 * 0x3D10UL;
    r[4] = c & 0x3FFFFFFUL; c = c >> 26;
    c = c + t5 + (uint64_t)t14*0x400UL + (uint64_t)t15 * 0x3D10UL;
    r[5] = c & 0x3FFFFFFUL; c = c >> 26;
    c = c + t6 + (uint64_t)t15*0x400UL + (uint64_t)t16 * 0x3D10UL;
    r[6] = c & 0x3FFFFFFUL; c = c >> 26;
    c = c + t7 + (uint64_t)t16*0x400UL + (uint64_t)t17 * 0x3D10UL;
    r[7] = c & 0x3FFFFFFUL; c = c >> 26;
    c = c + t8 + (uint64_t)t17*0x400UL + (uint64_t)t18 * 0x3D10UL;
    r[8] = c & 0x3FFFFFFUL; c = c >> 26;
    c = c + t9 + (uint64_t)t18*0x400UL + (uint64_t)t19 * 0x1000003D10UL;
    r[9] = c & 0x03FFFFFUL; c = c >> 22;
    uint64_t d = t0 + c * 0x3D1UL;
    r[0] = d & 0x3FFFFFFUL; d = d >> 26;
    d = d + t1 + c*0x40;
    r[1] = d & 0x3FFFFFFUL; d = d >> 26;
    r[2] = t2 + d;
}

Mathematics:

We can represent a and b in this way:
a = a0 + a1 × 2^26 + a2 × 2^52 + a3 × 2^78 + ... + a7 × 2^182 + a8 × 2^208 + a9 × 2^234
b = b0 + b1 × 2^26 + b2 × 2^52 + b3 × 2^78 + ... + b7 × 2^182 + b8 × 2^208 + b9 × 2^234
where a0..a8 and b0..b8 are 26-bit integers, and a9 and b9 are 22-bit integers.
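As an illustration of this limb layout, the sketch below splits a 256-bit value into ten limbs. It assumes the value is given as four little-endian 64-bit words; to_limbs26 is a hypothetical helper and not part of libsecp256k1:

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Split a 256-bit value, given as four little-endian 64-bit words, into ten
// limbs: limbs 0..8 hold 26 bits each and limb 9 holds the top 22 bits.
// to_limbs26 is a hypothetical helper, not part of libsecp256k1.
std::array<uint32_t, 10> to_limbs26(const std::array<uint64_t, 4>& w) {
  std::array<uint32_t, 10> limbs{};
  for (int i = 0; i < 10; ++i) {
    int bit = 26 * i;                  // position of this limb's lowest bit
    int word = bit / 64, shift = bit % 64;
    uint64_t v = w[word] >> shift;
    if (shift > 38 && word + 1 < 4)    // limb straddles two 64-bit words
      v |= w[word + 1] << (64 - shift);
    limbs[i] = static_cast<uint32_t>(v & 0x3FFFFFF);
  }
  assert(limbs[9] <= 0x3FFFFF);        // 256 - 9 * 26 = 22 bits remain
  return limbs;
}
```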

The coefficients of t are (each t keeps only the low 26 bits of its sum; the bits above are the carry into the next coefficient):

t0 = a0b0
t1 = a0b1 + a1b0 + t0's carry
t2 = a0b2 + a1b1 + a2b0 + t1's carry
t3 = a0b3 + a1b2 + a2b1 + a3b0 + t2's carry
...
t9 = a0b9 + a1b8 + a2b7 + a3b6 + a4b5 + a5b4 + a6b3 + a7b2 + a8b1 + a9b0 + t8's carry
...
t17 = a8b9 + a9b8 + t16's carry
t18 = a9b9 + t17's carry
t19 = t18's carry
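The column sums above can also be written as a loop. The following is a sketch of the first half of the function only (the product, without the reduction modulo p), assuming 10-limb inputs whose limbs fit in 26 bits; mul_limbs26 is a hypothetical name:

```cpp
#include <cassert>
#include <cstdint>

// Loop form of the unrolled schoolbook product: t[k] receives the low 26
// bits of column k and the rest is carried into column k + 1.
// mul_limbs26 is a hypothetical name used for illustration.
void mul_limbs26(const uint32_t a[10], const uint32_t b[10], uint32_t t[20]) {
  uint64_t c = 0;                       // running column sum plus carry
  for (int k = 0; k <= 18; ++k) {
    int lo = k > 9 ? k - 9 : 0;         // keep both indices within 0..9
    int hi = k < 9 ? k : 9;
    for (int i = lo; i <= hi; ++i) {
      assert(a[i] <= 0x3FFFFFF && b[k - i] <= 0x3FFFFFF);
      c += static_cast<uint64_t>(a[i]) * b[k - i];
    }
    t[k] = static_cast<uint32_t>(c & 0x3FFFFFF);
    c >>= 26;
  }
  t[19] = static_cast<uint32_t>(c);     // final carry
}
```

A column holds at most ten 52-bit products plus a carry, which stays well below 2^64, so the 64-bit accumulator cannot overflow.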

Reduced Representation:

p = 2^256 - 0x1000003D1
⇒ 2^256 - 0x1000003D1 ≡ 0 (mod p)
⇒ 2^256 ≡ 0x1000003D1 (mod p)

a×b = t0 + t1 × 2^26 + t2 × 2^52 + t3 × 2^78 + ... + t9 × 2^234 + ... + t17 × 2^442 + t18 × 2^468 + t19 × 2^494
Since 2^256 ≡ 0x1000003D1 (mod p), and congruences are compatible with scaling, we can rewrite the t10..t19 terms in this way:
t10 × 2^260 = t10 × 2^4 × 2^256 ≡ t10 × 2^4 × 0x1000003D1 (mod p)
t11 × 2^286 = t11 × 2^30 × 2^256 ≡ t11 × 2^30 × 0x1000003D1 (mod p)
t12 × 2^312 = t12 × 2^56 × 2^256 ≡ t12 × 2^56 × 0x1000003D1 (mod p)
...
t19 × 2^494 = t19 × 2^238 × 2^256 ≡ t19 × 2^238 × 0x1000003D1 (mod p)

By compatibility with addition, we can fold these terms into the low limbs. For example, r0 = t0 + t10 × 2^4 × 0x1000003D1, but r0 only holds the low 26 bits. Since 2^4 × 0x1000003D1 = 0x1000003D10 = 0x3D10 + 2^10 × 2^26, we keep t10 × 0x3D10 in r0 and leave t10 × 2^10 to r1. So (each r keeps its low 26 bits and passes the rest on as a carry):
r0 = t0 + t10 × 2^4 × 0x3D1
r1 = t1 + t10 × 2^10 + t11 × 2^4 × 0x3D1 + r0's carry
r2 = t2 + t11 × 2^10 + t12 × 2^4 × 0x3D1 + r1's carry
r3 = t3 + t12 × 2^10 + t13 × 2^4 × 0x3D1 + r2's carry
...
r8 = t8 + t17 × 2^10 + t18 × 2^4 × 0x3D1 + r7's carry
r9 = t9 + t18 × 2^10 + t19 × 0x1000003D10 + r8's carry

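Wait, what if r9 ≥ 2^22? Don't worry, just reduce it again by applying
2^256 ≡ 0x1000003D1 (mod p).
The bits of r9 from bit 22 upward have weight 2^(234+22) = 2^256, and 0x1000003D1 = 0x3D1 + 0x40 × 2^26, so the overflow c is added back as c × 0x3D1 into r0 and c × 0x40 into r1, which is what the last lines of the code do.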
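The constants 0x3D10, 0x400, 0x3D1 and 0x40 in the source all come from splitting 0x1000003D1 (or 2^4 times it) at a 26-bit limb boundary. A small compile-time check, with kFold, low26 and above26 as hypothetical names:

```cpp
#include <cstdint>

constexpr uint64_t kFold = 0x1000003D1ULL;  // 2^256 mod p

// Part of v that stays in the current 26-bit limb.
constexpr uint64_t low26(uint64_t v) { return v & 0x3FFFFFF; }
// Part of v that moves up into the next limb.
constexpr uint64_t above26(uint64_t v) { return v >> 26; }

// Folding t10..t18: 2^4 * kFold = 0x3D10 (same limb) + 0x400 * 2^26.
static_assert(low26(kFold << 4) == 0x3D10, "same-limb multiplier");
static_assert(above26(kFold << 4) == 0x400, "next-limb multiplier");

// Final fold of the overflow of r9: kFold = 0x3D1 + 0x40 * 2^26.
static_assert(low26(kFold) == 0x3D1, "multiplier into r0");
static_assert(above26(kFold) == 0x40, "multiplier into r1");
```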

Special thanks to Shen-Ta Hsieh for helping me understand the magic =)