Toxicity - SummerRush CTF

Challenge code:
from Crypto.Util.number import bytes_to_long
from os import urandom
from PIL import Image
import numpy as np
import random

class DoubleLCG:
    def __init__(self, a1,a2, b1,b2, m, seed1, seed2):
        self.a1 = a1
        self.a2 = a2
        self.b1 = b1
        self.b2 = b2
        self.m = m
        self.state1 = bytes_to_long(urandom(6)) if seed1 is None else seed1
        self.state2 = bytes_to_long(urandom(6)) if seed2 is None else seed2
        self.counter = 0

    def refresh(self):
        self.counter = 0
        self.state2 = (self.a2 * self.state2 + self.b2) % self.m
        self.state1 = self.state1 ^ (self.state2 >> 40)

    def next_state(self):
        self.state1 = (self.a1 * self.state1 + self.b1) % self.m

    def get_random_bits(self, k):
        if self.counter == 16:
            self.refresh()
        self.counter += 1
        self.next_state()
        return self.state1 >> (48 - k)

    def get_random_bytes(self, number):
        bytes_sequence = b''
        for i in range(number):
            bytes_sequence += bytes([self.get_random_bits(8)])
        return bytes_sequence

def xor_image(strm, im_bytes, im_array):
    xor_bytes = bytes([a ^ b for a, b in zip(im_bytes, strm)])
    xor_array = np.frombuffer(xor_bytes, dtype=im_array.dtype)
    xor_array = xor_array.reshape(im_array.shape)
    return Image.fromarray(xor_array)

def and_image(strm, im_bytes, im_array):
    and_out = bytes([a & b for a, b in zip(im_bytes, strm)])
    and_out = np.frombuffer(and_out, dtype=im_array.dtype)
    and_out = and_out.reshape(im_array.shape)
    return Image.fromarray(and_out)

a1, b1, m = 0xC0FF1555, 0xB1, 1 << 48
a2, b2 = 0xBABE1337, 0xB2
seed1 = bytes_to_long(urandom(6))
seed2 = bytes_to_long(urandom(6))
lcg = DoubleLCG(a1, a2, b1, b2, m, seed1, seed2)
inp = Image.open('serjy.png')
img_array = np.array(inp)
img_bytes = img_array.tobytes()

stream = lcg.get_random_bytes(len(img_bytes))

print(len(stream))

out = xor_image(stream, img_bytes, img_array)
out.save('serjy_out.png')

# Corrupt image
stream2 = b'\x00'*len(img_bytes)
save = []
for _ in range(10):
    shift = random.randint(0, 7)
    reveal_int = int.from_bytes(b'\xff'*9, 'big') << shift*8
    reveal = reveal_int.to_bytes(16, 'big')
    save.append(reveal)
save = b''.join(save)
idx = random.randint(0, (len(stream2) - len(save))//16) * 16
stream2 = stream2[:idx] + save + stream2[idx + len(save):]

out = and_image(stream2, img_bytes, img_array)
out.save('serjy_corrupt.png')

The other notorious challenge from this CTF, which managed to stay unsolved for the entire competition. It’s a shame: the decrypted image is really nice, and I wanted you to see it 😔


Analysis

Functionality

You’re given 2 images alongside the source code:

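  • serjy_out.png: The plain image XORed with a pseudorandom keystream derived from two intertwined LCGs
  • serjy_corrupt.png: The plain image with everything zeroed out except ten consecutive 16-byte blocks, each of which reveals a window of 9 plaintext bytes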

The keystream is generated pseudorandomly by two linear congruential generators, or LCGs (let’s call them $L_1$ and $L_2$ for convenience), in the following manner:

  • Both LCGs are initialized with a random 48-bit seed.
  • The output keystream bytes are derived exclusively from $L_1$: each byte is the top $8$ bits of its state.
  • Each output byte advances the $L_1$ state by $1$.
  • Every $16$ bytes, $L_2$ advances by one state and $L_1$’s state is “refreshed” as follows: $$ state_1 = state_1 \oplus \left( state_2 \gg 40 \right) $$ In other words, $state_1$ is XORed with the most significant byte of $state_2$, so only the low $8$ bits of $state_1$ can change.

This keeps going until the entire image is encrypted.
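To make the schedule above concrete, here is a minimal sketch of the same keystream as a flat loop. The constants are copied from the challenge source; the function name `keystream` is mine, not part of the challenge.

```python
# Parameters copied from the challenge source.
A1, B1, M = 0xC0FF1555, 0xB1, 1 << 48
A2, B2 = 0xBABE1337, 0xB2


def keystream(seed1, seed2, n):
    """Return n keystream bytes following the schedule described above."""
    s1, s2 = seed1, seed2
    out = []
    for i in range(n):
        if i and i % 16 == 0:        # refresh before bytes 16, 32, 48, ...
            s2 = (A2 * s2 + B2) % M  # L2 advances once per 16-byte block
            s1 ^= s2 >> 40           # only the low 8 bits of s1 can change
        s1 = (A1 * s1 + B1) % M      # L1 advances once per output byte
        out.append(s1 >> 40)         # top 8 bits of the 48-bit state
    return bytes(out)


first_block = keystream(1, 2, 16)
```

Note that the first 16 bytes never touch $L_2$: the first refresh only happens right before byte 16.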

Known Plaintext Attack

We know that LCGs are a deprecated cryptographic primitive that has been broken time and time again: with just a few outputs, we can reconstruct the entire state of the generator (more about the technical details here).

To do so, you need to know some parts of the keystream, otherwise you’ll be driving blind. That is why I gave you serjy_corrupt.png, which lets you perform a known plaintext attack. XORing it with the encrypted image at the revealed positions will grant you some bytes of the keystream, but only some of them!
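As a rough illustration of that XOR step, here is a sketch on toy bytes instead of the real images. The helper `recover_keystream` and all the sample values are made up for the example.

```python
def recover_keystream(cipher, partial_plain, positions):
    """XOR ciphertext with known plaintext at the given byte positions."""
    return [cipher[i] ^ partial_plain[i] for i in positions]


# Toy data standing in for the image bytes (hypothetical values).
plain = bytes(range(100, 132))                       # 32 "pixel" bytes
key = bytes((37 * i + 11) % 256 for i in range(32))  # stand-in keystream
cipher = bytes(p ^ k for p, k in zip(plain, key))

# Reveal 9 bytes at offset 7 of the first 16-byte block, zero everything else.
window = range(7, 16)
partial = bytes(p if i in window else 0 for i, p in enumerate(plain))

leaked = recover_keystream(cipher, partial, window)
```

One caveat: a revealed pixel byte that happens to be 0 looks exactly like a masked one, so finding the window by searching for the first non-zero byte can be off by a position or more.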


Truncated LCG Attack

The problem with the aforementioned method is that you need the full states of the LCG (pretty similar to mt19937 in the jigsaw challenge)… or do you? *Vsauce theme*

Turns out that, even with truncated outputs, you can recover the entire state of a congruential generator! We could use Z3 like in the jigsaw task, but it’ll take a LOOOOT longer, so why don’t we use something else?


Notice the first word of the acronym LCG? Linear? Interesting… What can we use to model linear systems? THAT’s RIGHT! FUCKING ALGEBRAAAAAA

Our LCGs are defined by the recurrence $x_{i+1} \equiv a \cdot x_i + c \pmod m$ (the increment $c$ is called b1/b2 in the code). We are given the truncated outputs $y_i$, which are the $s$ most significant bits of the $k$-bit state $x_i$. This means $x_i = y_i \cdot 2^{k-s} + \delta_i$, where $\delta_i$ is the unknown lower part, $0 \le \delta_i < 2^{k-s}$.

Substituting this into the LCG recurrence gives:

$$ y_{i+1} \cdot 2^{k-s} + \delta_{i+1} \equiv a \cdot (y_i \cdot 2^{k-s} + \delta_i) + c \pmod m $$

Rearranging for the unknown $\delta_i$ terms:

$$ \delta_{i+1} - a \cdot \delta_i \equiv a \cdot y_i \cdot 2^{k-s} + c - y_{i+1} \cdot 2^{k-s} \pmod m $$

Let $z_i = a \cdot y_i \cdot 2^{k-s} + c - y_{i+1} \cdot 2^{k-s}$. The $z_i$ values are known. We now have a system of linear congruences for the small unknown values $\delta_i$:

$$ \delta_1 - a \cdot \delta_0 \equiv z_0 \pmod m \\ \delta_2 - a \cdot \delta_1 \equiv z_1 \pmod m \\ \vdots \\ \delta_n - a \cdot \delta_{n-1} \equiv z_{n-1} \pmod m $$
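If you want to convince yourself that the congruences hold, here is a quick numeric check using $L_1$’s parameters with $k = 48$ and $s = 8$. The starting state is an arbitrary value picked for the example.

```python
a, c, m = 0xC0FF1555, 0xB1, 1 << 48
k, s = 48, 8
shift = k - s

# A few consecutive full states from an arbitrary (made-up) starting point.
x = [0x123456789ABC]
for _ in range(8):
    x.append((a * x[-1] + c) % m)

y = [xi >> shift for xi in x]                  # known truncated outputs
delta = [xi & ((1 << shift) - 1) for xi in x]  # unknown low parts

# z_i as defined above, reduced mod m.
z = [(a * y[i] * 2**shift + c - y[i + 1] * 2**shift) % m
     for i in range(len(x) - 1)]

# Left-hand side of each congruence, reduced mod m.
lhs = [(delta[i + 1] - a * delta[i]) % m for i in range(len(x) - 1)]

assert lhs == z
```

The attacker only sees `y`, but since every `delta[i]` is below $2^{40}$ while $m = 2^{48}$, the unknowns are small compared to the modulus, which is what the lattice step exploits.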

Does this look familiar? If you said that it resembles the Hidden Number Problem then you’re a fucking nerd, ngl twin 💀 but you’re right!

We can solve this by constructing a lattice as follows:

$$ B = \begin{pmatrix} m & 0 & 0 & \cdots & 0 \\ a^1 & -1 & 0 & \cdots & 0 \\ a^2 & 0 & -1 & \cdots & 0 \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ a^{n-1} & 0 & 0 & \cdots & -1 \end{pmatrix} $$

with $a$ and $m$ being the parameters of our LCG. With this matrix in place, we call $LLL$ to reduce the lattice into a basis of short vectors, which makes the calculations more manageable. We can then solve our system of equations (modeled as a lattice in this case) to recover the $\delta_i$ values, letting us recover the full state of the LCG in a pretty elegant way!
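Here is a small sanity check of why this basis fits, in plain Python and without the LLL step itself (that part needs Sage or a similar library). It assumes the affine part is removed first by subtracting the offsets $d_0 = c$, $d_{i+1} = a \cdot d_i + c$, which is how I read the solver; the names `d`, `h` and the starting state are mine.

```python
a, c, m = 0xC0FF1555, 0xB1, 1 << 48
n = 9

# True consecutive states, starting from an arbitrary (made-up) value.
x = [0x0BADC0FFEE42]
for _ in range(n - 1):
    x.append((a * x[-1] + c) % m)

# Offsets contributed by the increment: d_0 = c, d_{i+1} = a*d_i + c.
d = [c % m]
for _ in range(n - 1):
    d.append((a * d[-1] + c) % m)

# Homogeneous sequence: h_i = a^i * h_0 (mod m).
h = [(xi - di) % m for xi, di in zip(x, d)]

# The basis B from the text.
B = [[0] * n for _ in range(n)]
B[0][0] = m
for i in range(1, n):
    B[i][0] = a ** i
    B[i][i] = -1

# Every row of B annihilates h modulo m.
residues = [sum(B[r][j] * h[j] for j in range(n)) % m for r in range(n)]
assert residues == [0] * n
```

LLL only takes integer combinations of rows, so every row of the reduced basis still satisfies the same congruence. Because the reduced rows are short, the product of a reduced row with the small $\delta$ vector stays well below $m$ in absolute value, so it can be read off from the known high parts alone.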

Dual LCG

You might have noticed that we only talked about a single truncated LCG, but we have two of them, and they’re intertwined! Guess what, we know $L_1$ is partially derived from $L_2$, meaning that, if we recovered enough concrete states of $L_1$ we’re gonna get enough MSBs of $L_2$ to, you guessed it, perform ANOTHER Truncated LCG attack!

This might sound a bit convoluted, but think of it like how the earth spins $365$ times around itself, and once around the sun in a single year; now think of the earth as $L_1$ and the sun as $L_2$.
Sounds good xd? Hopefully the script clears things up
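In code, the link between the two generators can be sketched like this. "Block start" here means the $L_1$ state right after a refresh, before the 16 steps of that block; the function names and seeds are made up for the example.

```python
A1, B1, M = 0xC0FF1555, 0xB1, 1 << 48
A2, B2 = 0xBABE1337, 0xB2


def advance_l1(state, steps=16):
    """Run L1 forward by a whole 16-byte block."""
    for _ in range(steps):
        state = (A1 * state + B1) % M
    return state


def l2_msb(block_start, next_block_start):
    """Top byte of L2's state at the refresh between two adjacent blocks."""
    return advance_l1(block_start) ^ next_block_start


# Simulate one block boundary with arbitrary (made-up) seeds.
s1, s2 = 0x1122334455AA, 0xFEDCBA987654
s2_next = (A2 * s2 + B2) % M
next_start = advance_l1(s1) ^ (s2_next >> 40)

assert l2_msb(s1, next_start) == s2_next >> 40
```

Each pair of adjacent recovered blocks therefore leaks one truncated output of $L_2$, which is exactly the input the truncated LCG attack expects.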

Solution

Plan

  • Recover the $L_1$ keystream bytes with the “known plaintext attack”
  • Reconstruct each $L_1$ state
  • Derive the $L_2$ state MSBs from each pair of consecutive $L_1$ blocks
  • Reconstruct the $L_2$ state
  • Use the same keystream derivation backwards to reconstruct the entire keystream
  • XOR the encrypted image with the keystream
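The "backwards" step works because both multipliers are odd, so they are invertible modulo $2^{48}$, and the refresh XOR is its own inverse. A minimal sketch, assuming Python 3.8+ for `pow(a, -1, m)`; the helper names and seeds are mine:

```python
A1, B1, M = 0xC0FF1555, 0xB1, 1 << 48
A2, B2 = 0xBABE1337, 0xB2


def step(state, a, b):
    return (a * state + b) % M


def unstep(state, a, b):
    # Undo one LCG step using the modular inverse of the multiplier.
    return (state - b) * pow(a, -1, M) % M


def forward_block(s1, s2):
    """Refresh, then 16 L1 steps (any block after the first)."""
    s2 = step(s2, A2, B2)
    s1 ^= s2 >> 40
    for _ in range(16):
        s1 = step(s1, A1, B1)
    return s1, s2


def backward_block(s1, s2):
    """Exact inverse of forward_block."""
    for _ in range(16):
        s1 = unstep(s1, A1, B1)
    s1 ^= s2 >> 40
    s2 = unstep(s2, A2, B2)
    return s1, s2


start = (0x0123456789AB, 0xBA5EBA11CAFE)  # arbitrary (made-up) states
assert backward_block(*forward_block(*start)) == start
```

Rewinding block by block from the recovered position to the start of the stream gives the original seeds, from which the full keystream can be regenerated.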

Note: You might notice that sometimes when you decrypt, the output will still look like gibberish. That’s why you should always debug your code step by step to get a clear idea of what’s going on.
I purposely left the debugging logs in the solver so you can understand it as well!

Solver

Solver code:
from PIL import Image
import numpy as np
from Crypto.Util.number import bytes_to_long
from pprint import pprint
from os import urandom
from sage.all import QQ
from sage.all import ZZ
from sage.all import matrix
from sage.all import vector


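def truncated_lcg_rec(y, k, s, m, a, c):
    # y: truncated outputs (top s bits of consecutive k-bit states).
    # Returns the full states that produced them.
    diff_bit_length = k - s
    # Shift the outputs into place and subtract the affine part, so the
    # remaining sequence is purely multiplicative: x'_i = a^i * x'_0 (mod m).
    delta = c % m
    y = vector(ZZ, y)
    for i in range(len(y)):
        y[i] = (y[i] << diff_bit_length) - delta
        delta = (a * delta + c) % m

    # Lattice basis: every row r satisfies r . x' == 0 (mod m).
    B = matrix(ZZ, len(y), len(y))
    B[0, 0] = m
    for i in range(1, len(y)):
        B[i, 0] = a ** i
        B[i, i] = -1

    B = B.LLL()
    # B * (unknown low bits) is small, so it equals the centered residue
    # of -B * y modulo m.
    b = B * y
    for i in range(len(b)):
        b[i] = round(QQ(b[i]) / m) * m - b[i]

    # Solve for the low bits, then add back the high bits and the affine part.
    delta = c % m
    x = list(B.solve_right(b))
    for i, state in enumerate(x):
        x[i] = int(y[i] + state + delta)
        delta = (a * delta + c) % m

    return x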

class DoubleLCG:
    def __init__(self, a1,a2, b1,b2, m, seed1, seed2):
        self.a1 = a1
        self.a2 = a2
        self.b1 = b1
        self.b2 = b2
        self.m = m
        self.state1 = bytes_to_long(urandom(6)) if seed1 is None else seed1
        self.state2 = bytes_to_long(urandom(6)) if seed2 is None else seed2
        self.counter = 0

    def refresh(self):
        self.counter = 0
        self.state2 = (self.a2 * self.state2 + self.b2) % self.m
        self.state1 = self.state1 ^ (self.state2 >> 40)

    def next_state(self):
        self.state1 = (self.a1 * self.state1 + self.b1) % self.m

    def get_random_bits(self, k):
        if self.counter == 16:
            self.refresh()
        self.counter += 1
        self.next_state()
        return self.state1 >> (48 - k)

    def get_random_bytes(self, number):
        bytes_sequence = b''
        for i in range(number):
            bytes_sequence += bytes([self.get_random_bits(8)])
        return bytes_sequence

def xor_image(strm):
    xor_bytes = bytes([a ^ b for a, b in zip(enc_bytes, strm)])
    xor_array = np.frombuffer(xor_bytes, dtype=enc_array.dtype)
    xor_array = xor_array.reshape(enc_array.shape)

    return Image.fromarray(xor_array)




a1, b1, m = 0xC0FF1555, 0xB1, 1 << 48
a2, b2 = 0xBABE1337, 0xB2
def prev_state(state, a, b, m):
    return (state - b) * pow(a, -1, m) % m

def prev_refresh(state1_after, state2_after, a1, a2, b1, b2, m):
    state1_before = state1_after
    for _ in range(16):
        state1_before = prev_state(state1_before, a1, b1, m)
    state1_mid = state1_before ^ (state2_after >> 40)
    state2_before = prev_state(state2_after, a2, b2, m)



    return state1_mid, state2_before


corrupt = Image.open('serjy_corrupt.png')
corrupt_array = np.array(corrupt)
corrupt_bytes = corrupt_array.tobytes()

enc = Image.open('serjy_out.png')
enc_array = np.array(enc)
enc_bytes = enc_array.tobytes()

chunks = {}
idx = 0

for i in range(0,len(corrupt_bytes),16):
    if corrupt_bytes[i:i+16] != b'\x00'*16:
        chunk = corrupt_bytes[i:i+16]
        for j in range(16):
            if chunk[j] != 0:
                idx = j
                break
        idx = i + idx
        print(f"\nIndex: {idx}")
        ret = [x ^ y for x, y in zip(corrupt_bytes[idx:idx+9], enc_bytes[idx:idx+9])]
        print(ret)
        chunks[idx] = ret

seeds = []
for k, v in chunks.items():
    if len(v) >= 9:
        try:
            s = truncated_lcg_rec(v, 48, 8, m, a1, b1)
            if s:
                st = prev_state(s[0], a1, b1, m)
                print(f"Recovered state: {st}")
                position_in_stream = k
                lg = DoubleLCG(a1, a2, b1, b2, m, st, st)
                stream = lg.get_random_bytes(16)
                splits = [sk for sk in stream]
                print(f"Index: {k} | Stream: {splits}")

                block_number = position_in_stream // 16
                offset_in_block = position_in_stream % 16
                seeds.append((st, offset_in_block, block_number))
                print(f"Successfully recovered state from chunk at position {k}")

        except Exception as e:
            print(f"Failed to recover state from chunk at position {k}: {e}")
            continue
    else:
        print(f"Chunk at position {k} is too short to recover a state: {len(v)} bytes")
        exit()

if not seeds:
    print("No valid seeds recovered!")
    exit()
seeds.sort(key=lambda x: x[2])  # Sort by block number
pprint(seeds)

normalized_seeds = []
for s, offset, block_num in seeds:
    ps = s
    for i in range(offset):
        ps = prev_state(ps, a1, b1, m)
    normalized_seeds.append((ps, block_num))

print("Normalized seeds:")
pprint(normalized_seeds)

prev = 0
second_states = []
for i in range(len(normalized_seeds)-1):
    lgg = DoubleLCG(a1, a2, b1, b2, m, normalized_seeds[i][0], normalized_seeds[i][0])
    last = 0
    for j in range(16):
        lgg.next_state()
        last = lgg.state1
        print(last>>40, end=' ')
    print()
    print(f"Last state for seed {i}: {last}")
    second_states.append(last ^ normalized_seeds[i+1][0])
print("Second states:")
pprint(second_states)


current_state2 = truncated_lcg_rec(second_states, 48, 8, m, a2, b2)[0]
print(f"Recovered second state: {current_state2}")

sec1, block_num = normalized_seeds[0]
sc2 = prev_state(current_state2, a2, b2, m)
lcg = DoubleLCG(a1, a2, b1, b2, m, sec1,sc2 )

_ = lcg.get_random_bytes(16) #for alignment
s1,s2 = lcg.state1, lcg.state2
for _ in range(block_num):
    s1, s2 = prev_refresh(s1, s2, a1, a2, b1, b2, m)
s2 = (s2*a2+b2)%m
s1 = s1 ^ (s2 >> 40)

initial_lcg = DoubleLCG(a1, a2, b1, b2, m, s1, s2)
stream = initial_lcg.get_random_bytes(len(enc_bytes))
sk = b'\x00' * 16  + stream
out = xor_image(sk)
out.save('decrypted_serjy.png')

Here’s the decrypted picture if you’re wondering!

cute serj :3