For educational purposes, I have implemented the AES block cipher in Python. I decided to follow the interface for block cipher modules defined in PEP 272. The implementation consists of two Python files, `aes.py` and `block_cipher.py`.
aes.py (~300 lines of code)
# coding: utf-8
"""
Advanced Encryption Standard (AES), as specified in NIST FIPS 197.

Identifiers such as Nk, Nb, Nr, w, state, sub_bytes, shift_rows,
mix_columns and add_round_key follow the notation of FIPS 197.

The implementations of mix_columns() and inv_mix_columns() use carry-less
multiplication (cl_mul) with hardcoded factors and a branch-free polynomial
reduction, in order to avoid data-dependent branches in those steps.

NOTE(review): this alone does not make the module resistant to side-channel
attacks. The SBOX/INV_SBOX lookups are indexed by secret bytes, and CPython
integer arithmetic is not constant time. Treat this module as educational
code only.
"""
from block_cipher import BlockCipher, BlockCipherWrapper
from block_cipher import MODE_ECB, MODE_CBC, MODE_CFB, MODE_OFB, MODE_CTR
# Public API: the PEP 272 module-level names (new, block_size, key_size)
# plus the mode identifiers re-exported from block_cipher.
__all__ = [
    'new',
    'block_size',
    'key_size',
    'MODE_ECB',
    'MODE_CBC',
    'MODE_CFB',
    'MODE_OFB',
    'MODE_CTR',
]
# Forward S-box used by SubBytes (FIPS 197, Sec. 5.1.1) and by SubWord in
# the key expansion: SBOX[b] is the substitution value for byte b.
# NOTE(review): lookups indexed by secret bytes are a classic cache-timing
# side channel.
SBOX = (
    0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5,
    0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
    0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0,
    0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
    0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc,
    0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
    0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a,
    0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
    0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0,
    0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
    0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b,
    0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
    0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85,
    0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
    0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5,
    0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
    0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17,
    0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
    0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88,
    0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
    0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c,
    0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
    0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9,
    0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
    0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6,
    0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
    0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e,
    0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
    0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94,
    0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
    0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68,
    0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16,
)
# Inverse S-box used by InvSubBytes (FIPS 197, Sec. 5.3.2).
# It is meant to satisfy INV_SBOX[SBOX[b]] == b for every byte b.
# NOTE(review): it could instead be derived from SBOX at import time to rule
# out transcription errors.
INV_SBOX = (
    0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38,
    0xbf, 0x40, 0xa3, 0x9e, 0x81, 0xf3, 0xd7, 0xfb,
    0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87,
    0x34, 0x8e, 0x43, 0x44, 0xc4, 0xde, 0xe9, 0xcb,
    0x54, 0x7b, 0x94, 0x32, 0xa6, 0xc2, 0x23, 0x3d,
    0xee, 0x4c, 0x95, 0x0b, 0x42, 0xfa, 0xc3, 0x4e,
    0x08, 0x2e, 0xa1, 0x66, 0x28, 0xd9, 0x24, 0xb2,
    0x76, 0x5b, 0xa2, 0x49, 0x6d, 0x8b, 0xd1, 0x25,
    0x72, 0xf8, 0xf6, 0x64, 0x86, 0x68, 0x98, 0x16,
    0xd4, 0xa4, 0x5c, 0xcc, 0x5d, 0x65, 0xb6, 0x92,
    0x6c, 0x70, 0x48, 0x50, 0xfd, 0xed, 0xb9, 0xda,
    0x5e, 0x15, 0x46, 0x57, 0xa7, 0x8d, 0x9d, 0x84,
    0x90, 0xd8, 0xab, 0x00, 0x8c, 0xbc, 0xd3, 0x0a,
    0xf7, 0xe4, 0x58, 0x05, 0xb8, 0xb3, 0x45, 0x06,
    0xd0, 0x2c, 0x1e, 0x8f, 0xca, 0x3f, 0x0f, 0x02,
    0xc1, 0xaf, 0xbd, 0x03, 0x01, 0x13, 0x8a, 0x6b,
    0x3a, 0x91, 0x11, 0x41, 0x4f, 0x67, 0xdc, 0xea,
    0x97, 0xf2, 0xcf, 0xce, 0xf0, 0xb4, 0xe6, 0x73,
    0x96, 0xac, 0x74, 0x22, 0xe7, 0xad, 0x35, 0x85,
    0xe2, 0xf9, 0x37, 0xe8, 0x1c, 0x75, 0xdf, 0x6e,
    0x47, 0xf1, 0x1a, 0x71, 0x1d, 0x29, 0xc5, 0x89,
    0x6f, 0xb7, 0x62, 0x0e, 0xaa, 0x18, 0xbe, 0x1b,
    0xfc, 0x56, 0x3e, 0x4b, 0xc6, 0xd2, 0x79, 0x20,
    0x9a, 0xdb, 0xc0, 0xfe, 0x78, 0xcd, 0x5a, 0xf4,
    0x1f, 0xdd, 0xa8, 0x33, 0x88, 0x07, 0xc7, 0x31,
    0xb1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xec, 0x5f,
    0x60, 0x51, 0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d,
    0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef,
    0xa0, 0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0,
    0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61,
    0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26,
    0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c, 0x7d,
)
# Round constants for the key expansion (FIPS 197, Sec. 5.2): successive
# powers of x (0x02) in GF(2^8), reduced modulo the AES polynomial 0x11b.
round_constants = (
    0x01, 0x02, 0x04, 0x08, 0x10,
    0x20, 0x40, 0x80, 0x1b, 0x36,
)

