from Flag import FLAG
from Crypto.Util.number import bytes_to_long, isPrime, getPrime
from random import randint
# The flag must be exactly 44 bytes long (written as 20 + 24); abort otherwise.
assert len(FLAG) == 20 + 24
# Convert the flag bytes to a single integer; it is used further down as the
# secret exponent of the discrete-log instance.
flag = bytes_to_long(FLAG)
def get_prime(n = 32, RANGE = 16):
    """Return a prime of the form ``a * s + 1`` where ``s`` is a product of small factors.

    ``s`` is the product of ``n`` integers, each drawn with ``randint`` from
    ``[2**RANGE, 2**(RANGE + 1)]`` (both ends inclusive).  ``a`` starts at 1
    and, while ``a * s + 1`` is not prime, is replaced by a fresh
    ``getPrime(RANGE * RANGE)`` value.

    Args:
        n: Number of random factors multiplied into ``s``.
        RANGE: Bit position of the lower bound of every random factor; also
            determines the bit size requested from ``getPrime``.

    Returns:
        The first candidate ``a * s + 1`` accepted by ``isPrime``.
    """
    lower, upper = 1 << RANGE, 1 << (RANGE + 1)
    smooth_part = 1
    for _ in range(n):
        smooth_part *= randint(lower, upper)

    # First try cofactor 1, then keep drawing prime cofactors until the
    # candidate itself is prime.
    candidate = smooth_part + 1
    while not isPrime(candidate):
        candidate = getPrime(RANGE * RANGE) * smooth_part + 1
    return candidate
# Build the public parameters of the challenge: the modulus and the base.
p = get_prime()
g = bytes_to_long(b"20" + b"24")

# Publish the modulus, the base and g**flag mod p, one "name = value" line each.
with open(r"./crypto3-BabyDLog/SourceCode/output.txt", 'w') as f:
    for line in (f"{p = }", f"{g = }", f"{pow(g, flag, p) = }"):
        print(line, file=f)
