I am looking for ways to make this code more "Pythonic", and for any issues with my implementation. It was written to solve Cryptopals Challenge 10, which requires recreating the AES-CBC cipher using a library-provided AES-ECB function. PKCS7 padding is used to pad the plaintext to a multiple of the block size.
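For reference, the chaining that CBC builds on top of single-block AES-ECB, and the PKCS7 padding rule, can be summarised as follows. This is standard notation, not text from the challenge: E_K and D_K are single-block encryption and decryption under key K, P_i and C_i are the i-th plaintext and ciphertext blocks, the IV acts as C_0, b is the block size and M is the message.

```latex
\begin{aligned}
C_0 &= \mathrm{IV}, \\
C_i &= E_K\left(P_i \oplus C_{i-1}\right), \qquad i = 1, \dots, n, \\
P_i &= D_K\left(C_i\right) \oplus C_{i-1}, \\
p   &= b - \left(|M| \bmod b\right), \qquad 1 \le p \le b \quad \text{(append $p$ bytes of value $p$)}
\end{aligned}
```

So with b = 16, a 13-byte message gets three `0x03` bytes, and a 16-byte message gets a whole extra block of `0x10` bytes.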
```python
from Cryptodome.Cipher import AES

def xor(in1, in2):
    ret = []
    for i in range(0, max(len(in1), len(in2))):
        ret.append(in1[i % len(in1)] ^ in2[i % len(in2)])
    return bytes(ret)

def decrypt_aes_ecb(data, key):
    cipher = AES.new(key, AES.MODE_ECB)
    return cipher.decrypt(data)

def encrypt_aes_ecb(data, key):
    cipher = AES.new(key, AES.MODE_ECB)
    return cipher.encrypt(data)

def pkcs7(val, block_size=16):
    remaining = block_size - len(val) % block_size
    if remaining == block_size:
        remaining = 16
    ret = val + chr(remaining).encode() * remaining
    return ret

def unpkcs7(val, block_size=16):
    pad_amount = val[-1]
    if pad_amount == 0:
        raise Exception
    for i in range(len(val) - 1, len(val) - (pad_amount + 1), -1):
        if val[i] != pad_amount:
            raise Exception
    return val[:-pad_amount]

def decrypt_aes_cbc(data, key, iv = b'\x00' * 16, pad=True):
    prev_chunk = iv
    decrypted = []
    for i in range(0, len(data), 16):
        chunk = data[i : i + 16]
        decrypted += xor(decrypt_aes_ecb(chunk, key), prev_chunk)
        prev_chunk = chunk
    if pad:
        return unpkcs7(bytes(decrypted))
    return bytes(decrypted)

def encrypt_aes_cbc(data, key, iv = b'\x00' * 16, pad=True):
    if pad:
        padded = pkcs7(data)
    else:
        padded = data
    prev_chunk = iv
    encrypted = []
    for i in range(0, len(padded), 16):
        chunk = padded[i : i + 16]
        encrypted_block = encrypt_aes_ecb(xor(chunk, prev_chunk), key)
        encrypted += encrypted_block
        prev_chunk = encrypted_block
    return bytes(encrypted)
```
Comment from Vogel612 (Mar 21, 2019 at 15:45 UTC): Links can rot. Please include a description of the challenge here in your question.
2 Answers
Try to follow PEP 8. Just a few little things: when assigning defaults to arguments, don't put spaces around the `=`; leave two blank lines between functions and after the imports. (Your slice `[i : i + 16]` is actually fine, since PEP 8 treats the slice colon like a binary operator with equal spacing on both sides; `[i: i + 16]` would not be.) Very minor stuff, but it adds up over a larger piece of code.

Avoid assigning variables to small bits of logic if they're only going to be used once. For example:

```python
ret = val + chr(remaining).encode() * remaining
return ret
```

could become:

```python
return val + chr(remaining).encode() * remaining
```
Your names could be improved for clarity; avoid vague names like `key`, `val`, `data`, etc.
I like ternary operators, so I'd change the `if`/`else` in `encrypt_aes_cbc()` to:

```python
padded = pkcs7(data) if pad else data
```
You can do something similar with the returns in `decrypt_aes_cbc()`: `return unpkcs7(bytes(decrypted)) if pad else bytes(decrypted)`.

In fact, I'd extract the second part of the function into a generator and call it in `encrypt_aes_cbc()`. By which I mean:

```python
def gen_encrypted(padded, key, iv):
    """Maybe give this function a better name!"""
    prev_chunk = iv
    for i in range(0, len(padded), 16):
        chunk = padded[i : i + 16]
        encrypted_block = encrypt_aes_ecb(xor(chunk, prev_chunk), key)
        yield encrypted_block
        prev_chunk = encrypted_block
```

Then change the second half of `encrypt_aes_cbc()` to call that generator. This nicely compartmentalises your code, removes the need for the `encrypted` list, and may be faster. I can explain this more thoroughly if you're confused about `yield`.
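A self-contained sketch of that refactor, assuming one deviation from the question's signature: the function takes an already-constructed ECB cipher object (anything with an `encrypt(block)` method, such as `AES.new(key, AES.MODE_ECB)` from pycryptodome) instead of a raw key, so the chaining can be exercised without the library. The names `gen_encrypted` and `encrypt_cbc` are illustrative only.

```python
from typing import Iterator

BLOCK_SIZE = 16  # AES block size in bytes


def xor(a: bytes, b: bytes) -> bytes:
    # Both arguments are whole blocks here, so a plain zip is enough.
    return bytes(x ^ y for x, y in zip(a, b))


def pkcs7(data: bytes, block_size: int = BLOCK_SIZE) -> bytes:
    # Always appends 1..block_size bytes, each equal to the pad length.
    pad_len = block_size - len(data) % block_size
    return data + bytes([pad_len]) * pad_len


def gen_encrypted(cipher, padded: bytes, iv: bytes) -> Iterator[bytes]:
    """Yield CBC ciphertext blocks one at a time.

    `cipher` is assumed to be an ECB-mode object with an `encrypt(block)`
    method, e.g. AES.new(key, AES.MODE_ECB) from pycryptodome.
    """
    prev_block = iv
    for i in range(0, len(padded), BLOCK_SIZE):
        block = cipher.encrypt(xor(padded[i : i + BLOCK_SIZE], prev_block))
        yield block
        prev_block = block  # each ciphertext block chains into the next


def encrypt_cbc(
    cipher, data: bytes, iv: bytes = bytes(BLOCK_SIZE), pad: bool = True
) -> bytes:
    padded = pkcs7(data) if pad else data
    # b"".join consumes the generator, so no explicit `encrypted` list is needed.
    return b"".join(gen_encrypted(cipher, padded, iv))
```

With pycryptodome this would be called as `encrypt_cbc(AES.new(key, AES.MODE_ECB), plaintext)`. Passing the cipher object in also avoids constructing a new AES object for every block, which the question's `encrypt_aes_ecb` helper does.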
Your argument `block_size=16` in `unpkcs7` is not used.

Your indentation is off in `pkcs7`.
Try to be specific when raising exceptions: what error, exactly, are you signalling? Exceptions also take a string argument, so you can explain the problem, e.g.:

```python
raise IndexError('This is not a valid position in some_list')
```
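Applied to the padding code, a minimal sketch might look like the following. `ValueError` is chosen here as the closest built-in fit for malformed padding; that choice is an assumption, and a custom exception subclass would work just as well. This version also actually uses `block_size` in `unpkcs7`. The `if remaining == block_size` branch in the original `pkcs7` is unnecessary, because `block_size - len(val) % block_size` already equals `block_size` when the input is aligned, and `bytes([n])` avoids `chr(n).encode()`, which produces two bytes in UTF-8 for n >= 128.

```python
def pkcs7(data: bytes, block_size: int = 16) -> bytes:
    """Pad `data` to a multiple of `block_size` (always adds 1..block_size bytes)."""
    if not 1 <= block_size <= 255:
        raise ValueError(f"block_size must be in 1..255, got {block_size}")
    pad_len = block_size - len(data) % block_size
    return data + bytes([pad_len]) * pad_len


def unpkcs7(data: bytes, block_size: int = 16) -> bytes:
    """Remove PKCS7 padding, raising ValueError if it is malformed."""
    if not data or len(data) % block_size:
        raise ValueError(
            f"padded data must be a non-empty multiple of {block_size} bytes, "
            f"got {len(data)}"
        )
    pad_len = data[-1]
    if not 1 <= pad_len <= block_size:
        raise ValueError(f"invalid padding length {pad_len}")
    if data[-pad_len:] != bytes([pad_len]) * pad_len:
        raise ValueError("padding bytes do not all equal the padding length")
    return data[:-pad_len]
```

A caller can then catch `ValueError` around decryption and report the message, instead of a bare `Exception` with no explanation.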
Comment from AlexV (Mar 21, 2019 at 21:09 UTC): Especially the comment on using specific exceptions and error messages is enormously important! Nothing is more annoying than having to read error messages (if any) that tell you nothing apart from "It's broken!"
In addition to Hobo's good answer, you can simplify the `xor` function:

```python
def xor(in1, in2):
    ret = []
    for i in range(0, max(len(in1), len(in2))):
        ret.append(in1[i % len(in1)] ^ in2[i % len(in2)])
    return bytes(ret)
```
You don't need the `0` as the first argument of `range`, because it starts at `0` automatically. Secondly, `bytes` will accept any iterable of integers, so you can feed it a generator expression:

```python
def xor(in1, in2):
    return bytes(in1[i % len(in1)] ^ in2[i % len(in2)] for i in range(max(len(in1), len(in2))))
```
You could also use `cycle` from the `itertools` module in combination with `zip`:

```python
from itertools import cycle

def xor2(in1, in2):
    value, repeater = (in1, in2) if len(in1) > len(in2) else (in2, in1)
    return bytes(v ^ r for v, r in zip(value, cycle(repeater)))
```
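In the CBC functions specifically, both arguments to `xor` are always whole 16-byte blocks for well-formed input (padded plaintext, or ciphertext whose length is a multiple of 16), so the cycling is not needed there. A minimal sketch for that case, under that assumption; the name `xor_blocks` is hypothetical, and the explicit length check is there because `zip` would otherwise silently truncate to the shorter input:

```python
def xor_blocks(a: bytes, b: bytes) -> bytes:
    """XOR two equal-length byte strings (e.g. two cipher blocks)."""
    if len(a) != len(b):
        raise ValueError(f"length mismatch: {len(a)} != {len(b)}")
    # zip pairs up the byte values; bytes() builds the result from the ints.
    return bytes(x ^ y for x, y in zip(a, b))
```

On Python 3.10 and later, `zip(a, b, strict=True)` performs the same length check and raises `ValueError` on a mismatch.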