# Module attributes required by PEP 272.
block_size = 16  # bytes per block
# None means "variable key length" in PEP 272. AES() is meant to accept
# 16-, 24- and 32-byte keys. Keep this even though nothing here reads it.
key_size = None
def new(key, mode, IV=None, **kwargs) -> BlockCipherWrapper:
    """Return a PEP 272 cipher object that applies AES in *mode*.

    :param key: AES key; validated by AES() (16, 24 or 32 bytes).
    :param mode: one of the MODE_* constants.
    :param IV: initialization vector. Required for MODE_CBC, MODE_CFB and
        MODE_OFB, and then it must be exactly ``block_size`` bytes long.
    :keyword segment_size: CFB segment size in bits, between 1 and
        ``block_size * 8`` (default ``block_size * 8``).
    :keyword counter: callable returning the next counter block (MODE_CTR).
    :raises ValueError: if a required IV or counter is missing or malformed,
        or if segment_size is out of range.
    """
    if mode in (MODE_CBC, MODE_CFB, MODE_OFB):
        if IV is None:
            raise ValueError("This mode requires an IV")
        # Reject a wrong-sized IV up front. Otherwise it surfaces later as an
        # obscure error, or (in OFB) is silently truncated to one block.
        if len(IV) != block_size:
            raise ValueError(f"IV must be {block_size} bytes long")
    cipher = BlockCipherWrapper()
    cipher.block_size = block_size
    cipher.IV = IV
    cipher.mode = mode
    cipher.cipher = AES(key)
    if mode == MODE_CFB:
        segment_size = kwargs.get('segment_size', block_size * 8)
        # 0 would divide by zero and negative or oversized values would
        # produce garbage or late errors in the CFB implementation.
        if not 0 < segment_size <= block_size * 8:
            raise ValueError(
                f"segment_size must be between 1 and {block_size * 8} bits"
            )
        cipher.segment_size = segment_size
    elif mode == MODE_CTR:
        counter = kwargs.get('counter')
        # callable(None) is False, so this also covers a missing counter.
        if not callable(counter):
            raise ValueError("CTR mode requires a callable counter object")
        cipher.counter = counter
    return cipher
