**Title**: Additional Problems

**Category**: cryptography

**Points**: 500

**Solves**: 2

**Description**:

Version 1.0 of our new encryption service has just launched! It is blazingly fast and uses state-of-the-art encryption. Stay tuned for version 2.0; I hear it will bring tons of improvements and security fixes.

Access the server via `nc IP_ADDRESS PORT`

(the server is now down, but you can run it yourself)

# Introduction

This challenge proved to be among the hardest challenges at CSCBE, with only one team submitting a flag. The teams at Zeus WPI ended up in 2nd, 9th (virtual, since winners are not allowed to participate again), 52nd and 83rd and 104th place at the online qualifiers.

# Description

The challenge provides a Python file, which runs a server accessible through a TCP socket. The server implements a form of homomorphic encryption: with homomorphic encryption, addition of two ciphertexts results in a ciphertext that decrypts to the addition of the two plaintexts. Here, the encryption was implemented with a secret 128-bit prime `p`

, and an `N`

between 128 and 255. The cryptosystem works as follows:

```
def dghv_encrypt(p, N, m):
"""
Encrypt a value to later decrypt with `dghv_decrypt`
"""
assert 2**7 <= N < 2**8 # Normally this is 2, but by using a bigger `N` we can encode ASCII bytes instead of bits! That's much more efficient. All `N` in this range should be secure, so let's make it an assertion
q = reduce(lambda x, y: x*y, [Crypto.Util.number.getPrime(128) for _ in range(8)]) # `q` can be any number, but as we all know, big primes are the safest numbers there are
rmax = 2**128 / N / 4
r = random.randint(0, rmax) # In v2.0, we will let `r` be negative as well as positive => double the randomness!
return p*q + N*r + m
def dghv_decrypt(p, N, c):
"""
Since c = pq + Nr + m, we can find m as (c mod p) mod N!
"""
return (c % p) % N
def dghv_add(c1, c2):
"""
The sum of ciphertexts decodes to the sum of plaintexts!!!
"""
return c1 + c2 # We will add bootstrapping to make this fully homomorphic in v2.0
```

Encryption, decryption and adding works on byte-based granularity: every byte is encrypted separately.

When first connecting, the server generates a new `p`

, encrypts the flag with it and prints the encrypted flag. It then asks for an `N`

value and a (hexadecimal) plaintext that will be used in your first session.

It then provides a menu where you have three options:

- Start a new session, where you need to give an
`N`

value and a plaintext; it will use`dghv_encrypt`

to set the ciphertext - Add an additional plaintext to the current session: it will encrypt that plaintext with
`dghv_encrypt`

, then use`dghv_add`

to add it to the current ciphertext - Decrypt the ciphertext: it returns the result of
`dghv_decrypt`

on the current ciphertext

Note that other than the initial ciphertext printed at startup, we never have access to any ciphertexts.

# Vulnerability

If you add too many plaintexts, the `N*r`

terms become greater than `p`

, making the decryption step `(c % p) % N`

not work correctly anymore. Suppose you always send zero-filled plaintext, and it eventually decrypts to a value `x != 0`

, then we have:

```
(p*q + N*r + (m_0 + m_1 + m_... + m_n)) % p % N == x
because all message bytes m_n are 0, we have
(p*q + N*r) % p % N == x
now, let N*r > p, but smaller than 2*p (we can be sure of this,
since r only increases with maximum 2**128 / N / 4 per encryption
and p is a 128 bit prime), then we have
N*r = p + a.
By substituting N*r for p + a
(p*q + p + a) % p % N == x
which is then equal to
a % p % N == x
and thus, since a < p,
a % N = x
and since 0 < x < N,
p % N == N - x.
```

We can then iterate through all prime values of N between 128 and 255 to get each value of `p % N_i`

; once we have these values, we can use the Chinese Remainder Theorem to calculate p (since the product of all those primes is bigger than 128 bits, p is uniquely determined).

# Solution script

This script was written in the heat of the moment, so itâ€™s not the cleanest, but it works. It uses the excellent pwntools library.

```
from pwn import remote
conn = remote('additional_problems.challenges.cybersecuritychallenge.be', 1340)
def dghv_decrypt(p, N, c):
"""
Since c = pq + Nr + m, we can find m as (c mod p) mod N!
"""
return (c % p) % N
def add(conn, msg):
print('add')
conn.sendline(b'2')
conn.sendline(msg)
conn.recvuntil(b'> ')
def get(conn):
print('get')
conn.sendline(b'3')
lines = conn.recvuntil(b'> ')
hexed = [int(a, 16) for a in lines.decode().split('\n')[0].split(': ')[1].split(' ')]
return hexed
def new(conn, N, msg):
print('new')
conn.sendline(b'1')
conn.recvuntil(b'Choose N: ')
conn.sendline(str(N).encode())
conn.recvuntil(b'Message to encode (converted to hexadecimal): ')
conn.sendline(msg)
conn.recvuntil(b'> ')
conn.recvuntil(b"We've even encrypted our secret flag with it:\n")
encrypted = conn.recvuntil(b'\n\n').decode().strip()
# only in modified testserver
# p, encrypted = encrypted.split('\n', maxsplit=1)
# p = int(p.strip())
ciphertexts = [int(e.strip()) for e in encrypted.split('\n')]
# skip initial setup
conn.sendline(b'128')
conn.recvuntil(b'Message to encode (converted to hexadecimal): ')
conn.sendline((' '.join('00' for _ in range(30))).encode())
conn.recvuntil(b'> ')
# real shit
all_zeroes = (' '.join('00' for _ in range(10))).encode()
primes = [131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251]
mapped = {}
from functools import reduce
def chinese_remainder(m, a):
sum = 0
prod = reduce(lambda acc, b: acc*b, m)
for n_i, a_i in zip(m, a):
p = prod // n_i
sum += a_i * mul_inv(p, n_i) * p
return sum % prod
def mul_inv(a, b):
b0 = b
x0, x1 = 0, 1
if b == 1: return 1
while a > 1:
q = a // b
a, b = b, a%b
x0, x1 = x1 - q * x0, x0
if x1 < 0: x1 += b0
return x1
for prime_idx, prime in enumerate(primes):
new(conn, prime, all_zeroes)
for i in range(32):
add(conn, all_zeroes)
received = get(conn)
reduced = [d for d in received if d != 0]
if reduced:
mapped[prime] = prime - reduced[0]
# only in modified testserver
# assert (p % prime == mapped[prime])
# print("OK!")
break
print(f'{prime_idx} of {len(primes)} {i=}')
moduli = [mapped[prime] for prime in primes]
recovered = chinese_remainder(primes, moduli)
print('FLAG=')
for cipher in ciphertexts:
print(chr(dghv_decrypt(recovered, 128, cipher)), end='')
print('')
```