I'd like to have some feedback on this small snippet I wrote to implement a base64 encoder and decoder.
Specifically, I'm not sure I'm handling padding in the best way possible.
import string

values = string.ascii_uppercase + string.ascii_lowercase + string.digits + '+/'
ALPHABET = dict(zip(range(64), values))

def encode(unpadded_buf):
    padded_buf = unpadded_buf[:]
    if len(padded_buf) % 3 != 0:
        padded_buf += [0] * (3 - len(unpadded_buf) % 3)
    encoded_padded_buf = []
    for i in range(0, len(unpadded_buf), 3):
        encoded_padded_buf += [ padded_buf[i] >> 2 ]
        encoded_padded_buf += [ (padded_buf[i] & 3) << 4 | padded_buf[i + 1] >> 4 ]
        encoded_padded_buf += [ (padded_buf[i + 1] & 15) << 2 | padded_buf[i + 2] >> 6 ]
        encoded_padded_buf += [ padded_buf[i + 2] & 63]
    asciied = ''.join(map(lambda n: ALPHABET[n], encoded_padded_buf))
    if len(padded_buf) - len(unpadded_buf) == 2:
        return asciied[:-2] + '=='
    elif len(padded_buf) - len(unpadded_buf) == 1:
        return asciied[:-1] + '='
    return asciied[:]

def decode(asciied):
    _alphabet = dict(map(lambda p: (p[1], p[0]), ALPHABET.items()))
    encoded_buf = list(map(lambda c: _alphabet.get(c, 0), asciied))
    decoded_buf = []
    for i in range(0, len(encoded_buf), 4):
        decoded_buf += [ encoded_buf[i] << 2 | encoded_buf[i + 1] >> 4 ]
        decoded_buf += [ (encoded_buf[i + 1] & 15) << 4 | encoded_buf[i + 2] >> 2 ]
        decoded_buf += [ (encoded_buf[i + 2] & 3) << 6 | encoded_buf[i + 3] ]
    return decoded_buf

if __name__ == '__main__':
    assert bytes(decode(encode(list('foo'.encode())))).decode() == 'foo'
- Are you aware that Python has base64 support built in? (200_success, Jul 25, 2016 at 21:15)
- @200_success yes (Giuseppe Crinò, Jul 26, 2016 at 8:29)
1 Answer
Interface
assert bytes(decode(encode(list('foo'.encode())))).decode() == 'foo'
To avoid using encode on strings or decode on bytes, you can use byte literals: b'foo'. This simplifies the line to:
assert bytes(decode(encode(list(b'foo')))) == b'foo'
But this makes me wonder: why does the caller have to convert the arguments themselves? Why the need for list and bytes? A proper usage would be
assert decode(encode(b'foo')) == b'foo'
This might leave someone wondering: decode to, or encode from, what? But assuming the module is properly named, that should not be an issue.
Now onto building functions that accept and return bytes.
Iteration and padding
When it comes to iterating over something in Python, it is deemed inelegant (and somewhat slower) to use indices rather than directly accessing content. Meaning for element in array: is way better than for i in range(len(array)): element = array[i]. Here you can argue that you can't easily iterate over groups of 3 or 4 elements at a time, but that is forgetting about the itertools module and its variety of recipes. grouper is the one you need:
from itertools import zip_longest

def grouper(iterable, n, fillvalue=0):
    "Collect data into fixed-length chunks or blocks"
    # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx
    args = [iter(iterable)] * n
    return zip_longest(*args, fillvalue=fillvalue)
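As a quick illustration of how the fill value takes care of incomplete groups, here is a small sketch. The sample inputs are my own, and grouper is repeated only so the snippet runs on its own.

```python
from itertools import zip_longest

def grouper(iterable, n, fillvalue=0):
    "Collect data into fixed-length chunks or blocks"
    args = [iter(iterable)] * n
    return zip_longest(*args, fillvalue=fillvalue)

# Iterating over bytes yields integers, so the groups are tuples of ints.
full = list(grouper(b'foo', 3))   # length is a multiple of 3: nothing to fill
short = list(grouper(b'fo', 3))   # last group is completed with fillvalue

print(full)   # [(102, 111, 111)]
print(short)  # [(102, 111, 0)]
```

The trailing 0 in the second result plays the role of the manual zero padding in the original encode.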
There is thus no need to manually handle the padding in encode and you can simplify iteration:
for first, second, third in grouper(unpadded_buf, 3):
    encoded_padded_buf += [ first >> 2 ]
    encoded_padded_buf += [ (first & 3) << 4 | second >> 4 ]
    encoded_padded_buf += [ (second & 15) << 2 | third >> 6 ]
    encoded_padded_buf += [ third & 63]
It also has the advantage of handling any iterable the same way. Meaning you don't need to convert your bytes to list beforehand, it will do just fine.
Extra characters
You don't need to differentiate the case where len(unpadded_buf) % 3 is 1 or 2. You can perform the same task using:
# Number of '=' needed: 2 when one byte is left over, 1 when two are
extra = -len(unpadded_buf) % 3
if extra:
    return asciied[:-extra] + '=' * extra
return asciied
You can also note that there is no reason to make a copy of asciied, just like you don't make a copy of decoded_buf in decode.
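To make the padding arithmetic concrete: a final group holding one byte needs two '=' characters, and a final group holding two bytes needs one. The sketch below tabulates this for a few input lengths; padding_length is a hypothetical helper name used only for illustration.

```python
def padding_length(n_bytes):
    # Number of '=' characters needed for an input of n_bytes bytes.
    # -n % 3 is the distance from n up to the next multiple of 3.
    return -n_bytes % 3

for n in range(7):
    # input length, leftover bytes in the last group, '=' characters needed
    print(n, n % 3, padding_length(n))
```

Note that the leftover count and the padding count are not the same number: they add up to 3 whenever the length is not a multiple of 3.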
ALPHABET?
It feels like this variable is misnamed: it is not an alphabet, it is a conversion table, TO_BASE64. Same for _alphabet: I would have named it FROM_BASE64 and computed it once, as a constant, since its content never changes. I would also incorporate {'=': 0} to simplify the map call.
So far, the code could be:
import string
from itertools import zip_longest

TO_BASE64 = dict(zip(range(64),
                     string.ascii_uppercase + string.ascii_lowercase +
                     string.digits + '+/'))
FROM_BASE64 = {v: k for k, v in TO_BASE64.items()}
FROM_BASE64['='] = 0

def grouper(iterable, n, fillvalue=0):
    "Collect data into fixed-length chunks or blocks"
    # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx
    args = [iter(iterable)] * n
    return zip_longest(*args, fillvalue=fillvalue)

def encode(unpadded_buf):
    encoded_buf = []
    for first, second, third in grouper(unpadded_buf, 3):
        encoded_buf += [ first >> 2 ]
        encoded_buf += [ (first & 3) << 4 | second >> 4 ]
        encoded_buf += [ (second & 15) << 2 | third >> 6 ]
        encoded_buf += [ third & 63]
    asciied = ''.join(TO_BASE64[n] for n in encoded_buf)
    # Number of '=' needed: 2 when one byte is left over, 1 when two are
    extra = -len(unpadded_buf) % 3
    if extra:
        return asciied[:-extra] + '=' * extra
    return asciied

def decode(asciied):
    encoded_buf = (FROM_BASE64[n] for n in asciied)
    decoded_buf = []
    for first, second, third, fourth in grouper(encoded_buf, 4):
        decoded_buf += [ first << 2 | second >> 4 ]
        decoded_buf += [ (second & 15) << 4 | third >> 2 ]
        decoded_buf += [ (third & 3) << 6 | fourth ]
    return bytes(decoded_buf)

if __name__ == '__main__':
    assert decode(encode(b'foo')) == b'foo'
Adding elements to a list
You extend your "result" lists using other lists. This is rather inefficient, as the list you use to extend the first one will be discarded right away. You'd be better off appending values to the "result" instead. Or use helper generator functions to build results as they are computed, without having to populate a list that will be discarded right after:
import string
from itertools import zip_longest

TO_BASE64 = dict(zip(range(64),
                     string.ascii_uppercase + string.ascii_lowercase +
                     string.digits + '+/'))
FROM_BASE64 = {v: k for k, v in TO_BASE64.items()}
FROM_BASE64['='] = 0

def grouper(iterable, n, fillvalue=0):
    "Collect data into fixed-length chunks or blocks"
    # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx
    args = [iter(iterable)] * n
    return zip_longest(*args, fillvalue=fillvalue)

def _encoder(stream):
    for first, second, third in grouper(stream, 3):
        yield first >> 2
        yield (first & 3) << 4 | second >> 4
        yield (second & 15) << 2 | third >> 6
        yield third & 63

def encode(unpadded_buf):
    asciied = ''.join(TO_BASE64[n] for n in _encoder(unpadded_buf))
    # Number of '=' needed: 2 when one byte is left over, 1 when two are
    extra = -len(unpadded_buf) % 3
    if extra:
        return asciied[:-extra] + '=' * extra
    return asciied

def _decoder(stream):
    for first, second, third, fourth in grouper(stream, 4):
        yield first << 2 | second >> 4
        yield (second & 15) << 4 | third >> 2
        yield (third & 3) << 6 | fourth

def decode(asciied):
    return bytes(_decoder(FROM_BASE64[n] for n in asciied))

if __name__ == '__main__':
    assert decode(encode(b'foo')) == b'foo'
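Since 'foo' has a length that is a multiple of 3, that single assertion never exercises the padding. One way to gain more confidence is to compare against the standard library's base64 module, mentioned in the comments, on inputs of every length modulo 3. The sketch below is self-contained and rests on two assumptions of my own: the number of '=' characters is computed as -len(data) % 3, and decode drops one trailing byte per '=' found in the input (that stripping step is an addition for illustration, not part of the review so far).

```python
import base64
import string
from itertools import zip_longest

TO_BASE64 = dict(zip(range(64),
                     string.ascii_uppercase + string.ascii_lowercase +
                     string.digits + '+/'))
FROM_BASE64 = {v: k for k, v in TO_BASE64.items()}
FROM_BASE64['='] = 0

def grouper(iterable, n, fillvalue=0):
    "Collect data into fixed-length chunks, filling the last one"
    args = [iter(iterable)] * n
    return zip_longest(*args, fillvalue=fillvalue)

def _encoder(stream):
    # Split each group of three 8-bit values into four 6-bit values.
    for first, second, third in grouper(stream, 3):
        yield first >> 2
        yield (first & 3) << 4 | second >> 4
        yield (second & 15) << 2 | third >> 6
        yield third & 63

def encode(data):
    asciied = ''.join(TO_BASE64[n] for n in _encoder(data))
    padding = -len(data) % 3  # assumption: 0, 1 or 2 '=' characters
    if padding:
        return asciied[:-padding] + '=' * padding
    return asciied

def _decoder(stream):
    # Join each group of four 6-bit values back into three 8-bit values.
    for first, second, third, fourth in grouper(stream, 4):
        yield first << 2 | second >> 4
        yield (second & 15) << 4 | third >> 2
        yield (third & 3) << 6 | fourth

def decode(asciied):
    decoded = bytes(_decoder(FROM_BASE64[c] for c in asciied))
    padding = asciied.count('=')  # assumption: '=' only appears as padding
    return decoded[:len(decoded) - padding]

# Sample inputs covering every remainder of len(sample) % 3.
for sample in (b'', b'f', b'fo', b'foo', b'foob', b'fooba', b'foobar'):
    assert encode(sample) == base64.b64encode(sample).decode('ascii')
    assert decode(encode(sample)) == sample
```

The loop is silent when everything matches; any mismatch with the standard library raises an AssertionError pointing at the offending check.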