AlpacaHack Round 5 Broadcasting NTRU

2024-10-18

Challenge

import os
from Crypto.Cipher import AES
from hashlib import sha256

FLAG = os.environ.get("FLAG", "fakeflag")

N = 509
q = 2048
p = 3
d = 253

Zx.<x> = ZZ[]

def invertmodprime(f, p):
    T = Zx.change_ring(Integers(p)).quotient(x ^ N - 1)
    return Zx(lift(1 / T(f)))

def invertmodpowerof2(f, q):
    assert q.is_power_of(2)
    h = invertmodprime(f, 2)
    while True:
        r = balancedmod(convolution(h, f), q)
        if r == 1:
            return h
        h = balancedmod(convolution(h, 2 - r), q)

def balancedmod(f, q):
    g = list(((f[i] + q // 2) % q) - q // 2 for i in range(N))
    return Zx(g)

def convolution(f, g):
    return (f * g) % (x ^ N - 1)

def generate_polynomial(d):
    coeffs = [1] * d + [0] * (N - d)
    shuffle(coeffs)
    return Zx(coeffs)

def generate_keys():
    while True:
        try:
            f = generate_polynomial(d)
            g = generate_polynomial(d)
            f_p = invertmodprime(f, p)
            f_q = invertmodpowerof2(f, q)
            break
        except:
            pass
    public_key = balancedmod(p * convolution(f_q, g), q)
    secret_key = f, f_p
    return public_key, secret_key

def generate_message():
    result = list(randrange(2) for j in range(N))
    return Zx(result)

def encrypt(message, public_key):
    r = Zx(list(randrange(2) for j in range(N)))
    return balancedmod(convolution(public_key, r) + message, q)


msg = generate_message()

public_keys = []
ciphertexts = []
for _ in range(777):
    public_key, secret_key = generate_keys()
    ct = encrypt(msg, public_key)
    public_keys.append(public_key)
    ciphertexts.append(ct)

print("public keys:", public_keys)
print("ciphertexts:", ciphertexts)

key = sha256(str(msg).encode()).digest()[:16]
cipher = AES.new(key=key, mode=AES.MODE_CTR)
enc_flag = cipher.encrypt(FLAG.encode())
print("encrypted flag:", (cipher.nonce + enc_flag).hex())

The author also provides a link to the paper that describes a broadcasting attack on NTRU.

Quick recap of the paper

The encryption of NTRU is

\[\mathbf{c} = \mathbf{H}\mathbf{r} + \mathbf{m} \pmod{q},\]

where \(\mathbf{H}\) is the public key (in the form of a circulant matrix) and \(\mathbf{r}\) is the nonce. Define \(\mathbf{\hat{H}} = \mathbf{H}^{-1}\) and \(\mathbf{b} = \mathbf{\hat{H}}\mathbf{c}\), so that \(\mathbf{r} = \mathbf{b} - \mathbf{\hat{H}}\mathbf{m} \pmod{q}\). If \(\mathbf{r}^T\mathbf{r}\) is known, we have

\[\mathbf{r}^T\mathbf{r} = (\mathbf{b} - \mathbf{\hat{H}}\mathbf{m})^T(\mathbf{b} - \mathbf{\hat{H}}\mathbf{m}) = \mathbf{b}^T\mathbf{b} - 2\mathbf{b}^T\mathbf{\hat{H}}\mathbf{m} + \mathbf{m}^T\mathbf{\hat{H}}^T\mathbf{\hat{H}}\mathbf{m} \pmod{q}.\]

This is a quadratic equation in the unknown \(\mathbf{m}\). The usual linearization technique would leave us with a quadratic number of variables. However, notice that \(\mathbf{\hat{H}}^T\mathbf{\hat{H}}\) is a symmetric circulant matrix, so many quadratic terms share the same coefficient and can be merged. This reduces the number of variables to \(N + \lceil N/2 \rceil\), and we need that many public key and ciphertext pairs to recover the secret message. For \(N = 509\) this is 764, which is below the 777 pairs the challenge provides.
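The merging of quadratic terms can be checked numerically. The following is a minimal pure-Python sketch, not taken from the paper or the challenge: the helper names, the toy size `n = 11` and the random stand-in for \(\hat{h}\) are assumptions for illustration. It compares \(\mathbf{m}^T\mathbf{\hat{H}}^T\mathbf{\hat{H}}\mathbf{m}\) computed with the full matrix against the merged form \(a_0x_0 + 2\sum_{i \ge 1} a_ix_i\), where `a` and `x` are the cyclic autocorrelations of \(\hat{h}\) and \(m\). It assumes `n` is odd, as in the challenge.

```python
import random

def circulant(h):
    """Circulant matrix whose product with a vector is cyclic convolution by h."""
    n = len(h)
    return [[h[(i - j) % n] for j in range(n)] for i in range(n)]

def autocorrelation(v):
    """Map s to sum_t v[t] * v[(t + s) % n]; symmetric under s -> -s."""
    n = len(v)
    return [sum(v[t] * v[(t + s) % n] for t in range(n)) for s in range(n)]

def quadratic_form_direct(h, m):
    """m^T H^T H m computed as |H m|^2 with the full n x n matrix."""
    H = circulant(h)
    Hm = [sum(hij * mj for hij, mj in zip(row, m)) for row in H]
    return sum(v * v for v in Hm)

def quadratic_form_merged(h, m):
    """Same value using only floor(n/2) + 1 products a_i * x_i (n odd)."""
    n = len(h)
    a = autocorrelation(h)
    x = autocorrelation(m)
    return a[0] * x[0] + 2 * sum(a[i] * x[i] for i in range(1, n // 2 + 1))

rng = random.Random(0)
n = 11
h = [rng.randrange(-5, 6) for _ in range(n)]   # hypothetical stand-in for the inverse key
m = [rng.randrange(2) for _ in range(n)]       # binary message
direct = quadratic_form_direct(h, m)
merged = quadratic_form_merged(h, m)
print(direct == merged)
```

Only the \(\lfloor n/2 \rfloor + 1\) values `x[0], ..., x[n // 2]` appear in the merged form, which is where the \(\lceil N/2 \rceil\) extra variables come from.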

From the paper to our challenge

In our challenge, \(\mathbf{r}\) and \(\mathbf{m}\) are random polynomials with binary coefficients (all coefficients are 0 or 1), which means \(\mathbf{r}^T\mathbf{r}\) is not a fixed, known value. The paper mentions a trick: consider \(\mathbf{r'} = 2\mathbf{r} - (1, 1, \cdots, 1)\). All the entries of \(\mathbf{r'}\) are either -1 or 1, so \(\mathbf{r'}^T\mathbf{r'} = N\), a value we know. This is the official solution as well, but here we provide another trick that can potentially give a faster solution.

The key insight is that \(r(1)\) in polynomial form is the same as \(\mathbf{r}^T\mathbf{r}\) in vector form because of its binary coefficients, and we have the equation

\[c(1) = H(1)r(1) + m(1) \pmod{q}.\]

Since \(H\) is generated by calculating \(p\cdot g/f\) where \(p=3\), and \(f(1) = g(1) = d = 253\), we have \(H(1) \equiv 3 \pmod{q}\) and therefore

\[c(1) = 3 \cdot r(1) + m(1)\]

if we scale \(c(1)\) to be in the range \([0, q)\) (we use the fact that the right-hand side is bounded by \(4N = 2036 < q\)). This means if we know the value of \(m(1)\), we have \(r(1)\) and can perform the paper's attack!
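The evaluation-at-1 relation can be tried out without Sage. The sketch below is an illustration under a loud assumption: instead of generating a real NTRU key, `toy_public_key` (a hypothetical helper) draws random coefficients and adjusts one of them so that \(h(1) \equiv 3 \pmod{q}\), which is the only property of the key that the argument uses. Everything else follows the challenge's `encrypt`.

```python
import random

N, q = 509, 2048

def convolution(f, g, q):
    """Cyclic convolution in Z_q[x]/(x^N - 1) on coefficient lists."""
    n = len(f)
    out = [0] * n
    for i, fi in enumerate(f):
        if fi == 0:
            continue
        for j, gj in enumerate(g):
            if gj:
                k = (i + j) % n
                out[k] = (out[k] + fi * gj) % q
    return out

def toy_public_key(rng):
    """Random coefficients with h(1) = 3 mod q (stand-in for a real key)."""
    h = [rng.randrange(q) for _ in range(N)]
    h[0] = (h[0] + 3 - sum(h)) % q
    return h

def encrypt(h, m, r):
    """c = h * r + m mod q, as in the challenge."""
    hr = convolution(r, h, q)
    return [(s + t) % q for s, t in zip(hr, m)]

rng = random.Random(2024)
h = toy_public_key(rng)
m = [rng.randrange(2) for _ in range(N)]
r = [rng.randrange(2) for _ in range(N)]
c = encrypt(h, m, r)

c1 = sum(c) % q                      # c(1) scaled into [0, q)
m1_mod_3 = c1 % 3                    # leaks m(1) mod 3
recovered_r1 = (c1 - sum(m)) // 3    # r(1) once m(1) is known
print(c1 == 3 * sum(r) + sum(m), recovered_r1 == sum(r))
```

Evaluating at \(x = 1\) is a ring homomorphism on \(\mathbb{Z}[x]/(x^N - 1)\), which is why the convolution collapses to the product \(h(1)r(1)\).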

Bruteforce \(m(1)\)

The value of \(m(1)\) can be viewed as the number of heads when flipping \(N = 509\) fair coins. The mean value is \(N/2\) and the standard deviation is \(\sqrt{N/4} \approx 11.28\). We also know \(m(1) \bmod 3\) because \(c(1) \equiv m(1) \pmod{3}\), and it turned out to be \(2\). Therefore, if we consider \(m(1) = 221, 224, \cdots, 287\) (a total of 23 cases), we are highly confident that we can find the correct \(m(1)\), as this covers up to 3 standard deviations.
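The candidate list can be reproduced with a few lines of Python. This is a small sketch rather than part of the solve script: the residue `2` is the value reported above for this particular instance, and the coverage figure is the exact binomial probability conditioned on that residue.

```python
from math import comb, sqrt

N = 509
residue = 2            # m(1) mod 3 observed in this instance, per the text above

mean = N / 2
sigma = sqrt(N) / 2    # standard deviation of Binomial(N, 1/2)

# Candidates with the right residue within 3 standard deviations of the mean.
candidates = [k for k in range(N + 1)
              if k % 3 == residue and abs(k - mean) <= 3 * sigma]

# Exact probability that m(1) is among the candidates, given its residue mod 3.
class_weight = sum(comb(N, k) for k in range(residue, N + 1, 3))
covered_weight = sum(comb(N, k) for k in candidates)
coverage = covered_weight / class_weight

print(candidates[0], candidates[-1], len(candidates))   # 221 287 23
print(f"coverage given the residue: {coverage:.4f}")
```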

Linearization

If we check the equation (3.7) in the paper

\[a_0x_0 + 2a_1x_1 + \cdots + 2a_{\lfloor N/2 \rfloor}x_{\lfloor N/2 \rfloor} -2w_0m_0 - 2w_1m_1 - \cdots - 2w_{N-1}m_{N-1} = \mathbf{r}^T\mathbf{r} - \mathbf{b}^T\mathbf{b} \pmod{q}, \]

the term \(a_0x_0\) is the only one without an even coefficient. \(a_0\) is a known value, and \(x_0\), coincidentally, is the value \(m(1)\) (again, using the fact that \(\mathbf{m}^T\mathbf{m} = m(1)\) because it's binary)! Therefore, we can rewrite the equation as

\[a_1x_1 + \cdots + a_{\lfloor N/2 \rfloor}x_{\lfloor N/2 \rfloor} -w_0m_0 - w_1m_1 - \cdots - w_{N-1}m_{N-1} = \frac{1}{2}(\mathbf{r}^T\mathbf{r} - \mathbf{b}^T\mathbf{b} - a_0m(1)) \pmod{q/2}. \]

The benefit is that, since \(q/2\) is even, we can now reduce the equation modulo 2 and solve it as a boolean linear system, which is a lot easier and faster than solving the system modulo \(q = 2048\).
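Both the equation modulo \(q\) and its halved, boolean form can be verified on a toy instance. The sketch below is an illustration under stated assumptions: it uses `n = 11` and a made-up invertible key \(h = 1 + 2u\) (whose inverse is a finite geometric series modulo \(2^{11}\)) instead of a real NTRU key, since the identity only needs \(h\) to be invertible. The helper names are hypothetical.

```python
import random

def conv(f, g, q):
    """Cyclic convolution of two length-n coefficient lists, reduced mod q."""
    n = len(f)
    out = [0] * n
    for i in range(n):
        for j in range(n):
            out[(i + j) % n] = (out[(i + j) % n] + f[i] * g[j]) % q
    return out

def inverse_of_one_plus_2u(u, q):
    """(1 + 2u)^(-1) = sum_k (-2u)^k; the series stops since (2u)^e = 0 mod 2^e."""
    n = len(u)
    minus_2u = [(-2 * v) % q for v in u]
    term = [1] + [0] * (n - 1)
    total = [0] * n
    for _ in range(q.bit_length() - 1):
        total = [(s + t) % q for s, t in zip(total, term)]
        term = conv(term, minus_2u, q)
    return total

def autocorrelation(v):
    n = len(v)
    return [sum(v[t] * v[(t + s) % n] for t in range(n)) for s in range(n)]

rng = random.Random(1)
n, q = 11, 2048

# Toy invertible "public key" h = 1 + 2u (an assumption; real keys are p*g/f).
u = [rng.randrange(q) for _ in range(n)]
h = [(2 * v) % q for v in u]
h[0] = (h[0] + 1) % q
h_inv = inverse_of_one_plus_2u(u, q)

m = [rng.randrange(2) for _ in range(n)]
r = [rng.randrange(2) for _ in range(n)]
c = [(s + t) % q for s, t in zip(conv(h, r, q), m)]
b = conv(h_inv, c, q)

# Transposing a circulant matrix maps h_inv(x) to h_inv(1/x).
h_inv_t = [h_inv[(-i) % n] for i in range(n)]
a = conv(h_inv_t, h_inv, q)
w = conv(h_inv_t, b, q)
x = autocorrelation(m)               # x[0] = m(1) because m is binary

half = n // 2
linear_part = (sum(a[i] * x[i] for i in range(1, half + 1))
               - sum(w[i] * m[i] for i in range(n)))
lhs = a[0] * x[0] + 2 * linear_part
rhs = sum(r) - sum(v * v for v in b)  # r^T r = r(1) for binary r
holds_mod_q = (lhs - rhs) % q == 0

# Halved equation: move a_0 x_0 to the right, divide by 2, then reduce mod 2.
numerator = rhs - a[0] * sum(m)
constant = numerator // 2
holds_mod_2 = numerator % 2 == 0 and (linear_part - constant) % 2 == 0
print(holds_mod_q, holds_mod_2)
```

In the real attack `x` and `m` are the unknowns; here they are known, so the script only confirms that the coefficients `a` and `w` and the halved constant term fit together.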

Final optimization

It seems like the main advantage of this alternative solution is solving the linear system over the booleans instead of over \(\mathbb{Z}/q\mathbb{Z}\), at the cost of bruteforcing 23 cases of \(m(1)\) and a small chance of failing. However, the matrix of the linear system is the same in all cases, which means we only have to calculate the matrix inversion once, rather than 23 times. The only repeated operation is the multiplication of the inverted matrix by a vector, which is not the bottleneck of the solution.
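Reusing one elimination for several right-hand sides can be sketched in plain Python. This is a hypothetical stand-in for a GF(2) solver, not the routine the writeup relies on: rows are packed into integers, every right-hand side is appended as an extra bit column, and a single Gauss-Jordan pass serves all of them. Free variables are set to 0.

```python
import random

def solve_gf2_multi(rows, rhs_list, ncols):
    """Solve A x = rhs over GF(2) for several rhs with one elimination.

    rows[i] is an int whose bit j is the coefficient of variable j.
    Returns one solution (list of bits) per rhs, or None if inconsistent.
    """
    # Low ncols bits: coefficients. Bit ncols + t: entry of the t-th rhs.
    aug = [row | sum(rhs[i] << (ncols + t) for t, rhs in enumerate(rhs_list))
           for i, row in enumerate(rows)]
    pivots = []
    rank = 0
    for col in range(ncols):
        pivot = next((i for i in range(rank, len(aug)) if (aug[i] >> col) & 1), None)
        if pivot is None:
            continue
        aug[rank], aug[pivot] = aug[pivot], aug[rank]
        for i in range(len(aug)):
            if i != rank and (aug[i] >> col) & 1:
                aug[i] ^= aug[rank]
        pivots.append(col)
        rank += 1
    solutions = []
    for t in range(len(rhs_list)):
        bit = ncols + t
        if any((aug[i] >> bit) & 1 for i in range(rank, len(aug))):
            solutions.append(None)        # a zero row with a nonzero rhs entry
            continue
        sol = [0] * ncols
        for i, col in enumerate(pivots):
            sol[col] = (aug[i] >> bit) & 1
        solutions.append(sol)
    return solutions

def mat_vec(rows, sol):
    """A * sol over GF(2), with rows packed as above."""
    packed = sum(bit << i for i, bit in enumerate(sol))
    return [bin(row & packed).count("1") % 2 for row in rows]

rng = random.Random(5)
ncols, nrows = 12, 16
rows = [rng.randrange(1 << ncols) for _ in range(nrows)]
secrets = [[rng.randrange(2) for _ in range(ncols)] for _ in range(2)]
rhs_list = [mat_vec(rows, s) for s in secrets]
solutions = solve_gf2_multi(rows, rhs_list, ncols)
print([mat_vec(rows, s) == rhs for s, rhs in zip(solutions, rhs_list)])
```

With only a handful of candidate right-hand sides, appending them as extra columns costs almost nothing compared with the elimination itself.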

Furthermore, the constant terms of the linear system are the parities of

\[\frac{1}{2}(\mathbf{r}^T\mathbf{r} - \mathbf{b}^T\mathbf{b} - a_0m(1))\]

Since \(\mathbf{r}^T\mathbf{r} = r(1) = (c(1) - m(1)) / 3\) and \(3^{-1} \equiv -1 \pmod{4}\), we have \(r(1) \equiv m(1) - c(1) \pmod{4}\), so

\[\frac{1}{2}(\mathbf{r}^T\mathbf{r} - \mathbf{b}^T\mathbf{b} - a_0m(1)) \equiv \frac{1-a_0}{2}\cdot m(1) - \frac{1}{2} (\mathbf{b}^T\mathbf{b} + c(1)) \pmod{2}\]

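Because \(a_0\) is odd, \(\frac{1-a_0}{2}\) is an integer, and the right-hand side depends on \(m(1)\) only through its parity. This means that we only have to consider two cases: \(m(1)\) is odd, or \(m(1)\) is even. No more 23x overhead.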
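The parity shortcut can be sanity-checked with random numbers. The sketch below is an illustration, not part of the solve script: the function names are hypothetical, and the inputs are generated under the assumptions that \(a_0\) is odd and that \(\mathbf{b}^T\mathbf{b}\) has the same parity as \(c(1)\), which is what makes both halves of the formula integers.

```python
import random

def constant_direct(a0, c1, bb, m1):
    """Parity of (r(1) - b^T b - a_0 m(1)) / 2 with r(1) = (c(1) - m(1)) / 3."""
    r1, remainder = divmod(c1 - m1, 3)
    assert remainder == 0
    numerator = r1 - bb - a0 * m1
    assert numerator % 2 == 0
    return (numerator // 2) % 2

def constant_from_parity(a0, c1, bb, m1_parity):
    """Same parity, computed from m(1) mod 2 only."""
    return ((1 - a0) // 2 * m1_parity - (bb + c1) // 2) % 2

rng = random.Random(7)
N = 509
all_match = True
for _ in range(200):
    a0 = 2 * rng.randrange(-512, 512) + 1          # odd, as assumed above
    r1, m1 = rng.randrange(N + 1), rng.randrange(N + 1)
    c1 = 3 * r1 + m1                                # c(1) scaled into [0, q)
    bb = 2 * rng.randrange(10**6) + (c1 % 2)        # same parity as c(1)
    # Every guess with the right residue mod 3 gives a parity-only constant.
    for guess in range(m1 % 3, c1 + 1, 3):
        direct = constant_direct(a0, c1, bb, guess)
        shortcut = constant_from_parity(a0, c1, bb, guess % 2)
        all_match = all_match and direct == shortcut
print(all_match)
```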

Solve script

# Modified from official writeup: https://chocorusk.hatenablog.com/entry/2024/10/12/180717
from Crypto.Cipher import AES
from hashlib import sha256

N = 509
q = 2048
p = 3
d = 253

Zx.<x> = ZZ[]

def invertmodprime(f,p):
    T = Zx.change_ring(Integers(p)).quotient(x^N-1)
    return Zx(lift(1 / T(f)))

def invertmodpowerof2(f,q):
    assert q.is_power_of(2)
    h = invertmodprime(f,2)
    while True:
        r = balancedmod(convolution(h,f),q)
        if r == 1: return h
        h = balancedmod(convolution(h,2 - r),q)

def balancedmod(f,q):
    g = list(((f[i] + q//2) % q) - q//2 for i in range(N))
    return Zx(g)

def convolution(f,g):
    return (f * g) % (x^N-1)

with open('output.txt') as f:
    public_keys = sage_eval(f.readline()[len("public keys: "):].strip(), locals={'x':x})
    ciphertexts = sage_eval(f.readline()[len("ciphertexts: "):].strip(), locals={'x':x})
    enc_flag = bytes.fromhex(f.readline()[len("encrypted flag: "):].strip())

mat = []

a0s = []
c_lens = []
b_lens = []

print("Constructing matrix")
for h, c in zip(public_keys, ciphertexts):
    # Skip public keys that are not invertible modulo q.
    try:
        h1 = invertmodpowerof2(h, q)
    except Exception:
        continue
    # b = H^{-1} c
    b = balancedmod(convolution(h1,c), q)

    # ht is h1(1/x), i.e. the transpose of the circulant matrix of h1.
    ht = list(h1)
    ht += [0]*(N-len(ht))
    ht = ht[::-1]
    ht = Zx([ht[-1]]+ht[:-1])
    # a is the first row of H^{-T} H^{-1}, and w = H^{-T} b.
    a = balancedmod(convolution(ht,h1), q)
    w = balancedmod(convolution(ht,b), q)

    a = list(a)
    w = list(w)
    a += [0]*(N-len(a))
    w += [0]*(N-len(w))

    a0s.append(a[0])
    c_lens.append(int(c(1) % q))
    b_lens.append(sum(v * v for v in b))
    # One equation per key: unknowns x_1..x_{N//2}, then m_0..m_{N-1}.
    mat.append([a[i] for i in range(1,N//2+1)]+[-w[i] for i in range(N)])

mat = matrix(GF(2), mat)

# c(1) = 3 r(1) + m(1), so c(1) mod 3 gives m(1) mod 3.
msg_len_mod_3 = int(ciphertexts[0](1) % q) % 3

# Try the candidates closest to the mean N/2 first.
possible_msg_len = [i for i in range(msg_len_mod_3, N, 3)]
possible_msg_len = sorted(possible_msg_len, key=lambda x: abs(x-N/2))

for msg_len in possible_msg_len:
    print(f"Trying message length = {msg_len}")
    vec = []
    for a0, c_len, b_len in zip(a0s, c_lens, b_lens):
        r_len = (c_len - msg_len) // 3
        constant_term = r_len - b_len - a0 * msg_len
        constant_term //= 2
        vec.append(constant_term)

    vec = vector(GF(2), vec)
    try:
        res = mat.solve_right(vec)
        assert mat * res == vec
        # The last N unknowns are the message bits.
        res = balancedmod([int(v) for v in list(res)[-N:]], 4)
        key = sha256(str(res).encode()).digest()[:16]
        # The first 8 bytes of the output are the CTR nonce.
        cipher = AES.new(key=key, mode=AES.MODE_CTR, nonce=enc_flag[:8])
        flag = cipher.decrypt(enc_flag[8:])
        if all(x < 128 for x in flag):
            print(flag)
            break
    except Exception:
        pass