TFHE: guide walkthrough. Part 1

Posted on 2023-08-15 | source code

Walkthrough of Guide to Fully Homomorphic Encryption over the [Discretized] Torus by Marc Joye. Part 1.

@misc{cryptoeprint:2021/1402,
      author = {Marc Joye},
      title = {Guide to Fully Homomorphic Encryption over the [Discretized] Torus},
      howpublished = {Cryptology ePrint Archive, Paper 2021/1402},
      year = {2021},
      note = {\url{https://eprint.iacr.org/2021/1402}},
      url = {https://eprint.iacr.org/2021/1402}
}

1.1 Torus and Torus Polynomials

Example 1. Take for example $a = \frac{2}{5}, b = \frac{4}{5}$ and $c = \frac{1}{3}$.

Over $\mathbb{T}$, we get

$(a + b) \times c = \frac{1}{5} \times \frac{1}{3} = \frac{1}{15}$ and

$a \times c + b \times c = \frac{2}{15} + \frac{4}{15} = \frac{6}{15} = \frac{2}{5}$, a contradiction: the product of two torus elements is not well defined on $\mathbb{T}$.

from fractions import Fraction

class T(Fraction):  # an element of the torus T = R/Z, kept as a fraction in [0, 1)
    def __new__(cls, n, d):
        return super().__new__(cls, n % d, d)  # reduce modulo 1
    
    def __add__(self, other):  # addition on T (modulo 1)
        f = Fraction.__add__(self, other)
        return T(f.numerator, f.denominator)
    
    def dot(self, n):  # external product Z x T -> T: n·t (modulo 1)
        return T(n * self.numerator, self.denominator)

    
a = T(2, 5)
b = T(4, 5)
c = T(1, 3)

# NOTE: `*` is the internal product, which IS NOT DEFINED on T.
# T does not override it, so the inherited Fraction.__mul__ multiplies the
# representatives as rational numbers and returns a plain Fraction (no reduction mod 1).

print((a + b) * c, "!=", # (𝑎+𝑏)×𝑐
      (a * c) + (b * c)) # 𝑎×𝑐+𝑏×𝑐
1/15 != 2/5
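To see concretely why `*` cannot be defined on $\mathbb{T}$, note that a torus element only fixes its representative up to an integer. The short sketch below uses plain `fractions.Fraction`, with `% 1` standing in for reduction modulo 1: it takes two representatives of the same $a = \frac{2}{5}$ and multiplies each by $c = \frac{1}{3}$, and the two products land on different torus elements.

```python
from fractions import Fraction

c = Fraction(1, 3)
a1 = Fraction(2, 5)  # representative of a in [0, 1)
a2 = a1 + 1          # 7/5: differs by an integer, so the same element of T

print(a1 % 1 == a2 % 1)            # True: same torus element
print((a1 * c) % 1, (a2 * c) % 1)  # 2/15 7/15: different torus elements
```

Multiplying by an integer does not have this problem: for $k, m \in \mathbb{Z}$ we have $k \cdot (a + m) = k \cdot a + km$, and $km$ vanishes modulo 1.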

Example 2. Take $k = 2, l = 3, a = \frac{2}{5}$ and $b = \frac{4}{5}$.

We get $(k + l) \cdot a = 5 \cdot \frac{2}{5} = 0$ and $k \cdot a + l \cdot a = \frac{4}{5} + \frac{1}{5} = 0$, as expected.

We also get $k \cdot (a + b) = 2 \cdot \frac{1}{5} = \frac{2}{5}$ and $k \cdot a + k \cdot b = \frac{4}{5} + \frac{3}{5} = \frac{2}{5}$.

Finally, taking $t = a = \frac{2}{5}$, we get $k \cdot (l \cdot t) = 2 \cdot \frac{1}{5} = \frac{2}{5}$ and $(kl) \cdot t = 6 \cdot \frac{2}{5} = \frac{2}{5}$, as expected.

k = 2
l = 3
t = a = T(2, 5)
b = T(4, 5)

print(
  a.dot(k + l), "=",   # (𝑘+𝑙)⋅𝑎
  a.dot(k) + a.dot(l)) # 𝑘⋅𝑎+𝑙⋅𝑎
print(
  (a + b).dot(k), "=", # 𝑘⋅(𝑎+𝑏)
  a.dot(k) + b.dot(k)) # 𝑘⋅𝑎+𝑘⋅𝑏
print(
  t.dot(l).dot(k), "=",# 𝑘⋅(𝑙⋅𝑡)
  t.dot(l * k))        # (𝑘𝑙)⋅𝑡
0 = 0
2/5 = 2/5
2/5 = 2/5
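Example 2 checks the $\mathbb{Z}$-module laws at a single point. As a sanity check, the sketch below tests them exhaustively on the multiples of $\frac{1}{15}$ in $[0, 1)$ with small integer scalars. The helpers `add` and `smul` are illustrative names, and every intermediate result is reduced modulo 1, as on $\mathbb{T}$.

```python
from fractions import Fraction
from itertools import product

def add(a, b):   # addition on T: reduce the sum modulo 1
    return (a + b) % 1

def smul(k, a):  # external product Z x T -> T: reduce k*a modulo 1
    return (k * a) % 1

q = 15
torus_q = [Fraction(i, q) for i in range(q)]  # 0, 1/15, ..., 14/15
scalars = range(-4, 5)

ok = all(
    smul(k + l, a) == add(smul(k, a), smul(l, a))            # (k+l)·a = k·a + l·a
    and smul(k, add(a, b)) == add(smul(k, a), smul(k, b))    # k·(a+b) = k·a + k·b
    and smul(k, smul(l, a)) == smul(k * l, a)                # k·(l·a) = (kl)·a
    for k, l in product(scalars, repeat=2)
    for a, b in product(torus_q, repeat=2)
)
print(ok)  # True
```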

Example 3. Take for example $\mathscr{p}(X) = \frac{2}{5}X + \frac{1}{3}$, $\mathscr{q}(X) = \frac{4}{5}X + \frac{1}{2}$, and $\mathscr{r}(X) = 2X + 7$.

Then $(\mathscr{p} + \mathscr{q})(X) = \frac{1}{5}X + \frac{5}{6}$

p = (T(1, 3), T(2, 5)) # 1/3 + 2/5X
q = (T(1, 2), T(4, 5)) # 1/2 + 4/5X
r = (7, 2)             #   7 + 2X

print(p[0] + q[0], p[1] + q[1]) # p + q
5/6 1/5

and

$(\mathscr{r} \cdot \mathscr{p})(X) = \frac{4}{5}X^2 + \frac{7}{15}X + \frac{1}{3} = ...$

t = [T(0, 1)] * 3 # 0 + 0X + 0X^2

for i in range(2):
    for j in range(2):
        t[i + j] += p[i].dot(r[j])
print(*t)
1/3 7/15 4/5

$...= -\frac{4}{5} + \frac{7}{15}X + \frac{1}{3} = \frac{7}{15}X + \frac{8}{15}$

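Recall that the polynomials are taken modulo $X^2 + 1$ (here $N = 2$), so $X^2 \equiv -1$ and the $\frac{4}{5}X^2$ term contributes $-\frac{4}{5}$ to the constant coefficient.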

print(t[0] + t[2].dot(-1), t[1])
8/15 7/15
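The double loop and the final $X^2 \equiv -1$ correction above generalize to any $N$: when reducing modulo $X^N + 1$, a product term landing on $X^{i+j}$ with $i + j \ge N$ wraps around to $X^{i+j-N}$ with its sign flipped. A minimal sketch in plain Python, assuming both inputs are lists of $N$ coefficients ordered from $X^0$ upwards (`ext_poly_mul` is a hypothetical helper name, not part of the guide):

```python
from fractions import Fraction

def ext_poly_mul(r, p):
    """Multiply an integer polynomial r by a torus polynomial p in T[X]/(X^N + 1).

    Both r and p hold N coefficients, ordered from X^0 upwards.
    """
    N = len(p)
    out = [Fraction(0)] * N
    for i, pi in enumerate(p):
        for j, rj in enumerate(r):
            k = i + j
            if k < N:
                out[k] += rj * pi
            else:  # X^k = X^N * X^(k-N) ≡ -X^(k-N)
                out[k - N] -= rj * pi
    return [c % 1 for c in out]  # reduce every coefficient modulo 1

p = [Fraction(1, 3), Fraction(2, 5)]  # 1/3 + 2/5 X
r = [7, 2]                            # 7 + 2X
print(*ext_poly_mul(r, p))            # 8/15 7/15
```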

1.2 Discretized Torus (+Jaxite)

Later in this walkthrough, instead of implementing everything from scratch, I am going to use Jaxite, Google's FHE library written in JAX.

Let's compute the $(\mathscr{r} \cdot \mathscr{p})(X)$ from Example 3 using Jaxite.

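In practice, torus elements are not represented as fractions but as integers modulo $q$: an element $t$ of the discretized torus $\mathbb{T}_q = q^{-1}\mathbb{Z}/\mathbb{Z}$ is stored as $\bar{t} = q \cdot t \in \mathbb{Z}_q$, so that all arithmetic becomes integer arithmetic.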

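Consider $\mathscr{p}(X) = \frac{2}{5}X + \frac{1}{3}$ from $\mathbb{T}_2[X]$. Both coefficients are multiples of $\frac{1}{15}$ ($\frac{2}{5} = \frac{6}{15}$ and $\frac{1}{3} = \frac{5}{15}$), so with $q = 15$ it can be represented as $\overline{\mathscr{p}}(X) = 6X + 5$ in $\mathbb{Z}_{2,15}[X] = \mathbb{Z}_{15}[X]/(X^2 + 1)$.

Then $(\mathscr{r} \cdot \overline{\mathscr{p}})(X) = 12X^2 + 52X + 35 \equiv 52X + 23 \equiv 7X + 8$ in $\mathbb{Z}_{2,15}[X]$, which corresponds to $\frac{7}{15}X + \frac{8}{15}$ from Example 3.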
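Moving between the two representations is just scaling by $q$. A small sketch, assuming $q = 15$ as in $\mathbb{Z}_{2,15}[X]$; `encode` and `decode` are hypothetical helper names. `encode` rejects torus elements that are not multiples of $\frac{1}{q}$, and `decode` maps an integer back to a fraction in $[0, 1)$.

```python
from fractions import Fraction

def encode(t, q):
    """Torus element t (a Fraction) -> integer q*t in Z_q; q*t must be integral."""
    scaled = Fraction(t) * q
    if scaled.denominator != 1:
        raise ValueError(f"{t} is not a multiple of 1/{q}")
    return scaled.numerator % q

def decode(x, q):
    """Integer x in Z_q -> torus element x/q in [0, 1)."""
    return Fraction(x % q, q)

q = 15
p = [Fraction(1, 3), Fraction(2, 5)]    # 1/3 + 2/5 X
print([encode(c, q) for c in p])        # [5, 6], i.e. 5 + 6X
print(*(decode(x, q) for x in [8, 7]))  # 8/15 7/15, i.e. 8/15 + 7/15 X
```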

import jax.numpy as jnp
from jaxite.jaxite_lib.matrix_utils import poly_mul

p = jnp.array([5, 6]) # in Z_{2,15}[X]
r = jnp.array([7, 2]) # in Z[X]
# poly_mul will account for modulo (X^2 + 1), but not for modulo 15, do it manually
print(poly_mul(r, p) % 15)
[8 7]
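Here $q = 15$ was chosen so that the fractions of Example 3 are represented exactly. Implementations typically pick a power of two instead, such as $q = 2^{32}$, so that reduction modulo $q$ is simply unsigned integer overflow; arbitrary reals are then rounded to the nearest multiple of $2^{-32}$. The sketch below illustrates this with NumPy `uint32` arrays. `to_torus32` and `from_torus32` are hypothetical helpers, and $q = 2^{32}$ is an assumption for illustration, not something the Jaxite call above requires.

```python
import numpy as np

Q = 2**32  # assumption: q = 2^32, so Z_q arithmetic is plain uint32 arithmetic

def to_torus32(ts):
    """Round reals to the nearest multiple of 1/Q modulo 1, stored as uint32."""
    return np.array([round((t % 1.0) * Q) % Q for t in ts], dtype=np.uint32)

def from_torus32(xs):
    """Map uint32 representatives back to floats in [0, 1)."""
    return xs.astype(np.float64) / Q

a = to_torus32([2 / 5])
b = to_torus32([4 / 5])
s = a + b               # uint32 addition wraps modulo 2^32, i.e. modulo 1 on T
print(from_torus32(s))  # close to 0.2, up to rounding to multiples of 1/2^32
```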

1.3 Notation

Example 5. The vector $\pmb{\mathcal{v}} = (3, 4) \in \mathbb{Z}^2$ is regarded as the row matrix $\begin{pmatrix} 3 & 4 \end{pmatrix} \in \mathbb{Z}^{1 \times 2}$, and if $\pmb{A} = \begin{pmatrix} 1 & 2 \\ 0 & 1 \end{pmatrix}$ then $\pmb{\mathcal{v}}\pmb{A} = \begin{pmatrix} 3 & 10 \end{pmatrix} = (3, 10)$.

v = jnp.array([[3, 4]]) 
A = jnp.array([[1, 2], [0, 1]])
print(jnp.matmul(v, A))
[[ 3 10]]
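Written out, the row-vector convention means $\pmb{\mathcal{v}}\pmb{A}$ is the linear combination of the rows of $\pmb{A}$ weighted by the entries of $\pmb{\mathcal{v}}$, here $3 \cdot (1, 2) + 4 \cdot (0, 1) = (3, 10)$. A plain-Python sketch of the entry-wise formula $(\pmb{\mathcal{v}}\pmb{A})_j = \sum_i v_i A_{ij}$:

```python
v = [3, 4]
A = [[1, 2],
     [0, 1]]

# (vA)_j = sum_i v_i * A[i][j]: a weighted sum of the rows of A
vA = [sum(v[i] * A[i][j] for i in range(len(v))) for j in range(len(A[0]))]
print(vA)  # [3, 10]
```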