class AES(BlockCipher):
    """AES block cipher as specified in FIPS 197 (16-byte blocks).

    Attribute names follow FIPS 197:
    - ``Nk`` is the number of 32-bit key words.
    - ``Nb`` is the number of block words.
    - ``Nr`` is the number of rounds.
    - ``w`` is the expanded key schedule, a list of 4-byte words.

    The cipher state is stored in ``self.state`` as four columns of four
    byte values each.

    NOTE(review): encrypt_block()/decrypt_block() read and write
    ``self.state``, so calls on a single instance must not be interleaved
    (e.g. from several threads).
    """

    def __init__(self, key: bytes):
        """Validate *key* and expand it into the round-key schedule ``w``.

        :param key: 16, 24 or 32 bytes (AES-128/192/256).
        :raises ValueError: if *key* has any other length.
        """
        # Check the byte length itself. Checking only len(key) // 4 would also
        # accept e.g. 17-, 25- or 35-byte keys and silently ignore the
        # trailing bytes.
        if len(key) not in (16, 24, 32):
            raise ValueError("invalid key size")
        self.key = key
        self.Nk = len(key) // 4  # words per key
        self.Nr = self.Nk + 6  # rounds: 10, 12 or 14
        self.Nb = 4  # words per block
        self.state: list[list[int]] = []

        # Key expansion (FIPS 197, Sec. 5.2). The first Nk words are the key
        # itself. Every later word is w[i - Nk] XOR a transformed w[i - 1].
        self.w: list[list[int]] = [
            list(key[4 * i:4 * i + 4]) for i in range(self.Nk)
        ]
        for i in range(self.Nk, self.Nb * (self.Nr + 1)):
            temp: list[int] = self.w[i - 1]
            q, r = divmod(i, self.Nk)
            if not r:
                # First word of each key-length chunk: RotWord, SubWord, Rcon.
                # sub_word() returns a fresh list, so w[i - 1] is not mutated.
                temp = self.sub_word(self.rot_word(temp))
                temp[0] ^= round_constants[q - 1]
            elif self.Nk > 6 and r == 4:
                # Extra SubWord step that only AES-256 (Nk == 8) performs.
                temp = self.sub_word(temp)
            self.w.append([a ^ b for a, b in zip(self.w[i - self.Nk], temp)])

    def encrypt_block(self, block: bytes) -> bytes:
        """Encrypt one 16-byte *block* (FIPS 197, Sec. 5.1)."""
        self.set_state(block)
        self.add_round_key(0)
        for r in range(1, self.Nr):
            self.sub_bytes()
            self.shift_rows()
            self.mix_columns()
            self.add_round_key(r)
        # The final round omits MixColumns.
        self.sub_bytes()
        self.shift_rows()
        self.add_round_key(self.Nr)
        return self.get_state()

    def decrypt_block(self, block: bytes) -> bytes:
        """Decrypt one 16-byte *block* with the inverse cipher (FIPS 197, Sec. 5.3)."""
        self.set_state(block)
        self.add_round_key(self.Nr)
        for r in range(self.Nr - 1, 0, -1):
            self.inv_shift_rows()
            self.inv_sub_bytes()
            self.add_round_key(r)
            self.inv_mix_columns()
        # The final inverse round omits InvMixColumns.
        self.inv_shift_rows()
        self.inv_sub_bytes()
        self.add_round_key(0)
        return self.get_state()

    @staticmethod
    def rot_word(word: list[int]) -> list[int]:
        """Return *word* rotated left by one byte (RotWord, key schedule)."""
        return word[1:] + word[:1]

    @staticmethod
    def sub_word(word: list[int]) -> list[int]:
        """Return a new list with SBOX applied to each byte (SubWord)."""
        return [SBOX[b] for b in word]

    def set_state(self, block: bytes):
        """Load *block* into ``self.state`` as four 4-byte columns.

        :raises ValueError: if *block* is not exactly one block long.
        """
        # Without this check a short block yields truncated columns and a long
        # one is silently cut to 16 bytes.
        if len(block) != 4 * self.Nb:
            raise ValueError(f"block must be {4 * self.Nb} bytes long")
        self.state = [
            list(block[i:i + 4])
            for i in range(0, 16, 4)
        ]

    def get_state(self) -> bytes:
        """Serialize ``self.state`` column by column back into 16 bytes."""
        return b''.join(
            bytes(col)
            for col in self.state
        )

    def add_round_key(self, r: int):
        """XOR round key *r* (words w[r*Nb:(r+1)*Nb]) into the state."""
        round_key = self.w[r * self.Nb:(r + 1) * self.Nb]
        for col, word in zip(self.state, round_key):
            for row_index in range(4):
                col[row_index] ^= word[row_index]

    def mix_columns(self):
        """Multiply each state column by the MixColumns matrix in GF(2^8)."""
        for i, word in enumerate(self.state):
            new_word = []
            for _ in range(4):
                # Element-wise carry-less multiplication with the matrix row
                # (2, 3, 1, 1). The result fits in 9 bits.
                value = (word[0] << 1)
                value ^= (word[1] << 1) ^ word[1]
                value ^= word[2] ^ word[3]
                # Branch-free reduction modulo 0x11b: -(value >> 8) is either
                # 0 or an all-ones mask.
                value ^= 0x11b & -(value >> 8)
                new_word.append(value)
                # Rotate so that the next iteration applies the next matrix row.
                word = self.rot_word(word)
            self.state[i] = new_word

    def inv_mix_columns(self):
        """Multiply each state column by the InvMixColumns matrix in GF(2^8)."""
        for i, word in enumerate(self.state):
            new_word = []
            for _ in range(4):
                # Element-wise carry-less multiplication with the matrix row
                # (0x0e, 0x0b, 0x0d, 0x09). The result fits in 11 bits.
                value = (word[0] << 3) ^ (word[0] << 2) ^ (word[0] << 1)
                value ^= (word[1] << 3) ^ (word[1] << 1) ^ word[1]
                value ^= (word[2] << 3) ^ (word[2] << 2) ^ word[2]
                value ^= (word[3] << 3) ^ word[3]
                # Branch-free reduction modulo 0x11b, clearing bits 10, 9, 8.
                value ^= (0x11b << 2) & -(value >> 10)
                value ^= (0x11b << 1) & -(value >> 9)
                value ^= 0x11b & -(value >> 8)
                new_word.append(value)
                # Rotate so that the next iteration applies the next matrix row.
                word = self.rot_word(word)
            self.state[i] = new_word

    def shift_rows(self):
        """Rotate row ``row_index`` of the state left by ``row_index`` columns."""
        for row_index in range(4):
            row = [
                col[row_index] for col in self.state
            ]
            row = row[row_index:] + row[:row_index]
            for col_index in range(4):
                self.state[col_index][row_index] = row[col_index]

    def inv_shift_rows(self):
        """Rotate row ``row_index`` of the state right by ``row_index`` columns."""
        for row_index in range(4):
            row = [
                col[row_index] for col in self.state
            ]
            # For row_index == 0 this is row[0:] + row[:0], i.e. unchanged.
            row = row[-row_index:] + row[:-row_index]
            for col_index in range(4):
                self.state[col_index][row_index] = row[col_index]

    def sub_bytes(self):
        """Replace every state byte with its SBOX value."""
        for col in self.state:
            for row_index in range(4):
                col[row_index] = SBOX[col[row_index]]

    def inv_sub_bytes(self):
        """Replace every state byte with its INV_SBOX value."""
        for col in self.state:
            for row_index in range(4):
                col[row_index] = INV_SBOX[col[row_index]]

    def print_state(self):
        """Debug helper: print the state as a 4x4 hex matrix, row by row."""
        for row_index in range(4):
            print(' '.join(f'{col[row_index]:02x}' for col in self.state))
        print()
