BalsnCTF 2021

2021-11-21

Introduction

I played BalsnCTF solo and got destroyed QwQ. I only solved one crypto challenge, which is about predicting the output of the Mersenne Twister. I had only implemented the twister before, so this was quite a new experience for me :D.

1337 pins

Challenge

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import random

# Number of consecutive correct guesses needed to get the flag.
N = 1337

remaining = N
for i in range(31337):
    # Only the last decimal digit of each 32-bit output is compared.
    y = random.getrandbits(32) % 10
    x = int(input())

    if x == y:
        remaining -= 1
        print('.')
    else:
        # A wrong guess resets the streak and reveals the digit.
        remaining = N
        print(y)

    if remaining == 0:
        with open('/flag.txt') as f:
            print(f.read())
        break

Solution

From the implementation of random and getrandbits, we can see that it is just a plain mt19937. The initial state is an array of 624 32-bit numbers (of which the lowest 31 bits of the first number do not matter), and each time this PRNG "twists", it applies some bitwise-operation magic. A cool property is that, if we treat the original 19968 bits as 19968 variables, every bit of every future state can be written as an XOR sum of some of these 19968 variables. As mentioned earlier, the lowest 31 bits of the first state word do not matter, so there are actually only 19937 variables (which is where the 19937 in the name mt19937 comes from).

The challenge provides the last decimal digit of each output, but we only need its last bit (i.e. the parity). We collect at least 19937 outputs (the script below gathers 20000) and construct one linear equation over the 19937 variables above from each of them. Solving these linear equations takes \(O(n^3)\) time, but we can do it with bit operations on big integers, which makes it much faster. With the solution of the equations, we can clone the mt19937 PRNG and predict its output.

As a side note, the system of equations, written as a matrix, seems to be sparse, so skipping the unnecessary iterations of the for loop is a huge optimization. I personally don't know why, but adding this optimization solves the equations in less than 5 seconds.

The script:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import random
from pwn import *
from tqdm import tqdm

class MT19937:
    """Pure-Python 32-bit Mersenne Twister whose internal state is directly accessible.

    The generator parameters are class attributes, so they can be read as
    ``MT19937.n`` etc. without creating an instance first.

    Instance attributes:
        states: list of ``n`` 32-bit state words.
        index: position of the next word to output; ``n`` means the state
            must be twisted before the next output.
    """

    # Word size, number of state words, middle offset, and number of low bits
    # taken from the following word when twisting.
    w, n, m, r = 32, 624, 397, 31
    # Twist constant XORed in when the shifted-out bit is 1.
    a = 0x9908B0DF
    # Tempering shifts and masks.
    u, d = 11, 0xFFFFFFFF
    s, b = 7, 0x9D2C5680
    t, c = 15, 0xEFC60000
    l = 18
    # Multiplier used to expand the seed into the initial state.
    f = 1812433253
    lowerMask = (1 << r) - 1
    mask = (1 << w) - 1
    upperMask = mask ^ lowerMask

    def __init__(self, seed):
        """Expand ``seed`` into the ``n`` initial state words.

        Only the low ``w`` bits of ``seed`` are used.
        """
        self.states = [seed & self.mask]
        for i in range(1, self.n):
            prev = self.states[i - 1]
            self.states.append(self.mask & (i + self.f * (prev ^ (prev >> (self.w - 2)))))
        # Force a twist before the first output.
        self.index = self.n

    def temper(self, num):
        """Return the tempered output value for the raw state word ``num``."""
        num = num ^ ((num >> MT19937.u) & MT19937.d)
        num = num ^ ((num << MT19937.s) & MT19937.b)
        num = num ^ ((num << MT19937.t) & MT19937.c)
        num = num ^ (num >> MT19937.l)
        return num

    def rand(self):
        """Return the next 32-bit output, twisting the state first when it is exhausted."""
        if self.index >= MT19937.n:
            self.twist()
        y = self.states[self.index]
        self.index += 1
        return self.temper(y)

    def twist(self):
        """Regenerate all ``n`` state words in place and rewind ``index`` to 0."""
        for i in range(MT19937.n):
            # Top bit of the current word joined with the low 31 bits of the next one.
            x = (self.states[i] & MT19937.upperMask) ^ (self.states[(i + 1) % MT19937.n] & MT19937.lowerMask)
            xA = x >> 1
            if x & 1:
                xA = xA ^ MT19937.a

            self.states[i] = self.states[(i + MT19937.m) % MT19937.n] ^ xA

        self.index = 0

