Recap, Recess
Author: GabiTulba
Contest: Timisoara CTF Finals 2018
This was a two-in-one problem. We were given a copy of the code running on the server; the flags for the two tasks were stored in the files flag1 and flag2:
import signal
import sys
import os
import binascii
import random
os.chdir(os.path.dirname(os.path.abspath(__file__)))
MAGIC_NUMBER = 11
MSG_LENGTH = 119
CERT_CNT = MAGIC_NUMBER
coeffs = [random.randint(0, MAGIC_NUMBER**128) for i in range(MAGIC_NUMBER)]
def enc_func(msg):
    global coeffs
    msg = msg * 0x100 + 0xFF
    acc = 0
    cur = 1
    for coeff in coeffs:
        acc = (acc + coeff * cur) % (MAGIC_NUMBER**128)
        cur = (cur * msg) % (MAGIC_NUMBER**128)
    return acc % (MAGIC_NUMBER**128)
def get_hex_msg():
    try:
        msg = raw_input()
        msg = int(msg,16)
        return msg % (MAGIC_NUMBER ** (128) )
    except:
        print "Bad input"
        exit()
def encryption():
    print 'Input your message:'
    msg = get_hex_msg()
    ct = enc_func( msg )
    print "Encryption >>>> %#x" % ct
def challenge1():
    for i in range(CERT_CNT):
        msg = random.randint(0, MAGIC_NUMBER**(MSG_LENGTH) )
        ct = enc_func( msg )
        print 'Encrypt this: %#x' % msg
        ct2 = get_hex_msg()
        if ct != ct2:
            print "Your input %#x should have been %#x" % (ct2, ct)
            exit()
    print "You win challenge 1"
    print open("flag1").read()
def challenge2():
    for i in range(CERT_CNT):
        msg = random.randint(0, MAGIC_NUMBER**(MSG_LENGTH))
        ct = enc_func( msg )
        print 'Decrypt this: %#x' % (ct)
        msg2 = get_hex_msg()
        if enc_func(msg) != enc_func(msg2):
            print "Your input %#x should have been %#x" % (msg2, msg)
            exit()
    print "You win challenge 2"
    print open("flag2").read()
def input_int(prompt):
    sys.stdout.write(prompt)
    try:
        n = int(raw_input())
        return n
    except ValueError:
        return 0
    except:
        exit()
def menu():
    while True:
        print "Horrible Crypto"
        print "1. Arbitrary Encryption"
        print "2. Encryption Challenge"
        print "3. Decryption Challenge"
        print "4. Exit"
        choice = input_int("Command: ")
        {
            1: encryption,
            2: challenge1,
            3: challenge2,
            4: exit,
        }.get(choice, lambda *args:1)()
if __name__ == "__main__":
    signal.alarm(15)
    menu()
Ok, so let’s analyse what we are given and what we have to do:
Running the code, we see this menu:
Horrible Crypto
1. Arbitrary Encryption
2. Encryption Challenge
3. Decryption Challenge
4. Exit
So we are given an encryption oracle and two challenges: encryption (task Recap) and decryption (task Recess).
Understanding the encryption
Let’s start with the encryption:
def enc_func(msg):
    global coeffs
    msg = msg * 0x100 + 0xFF
    acc = 0
    cur = 1
    for coeff in coeffs:
        acc = (acc + coeff * cur) % (MAGIC_NUMBER**128)
        cur = (cur * msg) % (MAGIC_NUMBER**128)
    return acc % (MAGIC_NUMBER**128)
def get_hex_msg():
    try:
        msg = raw_input()
        msg = int(msg,16)
        return msg % (MAGIC_NUMBER ** (128) )
    except:
        print "Bad input"
        exit()
def encryption():
    print 'Input your message:'
    msg = get_hex_msg()
    ct = enc_func( msg )
    print "Encryption >>>> %#x" % ct
Inside encryption(), the input is read as a hex number, decoded to an int (get_hex_msg() also reduces it modulo 11^128) and then actually encrypted in enc_func(msg).
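Now, let’s focus on how the encryption works: it takes the message, multiplies it by 256 and adds 255 (msg = msg * 0x100 + 0xFF), then evaluates the polynomial coeffs at that value modulo 11^128 (MAGIC_NUMBER = 11 is a constant), and that’s our ciphertext. Since coeffs holds MAGIC_NUMBER = 11 random coefficients, the polynomial has degree 10:
ct = coeffs[0] + coeffs[1]*y + ... + coeffs[10]*y^10 mod 11^128, where y = 256*msg + 255.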
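To make this concrete, here is a minimal Python 3 sketch checking that the server’s loop is just ordinary polynomial evaluation (Horner’s rule) at g(msg) = 256*msg + 255. It assumes the coefficients are passed in as an argument instead of living in a global, and the key used here is a randomly generated, hypothetical one.

```python
import random

MOD = 11 ** 128  # MAGIC_NUMBER ** 128


def enc_func(msg, coeffs):
    # Same loop as the server's enc_func, with coeffs passed as an argument
    msg = msg * 0x100 + 0xFF
    acc, cur = 0, 1
    for coeff in coeffs:
        acc = (acc + coeff * cur) % MOD
        cur = (cur * msg) % MOD
    return acc


def horner(coeffs, x):
    # Evaluate coeffs[0] + coeffs[1]*x + ... + coeffs[-1]*x**(n-1) mod MOD
    acc = 0
    for c in reversed(coeffs):
        acc = (acc * x + c) % MOD
    return acc


def g(m):
    # The pre-processing step: g(m) = 256*m + 255
    return (256 * m + 255) % MOD


random.seed(0)
coeffs = [random.randint(0, MOD) for _ in range(11)]  # hypothetical key
for m in (0, 1, 0xDEADBEEF, 11 ** 119):
    assert enc_func(m, coeffs) == horner(coeffs, g(m))
```

Reducing 256*msg + 255 modulo 11^128 before evaluating changes nothing, because every power is taken modulo 11^128 anyway.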
Task Recap
Recap was the first challenge. The adversary is given 11 messages to encrypt, and if all of them are encrypted correctly, the server prints the flag.
def challenge1():
    for i in range(CERT_CNT):
        msg = random.randint(0, MAGIC_NUMBER**(MSG_LENGTH) )
        ct = enc_func( msg )
        print 'Encrypt this: %#x' % msg
        ct2 = get_hex_msg()
        if ct != ct2:
            print "Your input %#x should have been %#x" % (ct2, ct)
            exit()
    print "You win challenge 1"
    print open("flag1").read()
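So basically we need to find the coefficients of the polynomial coeffs in order to be able to encrypt. We can’t read them directly, but the encryption oracle lets us evaluate coeffs at any point we like: sending the message m returns coeffs(256*m+255). That’s great, because 11 such values determine a polynomial of degree 10, so we can recreate the original polynomial with Lagrange interpolation. The only catch is that 11^128 is not prime, so every difference x_i - x_j between two interpolation points must be invertible modulo 11^128, i.e. not divisible by 11. Querying the messages 0 to 10 gives points whose differences are 256*(i - j) with 0 < |i - j| ≤ 10, so this always holds.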
You can read more about it on Wikipedia; the math is pretty simple, and there’s an image there that describes the algorithm visually very well.
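The interpolation step can be sketched in a few lines of Python 3. This is a minimal, self-contained sketch (lagrange, poly_mul and horner are illustrative names, and the oracle is simulated locally with hypothetical random coefficients); it relies on pow(d, -1, MOD) from Python 3.8+, which raises ValueError if some denominator is divisible by 11.

```python
import random

MOD = 11 ** 128  # the challenge modulus, MAGIC_NUMBER ** 128


def horner(coeffs, x):
    """Evaluate sum(coeffs[i] * x**i) mod MOD (lowest degree first)."""
    acc = 0
    for c in reversed(coeffs):
        acc = (acc * x + c) % MOD
    return acc


def poly_mul(a, b):
    """Product of two coefficient lists (lowest degree first), mod MOD."""
    out = [0] * (len(a) + len(b) - 1)
    for i, x in enumerate(a):
        for j, y in enumerate(b):
            out[i + j] = (out[i + j] + x * y) % MOD
    return out


def lagrange(points):
    """Coefficients of the unique polynomial of degree < len(points) through points.

    Every x_i - x_j must be invertible mod MOD, i.e. not divisible by 11.
    """
    result = [0] * len(points)
    for i, (xi, yi) in enumerate(points):
        basis, denom = [1], 1
        for j, (xj, _) in enumerate(points):
            if j != i:
                basis = poly_mul(basis, [-xj % MOD, 1])  # times (x - x_j)
                denom = denom * (xi - xj) % MOD
        scale = yi * pow(denom, -1, MOD) % MOD  # y_i / prod(x_i - x_j)
        for k, c in enumerate(basis):
            result[k] = (result[k] + c * scale) % MOD
    return result


# Simulated oracle with a hypothetical secret key, queried at messages 0..10
random.seed(2018)
secret = [random.randint(0, MOD) for _ in range(11)]
points = [(256 * m + 255, horner(secret, 256 * m + 255)) for m in range(11)]
recovered = lagrange(points)
print(recovered == [c % MOD for c in secret])  # True
```

Because the Vandermonde determinant is a product of the differences x_j - x_i, which are all units here, the recovered coefficients are unique modulo 11^128.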
NOTE: since the contest’s servers don’t work anymore, you will have to run the challenge code locally.
Task Recess
Now, the second task was the real deal. The adversary has to decrypt 11 messages. I didn’t solve this task during the contest (I solved it about 4 hours after the contest ended, but I still got goosebumps when I got the flag). About 3 hours before the end of the contest a hint was added (sadly, I don’t remember its exact form), and eventually I found an article on Wikipedia and a bit of code on GitHub. After a good 30 minutes of reading I implemented the solution: for a given ciphertext ct, we have to find a root of the polynomial f(x) = coeffs(x) - ct modulo 11^128. The mistake I made was that I didn’t realize what this line of code was doing:
msg = msg * 0x100 + 0xFF
This makes the encryption a bit more complicated: the ciphertext isn’t just coeffs(msg), it’s coeffs(g(msg)), where g is another polynomial, g(x) = 256*x + 255. That means that Hensel’s lemma had to be applied to the composed polynomial coeffs(g(x)) minus the ciphertext, so I wrote a function that composes two polynomials.
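Composition with a linear g can be sketched compactly in Python 3 with Horner’s rule applied to polynomials, f(g) = (...(f_n*g + f_(n-1))*g + ...)*g + f_0. This is a minimal sketch: compose, poly_mul and horner are illustrative names, not the Polynomial class used in the solution, and the toy f stands in for the recovered coeffs.

```python
MOD = 11 ** 128


def poly_mul(a, b):
    """Product of two coefficient lists (lowest degree first), mod MOD."""
    out = [0] * (len(a) + len(b) - 1)
    for i, x in enumerate(a):
        for j, y in enumerate(b):
            out[i + j] = (out[i + j] + x * y) % MOD
    return out


def horner(coeffs, x):
    """Evaluate sum(coeffs[i] * x**i) mod MOD."""
    acc = 0
    for c in reversed(coeffs):
        acc = (acc * x + c) % MOD
    return acc


def compose(f, g):
    """Coefficients of f(g(x)) mod MOD, via Horner's rule on polynomials."""
    result = []
    for c in reversed(f):
        # result <- result * g + c
        result = poly_mul(result, g) if result else [0]
        result[0] = (result[0] + c) % MOD
    return result


g = [255, 256]  # g(x) = 255 + 256*x
f = [3, 0, 1]   # toy f(x) = x**2 + 3
h = compose(f, g)
print(h)  # [65028, 130560, 65536]
```

To decrypt a ciphertext ct, one would subtract ct from h[0] and search for roots of the result modulo 11^128; a root m then satisfies coeffs(256*m + 255) = ct.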
The code
Here’s my solution to both problems:
from pwn import *
from Crypto.Util.number import inverse
MOD=11**128
def transform(x):
    # the server's pre-processing step g(x) = 256*x + 255
    return (x*256+255)%MOD
#Polynomial part
class Polynomial(object):
    # coefficients are stored lowest degree first, arithmetic is mod MOD
    def __init__(self,coeffs):
        # copy the list, so in-place methods such as multiply() never change
        # the coefficients of another Polynomial built from the same list
        self.coeffs = list(coeffs)
    def evaluate(self,x,mod):
        val = 0
        for i in range(len(self.coeffs)):
            val = (val + x**i * self.coeffs[i]) % mod
        return val
    def raise_degree(self,x):
        # multiply by X**x (shift the coefficients up)
        coeffs=[]
        for i in range(x):
            coeffs.append(0)
        for i in range(len(self.coeffs)):
            coeffs.append(self.coeffs[i])
        self.coeffs=coeffs
    def add_to_degree(self,x,y):
        while(len(self.coeffs)<=x):
            self.coeffs.append(0)
        self.coeffs[x]=(self.coeffs[x]+y)%MOD
    def add_poly(self,x):
        while(len(self.coeffs)<len(x.coeffs)):
            self.coeffs.append(0)
        for i in range(len(x.coeffs)):
            self.coeffs[i]=(self.coeffs[i]+x.coeffs[i])%MOD
    def multiply(self,x):
        for i in range(len(self.coeffs)):
            self.coeffs[i]=(self.coeffs[i] * x)%MOD
    def multiply_with_poly(self,p):
        coeffs=Polynomial([])
        for i in range(len(self.coeffs)):
            q=Polynomial(p.coeffs)
            q.raise_degree(i)
            q.multiply(self.coeffs[i])
            coeffs.add_poly(q)
        self.coeffs=coeffs.coeffs
    def calculate_derivative(self):
        p=Polynomial([])
        for i in range(1,len(self.coeffs)):
            p.coeffs.append((self.coeffs[i]*i)%MOD)
        return p
    def compose(self,p):
        # self <- self(p(x)) = sum of coeffs[i] * p**i, with q holding p**i
        P=Polynomial([])
        q=Polynomial([1])
        for i in range(len(self.coeffs)):
            q2=Polynomial(q.coeffs)
            q2.multiply(self.coeffs[i])
            P.add_poly(q2)
            q.multiply_with_poly(p)
        self.coeffs=P.coeffs
    def print_poly(self):
        return self.coeffs
#Lagrange Interpolation part
def Lagrange_Basis_Polynomial(xlist,index):
    l=Polynomial([1])
    for i in range(len(xlist)):
        if(i==index):
            continue
        p=Polynomial([(MOD-xlist[i])%MOD,1])
        p.multiply(inverse((MOD+xlist[index]-xlist[i])%MOD,MOD))
        l.multiply_with_poly(p)
    return l
def Lagrange_Polynomial(xylist):
    xlist=[]
    ylist=[]
    L=Polynomial([])
    for a in xylist:
        xlist.append(a[0])
        ylist.append(a[1])
    for i in range(len(ylist)):
        l=Lagrange_Basis_Polynomial(xlist,i)
        l.multiply(ylist[i])
        L.add_poly(l)
    return L
#Hensel's Lemma part
def Hensel(f,d,p,k):
    # roots of f modulo p**k; d is the derivative of f
    if k==1:
        r=[]
        for i in range(p):
            if(f.evaluate(i,p)==0):
                r.append(i)
        return r
    r=Hensel(f,d,p,k-1)
    new=[]
    for i,n in enumerate(r):
        dr=d.evaluate(n,p)
        fr=f.evaluate(n,p**k)
        if dr!=0:
            # simple root: exactly one t in 0..p-1 lifts it
            for t in range(p):
                if(f.evaluate(r[i]+t*p**(k-1),p**k)==0):
                    new.append(r[i]+t*p**(k-1))
                    break
        if dr==0:
            # singular root: either every lift works or none does
            if fr%p**k==0:
                for t in range(p):
                    new.append(r[i]+t*p**(k-1))
    return new
#Main part
xylist=[]
s=remote('localhost',1337)
# query the oracle with the messages 0..10 and store (g(m), enc(m)) pairs
for i in range(11):
    s.recvuntil('Command: ')
    s.sendline('1')
    s.recvuntil('Input your message:')
    s.sendline(hex(i)[2:])
    s.recvline()
    x=s.recvline()[18:][:-1]  # strip "Encryption >>>> 0x" and the newline
    x=int(x,16)
    xylist.append((transform(i),x))
print "Values of the polynomial:"
for x in xylist:
    print 'f('+str(x[0])+') = '+str(x[1])
coeffs=Lagrange_Polynomial(xylist)
print "\nReconstructed polynomial"
for i,x in enumerate(coeffs.print_poly()):
    print str(i)+':',x
print '\nStarting challenge 1!'
s.recvuntil('Command: ')
s.sendline('2')
for i in range(11):
    s.recvuntil('Encrypt this: ')
    x=int(s.recvline()[2:],16)
    # '%x' avoids the trailing 'L' that hex() appends to Python 2 longs
    s.sendline('%x' % coeffs.evaluate(transform(x),MOD))
    print 'step',i+1,'done'
print '\n'+s.recvline()
print s.recvline()
print '\nComposed polynomial'
# f(x) = coeffs(g(x)) with g(x) = 256*x + 255
f=Polynomial(coeffs.coeffs)
g=Polynomial([255,256])
f.compose(g)
d=f.calculate_derivative()  # f - ct has the same derivative as f
for i,x in enumerate(f.print_poly()):
    print str(i)+':',x
print '\nStarting challenge 2!'
s.recvuntil('Command: ')
s.sendline('3')
for i in range(11):
    s.recvuntil('Decrypt this: ')
    x=int(s.recvline()[2:],16)
    print x
    # find a root of h(m) = coeffs(g(m)) - ct modulo 11**128
    h=Polynomial(f.coeffs)
    h.add_to_degree(0,-x)
    out=Hensel(h,d,11,128)
    print coeffs.evaluate(transform(out[0]),MOD)  # should equal x
    s.sendline('%x' % out[0])
    print 'step',i+1,'done'
print '\n'+s.recvline()
print s.recvline()
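The Hensel function above lifts a simple root by trying all p = 11 values of t at every level. As a hedged alternative, here is a minimal Python 3 sketch (evaluate, derivative and hensel_roots are illustrative names, not part of the original solution) that computes the unique lift of a simple root directly with a Newton step, r - f(r)*f'(r)^(-1) with the inverse taken modulo p, and handles singular roots the same way as the original: keep all p candidates when f(r) is 0 modulo p^j, drop the root otherwise. It uses pow(x, -1, m) from Python 3.8+.

```python
MOD = 11 ** 128


def evaluate(coeffs, x, mod):
    """Evaluate sum(coeffs[i] * x**i) mod mod with Horner's rule."""
    acc = 0
    for c in reversed(coeffs):
        acc = (acc * x + c) % mod
    return acc


def derivative(coeffs):
    """Coefficients of the formal derivative."""
    return [i * c for i, c in enumerate(coeffs)][1:]


def hensel_roots(f, p, k):
    """All roots of f modulo p**k, lifted one power of p at a time."""
    df = derivative(f)
    roots = [r for r in range(p) if evaluate(f, r, p) == 0]
    for j in range(2, k + 1):
        pj, step = p ** j, p ** (j - 1)
        lifted = []
        for r in roots:
            dr = evaluate(df, r, p)
            if dr != 0:
                # Simple root: Newton step with f'(r)^-1 mod p gives the unique lift
                lifted.append((r - evaluate(f, r, pj) * pow(dr, -1, p)) % pj)
            elif evaluate(f, r, pj) == 0:
                # Singular root that survives: every r + t*p**(j-1) is a root mod p**j
                lifted.extend(r + t * step for t in range(p))
            # otherwise the singular root does not lift and is dropped
        roots = lifted
    return roots


# Example: 5 + 256*x = 0 mod 11**128 has exactly one (simple) root
print(hensel_roots([5, 256], 11, 128) == [(-5 * pow(256, -1, MOD)) % MOD])  # True
```

The Newton step works because f(r + t*p^(j-1)) is congruent to f(r) + t*p^(j-1)*f'(r) modulo p^j for j ≥ 2. Singular roots can multiply the candidate list at every level, so for an unlucky key the search may grow large in either version.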