典型的 Pohlig-Hellman 算法:由于已知 `flag` 的大小(44 字节),所以其实我们不必要把 `flag`
对每个 $p_i^{e_i}$ 的模数求出来就可以做 CRT。with open(r"./crypto3-BabyDLog/SourceCode/output.txt", "r") as f:
p, G, H = [int(line.split("=")[1]) for line in f.readlines()]
print(f"{p = }")
print(f"{G = }")
print(f"{H = }")
from sympy import factorint, prod
from Crypto.Util.number import long_to_bytes
from itertools import product
def get_order(m, factorList, p):
    """Return the prime factorisation of the multiplicative order of ``m`` mod ``p``.

    Args:
        m: Element whose order is wanted.
        factorList: Mapping ``{q: alpha}`` giving the factorisation of
            ``p - 1``.
        p: The modulus.

    Returns:
        A dict ``{q: k}`` with ``1 <= k <= alpha`` for every prime ``q`` whose
        power ``q**k`` exactly divides the order of ``m``.  Primes that do not
        divide the order are omitted, so the product of ``q**k`` over the
        result is the order itself.
    """
    newfactorList = {}
    for q, alpha in factorList.items():
        # Strip every prime other than q from the order: b has order q**k
        # for some 0 <= k <= alpha.
        b = pow(m, (p - 1) // pow(q, alpha), p)

        # Count how many q-th powers are needed to reach 1.  Starting the
        # count at zero matters: if b is already 1, q does not divide the
        # order at all and must not be reported with exponent 1.
        k = 0
        while b != 1 and k < alpha:
            b = pow(b, q, p)
            k += 1

        if b == 1 and k:
            newfactorList[q] = k
    return newfactorList
def naive_dlp(g, h, p, r):
    """Brute-force every exponent ``i`` in ``range(r)`` with ``g**i == h (mod p)``.

    Args:
        g: Base of the discrete logarithm.
        h: Target value; it is reduced modulo ``p`` before comparing.
        p: The modulus.
        r: Exclusive upper bound of the exponents to try.

    Returns:
        The list of all matching exponents in increasing order.  When ``r``
        has more than 32 bits the search is skipped: the bit length of ``r``
        is printed and the placeholder ``[0]`` is returned.
    """
    if r.bit_length() > 32:
        # Too large to enumerate; report the size and hand back a dummy result.
        print(r.bit_length())
        return [0]

    target = h % p
    res = []
    # Keep a running power g**i mod p instead of recomputing pow(g, i, p)
    # from scratch on every iteration.
    power = 1 % p
    for i in range(r):
        if power == target:
            res.append(i)
        power = power * g % p
    return res
def partial_dlp(g, h, q, alpha, order):
    """Find the candidates for ``log_g(h)`` modulo ``q**alpha``, one base-q digit at a time.

    The modulus is not a parameter: it is read from the module-level name
    ``p``.

    Args:
        g: Base of the discrete logarithm.
        h: Target value.
        q: Prime whose power is being handled.
        alpha: Number of base-``q`` digits to recover.
        order: Value treated as the order of ``g``; it must stay divisible by
            ``q`` for each of the ``alpha`` rounds (checked with ``assert``).

    Returns:
        A list with every candidate built from the digits that ``naive_dlp``
        returned in each round.
    """
    # Element g**(order // q), used as the base of every per-digit search.
    subgroup_gen = pow(g, order // q, p)

    exponent = order
    digit_weight = 1
    candidates = [0]
    rounds_left = alpha
    while rounds_left:
        assert exponent % q == 0
        exponent //= q
        # For each partial solution, remove its known contribution from h,
        # raise the rest to `exponent`, and brute-force the next digit.
        candidates = [
            known + digit * digit_weight
            for known in candidates
            for digit in naive_dlp(
                subgroup_gen,
                pow(h * pow(g, -known, p), exponent, p),
                p,
                q,
            )
        ]
        digit_weight *= q
        rounds_left -= 1
    return candidates
def crt_solve(modList, primeExpList):
    """Combine residues with the Chinese Remainder Theorem.

    Args:
        modList: The moduli; each one must be invertible-compatible with the
            product of the others (pairwise coprime), otherwise the modular
            inverse computed by ``pow`` raises ``ValueError``.
        primeExpList: The residues, paired positionally with ``modList`` via
            ``zip``.

    Returns:
        The value in ``range(prod(modList))`` that is congruent to each
        residue modulo its paired modulus.
    """
    modulus = prod(modList)
    total = 0
    for residue, m in zip(primeExpList, modList):
        cofactor = modulus // m
        # cofactor * cofactor^-1 is 1 modulo m and 0 modulo every other modulus.
        total += residue * pow(cofactor, -1, m) * cofactor
    return total % modulus
# Factor p - 1 and derive from it the factorisation and value of the order of G.
factorList = factorint(p - 1)
newfactorList = get_order(G, factorList, p)
neworder = prod(q ** alpha for q, alpha in newfactorList.items())

# Solve the discrete log in every prime-power component of the order.
# naive_dlp gives up (returning the placeholder [0]) on primes wider than
# 32 bits, so those components carry no information.
componentLogs = [
    (q, alpha, partial_dlp(G, H, q, alpha, neworder))
    for q, alpha in newfactorList.items()
]

# Keep only the components that were really brute-forced, pairing each
# modulus with its own residues.  Pairing explicitly (rather than truncating
# every residue tuple to len(modList)) stays correct even when an oversized
# prime is not the last key of the factor dict.
solvedComponents = [
    (q ** alpha, residues)
    for q, alpha, residues in componentLogs
    if q.bit_length() < 33
]
modList = [modulus for modulus, _ in solvedComponents]
primeExpList = list(product(*(residues for _, residues in solvedComponents)))

# Try every combination of residues; print the decoded message and the
# combination's index whenever the expected marker shows up.
for index, pl in enumerate(primeExpList):
    msg = long_to_bytes(crt_solve(modList, pl))
    if b"ucatflags" in msg:
        print(msg)
        print(index)
这题出拐了……