def main():
    """Run the FIPS 197 Appendix B example through AES-128 in ECB mode.

    Prints the plaintext, the ciphertext and the decrypted ciphertext, then
    checks the results.

    :raises RuntimeError: if the ciphertext differs from the published
        known answer, or if decryption does not recover the plaintext.
    """
    key = bytes.fromhex('2b 7e 15 16 28 ae d2 a6 ab f7 15 88 09 cf 4f 3c')
    cipher = new(key, MODE_ECB)
    plain_text = bytes.fromhex('32 43 f6 a8 88 5a 30 8d 31 31 98 a2 e0 37 07 34')
    # Known-answer ciphertext published in FIPS 197, Appendix B.
    expected = bytes.fromhex('39 25 84 1d 02 dc 09 fb dc 11 85 97 19 6a 0b 32')
    print(plain_text)
    cipher_text = cipher.encrypt(plain_text)
    print(cipher_text)
    decrypted = cipher.decrypt(cipher_text)
    print(decrypted)
    # Explicit raises rather than assert, so the check survives `python -O`.
    if cipher_text != expected:
        raise RuntimeError("ciphertext does not match FIPS 197 Appendix B")
    if decrypted != plain_text:
        raise RuntimeError("decryption did not recover the plaintext")


if __name__ == '__main__':
    main()