class bitwiseMT19937:
    """Symbolic counterpart of MT19937 that tracks every bit as an XOR of initial-state bits.

    Each state word is a list of 32 integers, least significant bit first.
    Each integer is a bitmask over the unknown initial-state bits: bit
    ``i * 32 + j`` of the mask stands for bit ``j`` of initial word ``i``, and
    the tracked bit equals the XOR of all unknowns whose mask bit is set.

    The generator parameters are read from the ``MT19937`` class attributes.
    """

    def __init__(self):
        """Start with every state bit equal to its own unknown."""
        self.states = [[1 << (i * 32 + j) for j in range(32)] for i in range(MT19937.n)]
        self.index = MT19937.n

    def temper(self, num):
        """Return the 32 symbolic bits of the tempered output for the symbolic word ``num``.

        ``num`` itself is left unmodified.
        """
        ret = num[:]

        # num ^= (num >> u) & d
        # Ascending order, so ret[i + u] still holds the pre-step value.
        for i in range(32):
            if (MT19937.d >> i) & 1:
                if i + MT19937.u < 32:
                    ret[i] ^= ret[i + MT19937.u]

        # num ^= (num << s) & b
        # Descending order, so ret[i - s] still holds the pre-step value.
        for i in range(31, -1, -1):
            if (MT19937.b >> i) & 1:
                if i - MT19937.s >= 0:
                    ret[i] ^= ret[i - MT19937.s]

        # num ^= (num << t) & c
        for i in range(31, -1, -1):
            if (MT19937.c >> i) & 1:
                if i - MT19937.t >= 0:
                    ret[i] ^= ret[i - MT19937.t]

        # num ^= num >> l
        for i in range(32):
            if i + MT19937.l < 32:
                ret[i] ^= ret[i + MT19937.l]

        return ret

    def rand(self):
        """Return the 32 symbolic bits (LSB first) of the next output word."""
        if self.index >= MT19937.n:
            self.twist()
        y = self.states[self.index]
        self.index += 1
        return self.temper(y)

    def twist(self):
        """Symbolically regenerate all state words in place and rewind ``index`` to 0."""
        for i in range(MT19937.n):
            # x = (states[i] & upperMask) ^ (states[i + 1] & lowerMask):
            # low 31 bits come from the next word, the top bit from this one.
            x = self.states[(i + 1) % MT19937.n][:-1] + self.states[i][-1:]
            # xA = x >> 1
            xA = x[1:] + [0]
            # if x & 1: xA ^= a  -- x[0] is XORed in wherever `a` has a 1 bit.
            for t in range(32):
                if (MT19937.a >> t) & 1:
                    xA[t] ^= x[0]

            # states[i] = states[i + m] ^ xA
            for t in range(32):
                self.states[i][t] = self.states[(i + MT19937.m) % MT19937.n][t] ^ xA[t]

        self.index = 0

def count(x, bits=20000):
    """Return the parity (0 or 1) of the set bits among the low ``bits`` bits of ``x``.

    Bits at position ``bits`` and above are ignored.
    """
    return bin(x & ((1 << bits) - 1)).count("1") & 1

# Number of unknown bits: 624 state words of 32 bits each.
TOTAL_BITS = 19968

