# Recap, Recess

Author: GabiTulba

Contest: Timisoara CTF Finals 2018

This was a two-in-one problem. We were given a copy of the code running on the server, where the flags were stored in the files flag1 and flag2:

```
import signal
import sys
import os
import binascii
import random

os.chdir(os.path.dirname(os.path.abspath(__file__)))

MAGIC_NUMBER = 11
MSG_LENGTH = 119
CERT_CNT = MAGIC_NUMBER
coeffs = [random.randint(0, MAGIC_NUMBER**128) for i in range(MAGIC_NUMBER)]

def enc_func(msg):
    global coeffs
    msg = msg * 0x100 + 0xFF
    acc = 0
    cur = 1
    for coeff in coeffs:
        acc = (acc + coeff * cur) % (MAGIC_NUMBER**128)
        cur = (cur * msg) % (MAGIC_NUMBER**128)
    return acc % (MAGIC_NUMBER**128)

def get_hex_msg():
    try:
        msg = raw_input()
        msg = int(msg,16)
        return msg % (MAGIC_NUMBER ** (128) )
    except:
        print "Bad input"
        exit()

def encryption():
    print 'Input your message:'
    msg = get_hex_msg()
    ct = enc_func( msg )
    print "Encryption >>>> %#x" % ct

def challenge1():
    for i in range(CERT_CNT):
        msg = random.randint(0, MAGIC_NUMBER**(MSG_LENGTH) )
        ct = enc_func( msg )
        print 'Encrypt this: %#x' % msg
        ct2 = get_hex_msg()
        if ct != ct2:
            print "Your input %#x should have been %#x" % (ct2, ct)
            exit()
    print "You win challenge 1"
    print open("flag1").read()

def challenge2():
    for i in range(CERT_CNT):
        msg = random.randint(0, MAGIC_NUMBER**(MSG_LENGTH))
        ct = enc_func( msg )
        print 'Decrypt this: %#x' % (ct)
        msg2 = get_hex_msg()
        if enc_func(msg) != enc_func(msg2):
            print "Your input %#x should have been %#x" % (msg2, msg)
            exit()
    print "You win challenge 2"
    print open("flag2").read()

def input_int(prompt):
    sys.stdout.write(prompt)
    try:
        n = int(raw_input())
        return n
    except ValueError:
        return 0
    except:
        exit()

def menu():
    while True:
        print "Horrible Crypto"
        print "1. Arbitrary Encryption"
        print "2. Encryption Challenge"
        print "3. Decryption Challenge"
        print "4. Exit"
        choice = input_int("Command: ")
        {
            1: encryption,
            2: challenge1,
            3: challenge2,
            4: exit,
        }.get(choice, lambda *args:1)()

if __name__ == "__main__":
    signal.alarm(15)
    menu()
```

Ok, so let’s analyse what we are given and what we have to do:

Running the code we can see a menu:

Horrible Crypto

1. Arbitrary Encryption
2. Encryption Challenge
3. Decryption Challenge
4. Exit

So, we are given an encryption oracle and two challenges: encryption (**task recap**) and decryption (**task recess**).

### Understanding the encryption

Let’s start with the encryption:

```
def enc_func(msg):
    global coeffs
    msg = msg * 0x100 + 0xFF
    acc = 0
    cur = 1
    for coeff in coeffs:
        acc = (acc + coeff * cur) % (MAGIC_NUMBER**128)
        cur = (cur * msg) % (MAGIC_NUMBER**128)
    return acc % (MAGIC_NUMBER**128)

def get_hex_msg():
    try:
        msg = raw_input()
        msg = int(msg,16)
        return msg % (MAGIC_NUMBER ** (128) )
    except:
        print "Bad input"
        exit()

def encryption():
    print 'Input your message:'
    msg = get_hex_msg()
    ct = enc_func( msg )
    print "Encryption >>>> %#x" % ct
```

Inside `encryption()`, the input is a hex number that gets decoded to an int and then encrypted by `enc_func(msg)`.

Now, let’s focus on how the encryption works:

It takes the message, multiplies it by 256 and adds 255 (`msg = msg * 0x100 + 0xFF`), and then evaluates the polynomial `coeffs(msg)` modulo $11^{128}$ (`MAGIC_NUMBER = 11` is a constant). The result is our ciphertext.
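To make that structure explicit, here is a minimal sketch in Python 3 (the challenge code itself is Python 2). It checks that the server loop gives the same result as evaluating a degree-10 polynomial at `256*msg + 255`. The names `poly_eval`, `encrypt` and `enc_func_reference`, the seed and the demo values are illustrative assumptions, not part of the challenge.

```python
import random

MAGIC_NUMBER = 11
MOD = MAGIC_NUMBER ** 128

def poly_eval(coeffs, x, mod=MOD):
    """Evaluate sum(coeffs[i] * x**i) modulo mod using Horner's rule."""
    acc = 0
    for c in reversed(coeffs):
        acc = (acc * x + c) % mod
    return acc

def encrypt(coeffs, msg):
    """Polynomial view of the cipher: P(256*msg + 255) mod 11**128."""
    return poly_eval(coeffs, msg * 0x100 + 0xFF)

def enc_func_reference(coeffs, msg):
    """Direct transcription of the server loop, kept for comparison."""
    msg = msg * 0x100 + 0xFF
    acc, cur = 0, 1
    for coeff in coeffs:
        acc = (acc + coeff * cur) % MOD
        cur = (cur * msg) % MOD
    return acc % MOD

random.seed(2018)  # arbitrary seed, only to make the sketch reproducible
demo_coeffs = [random.randrange(MOD) for _ in range(MAGIC_NUMBER)]  # stand-in for the secret
demo_msg = random.randrange(MAGIC_NUMBER ** 119)
assert encrypt(demo_coeffs, demo_msg) == enc_func_reference(demo_coeffs, demo_msg)
```

Both functions agree, so from here on the cipher can be treated as an unknown polynomial with 11 coefficients.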

### Task Recap

Recap was the first challenge. The adversary is given 11 messages to encrypt. If all of them are encrypted correctly, the server prints the flag.

```
def challenge1():
    for i in range(CERT_CNT):
        msg = random.randint(0, MAGIC_NUMBER**(MSG_LENGTH) )
        ct = enc_func( msg )
        print 'Encrypt this: %#x' % msg
        ct2 = get_hex_msg()
        if ct != ct2:
            print "Your input %#x should have been %#x" % (ct2, ct)
            exit()
    print "You win challenge 1"
    print open("flag1").read()
```

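So basically we need to find the coefficients of the polynomial **coeffs** in order to be able to encrypt. Note that we have an encryption oracle, so we can obtain the value of **coeffs** at any point of the form `256*m + 255`. That’s great, because 11 such values are enough to rebuild the original polynomial (it has 11 coefficients, so its degree is at most 10) as a **Lagrange polynomial**. The modulus $11^{128}$ is not prime, so the interpolation only works if the differences between the chosen points are invertible. Querying the messages 0 through 10 gives differences `256*(i-j)` with `0 < |i-j| <= 10`, none of which is divisible by 11.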

You can read more about Lagrange polynomials on Wikipedia; the math is pretty simple, and there’s an image that illustrates the algorithm very well.

**NOTE**: since the contest’s servers are no longer running, you will have to run the challenge code locally.
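The interpolation step for Recap can be sketched on its own, without a server. This is a simplified Python 3.8+ illustration, not the solution script itself: `secret` is a stand-in for the server's hidden coefficients, `oracle` plays the role of menu option 1, and all helper names are assumptions made for this example.

```python
import random

MOD = 11 ** 128

def poly_eval(coeffs, x):
    """Evaluate the polynomial (lowest degree first) at x modulo MOD."""
    acc = 0
    for c in reversed(coeffs):
        acc = (acc * x + c) % MOD
    return acc

def poly_mul_linear(poly, root):
    """Multiply poly by (x - root) modulo MOD."""
    out = [0] * (len(poly) + 1)
    for i, c in enumerate(poly):
        out[i] = (out[i] - c * root) % MOD
        out[i + 1] = (out[i + 1] + c) % MOD
    return out

def lagrange_interpolate(points):
    """Coefficients of the polynomial of degree < len(points) through the
    given (x, y) pairs. Every difference x_i - x_j must be invertible mod MOD."""
    result = [0] * len(points)
    for i, (xi, yi) in enumerate(points):
        basis, denom = [1], 1
        for j, (xj, _) in enumerate(points):
            if j != i:
                basis = poly_mul_linear(basis, xj)
                denom = denom * (xi - xj) % MOD
        scale = yi * pow(denom, -1, MOD) % MOD  # modular inverse, Python 3.8+
        for k, c in enumerate(basis):
            result[k] = (result[k] + c * scale) % MOD
    return result

random.seed(1)  # arbitrary seed, only to make the sketch reproducible
secret = [random.randrange(MOD) for _ in range(11)]  # hypothetical server coefficients

def oracle(msg):
    """Stand-in for the "Arbitrary Encryption" menu option."""
    return poly_eval(secret, msg * 256 + 255)

# Query messages 0..10; the evaluation points are 256*m + 255.
points = [(m * 256 + 255, oracle(m)) for m in range(11)]
recovered = lagrange_interpolate(points)
assert recovered == secret
```

Once `recovered` matches the hidden coefficients, encrypting the 11 challenge messages is just a matter of evaluating the polynomial locally.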

### Task Recess

Now, the second task was the **real deal**. The adversary has to decrypt 11 messages. I didn’t solve this task during the contest (I solved it about 4 hours after the contest ended, but I still got goosebumps when I got the flag). About 3 hours before the end of the contest a hint was added (sadly, I don’t remember its exact wording), and eventually I found an article on Wikipedia and a bit of code on GitHub. After a good 30 minutes of reading I implemented the solution: for a ciphertext **ct** we had to find a root of the polynomial `f(x)=coeffs(x)-ct` modulo $11^{128}$. The mistake I made was that I didn’t realize what this line of code was doing:

```
msg = msg * 0x100 + 0xFF
```

This makes the encryption a bit more complicated: it isn’t just **coeffs(msg)**, it’s **coeffs(g(msg))**, where **g** is another polynomial, `g(x) = 256*x+255`. That means that Hensel’s lemma had to be applied to the polynomial **coeffs(g(x))**. So I wrote a function that composes two polynomials.
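As a rough illustration of these two steps, here is a small Python 3.8+ sketch that composes a polynomial with `g(x) = 256*x + 255` and then lifts roots from modulo 11 to modulo $11^{128}$. It is a simplified variant under stated assumptions: it only handles simple roots (derivative non-zero modulo 11) and uses the closed-form Hensel update, whereas the `Hensel` function in the solution tries all 11 candidates per step and also covers the degenerate case. The tiny polynomial `secret`, the value `message` and the helper names are made up for the example.

```python
P, K = 11, 128
MOD = P ** K

def poly_eval(coeffs, x, mod):
    """Evaluate the polynomial (lowest degree first) at x modulo mod."""
    acc = 0
    for c in reversed(coeffs):
        acc = (acc * x + c) % mod
    return acc

def compose_linear(coeffs, a, b, mod):
    """Coefficients of P(a*x + b), built with Horner's rule on polynomials."""
    result = [0]
    for c in reversed(coeffs):
        # result = result * (a*x + b) + c
        nxt = [0] + [(a * r) % mod for r in result]
        for i, r in enumerate(result):
            nxt[i] = (nxt[i] + b * r) % mod
        nxt[0] = (nxt[0] + c) % mod
        result = nxt
    while len(result) > 1 and result[-1] == 0:
        result.pop()  # drop the padding zero left over from the first step
    return result

def hensel_lift_simple_roots(f, p, k):
    """Roots of f modulo p**k that lift from simple roots modulo p."""
    df = [i * c for i, c in enumerate(f)][1:]  # formal derivative
    roots = []
    for r0 in range(p):
        if poly_eval(f, r0, p) != 0:
            continue
        slope = poly_eval(df, r0, p)
        if slope == 0:
            continue  # non-simple root: not handled by this sketch
        inv = pow(slope, -1, p)
        r, modulus = r0, p
        for _ in range(k - 1):
            modulus *= p
            r = (r - poly_eval(f, r, modulus) * inv) % modulus
        roots.append(r)
    return roots

secret = [1, 1, 1]     # hypothetical stand-in for the recovered coeffs
message = 123456789    # hypothetical plaintext
ct = poly_eval(secret, 256 * message + 255, MOD)

f = compose_linear(secret, 256, 255, MOD)
f[0] = (f[0] - ct) % MOD  # f(x) = coeffs(g(x)) - ct
roots = hensel_lift_simple_roots(f, P, K)
assert message in roots
```

Any returned root is a valid answer for the server, because `challenge2` only compares `enc_func(msg)` with `enc_func(msg2)`.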

### The code

Here’s my solution to both problems:

```
from pwn import *
from Crypto.Util.number import inverse

MOD=11**128

def transform(x):
    return (x*256+255)%MOD

#Polynomial part
class Polynomial(list):
    def __init__(self,coeffs):
        # copy the list, otherwise two polynomials would share (and mutate) the same coefficients
        self.coeffs = list(coeffs)
    def evaluate(self,x,mod):
        val = 0
        for i in range(len(self.coeffs)):
            val = (val + x**i * self.coeffs[i]) % mod
        return val
    def raise_degree(self,x):
        # multiply by X**x
        coeffs=[]
        for i in range(x):
            coeffs.append(0)
        for i in range(len(self.coeffs)):
            coeffs.append(self.coeffs[i])
        self.coeffs=coeffs
    def add_to_degree(self,x,y):
        # add y to the coefficient of X**x
        while(len(self.coeffs)<=x):
            self.coeffs.append(0)
        self.coeffs[x]=(self.coeffs[x]+y)%MOD
    def add_poly(self,x):
        while(len(self.coeffs)<len(x.coeffs)):
            self.coeffs.append(0)
        for i in range(len(x.coeffs)):
            self.coeffs[i]=(self.coeffs[i]+x.coeffs[i])%MOD
    def multiply(self,x):
        # multiply by the scalar x
        for i in range(len(self.coeffs)):
            self.coeffs[i]=(self.coeffs[i] * x)%MOD
    def multiply_with_poly(self,p):
        coeffs=Polynomial([])
        for i in range(len(self.coeffs)):
            q=Polynomial(p.coeffs)
            q.raise_degree(i)
            q.multiply(self.coeffs[i])
            coeffs.add_poly(q)
        self.coeffs=coeffs.coeffs
    def calculate_derivative(self):
        p=Polynomial([])
        for i in range(1,len(self.coeffs)):
            p.coeffs.append((self.coeffs[i]*i)%MOD)
        return p
    def compose(self,p):
        # replace self with self(p(X)); q holds the current power of p
        P=Polynomial([])
        q=Polynomial([1])
        for i in range(len(self.coeffs)):
            q2=Polynomial(q.coeffs)
            q2.multiply(self.coeffs[i])
            P.add_poly(q2)
            q.multiply_with_poly(p)
        self.coeffs=P.coeffs
    def print_poly(self):
        return self.coeffs

#Lagrange Interpolation part
def Lagrange_Basis_Polynomial(xlist,index):
    l=Polynomial([1])
    for i in range(len(xlist)):
        if(i==index):
            continue
        p=Polynomial([(MOD-xlist[i])%MOD,1])
        p.multiply(inverse((MOD+xlist[index]-xlist[i])%MOD,MOD))
        l.multiply_with_poly(p)
    return l

def Lagrange_Polynomial(xylist):
    xlist=[]
    ylist=[]
    L=Polynomial([])
    for a in xylist:
        xlist.append(a[0])
        ylist.append(a[1])
    for i in range(len(ylist)):
        l=Lagrange_Basis_Polynomial(xlist,i)
        l.multiply(ylist[i])
        L.add_poly(l)
    return L

#Hensel's Lemma part
def Hensel(f,d,p,k):
    # returns the roots of f modulo p**k; d is the derivative of f
    if k==1:
        r=[]
        for i in range(p):
            if(f.evaluate(i,p)==0):
                r.append(i)
        return r
    r=Hensel(f,d,p,k-1)
    new=[]
    for i,n in enumerate(r):
        dr=d.evaluate(n,p)
        fr=f.evaluate(n,p**k)
        if dr!=0:
            # simple root: exactly one lift exists, find it by trying every t
            for t in range(p):
                if(f.evaluate(r[i]+t*p**(k-1),p**k)==0):
                    new.append(r[i]+t*p**(k-1))
                    break
        if dr==0:
            # degenerate root: either every lift works or none does
            if fr%p**k==0:
                for t in range(p):
                    new.append(r[i]+t*p**(k-1))
    return new

#Main part
xylist=[]
s=remote('localhost',1337)
# query the oracle for messages 0..10 to get 11 points of the polynomial
for i in range(11):
    s.recvuntil('Command: ')
    s.sendline('1')
    s.recvuntil('Input your message:')
    s.sendline(hex(i)[2:])
    s.recvline()
    # strip the "Encryption >>>> 0x" prefix and the trailing newline
    x=s.recvline()[18:][:-1]
    x=int(x,16)
    xylist.append((transform(i),x))
print "Values of the polynomial:"
for x in xylist:
    print 'f('+str(x[0])+') = '+str(x[1])
coeffs=Lagrange_Polynomial(xylist)
print "\nReconstructed polynomial"
for i,x in zip([i for i in range(11)],coeffs.print_poly()):
    print str(i)+':',x

print '\nStarting challenge 1!'
s.recvuntil('Command: ')
s.sendline('2')
for i in range(11):
    s.recvuntil('Encrypt this: ')
    x=int(s.recvline()[2:],16)
    # '%x' avoids the trailing 'L' that hex() adds to Python 2 longs
    s.sendline('%x' % coeffs.evaluate(transform(x),MOD))
    print 'step',i+1,'done'
print '\n'+s.recvline()
print s.recvline()

print '\nComposed polynomial'
f=Polynomial(coeffs.coeffs)
g=Polynomial([255,256])
f.compose(g)
for i,x in zip([i for i in range(11)],f.print_poly()):
    print str(i)+':',x

print '\nStarting challenge 2!'
s.recvuntil('Command: ')
s.sendline('3')
for i in range(11):
    s.recvuntil('Decrypt this: ')
    x=int(s.recvline()[2:],16)
    print x
    f=Polynomial(coeffs.coeffs)
    g=Polynomial([255,256])
    f.compose(g)
    # h(X) = coeffs(g(X)) - ciphertext, whose roots are the valid plaintexts
    h=Polynomial(f.coeffs)
    h.add_to_degree(0,-x)
    d=f.calculate_derivative()
    out=Hensel(h,d,11,128)
    # sanity check: this should print the ciphertext again
    print coeffs.evaluate(transform(out[0]),MOD)
    s.sendline('%x' % out[0])
    print 'step',i+1,'done'
print '\n'+s.recvline()
print s.recvline()
```