block_cipher.py (~200 lines of code)
# coding: utf-8
"""
this module provides the following classes:
BlockCipher
Counter
BlockCipherWrapper
"""
# Mode-of-operation identifiers, numbered as in PEP 272.
MODE_ECB = 1  # Electronic Code Book
MODE_CBC = 2  # Cipher Block Chaining
MODE_CFB = 3  # Cipher FeedBack
MODE_PGP = 4  # OpenPGP CFB variant; optional per PEP 272, not implemented here
MODE_OFB = 5  # Output FeedBack
MODE_CTR = 6  # CounTeR mode
class BlockCipher:
    """Base interface for a raw block cipher: one fixed-size block in, one out.

    Concrete ciphers override both methods; these base versions only raise
    NotImplementedError.
    """

    def encrypt_block(self, block: bytes) -> bytes:
        """Return the encryption of the single block *block*."""
        raise NotImplementedError

    def decrypt_block(self, block: bytes) -> bytes:
        """Return the decryption of the single block *block*."""
        raise NotImplementedError
# <removed lines of code>
class Counter:
    """Stateful counter-block generator for CTR mode.

    Each call returns ``nonce + counter``, with the counter encoded in the
    remaining ``block_size - len(nonce)`` bytes, and then increments the
    counter.
    """

    def __init__(self, nonce: bytes, block_size: int, byte_order='big',
                 initial_value: int = 0):
        """Create a counter.

        :param nonce: fixed prefix of every counter block.
        :param block_size: total length of each counter block in bytes.
        :param byte_order: 'big' or 'little', passed to int.to_bytes.
        :param initial_value: first counter value (default 0).
        :raises ValueError: if the nonce is longer than a block or
            initial_value is negative.
        """
        # Previously these problems only surfaced on the first call.
        if len(nonce) > block_size:
            raise ValueError("nonce is longer than the block size")
        if initial_value < 0:
            raise ValueError("initial_value must be non-negative")
        self.nonce = nonce
        self.counter_size = block_size - len(nonce)
        self.byte_order = byte_order
        self.counter = initial_value

    def __call__(self) -> bytes:
        """Return the next counter block.

        :raises OverflowError: once the counter no longer fits in
            ``counter_size`` bytes.
        """
        # Never wrap around: reusing a counter block under the same key
        # destroys CTR-mode confidentiality.
        if self.counter >= 1 << (8 * self.counter_size):
            raise OverflowError("counter space exhausted; use a new nonce")
        out = self.nonce + self.counter.to_bytes(
            self.counter_size, self.byte_order
        )
        self.counter += 1
        return out
