Last weekend I wrote some code to help people use PyCrypto in a canonical fashion to symmetrically encrypt and decrypt strings and files. By canonical I mean I've followed Colin Percival's advice in Cryptographic Right Answers.
The code isn't very verbose, and I'd appreciate any comments or questions about the code, the tests, and the documentation. After I get this review I'm going to submit this to Security Stack Exchange for a design review.
Source code and a subset of tests follow:
crypto.py:
# ---------------------------------------------------------------------------
# Copyright (c) 2012 Asim Ihsan (asim dot ihsan at gmail dot com)
# Distributed under the MIT/X11 software license, see the accompanying
# file license.txt or http://www.opensource.org/licenses/mit-license.php.
# ---------------------------------------------------------------------------
import os
import sys
import struct
import cStringIO as StringIO
import bz2
from Crypto.Cipher import AES
from Crypto.Hash import SHA256, HMAC
from Crypto.Protocol.KDF import PBKDF2
# ----------------------------------------------------------------------------
# Constants.
# ----------------------------------------------------------------------------
# Length of salts in bytes.
salt_length_in_bytes = 16
# Hash function to use in general.
hash_function = SHA256
# PBKDF pseudo-random function. Used to mix a password and a salt.
# See Crypto\Protocol\KDF.py
pbkdf2_prf = lambda p, s: HMAC.new(p, s, hash_function).digest()
# PBKDF count, number of iterations.
pbkdf2_count = 1000
# PBKDF derived key length.
pbkdf2_dk_len = 32
# ----------------------------------------------------------------------------
class HMACIsNotValidException(Exception):
    pass

class InvalidFormatException(Exception):
    def __init__(self, reason):
        self.reason = reason
    def __str__(self):
        return repr(self.reason)

class CTRCounter:
    """ Callable class that returns an iterating counter for PyCrypto
    AES in CTR mode."""
    def __init__(self, nonce):
        """ Initialize the counter object.
        @nonce An 8-byte binary string.
        """
        assert(len(nonce)==8)
        self.nonce = nonce
        self.cnt = 0
    def __call__(self):
        """ Return the next 16 byte counter, as a binary string. """
        right_half = struct.pack('>Q', self.cnt)
        self.cnt += 1
        return self.nonce + right_half

def encrypt_string(plaintext, key, compress=False):
    plaintext_obj = StringIO.StringIO(plaintext)
    ciphertext_obj = StringIO.StringIO()
    encrypt_file(plaintext_obj, key, ciphertext_obj, compress=compress)
    return ciphertext_obj.getvalue()

def decrypt_string(ciphertext, key):
    plaintext_obj = StringIO.StringIO()
    ciphertext_obj = StringIO.StringIO(ciphertext)
    decrypt_file(ciphertext_obj, key, plaintext_obj)
    return plaintext_obj.getvalue()

def decrypt_file(ciphertext_file_obj,
                 key,
                 plaintext_file_obj,
                 chunk_size=4096):
    # ------------------------------------------------------------------------
    # Unpack the header values from the ciphertext.
    # ------------------------------------------------------------------------
    header_format = ">HHHHHQ?H"
    header_size = struct.calcsize(header_format)
    header_string = ciphertext_file_obj.read(header_size)
    try:
        header = struct.unpack(header_format, header_string)
    except struct.error:
        raise InvalidFormatException("Header is invalid.")
    pbkdf2_count = header[0]
    pbkdf2_dk_len = header[1]
    password_salt_size = header[2]
    nonce_size = header[3]
    hmac_salt_size = header[4]
    ciphertext_size = header[5]
    compress = header[6]
    hmac_size = header[7]

    # ------------------------------------------------------------------------
    # Unpack everything except the ciphertext and HMAC, which are the
    # last two strings in the ciphertext file.
    # ------------------------------------------------------------------------
    encrypted_string_format = ''.join([">",
                                       "%ss" % (password_salt_size, ) ,
                                       "%ss" % (nonce_size, ),
                                       "%ss" % (hmac_salt_size, )])
    encrypted_string_size = struct.calcsize(encrypted_string_format)
    body_string = ciphertext_file_obj.read(encrypted_string_size)
    try:
        body = struct.unpack(encrypted_string_format, body_string)
    except struct.error:
        raise InvalidFormatException("Start of body is invalid.")
    password_salt = body[0]
    nonce = body[1]
    hmac_salt = body[2]
    # ------------------------------------------------------------------------

    # ------------------------------------------------------------------------
    # Prepare the HMAC with everything except the ciphertext.
    #
    # Notice we do not HMAC the ciphertext_size, just like the encrypt
    # stage.
    # ------------------------------------------------------------------------
    hmac_password_derived = PBKDF2(password = key,
                                   salt = hmac_salt,
                                   dkLen = pbkdf2_dk_len,
                                   count = pbkdf2_count,
                                   prf = pbkdf2_prf)
    elems_to_hmac = [str(pbkdf2_count),
                     str(pbkdf2_dk_len),
                     str(len(password_salt)),
                     password_salt,
                     str(len(nonce)),
                     nonce,
                     str(len(hmac_salt)),
                     hmac_salt]
    hmac_object = HMAC.new(key = hmac_password_derived,
                           msg = ''.join(elems_to_hmac),
                           digestmod = hash_function)
    # ------------------------------------------------------------------------

    # ------------------------------------------------------------------------
    # First pass: stream in the ciphertext object into the HMAC object
    # and verify that the HMAC is correct.
    #
    # Notice we don't need to decompress anything here even if compression
    # is in use. We're using Encrypt-Then-MAC.
    # ------------------------------------------------------------------------
    ciphertext_file_pos = ciphertext_file_obj.tell()
    ciphertext_bytes_read = 0
    while True:
        bytes_remaining = ciphertext_size - ciphertext_bytes_read
        current_chunk_size = min(bytes_remaining, chunk_size)
        ciphertext_chunk = ciphertext_file_obj.read(current_chunk_size)
        if ciphertext_chunk == '':
            break
        ciphertext_bytes_read += len(ciphertext_chunk)
        hmac_object.update(ciphertext_chunk)
    if ciphertext_bytes_read != ciphertext_size:
        raise InvalidFormatException("first pass ciphertext_bytes_read %s != ciphertext_size %s" % (ciphertext_bytes_read, ciphertext_size))
    # the rest of the file is the HMAC.
    hmac = ciphertext_file_obj.read()
    if len(hmac) != hmac_size:
        raise InvalidFormatException("len(hmac) %s != hmac_size %s" % (len(hmac), hmac_size))
    hmac_calculated = hmac_object.digest()
    if hmac != hmac_calculated:
        raise HMACIsNotValidException
    # ------------------------------------------------------------------------

    # ------------------------------------------------------------------------
    # Second pass: stream in the ciphertext object and decrypt it into the
    # plaintext object.
    # ------------------------------------------------------------------------
    cipher_password_derived = PBKDF2(password = key,
                                     salt = password_salt,
                                     dkLen = pbkdf2_dk_len,
                                     count = pbkdf2_count,
                                     prf = pbkdf2_prf)
    cipher_ctr = AES.new(key = cipher_password_derived,
                         mode = AES.MODE_CTR,
                         counter = CTRCounter(nonce))
    ciphertext_file_obj.seek(ciphertext_file_pos, os.SEEK_SET)
    ciphertext_bytes_read = 0
    if compress:
        decompressor = bz2.BZ2Decompressor()
    while True:
        bytes_remaining = ciphertext_size - ciphertext_bytes_read
        current_chunk_size = min(bytes_remaining, chunk_size)
        ciphertext_chunk = ciphertext_file_obj.read(current_chunk_size)
        end_of_file = ciphertext_chunk == ''
        ciphertext_bytes_read += len(ciphertext_chunk)
        plaintext_chunk = cipher_ctr.decrypt(ciphertext_chunk)
        if compress:
            try:
                decompressed = decompressor.decompress(plaintext_chunk)
            except EOFError:
                decompressed = ""
            plaintext_chunk = decompressed
        plaintext_file_obj.write(plaintext_chunk)
        if end_of_file:
            break
    if ciphertext_bytes_read != ciphertext_size:
        raise InvalidFormatException("second pass ciphertext_bytes_read %s != ciphertext_size %s" % (ciphertext_bytes_read, ciphertext_size))
    # ------------------------------------------------------------------------

def encrypt_file(plaintext_file_obj,
                 key,
                 ciphertext_file_obj,
                 chunk_size = 4096,
                 compress = False):
    # ------------------------------------------------------------------------
    # Prepare input key.
    # ------------------------------------------------------------------------
    password_salt = os.urandom(salt_length_in_bytes)
    cipher_password_derived = PBKDF2(password = key,
                                     salt = password_salt,
                                     dkLen = pbkdf2_dk_len,
                                     count = pbkdf2_count,
                                     prf = pbkdf2_prf)
    # ------------------------------------------------------------------------

    # ------------------------------------------------------------------------
    # Prepare cipher object.
    # ------------------------------------------------------------------------
    nonce_size = 8
    nonce = os.urandom(nonce_size)
    cipher_ctr = AES.new(key = cipher_password_derived,
                         mode = AES.MODE_CTR,
                         counter = CTRCounter(nonce))
    # ------------------------------------------------------------------------

    # ------------------------------------------------------------------------
    # Prepare HMAC object, and hash what we have so far.
    #
    # Notice that we do not HMAC the size of the ciphertext. We don't
    # know how big it'll be until we compress it, if we do, and we can't
    # compress it without reading it into memory. So let the HMAC of
    # the ciphertext itself do.
    # ------------------------------------------------------------------------
    hmac_salt = os.urandom(salt_length_in_bytes)
    hmac_password_derived = PBKDF2(password = key,
                                   salt = hmac_salt,
                                   dkLen = pbkdf2_dk_len,
                                   count = pbkdf2_count,
                                   prf = pbkdf2_prf)
    elems_to_hmac = [str(pbkdf2_count),
                     str(pbkdf2_dk_len),
                     str(len(password_salt)),
                     password_salt,
                     str(len(nonce)),
                     nonce,
                     str(len(hmac_salt)),
                     hmac_salt]
    hmac_object = HMAC.new(key = hmac_password_derived,
                           msg = ''.join(elems_to_hmac),
                           digestmod = hash_function)
    # ------------------------------------------------------------------------

    # ------------------------------------------------------------------------
    # Write in what we have so far into the output, ciphertext file.
    #
    # Given that the plaintext may be compressed we don't know what
    # its final length will be without compressing it, and we can't
    # do this without reading it all into memory. Hence let's
    # put 0 as the ciphertext length for now and fill it in after.
    # ------------------------------------------------------------------------
    header_format = ''.join([">",
                             "H", # PBKDF2 count
                             "H", # PBKDF2 derived key length
                             "H", # Length of password salt
                             "H", # Length of CTR nonce
                             "H", # Length of HMAC salt
                             "Q", # Length of ciphertext
                             "?", # Is compression used?
                             "H", # Length of HMAC
                             "%ss" % (len(password_salt), ) , # Password salt
                             "%ss" % (nonce_size, ), # CTR nonce
                             "%ss" % (len(hmac_salt), )]) # HMAC salt
    header = struct.pack(header_format,
                         pbkdf2_count,
                         pbkdf2_dk_len,
                         len(password_salt),
                         len(nonce),
                         len(hmac_salt),
                         0, # This is the ciphertext size, wrong for now.
                         compress,
                         hmac_object.digest_size,
                         password_salt,
                         nonce,
                         hmac_salt)
    ciphertext_file_obj.write(header)
    # ------------------------------------------------------------------------

    # ------------------------------------------------------------------------
    # Stream in the input file and stream out ciphertext into the
    # ciphertext file.
    # ------------------------------------------------------------------------
    ciphertext_size = 0
    if compress:
        compressor = bz2.BZ2Compressor()
    while True:
        plaintext_chunk = plaintext_file_obj.read(chunk_size)
        end_of_file = plaintext_chunk == ''
        if compress:
            if end_of_file:
                compressed = compressor.flush()
            else:
                compressed = compressor.compress(plaintext_chunk)
            plaintext_chunk = compressed
        ciphertext_chunk = cipher_ctr.encrypt(plaintext_chunk)
        ciphertext_size += len(ciphertext_chunk)
        ciphertext_file_obj.write(ciphertext_chunk)
        hmac_object.update(ciphertext_chunk)
        if end_of_file:
            break
    # ------------------------------------------------------------------------

    # ------------------------------------------------------------------------
    # Write the HMAC to the ciphertext file.
    # ------------------------------------------------------------------------
    hmac = hmac_object.digest()
    ciphertext_file_obj.write(hmac)
    # ------------------------------------------------------------------------

    # ------------------------------------------------------------------------
    # Go back to the header and update the ciphertext size.
    #
    # Notice that we capture the header such that the last unpacked
    # element of the struct is the unsigned long long of the ciphertext
    # length.
    # ------------------------------------------------------------------------
    # Read in.
    ciphertext_file_obj.seek(0, os.SEEK_SET)
    header_format = ">HHHHHQ"
    header_size = struct.calcsize(header_format)
    header = ciphertext_file_obj.read(header_size)
    # Modify.
    header_elems = list(struct.unpack(header_format, header))
    header_elems[-1] = ciphertext_size
    # Write out.
    header = struct.pack(header_format, *header_elems)
    ciphertext_file_obj.seek(0, os.SEEK_SET)
    ciphertext_file_obj.write(header)
    # ------------------------------------------------------------------------

test_crypto.py:
#!/usr/bin/env python2.7
# ---------------------------------------------------------------------------
# Copyright (c) 2012 Asim Ihsan (asim dot ihsan at gmail dot com)
# Distributed under the MIT/X11 software license, see the accompanying
# file license.txt or http://www.opensource.org/licenses/mit-license.php.
# ---------------------------------------------------------------------------
import os
import sys
import tempfile
import cStringIO as StringIO
src_path = os.path.abspath(os.path.join(__file__, os.pardir, os.pardir, "src"))
assert(os.path.isdir(src_path))
sys.path.append(src_path)
from utilities import crypto
from nose.tools import raises, assert_false, assert_true, assert_not_equal, assert_equal, assert_less
from nose.plugins.skip import SkipTest, Skip
_multiprocess_can_split_ = True

def test_encrypt_then_decrypt_string():
    plaintext = "this is some text"
    key = "this is my key"
    ciphertext = crypto.encrypt_string(plaintext, key)
    plaintext_after = crypto.decrypt_string(ciphertext, key)
    assert_equal(plaintext_after, plaintext)

@raises(crypto.HMACIsNotValidException)
def test_encrypt_then_alter_raises_exception():
    plaintext = "this is some text"
    key = "this is my key"
    ciphertext = crypto.encrypt_string(plaintext, key)
    ciphertext = ciphertext[:-len(plaintext)] + '\0' * len(plaintext)
    plaintext_after = crypto.decrypt_string(ciphertext, key)

@raises(crypto.InvalidFormatException)
def test_decrypt_then_damage_raises_exception():
    plaintext = "this is some text"
    key = "this is my key"
    ciphertext = crypto.encrypt_string(plaintext, key)
    ciphertext = ciphertext[:len(ciphertext)-5]
    plaintext_after = crypto.decrypt_string(ciphertext, key)

@raises(crypto.HMACIsNotValidException)
def test_decrypt_with_wrong_password_raises_exception():
    plaintext = "this is some text"
    key = "this is my key"
    ciphertext = crypto.encrypt_string(plaintext, key)
    another_key = "this is my key 2"
    plaintext_after = crypto.decrypt_string(ciphertext, another_key)

def test_encrypt_then_decrypt_empty_string():
    plaintext = ""
    key = "this is my key"
    ciphertext = crypto.encrypt_string(plaintext, key)
    plaintext_after = crypto.decrypt_string(ciphertext, key)
    assert_equal(plaintext_after, plaintext)

def test_compressed_encrypt_then_decrypt_string():
    plaintext = "X" * 4096
    key = "this is my key"
    ciphertext = crypto.encrypt_string(plaintext, key, compress=True)
    assert_less(len(ciphertext), len(plaintext) / 10)
    plaintext_after = crypto.decrypt_string(ciphertext, key)
    assert_equal(plaintext, plaintext_after)

def test_compressed_encrypt_then_decrypt_random_string():
    plaintext = os.urandom(1024 * 1024)
    key = "this is my key"
    ciphertext = crypto.encrypt_string(plaintext, key, compress=True)
    plaintext_after = crypto.decrypt_string(ciphertext, key)
    assert_equal(plaintext, plaintext_after)

def _test_encrypt_then_decrypt_file(plaintext_size,
                                    chunk_size,
                                    wrong_password=False,
                                    alter_file=False,
                                    truncate_header=False,
                                    truncate_body=False,
                                    compress=False):
    tempfile.tempdir = os.path.join(__file__, os.pardir)
    plaintext_file = tempfile.NamedTemporaryFile(delete=False)
    ciphertext_file = tempfile.NamedTemporaryFile(delete=False)
    plaintext_after_file = tempfile.NamedTemporaryFile(delete=False)
    plaintext_filepath = plaintext_file.name
    ciphertext_filepath = ciphertext_file.name
    plaintext_after_filepath = plaintext_after_file.name
    plaintext_file.close()
    ciphertext_file.close()
    plaintext_after_file.close()
    try:
        key = "this is my key"

        # --------------------------------------------------------------------
        # Write plaintext to file.
        # --------------------------------------------------------------------
        with open(plaintext_filepath, "wb") as f:
            cnt = 0
            while cnt < plaintext_size:
                current_chunk_size = min(4096, plaintext_size - cnt)
                f.write("X" * current_chunk_size)
                cnt += current_chunk_size
        # --------------------------------------------------------------------

        # --------------------------------------------------------------------
        # Encrypt plaintext file to ciphertext file.
        #
        # Notice that the output, ciphertext file requires read and
        # write access.
        # --------------------------------------------------------------------
        with open(plaintext_filepath, "rb") as f_in:
            with open(ciphertext_filepath, "rb+") as f_out:
                crypto.encrypt_file(f_in,
                                    key,
                                    f_out,
                                    chunk_size=chunk_size,
                                    compress=compress)
        # --------------------------------------------------------------------

        # --------------------------------------------------------------------
        # If wrong password then let's adjust the password.
        # --------------------------------------------------------------------
        if wrong_password:
            key = "this is my key 2"
        # --------------------------------------------------------------------

        # --------------------------------------------------------------------
        # If alter file then let's alter the file.
        # If truncate_body then let's skip the last ten bytes from the
        # file.
        # If truncate_header then let's skip the first ten bytes from the
        # file.
        # --------------------------------------------------------------------
        if alter_file:
            with open(ciphertext_filepath, "rb+") as f:
                f.seek(-10, os.SEEK_END)
                f.write('\0' * 10)
        if truncate_header or truncate_body:
            with open(ciphertext_filepath, "rb") as f:
                contents = f.read()
            if truncate_header:
                with open(ciphertext_filepath, "wb") as f:
                    f.write(contents[10:])
            else:
                with open(ciphertext_filepath, "wb") as f:
                    f.write(contents[:-10])
        # --------------------------------------------------------------------

        # --------------------------------------------------------------------
        # Decrypt ciphertext file to another plaintext file.
        # --------------------------------------------------------------------
        with open(ciphertext_filepath, "rb") as f_in:
            with open(plaintext_after_filepath, "wb") as f_out:
                crypto.decrypt_file(f_in, key, f_out, chunk_size=chunk_size)
        # --------------------------------------------------------------------

        with open(plaintext_filepath, "rb") as f_original:
            with open(plaintext_after_filepath, "rb") as f_after:
                while True:
                    f_original_chunk = f_original.read(chunk_size)
                    f_after_chunk = f_after.read(chunk_size)
                    assert_equal(f_original_chunk, f_after_chunk)
                    if f_original_chunk == '':
                        break
    finally:
        for (file_obj, filepath) in [(plaintext_file, plaintext_filepath),
                                     (ciphertext_file, ciphertext_filepath),
                                     (plaintext_after_file, plaintext_after_filepath)]:
            os.remove(filepath)

def test_encrypt_then_decrypt_file_normal():
    _test_encrypt_then_decrypt_file(plaintext_size = 17,
                                    chunk_size = 4096)

2 Answers
Well then, I guess no one is stepping up to review the code! I've been thinking it over and reading over the code and tests a few times and have come up with the following review comments. These comments apply to the git branch at commit 655694e.
- crypto.py line 171 compares the calculated HMAC with the correct HMAC using a linear-time comparison operator. This leaks potentially useful information via a timing attack if this function is used as part of a web service. Fix is to use Nate Lawson's constant-time comparison function, described by Coda Hale here.
- crypto.py lines 39 and 42. All different exceptions thrown by this library should derive from a base "BaseCryptoException". This allows callers to catch the base exception, in cases where they don't care why the decryption failed.
- crypto.py line 211. Why is the PBKDF2 iteration count (1000) hard-coded? I agree the caller maybe shouldn't be allowed to modify this, because we want to simplify the API. But we may need to increase this number depending on the speed of the machine in use. Automatically determining a constant to use also seems too complicated. I suggest allowing the caller to set a value, with it set to 1000 by default. You'll also want to pack this as an unsigned integer ("I") rather than an unsigned short ("H") in order to allow values greater than 65535.
- test_crypto.py lines 31 and 39. Damaging the ciphertext payload and then attempting decryption needs to be more comprehensive. I suggest testing that strings that have a bit-wise Levenshtein edit distance of one (i.e. all strings with one bit added, removed, or toggled) fail decryption either because the format is wrong or because the HMAC fails.
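
For the first point, here is a minimal sketch of a constant-time comparison. It assumes a Python version that ships hmac.compare_digest (2.7.7+ or 3.3+) and falls back to an XOR-accumulate loop otherwise. The helper name constant_time_equals is hypothetical and is not part of crypto.py.

```python
import hmac


def constant_time_equals(expected, actual):
    """Compare two byte strings without exiting early on the first mismatch.

    Hypothetical helper, not part of crypto.py. Prefers the standard
    library's hmac.compare_digest and falls back to a manual loop.
    """
    if hasattr(hmac, "compare_digest"):
        return hmac.compare_digest(expected, actual)
    # Fallback: a length mismatch is reported immediately (the length of an
    # HMAC is not secret), then every byte pair is XORed and accumulated.
    if len(expected) != len(actual):
        return False
    result = 0
    for x, y in zip(bytearray(expected), bytearray(actual)):
        result |= x ^ y
    return result == 0


tag = b"\x01\x02\x03\x04"
print(constant_time_equals(tag, b"\x01\x02\x03\x04"))  # True
print(constant_time_equals(tag, b"\x01\x02\x03\x05"))  # False
```

In decrypt_file the check `hmac != hmac_calculated` would then become `not constant_time_equals(hmac, hmac_calculated)`.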
If anyone has anything else to add please let me know, thanks.
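
The exhaustive tampering test from the last point could be sketched roughly as follows. To stay self-contained it does not import crypto.py: seal() and unseal() are hypothetical stand-ins that only append and verify an HMAC-SHA256 tag (no encryption), and only bit toggles are covered, not inserted or removed bits. In the real test suite the loop would call crypto.decrypt_string and expect either InvalidFormatException or HMACIsNotValidException.

```python
import hashlib
import hmac


class TamperedError(Exception):
    """Raised by the stand-in unseal() when authentication fails."""


def seal(message, key):
    # Stand-in for crypto.encrypt_string: message followed by its HMAC tag.
    return message + hmac.new(key, message, hashlib.sha256).digest()


def unseal(blob, key):
    # Stand-in for crypto.decrypt_string: verify the trailing 32-byte tag.
    if len(blob) < 32:
        raise TamperedError("too short")
    message, tag = blob[:-32], blob[-32:]
    expected = hmac.new(key, message, hashlib.sha256).digest()
    if not hmac.compare_digest(tag, expected):
        raise TamperedError("bad tag")
    return message


def single_bit_flips(blob):
    """Yield every copy of blob that differs from it in exactly one bit."""
    for index in range(len(blob)):
        for bit in range(8):
            mutated = bytearray(blob)
            mutated[index] ^= 1 << bit
            yield bytes(mutated)


def count_accepted_flips(blob, key):
    """Return how many single-bit mutations are wrongly accepted."""
    accepted = 0
    for mutated in single_bit_flips(blob):
        try:
            unseal(mutated, key)
            accepted += 1
        except TamperedError:
            pass
    return accepted


key = b"this is my key"
blob = seal(b"this is some text", key)
print(count_accepted_flips(blob, key))  # 0
```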
I'd rather say variable-time instead of linear-time. – CodesInChaos, Dec 5, 2015 at 16:07
Some thoughts -- mostly stylistic. I'm not qualified to comment on the crypto stuff!
- I'm not sure why you need to have a package called utilities, given that it just contains a single file. utilities is a pretty common name. Perhaps you could just have a module called crypto_utils or something.
- You don't need to define __init__ and __str__ for InvalidFormatException -- you'll get these for free by inheriting from Exception.
- encrypt_file() and decrypt_file() are both fairly long functions. Could you break them up so that it's easier to see what they're doing?
- Encrypting and decrypting are symmetric operations -- is there any way you could exploit the symmetry to share code? For example, could you reuse the patterns you use to pack and unpack data?
- There's a bunch of repetition (eg key = "this is my key") in your tests. Could you make your tests methods of a subclass of unittest.TestCase and provide a common setUp() function?
- I'm not a fan of the @raises decorator -- I think it's useful to know where the exception has been raised, since good tests serve as documentation.
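
One way to share the pack and unpack patterns is to define the fixed header layout once and let both directions use it. The sketch below copies the field order from the ">HHHHHQ?H" format in crypto.py; the names HEADER_STRUCT, Header, pack_header and unpack_header are hypothetical, and the variable-length salts and nonce are left out for brevity.

```python
import collections
import struct

# Field order copied from the ">HHHHHQ?H" format used in crypto.py.
HEADER_STRUCT = struct.Struct(">HHHHHQ?H")
Header = collections.namedtuple("Header", [
    "pbkdf2_count", "pbkdf2_dk_len", "password_salt_size", "nonce_size",
    "hmac_salt_size", "ciphertext_size", "compress", "hmac_size"])


def pack_header(header):
    """Serialise a Header; the encrypt side would call this."""
    return HEADER_STRUCT.pack(*header)


def unpack_header(data):
    """Parse a Header; the decrypt side would call this."""
    if len(data) != HEADER_STRUCT.size:
        raise ValueError("Header is invalid.")
    return Header(*HEADER_STRUCT.unpack(data))


original = Header(1000, 32, 16, 8, 16, 0, False, 32)
raw = pack_header(original)
# Filling in the ciphertext size afterwards is a _replace plus a re-pack.
final = unpack_header(raw)._replace(ciphertext_size=4096)
print(len(raw))                                           # 21
print(unpack_header(pack_header(final)).ciphertext_size)  # 4096
```

With named fields, decrypt_file no longer needs to index header[0] through header[7], and the second, shorter ">HHHHHQ" format in encrypt_file becomes unnecessary.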
… _test_encrypt_then_decrypt_file) is not good, because it gets out of sync quickly. I don't know about the encryption, but I know that the function decrypt_file is far too long. You should split it up into multiple functions and maybe create a class for some parts of your code.