After watching Tom Scott explain Huffman coding in this YouTube video, I wanted to implement it myself. I want to use this project to further my understanding of Python. Additionally this tool should be easily usable on the command line e.g. normally input is read from stdin and output is written to stdout.
Did I miss some obvious more Pythonic way to do something? Could I increase the usability on the command line somehow?
#!/usr/bin/env python3
"""Encode or decode text with Huffman Coding.
The program reads from stdin and writes to stdout if no input or output file is given.
positional arguments:
{decode,encode} decode or encode
optional arguments:
-h, --help show this help message and exit
-i IN, --in IN the input file
-o OUT, --out OUT the output file
"""
import argparse
import os
import string
import sys
from collections import Counter
from functools import singledispatchmethod
class Huffman:
"""A node of a binary tree saving characters in its leaves."""
def __init__(self, char: str = None, weight: int = None, left: 'Huffman' = None, right: 'Huffman' = None):
self.codes = {}
if left is None and right is None and char is None:
raise StateError("A node needs either a char or at least one child.")
self.char = char
self.weight = weight
self.left = left
self.right = right
@classmethod
def from_binary_string(cls, data: str) -> ('Huffman', int):
"""Reconstruct a Huffman tree from a string containing binary data.
:param data: A string containing a binary representation of a Huffman tree as prefix
:return: A Huffman tree and the length of its binary representation in bits
"""
if not all(char in "01" for char in data):
raise CharsetError("Only '0' and '1' are allowed in a binary string.")
if data[:2] == '00':
tree, length = cls(char=chr(int(data[2:10], 2))), 10
elif data[:2] == '01':
right, r_length = cls.from_binary_string(data[2:])
tree, length = cls(right=right), 2 + r_length
elif data[:2] == '10':
left, l_length = cls.from_binary_string(data[2:])
tree, length = cls(left=left), 2 + l_length
else:
left, l_length = cls.from_binary_string(data[2:])
right, r_length = cls.from_binary_string(data[l_length + 2:])
tree, length = cls(left=left, right=right), 2 + l_length + r_length
tree.generate_codes('')
return tree, length
@classmethod
def from_bytes(cls, data: bytes) -> ('Huffman', int):
"""Construct a Huffman tree from a bytes-like object.
:param data: A bytes-like object containing a binary encoded Huffman tree as prefix
:return: A Huffman tree and the length of its binary representation in bits
"""
return Huffman.from_binary_string(format(int.from_bytes(data, byteorder='big'), 'b'))
@classmethod
def from_counter(cls, cnt: Counter) -> 'Huffman':
"""Construct a Huffman tree from a :py:class:`Counter` that uses characters as keys.
Only printable ASCII characters are allowed as keys in the counter.
:param cnt: A counter containing only printable ASCII characters as keys
:return: A Huffman tree
"""
if not all(char in string.printable for char in cnt.keys()):
raise CharsetError("Only printable ASCII characters are allowed.")
counts = cnt.most_common()
counts.reverse()
nodes = [Huffman(char=char, weight=weight) for char, weight in counts]
nodes.sort(key=lambda _: _.weight)
while len(nodes) > 1:
nodes.append(Huffman(weight=nodes[0].weight + nodes[1].weight, left=nodes[0], right=nodes[1]))
del nodes[0:2]
nodes.sort(key=lambda _: _.weight)
nodes[0].generate_codes('')
return nodes[0]
@classmethod
def from_string(cls, data: str) -> 'Huffman':
"""Construct a Huffman tree from a string.
Only printable ASCII characters are allowed.
:param data: A string containing only printable ASCII characters
:return: A Huffman tree
"""
if not all(char in string.printable for char in data):
raise CharsetError("Only printable ASCII characters are allowed.")
cnt = Counter(data)
return cls.from_counter(cnt)
@singledispatchmethod
def decode(self, data) -> str:
"""Decode a bytes-like object or string containing binary data.
:param data: A bytes-like object or a string containing binary data
:return: A string containing the decoded text
"""
raise NotImplementedError("Cannot decode an object")
@decode.register
def decode_from_bytes(self, data: bytes, tree_length: int) -> str:
"""Decode a bytes-like object encoding a Huffman tree as prefix of length tree_length and the encoded text.
:param data: The bytes-like object encoding the tree and text
:param tree_length: The length of the tree in bits
:return: A string containing the decoded text
"""
if not self.codes:
raise CodesError()
return self.decode(format(int.from_bytes(data, byteorder='big'), 'b')[tree_length:])
@decode.register
def decode_from_string(self, data: str) -> str:
"""Decode a string containing binary data.
:param data: A string containing binary data
:return: A string containing the decoded text
"""
if not self.codes:
raise CodesError()
if not all(char in '01' for char in data):
raise CharsetError("Only binary data is allowed.")
decoded = ''
node = self
if len(self.codes) > 1:
for bit in data:
if node.char is not None:
decoded += node.char
node = self
if bit == '0':
node = node.left
elif bit == '1':
node = node.right
decoded += node.char
else:
decoded = node.char * len(data)
return decoded
def encode(self, data: str) -> str:
"""Encode a string according to this tree.
:param data: The string to be encoded
:return: The encoded data as string containing binary data
"""
if not self.codes:
raise CodesError()
if not all(char in self.codes.keys() for char in data):
raise CharsetError()
if len(self.codes) > 1:
encoded = ''.join(self.codes[char] for char in data)
else:
encoded = f"{len(data):b}"
return encoded
def as_binary(self, recursive: bool = True) -> str:
"""Encode this tree as binary data.
:param recursive: Whether only the state of this node or the whole tree should be encoded
:return: This tree encoded in binary
"""
if self.char is None:
if self.left is None:
if self.right is None:
raise StateError()
else:
ret = "01" + self.right.as_binary() if recursive else "01"
else:
if self.right is None:
ret = "10" + self.left.as_binary() if recursive else "10"
else:
ret = "11" + self.left.as_binary() + self.right.as_binary() if recursive else "11"
else:
ret = "00" + "{0:08b}".format(ord(self.char))
return ret
def generate_codes(self, path: str):
"""Generate a binary representation of the characters saved in this (sub-)tree.
Recursively follow the tree structure. When this node has a character saved, update the codes dictionary
using this character as key and the path taken to get here as value. When following a right child,
add a '1' to the path already taken to get here. When following a left child, add a '0' to the path already
taken to get here. After either or both of these update the codes dictionary with the codes dictionaries of
the children.
:param path: The path taken to get to this node
"""
if self.char is not None:
self.codes.update({self.char: path})
else:
if self.right is not None:
self.right.generate_codes(path + '1')
self.codes.update(self.right.codes)
if self.left is not None:
self.left.generate_codes(path + '0')
self.codes.update(self.left.codes)
def __repr__(self):
return "{0}(char={1}, weight={2}, left={3!r}, right={4!r})".format(
type(self).__name__,
self.char if self.char is None else f'"{self.char}"',
self.weight,
self.left,
self.right)
def __str__(self):
return "({0}: {1}, {2}, <: {3}, >: {4})".format(self.char,
self.weight,
self.as_binary(recursive=False),
self.left,
self.right)
def __eq__(self, other):
return self.char == other.char and self.right == other.right and self.left == other.left
class CodesError(Exception):
"""
Throw when no codes were generated before attempting to en- or decode something.
"""
def __init__(self, message: str = None):
if message is None:
self.message = "There are no codes generated for this tree."
else:
self.message = message
class CharsetError(Exception):
"""
Throw when an illegal character is in some input.
"""
def __init__(self, message: str = None):
if message is None:
self.message = "At least one of the characters in the input string is not represented in the tree."
else:
self.message = message
class StateError(Exception):
"""
Throw when a node is in an impossible state.
"""
def __init__(self, message: str = None):
if message is None:
self.message = "Impossible state of a node."
else:
self.message = message
def eprint(*args, **kwargs):
"""Write messages into the standard error stream
:param args: The objects to print
:param kwargs: Keyword arguments for print
"""
print(*args, file=sys.stderr, **kwargs)
def main():
"""
The main function used to avoid polluting the global scope with variables
"""
parser = argparse.ArgumentParser(description="Encode or decode text with Huffman Coding.")
parser.add_argument("action", help="decode or encode", choices=['decode', 'encode'])
parser.add_argument("-i", "--in", help="the input file",
type=argparse.FileType('r' if 'encode' in sys.argv else 'rb'), default=sys.stdin, dest="input",
metavar="IN")
parser.add_argument("-o", "--out", help="the output file",
type=argparse.FileType("wb" if 'encode' in sys.argv else "w"), default=sys.stdout,
dest="output", metavar="OUT")
args = parser.parse_args()
if args.action == 'encode':
input_string = args.input.read()
if input_string == "":
eprint("The input was empty")
sys.exit(1)
else:
tree = Huffman.from_string(input_string)
state = tree.as_binary()
encoded = tree.encode(input_string)
message = state + encoded
with os.fdopen(sys.stdout.fileno(), "wb", closefd=False) if args.output is sys.stdout else args.output as out:
out.write(int(message, 2).to_bytes((len(message) + 7) // 8, 'big'))
else:
input_bytes = args.input.read()
if input_bytes == b"":
eprint("The input was empty")
sys.exit(1)
tree, tree_length = Huffman.from_bytes(input_bytes)
decoded = tree.decode(input_bytes, tree_length)
with args.output as out:
out.write(decoded)
if __name__ == "__main__":
main()
1 Answer 1
Good job!
- I like the use of
classmethod
s to allow different constructors. - You have typed a significant amount of the code.
- You have a nice amount of documentation.
- You have input validation on most of your functions.
Improvements
Most of these can be seen as nitpicks or alternate perspectives. Your code is pretty good. Nice job!
Whilst your code is almost fully statically typed, you're not quite there for mypy in strict mode. Since mypy's main goal is to help convert people from untyped code to typed code many of the checks don't run by default. This is because it'd be demoralizing fixing hundreds or thousands of issues just to get mypy to not complain.
You've not defined a return type for
generate_codes
,eprint
,main
and many double-under (dunder) methods.You're relying on mypy to automatically apply
Optional
.def __init__(self, char: str = None, weight: int = None, left: 'Huffman' = None, right: 'Huffman' = None):
You should use
typing.Tuple
rather than "('Huffman', int)
" to specify returning a tuple. Since we can just usetuple
in Python 3.9 I'll be using that in the below code.
If you're running Python 3.7+ then we can remove the need to use "
'Huffman'
" by postponing evaluation of annotations. We can do that by importingannotations
from__future__
.I would split the tree and the Huffman interface into two separate classes. To store the tree you can just define a simple
Node
class:@dataclasses.dataclass class Node: weight: int char: Optional[str] = None left: "Optional[Node]" = None right: "Optional[Node]" = None
generate_codes
is nice, it's roughly how I'd do it.
However I'd define it onNode
and make it work the same way thatitems()
does on dictionaries.This gives users a familiar interface and doesn't give them a full blown dictionary. Which they can make if needed.
class Node: ... def items(self): yield from self._items('') def _items(self, path): if self.char is not None: yield path, self.char else: yield from self.left._items(path + '0') yield from self.right._items(path + '1')
We can change
from_counter
to useheapq
so we don't need to call.sort()
all the time.By adding the
__lt__
dunder toNode
we can just enterNode
s into the heap and it'll play ball.heap = [] for char, weight in collections.Counter(text).items(): heapq.heappush(heap, Node(weight, char)) while 1 < len(heap): right = heapq.heappop(heap) left = heapq.heappop(heap) node = Node(left.weight + right.weight, None, left, right) heapq.heappush(heap, node) root = heap[0]
The function
decode_from_string
is pretty good. I'm not a fan of the premature optimization "decoded = node.char * len(data)
".The way I'd do it includes abusing
__getitem__
and iterators to consume the text whilst getting the values. I think your way is much easier to read and understand. However I will include it below so you can see this magic.I don't think the user should call
generate_codes
. If this is needed you should build it, and cache it toself._codes
.This just removes an unneeded step for the user of your class.
I think
from_bytes
is smart and cool. Nice!I'm not a fan of
as_binary
as all those string concatenations could get expensive. Assuming CPython isn't nice and makes string concatenation run in \$O(1)\$ time.To not rely on this I'd change to using a private generator function that you then just call
''.join
on in the public one. (Likeitems
above.)I'm not a fan of how you define most of your exceptions. Having a default message kinda makes sense. However it makes your exceptions function differently to Python's exceptions where you have to provide the message.
If this is because you want to DRY the messages then you can move them into a global constant.
Again your code is pretty good.
Here is the, really hacky, solution I came up with when trying to learn how Huffman
works.
from __future__ import annotations
import collections
import dataclasses
import heapq
from collections.abc import Iterator
from pprint import pprint
from typing import Optional
@dataclasses.dataclass
class Node:
weight: int
char: Optional[str] = None
left: Optional[Node] = None
right: Optional[Node] = None
def __lt__(self, other: Node) -> bool:
return self.weight < other.weight
def __getitem__(self, key: str) -> str:
if self.char is not None:
return self.char
key = iter(key)
if next(key) == '0':
return self.left[key]
else:
return self.right[key]
def items(self) -> Iterator[tuple[str, str]]:
yield from self._items('')
def _items(self, path) -> Iterator[tuple[str, str]]:
if self.char is not None:
yield path, self.char
else:
yield from self.left._items(path + '0')
yield from self.right._items(path + '1')
class Huffman:
_tree: Node
_graph: dict[str, str]
def __init__(self, tree: Node) -> None:
self._tree = tree
self._graph = None
@classmethod
def from_text(cls, text: str) -> Huffman:
heap = []
for char, weight in collections.Counter(text).items():
heapq.heappush(heap, Node(weight, char))
while 1 < len(heap):
right = heapq.heappop(heap)
left = heapq.heappop(heap)
node = Node(
weight=left.weight + right.weight,
left=left,
right=right,
)
heapq.heappush(heap, node)
return cls(heap[0])
def encode(self, text: str) -> str:
graph = self._graph
if graph is None:
self._graph = graph = {c: p for p, c in self._tree.items()}
return ''.join(
graph[letter]
for letter in text
)
def decode(self, text: str) -> str:
return ''.join(self._decode(iter(text)))
def _decode(self, text: Iterator[str]) -> Iterator[str]:
try:
while True:
yield self._tree[text]
except StopIteration:
pass
if __name__ == '__main__':
text = 'abcdeaba'
huff = Huffman.from_text(text)
encoded = huff.encode(text)
print(encoded)
print(huff.decode(encoded))