class BlockCipherWrapper:
    """PEP 272-style cipher object that combines a BlockCipher with a mode.

    Instances are normally built by a module-level ``new()`` function, which
    fills in the attributes set in __init__.

    As PEP 272 requires, ``IV`` is updated after every encrypt()/decrypt()
    call to the current feedback value (CBC, CFB, OFB). Consecutive calls
    therefore continue one long message instead of restarting from, and thus
    reusing, the original IV.
    """

    def __init__(self):
        """Initialize instance attributes with NotImplemented placeholders."""
        # PEP 272 required attributes
        self.block_size: int = NotImplemented  # measured in bytes
        self.IV: bytes = NotImplemented  # initialization vector / feedback
        # other attributes
        self.mode: int = NotImplemented
        self.cipher: BlockCipher = NotImplemented
        self.counter: Counter = NotImplemented
        self.segment_size: int = NotImplemented  # CFB only, in bits

    def encrypt(self, byte_string: bytes) -> bytes:
        """Encrypt *byte_string* in the configured mode.

        In CFB mode the length must be a multiple of segment_size bits. In
        every other mode it must be a multiple of block_size.

        :raises ValueError: on a length mismatch or a wrong-sized counter block.
        :raises NotImplementedError: for MODE_PGP, for CFB segment sizes that
            are not whole bytes, and for unknown modes.
        """
        self._check_length(byte_string)
        if self.mode == MODE_ECB:
            return b''.join(
                self.cipher.encrypt_block(block)
                for block in self._split_blocks(byte_string)
            )
        if self.mode == MODE_CBC:
            return self._encrypt_cbc(byte_string)
        if self.mode == MODE_CFB:
            return self._cfb(byte_string, decrypt=False)
        if self.mode == MODE_PGP:
            raise NotImplementedError("PGP mode is not supported")
        if self.mode == MODE_OFB:
            return self._ofb(byte_string)
        if self.mode == MODE_CTR:
            return self._ctr(byte_string)
        raise NotImplementedError("This mode is not supported")

    def decrypt(self, byte_string: bytes) -> bytes:
        """Decrypt *byte_string* in the configured mode.

        The length requirements are the same as for encrypt().

        :raises ValueError: on a length mismatch, a wrong-sized counter block
            or an unknown mode.
        :raises NotImplementedError: for MODE_PGP and for CFB segment sizes
            that are not whole bytes.
        """
        self._check_length(byte_string)
        if self.mode == MODE_ECB:
            return b''.join(
                self.cipher.decrypt_block(block)
                for block in self._split_blocks(byte_string)
            )
        if self.mode == MODE_CBC:
            return self._decrypt_cbc(byte_string)
        if self.mode == MODE_CFB:
            return self._cfb(byte_string, decrypt=True)
        if self.mode == MODE_PGP:
            raise NotImplementedError("PGP mode is not supported")
        # OFB and CTR are their own inverses.
        if self.mode == MODE_OFB:
            return self._ofb(byte_string)
        if self.mode == MODE_CTR:
            return self._ctr(byte_string)
        raise ValueError("unknown mode")

    def _check_length(self, byte_string: bytes):
        """Raise unless *byte_string* splits into whole segments/blocks."""
        if self.mode == MODE_CFB:
            # Guards against ZeroDivisionError and nonsensical negative sizes.
            if self.segment_size <= 0:
                raise ValueError("segment_size must be positive")
            if len(byte_string) * 8 % self.segment_size:
                raise ValueError("message length doesn't match segment size")
            if self.segment_size & 7:
                raise NotImplementedError(
                    "CFB segment sizes that are not a multiple of 8 bits "
                    "are not supported"
                )
        elif len(byte_string) % self.block_size:
            raise ValueError("message length doesn't match block size")

    def _split_blocks(self, byte_string: bytes) -> list:
        """Split *byte_string* into consecutive block_size-byte chunks."""
        size = self.block_size
        return [byte_string[i:i + size] for i in range(0, len(byte_string), size)]

    def _encrypt_cbc(self, byte_string: bytes) -> bytes:
        """CBC encryption: C_i = E(P_i XOR C_{i-1}), with C_0 = IV."""
        previous = self.IV
        out = []
        for block in self._split_blocks(byte_string):
            previous = self.cipher.encrypt_block(self.xor(block, previous))
            out.append(previous)
        self.IV = previous  # PEP 272: carry the chaining value forward
        return b''.join(out)

    def _decrypt_cbc(self, byte_string: bytes) -> bytes:
        """CBC decryption: P_i = D(C_i) XOR C_{i-1}, with C_0 = IV."""
        previous = self.IV
        out = []
        for block in self._split_blocks(byte_string):
            out.append(self.xor(self.cipher.decrypt_block(block), previous))
            previous = block
        self.IV = previous  # PEP 272: carry the chaining value forward
        return b''.join(out)

    def _cfb(self, byte_string: bytes, decrypt: bool) -> bytes:
        """CFB over whole s-byte segments; shared by encrypt and decrypt."""
        s = self.segment_size >> 3
        shift_register = self.IV
        out = []
        for i in range(0, len(byte_string), s):
            segment = byte_string[i:i + s]
            result = self.xor(segment, self.cipher.encrypt_block(shift_register)[:s])
            out.append(result)
            # The ciphertext segment is fed back in both directions: it is the
            # output when encrypting and the input when decrypting.
            shift_register = shift_register[s:] + (segment if decrypt else result)
        self.IV = shift_register  # PEP 272: carry the feedback forward
        return b''.join(out)

    def _ofb(self, byte_string: bytes) -> bytes:
        """OFB: XOR with the keystream E(IV), E(E(IV)), ..."""
        keystream_block = self.IV
        out = []
        for block in self._split_blocks(byte_string):
            keystream_block = self.cipher.encrypt_block(keystream_block)
            out.append(self.xor(block, keystream_block))
        self.IV = keystream_block  # PEP 272: carry the feedback forward
        return b''.join(out)

    def _ctr(self, byte_string: bytes) -> bytes:
        """CTR: XOR each block with E(counter()); the counter holds the state."""
        out = []
        for block in self._split_blocks(byte_string):
            ctr = self.counter()
            if len(ctr) != self.block_size:
                raise ValueError("counter has the wrong size")
            out.append(self.xor(self.cipher.encrypt_block(ctr), block))
        return b''.join(out)

    def xor(self, block1, block2):
        """Return the bytewise XOR of two equally sized operands.

        Both operands must be one segment (segment_size // 8 bytes) in CFB
        mode, or one block otherwise.

        :raises ValueError: if either operand has a different length.
        """
        size = (
            self.segment_size >> 3
            if self.mode == MODE_CFB
            else self.block_size
        )
        if not (len(block1) == len(block2) == size):
            raise ValueError(
                f"xor operands must both be {size} bytes long, "
                f"got {len(block1)} and {len(block2)}"
            )
        return bytes(a ^ b for a, b in zip(block1, block2))