# linear_base[k] holds the stored equation whose highest unknown is bit k, as
# (coefficient bitmask, right-hand-side bit); (-1, -1) marks an empty slot.
linear_base = [(-1, -1) for _ in range(TOTAL_BITS)]
# Number of occupied slots in linear_base.
total = 0


def add(bits, output):
    """Insert the equation ``XOR of unknowns in bits == output`` into ``linear_base``.

    The equation is reduced against the stored ones, highest set bit first.
    If something is left, it is stored in the slot of its highest bit and
    ``total`` is incremented. If it reduces to zero, nothing is stored.

    Args:
        bits: bitmask of the unknowns taking part in the equation.
        output: observed value (0 or 1) of their XOR.
    """
    global total
    while bits:
        idx = bits.bit_length() - 1
        if linear_base[idx] == (-1, -1):
            linear_base[idx] = (bits, output)
            total += 1
            return
        bits ^= linear_base[idx][0]
        output ^= linear_base[idx][1]


# Concrete generator. The seed is only a placeholder: its state words are
# replaced with the recovered ones further down.
# Keep this before bitwiseMT19937(): the symbolic generator reads its
# parameters from the MT19937 class attributes.
rng = MT19937(9487)
# Symbolic generator that yields, for each output bit, the mask of
# initial-state bits it depends on.
bitRng = bitwiseMT19937()

# Connection to the challenge server.
r = remote('1337pins.balsnctf.com', 27491)

def getNextOutput(debug = False):
    """Return the next secret digit (0-9) produced by the challenge.

    With ``debug`` set, the digit is drawn from the local ``random`` module
    instead of the server, so the solver can be tried offline.
    """
    if debug:
        return random.getrandbits(32) % 10

    # Always guess 0. The server answers '.' when the guess was right (so the
    # digit was 0) and otherwise prints the actual digit.
    r.sendline(b'0')
    tmp = r.recvline().strip()
    return 0 if tmp == b'.' else int(tmp)

# Number of server outputs consumed while collecting equations.
SAMPLES = 20000

# Each sample gives one equation: the symbolic lowest output bit equals the
# parity of the observed digit.
for i in tqdm(range(SAMPLES)):
    bits = bitRng.rand()[0]
    output = getNextOutput() & 1
    add(bits, output)

print(total)

# No stored equation may involve the low 31 bits of the first state word.
for i in range(TOTAL_BITS):
    if linear_base[i] == (-1, -1):
        continue
    assert (linear_base[i][0] & ((1 << 31) - 1)) == 0

# Back-substitution, lowest unknown first, so every row below i is already a
# single-bit row by the time row i is processed.
for i in tqdm(range(TOTAL_BITS)):
    if linear_base[i] == (-1, -1):
        # Unconstrained unknown: fix it to 0.
        linear_base[i] = (1 << i, 0)
        continue

    # Only visit the set bits below the leading one.
    mask = linear_base[i][0] ^ (1 << i)
    while mask:
        idx = mask.bit_length() - 1
        linear_base[i] = (linear_base[i][0] ^ linear_base[idx][0], linear_base[i][1] ^ linear_base[idx][1])
        mask ^= (1 << idx)

for i in range(TOTAL_BITS):
    assert linear_base[i][0] == (1 << i)

# Pack the solved bits back into 32-bit state words.
stateLong = sum(linear_base[i][1] << i for i in range(TOTAL_BITS))
states = [(stateLong >> (32 * k)) & ((1 << 32) - 1) for k in range(MT19937.n)]

rng.states = states[:]
# The recovered words are the initial state, so a twist must happen before
# the first output regardless of how `rng` was used before.
rng.index = MT19937.n

# Skip the outputs the server already produced during collection.
for i in range(SAMPLES):
    rng.rand()

for i in range(1337):
    myGuess = rng.rand() % 10
    # Send bytes, like the b'0' guesses sent during collection.
    r.sendline(str(myGuess).encode())
    r.recvline()

r.interactive()