题目
#!/usr/bin/env python3
from random import *
from binascii import *
from Crypto.Cipher import DES
from signal import *
import sys, os
from flag import flag
def die(*args):
pr(*args)
quit()
def pr(*args):
s = " ".join(map(str, args))
sys.stdout.write(s + "\n")
sys.stdout.flush()
def sc():
return sys.stdin.buffer.readline()
def shift(msg, l):
assert l < len(msg)
return msg[l:] + msg[:l]
def pad(text):
if len(text) % 8 != 0:
text += (b'\xff' * (8 - len(text) % 8))
return text
def encrypt(msg, key):
msg = pad(msg)
assert len(msg) % 8 == 0
assert len(key) == 8
des = DES.new(key, DES.MODE_ECB)
enc = des.encrypt(msg)
return enc
def main():
border = "+"
pr(border*72)
pr(border, ".:: Hi all, you should be fast, I mean super fact!! ::.", border)
pr(border, "You should send twenty 8-byte keys to encrypt the secret message and", border)
pr(border, "just decrypt the ciphertext to get the flag, Are you ready to start?", border)
pr(border*72)
secret_msg = b'TOP_SECRET:' + os.urandom(40)
cnt, STEP, KEYS = 0, 14, []
md = 1
while True:
pr(border, "please send your key as hex: ")
alarm(md + 1)
ans = sc().decode().strip()
alarm(0)
try:
key = unhexlify(ans)
if len(key) == 8 and key not in KEYS:
KEYS += [key]
cnt += 1
else:
die(border, 'Kidding me!? Bye!!')
except:
die(border, 'Your key is not valid! Bye!!')
if len(KEYS) == STEP:
print(KEYS)
HKEY = KEYS[:7]
shuffle(HKEY)
NKEY = KEYS[-7:]
shuffle(NKEY)
for h in HKEY: NKEY = [key, shift(key, 1)] + NKEY
enc = encrypt(secret_msg, NKEY[0])
for key in NKEY[1:]:
enc = encrypt(enc, key)
pr(border, f'enc = {hexlify(enc)}')
pr(border, f'Can you guess the secret message? ')
alarm(md + 1)
msg = sc().strip()
alarm(0)
if msg == hexlify(secret_msg):
die(border, f'Congrats, you deserve the flag: {flag}')
else:
die(border, f'Sorry, your input is incorrect! Bye!!')
if __name__ == '__main__':
main()
本题要求选手指定14个互不相同的DES 8bytes密钥(不过前7个在加密过程中完全没用到,所以实际上是指定7个)。我们将这7个密钥记为k0~k6,进而生成如下21个密钥:
k6
shift(k6, 1)
k6
shift(k6, 1)
k6
shift(k6, 1)
k6
shift(k6, 1)
k6
shift(k6, 1)
k6
shift(k6, 1)
k6
shift(k6, 1)
shuffle(k1~k6)[0]
shuffle(k1~k6)[1]
shuffle(k1~k6)[2]
shuffle(k1~k6)[3]
shuffle(k1~k6)[4]
shuffle(k1~k6)[5]
shuffle(k1~k6)[6]
将secret用这21个密钥迭代地进行加密。给出最终的密文,要求选手在非常短的时间内猜出secret的值。
查阅维基百科-Weak key,发现DES存在所谓的semi-weak keys,即满足性质 $E_{k_1}(E_{k_2}(M))=M$ 的密钥。
6 个常见的部分弱密钥对
- 0x011F011F010E010E and 0x1F011F010E010E01
- 0x01E001E001F101F1 and 0xE001E001F101F101
- 0x01FE01FE01FE01FE and 0xFE01FE01FE01FE01
- 0x1FE01FE00EF10EF1 and 0xE01FE01FF10EF10E
- 0x1FFE1FFE0EFE0EFE and 0xFE1FFE1FFE0EFE0E
- 0xE0FEE0FEF1FEF1FE and 0xFEE0FEE0FEF1FEF1
观察发现这些密钥对恰好满足题中的shift关系。所以当我们将k6设置为其中一个semi-weak key,那么前14轮相当于什么也没做。接下来只需要穷举最后7个密钥的组合,即穷举 $7!=5040$ 种情况即可。
由于题目给的时间限制比较短,可以用以空间换时间的方式,考虑建立字典:用已知的第一个block作为明文,预计算好加密结果和7个密钥的顺序之间的映射关系,这样当远程给到我们明文的时候,可以通过查表迅速得到密钥的顺序,按逆序依次解密即可得到secret。
solve.py
from pwn import remote
import ast
from random import randbytes
from Crypto.Cipher import DES
from get_dict import get_dict
def decrypt(ct, key):
assert len(ct) % 8 == 0
assert len(key) == 8
des = DES.new(key, DES.MODE_ECB)
enc = des.decrypt(ct)
return enc
key_dict = get_dict()
KEYS1 = [randbytes(8).hex().encode() for _ in range(7)]
KEYS2 = [b'aa'*8, b'bb'*8, b'cc'*8, b'dd'*8, b'ee'*8, b'ff'*8, b'011f011f010e010e']
KEYS = KEYS1 + KEYS2
io = remote("3.75.180.117", 37773, level="debug")
io.recvuntil(b"+ please send your key as hex:")
for key in KEYS:
io.sendline(key)
io.recvline()
io.recvuntil("+")
enc_hex = ast.literal_eval(io.recvline().decode().split('=')[1]).decode()
enc = bytes.fromhex(enc_hex)
print(key_dict[enc_hex[:16]])
for key in key_dict[enc_hex[:16]][::-1]:
enc = decrypt(enc, key)
io.sendline(enc.hex().encode()[:-10])
io.interactive()
# + Congrats, you deserve the flag: b'MAPNA{DES_h4s_A_f3W_5pec1f!c_kEys_7eRm3d_we4K_k3Ys_And_Sem1-wE4k_KeY5!}'
get_dict.py
from random import randbytes
from Crypto.Cipher import DES
def pad(text):
if len(text) % 8 != 0:
text += (b'\xff' * (8 - len(text) % 8))
return text
def encrypt(msg, key):
msg = pad(msg)
assert len(msg) % 8 == 0
assert len(key) == 8
des = DES.new(key, DES.MODE_ECB)
enc = des.encrypt(msg)
return enc
import itertools
def generate_permutations(lst):
permutations = itertools.permutations(lst)
permutations_list = list(permutations)
return permutations_list
KEYS1 = [randbytes(8) for _ in range(7)]
KEYS2 = [b'aa'*8, b'bb'*8, b'cc'*8, b'dd'*8, b'ee'*8, b'ff'*8, b'011f011f010e010e']
KEYS2 = [bytes.fromhex(k.decode()) for k in KEYS2]
KEYS = KEYS1 + KEYS2
permutations = generate_permutations(KEYS2)
def get_dict():
res = {}
for keys in permutations:
ct = b'TOP_SECR'
for key in keys:
ct = encrypt(ct, key)
res[ct.hex()] = keys
return res