def main():
    """Placeholder entry point; running this module directly does nothing."""


if __name__ == '__main__':
    main()
2 Answers
That is a lot of code to review, so this review will only touch on more superficial things.
- I'm not familiar with the pattern of modifying `__all__`. Prefixing names that should not be available from the outside with an underscore seems like a less "magic" way of achieving the same thing.
- `black` and `isort` can help make the files more idiomatic, even if that means the very long lists will be linearized.
- Stricter type annotations (for all parameters and late-initialized variables) would make the code more self-documenting.
- `pylint` and `flake8` can notify you of other non-idiomatic code, such as top-level variable names that are not all uppercase.
- `key_size` doesn't seem to be used for anything.
  - Edit: since it's required by PEP 272, I would suggest adding a comment to that effect. Otherwise it's pretty likely to be removed by the next developer, unless its presence is actually enforced by commit hooks and/or CI.
- Some of the variable and field names are not useful, like all the single-letter ones and `tmp`.
- Does "inv" in `INV_SBOX` stand for "inverse"? If so, can this variable be generated from `SBOX`?
- Several things here could usefully be `enum`s, such as `SBOX` and the modes.
- `@staticmethod`
is a code smell (for anyone tempted to comment: that doesn't always mean it's a bad thing). Often it's done because the IDE suggests it, but in my experience this is usually not helpful. Either there's no reason to ever call the method directly, or it is such a generic piece of code (often an idempotent input transformation) that it really belongs outside the class. In this case, `rot_word` and `sub_word` are not actually called on the class. Whether they belong as functions or as internal methods depends on how they are likely to be used and reused.
- `set_state` and `get_state` both transform the state. I would normally expect only one of them to actually transform it.
- The code in `aes.py`'s `main` function belongs in a test case. I would expect `main` to parse arguments and dispatch an action to the rest of the code.
- The `BlockCipher` class doesn't do anything, and there is only one child class, so it looks like it can be removed without losing anything.
- `xor` seems like a strange name for a method that doesn't just `return first ^ second`. Is there a name that more closely describes what it actually does?
- `block_cipher.py` has a no-op `main` function. I'm sure it's meant to eventually contain something, but I would normally reject that in a merge request because it's effectively dead code.
- Thanks for your answer! You're right, `key_size` isn't used for anything, but PEP 272 requires it for compatibility with other cipher modules/implementations. Thanks for the tool recommendations, I'll check them out. – Aemyl, Dec 23, 2020 at 21:06
- `xor` does just return `first ^ second`, after some checking. – President James K. Polk, Dec 24, 2020 at 16:14
Advanced Encryption Standard.
Please simply include a well defined reference to the standard (it's FIPS 197 by NIST). If you do use any of the standard notations in there, indicate that as well, it will make your code much easier to understand.
'MODE_ECB', 'MODE_CBC', 'MODE_CFB', 'MODE_OFB', 'MODE_CTR'
Modes of operation are not part of the block cipher itself. I'm not that familiar with the Python language, but if possible the import of these modes should be avoided for the block cipher implementation itself.
INV_SBOX = (
The blocks are nicely formed and well named. But I wonder why there are so many double empty lines before some statements, and this one is missing a single one.
def new(key, mode, IV=None, **kwargs) -> BlockCipherWrapper:
There should be no need to repeat this for each and every cipher you implement. Also, the more modes are added, the more this will expand. What about CCM, EAX, GCM, SIV, disk modes? This is where you can have a Mode
class and perform the initialization within the implementations of that class.
if self.Nk not in (4, 6, 8):
Better to define constants for the key sizes 128, 192 and 256 (in bits) and then divide by a `BYTE_SIZE` constant, or simply by 8 if you think that's obvious enough.
# key schedule
Here you could point to the right chapter in the standard. The code is clear in this section when it comes to how the algorithm is implemented, but you're not really explaining what you are doing (apart from performing the key schedule, obviously).
The AES implementation itself is nice and well structured enough. The names are fine, the code looks fine. I'd add a few more lines about the "what", but since it is unlikely to change much and the function of these methods is well understood, you can be excused for not including it.
Generally I would take a look at when you are including empty lines. The empty line directly after a method declaration doesn't look good to me, especially since you yourself choose to ignore it for comments. This makes the code look slightly unbalanced to me. I'd just directly start after the method declaration myself, less is more.
MODE_PGP = 4 # optional
What's optional? Optional for who?
def encrypt(self, byte_string: bytes) -> bytes:
Here and during decrypt you are falling in the same trap. Use separate classes and maybe defer to them, but please do not implement all modes in a single method. Check for instance how the Bouncy Castle lightweight API is implemented.
Also, it isn't clear to me if you need to handle a full message at once or not.
elif self.mode == MODE_PGP: raise NotImplementedError
Either leave it out or implement it, but don't leave it hanging. Note that there doesn't seem to be any way to test if the implementation is available. What about returning a set of enumerations to see which modes are available?
def xor(self, block1, block2):
This is a bit of a mixed-bag method. You should at the very least make it private (`__xor`).
Explore related questions
See similar questions with these tags.