\$\begingroup\$
\$\endgroup\$
Follow-up to this question.
I implemented most of the suggestions from the answers to the previous question.
Code:
#[macro_use]
extern crate quick_error;
extern crate byteorder;
#[cfg(test)]
#[macro_use]
extern crate quickcheck;
use std::collections::hash_map::HashMap;
use std::collections::BinaryHeap;
use byteorder::{ReadBytesExt, WriteBytesExt, LittleEndian};
use std::{result, io, path, cmp, fs};
use std::io::prelude::*;
type Result<T> = result::Result<T, HuffmanError>;
const BITS: usize = 8;
const DEFAULT_BUFFER_SIZE: usize = 8192;
/// Unwraps a `Result` inside a function that returns `Option<Result<_, _>>`.
///
/// `Ok(value)` evaluates to `value`; `Err(e)` makes the enclosing function
/// return `Some(Err(e))` immediately. The error is passed through unchanged
/// (no `From` conversion is applied).
macro_rules! try_opt {
    ($e:expr) => {
        match $e {
            Err(failure) => return Some(Err(failure)),
            Ok(value) => value,
        }
    };
}
// Error type shared by every fallible operation in this file (see the
// `Result<T>` alias). Generated with the `quick_error!` macro.
quick_error! {
#[derive(Debug)]
pub enum HuffmanError {
// Wraps an underlying I/O failure. `from()` is what lets `?` turn an
// `io::Error` into this variant inside functions returning `Result<T>`.
Io(err: io::Error) {
from()
cause(err)
description(err.description())
display("IO error: {}", err)
}
// Returned by `HuffmanTree::decode` when the encoded walk runs out of bits
// before the tree is complete.
ParseTree {
description("Parse tree error: invalid encoded huffman tree")
}
// Returned by `HuffmanTree::decode` when a leaf is announced by the walk but
// no alphabet byte is left to assign to it.
AlphabetMismatch {
description("Alphabet mismatch error: alphabet doesn't match parsed tree")
}
// Returned by `HuffmanTree::new` when the input produced no symbols at all.
Empty {
description("Empty stream")
}
}
}
struct DecodeBitIterator<'a> {
root: &'a HuffmanTree,
current: &'a HuffmanTree,
}
impl<'a> DecodeBitIterator<'a> {
fn next(&mut self, bit: bool) -> Option<u8> {
use HuffmanTree::*;
let (next, result) = match (self.root, self.current.next(bit)) {
(_, Some(&Leaf { byte, .. })) |
(&Leaf { byte, .. }, None) => (self.root, Some(byte)),
(_, Some(inner)) => (inner, None),
_ => panic!("assertion error"),
};
self.current = next;
result
}
}
/// Counts how often each byte value occurs in everything `reader` yields.
///
/// The reader is consumed in chunks of `buffer_size` bytes until it reports
/// end of stream. Only byte values that occur at least once appear in the
/// returned map.
///
/// Reads that fail with `ErrorKind::Interrupted` are retried. Counts saturate
/// at `u32::MAX` instead of overflowing.
///
/// # Errors
/// Any other I/O error from `reader` is converted into a `HuffmanError`.
fn create_mode_map<R: std::io::Read>(reader: &mut R,
                                     buffer_size: usize)
                                     -> Result<HashMap<u8, u32>> {
    // A flat table avoids hashing every single input byte.
    let mut counts = [0u32; 256];
    let mut buf = vec![0; buffer_size];
    loop {
        let size = match reader.read(&mut buf) {
            Ok(0) => break,
            Ok(size) => size,
            // An interrupted read is non-fatal and must simply be retried.
            Err(ref err) if err.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(err) => return Err(err.into()),
        };
        for &b in &buf[..size] {
            let slot = &mut counts[b as usize];
            *slot = slot.saturating_add(1);
        }
    }
    let mode_map = counts
        .iter()
        .enumerate()
        .filter(|&(_, &count)| count != 0)
        // The index is always below 256, so the cast cannot truncate.
        .map(|(byte, &count)| (byte as u8, count))
        .collect();
    Ok(mode_map)
}
/// Iterates over the first `size` bits stored in `bits`.
///
/// Within each byte the least significant bit comes first. If `size` exceeds
/// the number of bits available, iteration simply stops at the end of `bits`.
fn iterate_bits<'a>(bits: &'a [u8], size: usize) -> BitVectorIter<'a> {
    let every_bit = bits
        .iter()
        .flat_map(|&byte| (0..BITS).map(move |shift| ((byte >> shift) & 1) == 1));
    BitVectorIter { iter: Box::new(every_bit.take(size)) }
}
/// Binary Huffman code tree over byte symbols.
///
/// Equality and ordering are implemented by hand further down and look at
/// the frequency only.
#[derive(Eq, Debug)]
pub enum HuffmanTree {
/// Internal node with exactly two children.
Inner {
/// When built through `join`, the combined frequency of both children.
frequency: u32,
/// Child followed when the code bit is `true`.
left: Box<HuffmanTree>,
/// Child followed when the code bit is `false`.
right: Box<HuffmanTree>,
},
/// Terminal node carrying one byte value; `frequency` is 0 for trees that
/// were rebuilt by `decode`.
Leaf { frequency: u32, byte: u8 },
}
impl HuffmanTree {
pub fn new<R: std::io::Read>(reader: &mut R, buffer_size: usize) -> Result<HuffmanTree> {
let mut pq: BinaryHeap<_> = create_mode_map(reader, buffer_size)
?
.into_iter()
.map(|(c, f)| {
HuffmanTree::Leaf {
byte: c,
frequency: f,
}
})
.collect();
for _ in 1..pq.len() {
let min1 = pq.pop().expect("assertion error");
let min2 = pq.pop().expect("assertion error");
pq.push(min1.join(min2));
}
pq.pop().ok_or(HuffmanError::Empty)
}
fn freq(&self) -> u32 {
use HuffmanTree::*;
match *self {
Inner { frequency, .. } |
Leaf { frequency, .. } => frequency,
}
}
fn join(self, other: HuffmanTree) -> HuffmanTree {
HuffmanTree::Inner {
frequency: self.freq() + other.freq(),
left: Box::new(self),
right: Box::new(other),
}
}
fn create_mapper_recur(&self, bit_vec: &mut BitVector, map: &mut HashMap<u8, BitVector>) {
use HuffmanTree::*;
match *self {
Inner { ref left, ref right, .. } => {
bit_vec.push(true);
left.create_mapper_recur(bit_vec, map);
bit_vec.push(false);
right.create_mapper_recur(bit_vec, map);
}
Leaf { byte, .. } => {
map.insert(byte, bit_vec.clone());
}
}
bit_vec.pop();
}
fn create_mapper(&self) -> HashMap<u8, BitVector> {
let mut bit_vec = BitVector::new();
let mut map = HashMap::new();
if let HuffmanTree::Leaf { .. } = *self {
bit_vec.push(true);
}
self.create_mapper_recur(
&mut bit_vec,
&mut map);
map
}
fn decode<I, T>(encoded_walk: &mut I, mut bytes: &mut T) -> Result<HuffmanTree>
where I: Iterator<Item = bool>,
T: Iterator<Item = u8>
{
match encoded_walk.next() {
Some(true) => {
let left = Self::decode(encoded_walk, bytes)?;
let right = Self::decode(encoded_walk, bytes)?;
Ok(left.join(right))
}
Some(false) => {
let c = bytes.next()
.ok_or(HuffmanError::AlphabetMismatch)?;
Ok(HuffmanTree::Leaf {
frequency: 0,
byte: c,
})
}
None => Err(HuffmanError::ParseTree),
}
}
pub fn serialize<W: io::Write>(&self, writer: &mut W) -> Result<()> {
let (encoded_walk, alphabet_bytes) = self.encode();
let walk_bit_len = encoded_walk.len() as u64;
writer.write_u64::<LittleEndian>(walk_bit_len)?;
writer.write_all(&encoded_walk.bits)?;
let alphabet_byte_len = alphabet_bytes.len() as u64;
writer.write_u64::<LittleEndian>(alphabet_byte_len)?;
writer.write_all(&alphabet_bytes)?;
Ok(())
}
pub fn deserialize<R: io::Read>(reader: &mut R) -> Result<HuffmanTree> {
let walk_bit_len = reader.read_u64::<LittleEndian>()?;
let walk_byte_len = (walk_bit_len + BITS as u64 - 1) / BITS as u64;
let mut walk_bytes = Vec::new();
reader.take(walk_byte_len)
.read_to_end(&mut walk_bytes)?;
let alphabet_bytes_len = reader.read_u64::<LittleEndian>()?;
let mut alphabet_bytes = Vec::new();
reader.take(alphabet_bytes_len)
.read_to_end(&mut alphabet_bytes)?;
let bit_vec = BitVector {
bits: walk_bytes,
size: walk_bit_len as usize,
};
let bit_vec_iter = &mut bit_vec.iter();
Self::decode(
bit_vec_iter,
&mut alphabet_bytes.iter().cloned())
}
fn encode_recur(&self, bit_vec: &mut BitVector, alphabet_bytes: &mut Vec<u8>) {
use HuffmanTree::*;
match *self {
Inner { ref left, ref right, .. } => {
bit_vec.push(true);
left.encode_recur(bit_vec, alphabet_bytes);
right.encode_recur(bit_vec, alphabet_bytes);
}
Leaf { byte, .. } => {
bit_vec.push(false);
alphabet_bytes.push(byte);
}
}
}
fn encode(&self) -> (BitVector, Vec<u8>) {
let mut bit_vector = BitVector::new();
let mut alphabet_bytes = Vec::new();
self.encode_recur(&mut bit_vector, &mut alphabet_bytes);
(bit_vector, alphabet_bytes)
}
fn decode_bit_iter<'a>(&'a self) -> DecodeBitIterator {
DecodeBitIterator {
root: self,
current: self,
}
}
pub fn encode_iter<'a, R: io::Read>(&self,
reader: &'a mut R,
buffer_size: usize)
-> EncodeIterator<'a, R> {
EncodeIterator {
reader: reader,
mapper: self.create_mapper(),
buffer: vec![0u8; buffer_size],
}
}
pub fn decode_iter<'a, 'b, R: io::Read>(&'b self,
reader: &'a mut R,
buffer_size: u64,
bit_junk: usize)
-> io::Result<DecodeIterator<'a, 'b, R>> {
let mut last_buffer = Vec::new();
reader.take(buffer_size)
.read_to_end(&mut last_buffer)?;
let itr = DecodeIterator {
reader: reader,
bit_mapper: self.decode_bit_iter(),
buffer: Vec::new(),
buffer_size: buffer_size,
last_buffer: last_buffer,
bit_junk: bit_junk,
};
Ok(itr)
}
fn next(&self, bit: bool) -> Option<&HuffmanTree> {
use HuffmanTree::*;
match *self {
Inner { ref left, .. } if bit => Some(left),
Inner { ref right, .. } => Some(right),
Leaf { .. } => None,
}
}
}
/// Iterator that Huffman-encodes a reader chunk by chunk.
///
/// Created by `HuffmanTree::encode_iter`.
pub struct EncodeIterator<'a, R: io::Read + 'a> {
    /// Source of the raw bytes.
    reader: &'a mut R,
    /// Code word for every byte value known to the tree.
    mapper: HashMap<u8, BitVector>,
    /// Scratch buffer; its length is the chunk size.
    buffer: Vec<u8>,
}

impl<'a, R: io::Read> Iterator for EncodeIterator<'a, R> {
    type Item = io::Result<BitVector>;

    /// Reads the next chunk and returns the concatenation of the code words
    /// of its bytes, or `None` once the reader reports end of stream.
    ///
    /// Interrupted reads are retried. A byte that has no code word in the
    /// tree yields an `InvalidData` error instead of panicking.
    fn next(&mut self) -> Option<Self::Item> {
        let size = loop {
            match self.reader.read(&mut self.buffer) {
                Ok(size) => break size,
                // An interrupted read is non-fatal and must simply be retried.
                Err(ref err) if err.kind() == io::ErrorKind::Interrupted => continue,
                Err(err) => return Some(Err(err)),
            }
        };
        if size == 0 {
            return None;
        }
        let mut bit_vec = BitVector::new();
        for byte in &self.buffer[..size] {
            match self.mapper.get(byte) {
                Some(code) => bit_vec.append(code),
                None => {
                    // Happens when the tree was built from different data
                    // than what is being encoded now.
                    let err = io::Error::new(io::ErrorKind::InvalidData,
                                             "byte has no code in the Huffman tree");
                    return Some(Err(err));
                }
            }
        }
        Some(Ok(bit_vec))
    }
}
/// Iterator that Huffman-decodes a reader chunk by chunk.
///
/// Created by `HuffmanTree::decode_iter`. It always reads one chunk ahead so
/// that it can tell when the chunk being decoded is the last one.
pub struct DecodeIterator<'a, 'b, R: io::Read + 'a> {
    /// Source of the encoded bytes.
    reader: &'a mut R,
    /// Cursor into the code tree; keeps its position across chunks.
    bit_mapper: DecodeBitIterator<'b>,
    /// Receives the look-ahead chunk.
    buffer: Vec<u8>,
    /// Maximum number of bytes read per chunk.
    buffer_size: u64,
    /// Chunk that is decoded by the next call to `next`.
    last_buffer: Vec<u8>,
    /// Number of bits to ignore at the end of the final chunk.
    bit_junk: usize,
}

impl<'a, 'b, R: io::Read> Iterator for DecodeIterator<'a, 'b, R> {
    type Item = io::Result<Vec<u8>>;

    /// Decodes the pending chunk and returns the bytes it produced.
    ///
    /// Returns `None` once the reader is exhausted and no chunk is pending.
    fn next(&mut self) -> Option<Self::Item> {
        let size = try_opt!(self.reader
            .take(self.buffer_size)
            .read_to_end(&mut self.buffer));
        if size == 0 && self.last_buffer.is_empty() {
            return None;
        }
        let total_bits = BITS * self.last_buffer.len();
        let bit_len = if size == 0 {
            // Nothing follows, so the pending chunk is the final one: drop
            // the padding bits. Saturating keeps a corrupt padding count from
            // underflowing.
            total_bits.saturating_sub(self.bit_junk)
        } else {
            total_bits
        };
        let mut bytes = Vec::new();
        for bit in iterate_bits(&self.last_buffer, bit_len) {
            if let Some(byte) = self.bit_mapper.next(bit) {
                bytes.push(byte);
            }
        }
        // The look-ahead chunk becomes the pending one.
        std::mem::swap(&mut self.buffer, &mut self.last_buffer);
        self.buffer.clear();
        Some(Ok(bytes))
    }
}
/// Orders trees by frequency, reversed: the tree with the lower frequency
/// compares as greater, which turns `BinaryHeap` into a min-heap.
impl Ord for HuffmanTree {
    fn cmp(&self, other: &HuffmanTree) -> cmp::Ordering {
        // Swapping the operands is the same as reversing the comparison.
        other.freq().cmp(&self.freq())
    }
}
/// Total order delegated to the `Ord` implementation.
impl PartialOrd for HuffmanTree {
    fn partial_cmp(&self, other: &HuffmanTree) -> Option<cmp::Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
/// Two trees are equal when their frequencies are equal; the shape and the
/// bytes are not compared.
impl PartialEq for HuffmanTree {
    fn eq(&self, other: &HuffmanTree) -> bool {
        let (mine, theirs) = (self.freq(), other.freq());
        mine == theirs
    }
}
/// Huffman compression of `self` into a seekable writer.
pub trait HuffmanCompress {
    /// Compresses `self` into `writer` using `DEFAULT_BUFFER_SIZE` as the
    /// chunk size.
    fn compress<T>(&mut self, writer: &mut T) -> Result<()>
        where T: io::Write + io::Seek
    {
        self.compress_with_buffer_size(writer, DEFAULT_BUFFER_SIZE)
    }

    /// Compresses `self` into `writer`, processing the input in chunks of
    /// `buffer_size` bytes.
    // The writer parameter is named: anonymous trait-method parameters are
    // deprecated and rejected by later editions.
    fn compress_with_buffer_size<T>(&mut self, writer: &mut T, buffer_size: usize) -> Result<()>
        where T: io::Write + io::Seek;
}
/// Huffman decompression of `self` into a writer.
pub trait HuffmanDeCompress {
    /// Decompresses `self` into `writer` using `DEFAULT_BUFFER_SIZE` as the
    /// chunk size.
    fn decompress<T: io::Write>(&mut self, writer: &mut T) -> Result<()> {
        let chunk_size = DEFAULT_BUFFER_SIZE as u64;
        self.decompress_with_buffer_size(writer, chunk_size)
    }

    /// Decompresses `self` into `writer`, reading the encoded data in chunks
    /// of at most `buffer_size` bytes.
    fn decompress_with_buffer_size<T: io::Write>(&mut self,
                                                 writer: &mut T,
                                                 buffer_size: u64)
                                                 -> Result<()>;
}
/// Convenience trait combining `HuffmanDeCompress` and `HuffmanCompress`.
pub trait HuffmanCodes: HuffmanDeCompress + HuffmanCompress {}
/// Every seekable reader gets `HuffmanCodes` for free.
impl<T: io::Read + io::Seek> HuffmanCodes for T {}
/// Compression for every seekable reader.
impl<T: io::Read + io::Seek> HuffmanCompress for T {
    /// Writes the serialized tree, one byte holding the number of padding
    /// bits, and then the encoded payload.
    ///
    /// The input is read twice (once to build the tree, once to encode), so
    /// `self` is rewound to its start in between. The padding byte is written
    /// as a placeholder first and patched once the payload length is known.
    fn compress_with_buffer_size<W>(&mut self, writer: &mut W, buffer_size: usize) -> Result<()>
        where W: io::Write + io::Seek
    {
        use std::io::SeekFrom;

        let tree = HuffmanTree::new(self, buffer_size)?;
        tree.serialize(writer)?;
        self.seek(SeekFrom::Start(0))?;

        // Remember where the padding byte lives and reserve it.
        let junk_pos = writer.seek(SeekFrom::Current(0))?;
        writer.write_i8(0)?;

        // Bits that do not fill a whole byte yet are carried over to the
        // next chunk.
        let mut pending = BitVector::new();
        for code in tree.encode_iter(self, buffer_size) {
            let chunk = code?;
            pending.append(&chunk);

            let tail_bits = pending.len() % BITS;
            let tail_bytes = if tail_bits == 0 { 0 } else { 1 };
            let total_bytes = pending.byte_len();
            let whole_bytes = total_bytes - tail_bytes;
            writer.write_all(&pending.bits[..whole_bytes])?;

            // Keep only the trailing partial byte, moved to the front.
            pending.bits.swap(0, total_bytes - 1);
            pending.bits.truncate(tail_bytes);
            pending.size = tail_bits;
        }
        writer.write_all(&pending.bits)?;

        // Unused bits in the last byte; 0 when the payload is byte aligned.
        let junk = (BITS - pending.len() % BITS) % BITS;

        let end_pos = writer.seek(SeekFrom::Current(0))?;
        writer.seek(SeekFrom::Start(junk_pos))?;
        writer.write_i8(junk as i8)?;
        writer.seek(SeekFrom::Start(end_pos))?;
        Ok(())
    }
}
impl<T: std::io::Read> HuffmanDeCompress for T {
fn decompress_with_buffer_size<W>(&mut self, writer: &mut W, buffer_size: u64) -> Result<()>
where W: io::Write
{
let tree = HuffmanTree::deserialize(self)?;
let junk = self.read_i8()? as usize;
for bytes in tree.decode_iter(self, buffer_size, junk)? {
writer.write_all(&bytes?)?
}
Ok(())
}
}
/// Opens `path` for buffered reading, creates (or truncates) `target` for
/// buffered writing, and runs `mapping_function` on the two.
///
/// The writer is flushed explicitly afterwards: a `BufWriter` that is merely
/// dropped discards any error from writing out its remaining buffer.
///
/// # Errors
/// Returns an error if either file cannot be opened, if `mapping_function`
/// fails, or if the final flush fails.
pub fn io_map<P, T, F>(path: P, target: T, mapping_function: F) -> Result<()>
    where P: AsRef<path::Path>,
          T: AsRef<path::Path>,
          F: Fn(&mut io::BufReader<fs::File>, &mut io::BufWriter<fs::File>) -> Result<()>
{
    let f1 = fs::File::open(path)?;
    let mut reader = io::BufReader::new(f1);
    let f2 = fs::File::create(target)?;
    let mut writer = io::BufWriter::new(f2);
    mapping_function(&mut reader, &mut writer)?;
    io::Write::flush(&mut writer)?;
    Ok(())
}
pub fn decode<P, T>(path: P, target: T) -> Result<()>
where P: AsRef<path::Path>,
T: AsRef<path::Path>
{
io_map(path, target, |r, w| r.decompress(w))
}
pub fn encode<P, T>(path: P, target: T) -> Result<()>
where P: AsRef<path::Path>,
T: AsRef<path::Path>
{
io_map(path, target, |r, w| r.compress(w))
}
/// Growable sequence of bits packed into bytes.
///
/// Bit `i` lives in byte `i / BITS` at position `i % BITS`, least significant
/// bit first. `push`, `pop` and `put` keep the unused bits of the last byte
/// at zero, which `append` relies on.
#[derive(Clone, Debug)]
pub struct BitVector {
    /// Packed storage.
    bits: Vec<u8>,
    /// Number of valid bits.
    size: usize,
}

impl BitVector {
    /// Creates an empty bit vector.
    pub fn new() -> BitVector {
        BitVector {
            bits: Vec::new(),
            size: 0,
        }
    }

    /// Appends one bit.
    pub fn push(&mut self, bit: bool) {
        let leftover = self.size % BITS;
        if leftover == 0 {
            // The last byte is full (or there is none): start a new one.
            self.bits.push(0);
        }
        let last_byte = self.bits
            .last_mut()
            .expect("Assertion error");
        *last_byte |= (bit as u8) << leftover;
        self.size += 1;
    }

    /// Removes all bits.
    #[allow(dead_code)]
    pub fn clear(&mut self) {
        self.bits.clear();
        self.size = 0;
    }

    /// Removes the last bit; does nothing on an empty vector.
    pub fn pop(&mut self) {
        if self.len() == 0 {
            return;
        }
        let len = self.size - 1;
        // Zero the bit so the unused tail of the byte stays clean.
        self.put(len, false);
        self.size = len;
        if self.len() % BITS == 0 {
            self.bits.pop();
        }
    }

    /// Appends every bit of `bits` in order.
    #[allow(dead_code)]
    pub fn push_all(&mut self, bits: &[bool]) {
        for bit in bits {
            self.push(*bit);
        }
    }

    /// Appends all bits of `other`.
    pub fn append(&mut self, other: &BitVector) {
        let leftover = self.size % BITS;
        let empty_bits = BITS - leftover;
        let len = self.bits.len();
        self.bits.extend_from_slice(&other.bits);
        self.size += other.size;
        if leftover == 0 {
            // Byte aligned: the copied bytes are already in place.
            return;
        }
        // Shift every copied byte down into the free bits of its predecessor.
        for i in len..self.bits.len() {
            self.move_bits(empty_bits, i);
        }
        // The shift may have emptied the final byte completely.
        if (self.size - 1) / BITS != self.bits.len() - 1 {
            self.bits.pop();
        }
    }

    /// Asserts that `i` is a valid bit index.
    ///
    /// # Panics
    /// Panics if `i >= self.len()`.
    pub fn check(&self, i: usize) {
        // Literal format string: passing `format!(..)` as the message is
        // deprecated and no longer accepted by newer editions.
        assert!(i < self.size,
                "Index out of bounds. Index: {} >= len: {}",
                i,
                self.size);
    }

    /// Returns bit `i`.
    ///
    /// # Panics
    /// Panics if `i >= self.len()`.
    #[allow(dead_code)]
    pub fn get(&self, i: usize) -> bool {
        self.check(i);
        (1 & (self.bits[i / BITS] >> (i % BITS))) != 0
    }

    /// Overwrites bit `i` with `bit`.
    ///
    /// # Panics
    /// Panics if `i >= self.len()`.
    pub fn put(&mut self, i: usize, bit: bool) {
        self.check(i);
        let (byte, bit_pos) = (i / BITS, i % BITS);
        self.bits[byte] &= !(1 << bit_pos);
        self.bits[byte] |= (bit as u8) << bit_pos;
    }

    /// Moves the lowest `bits` bits of byte `i` into the top of byte `i - 1`
    /// and shifts the remainder of byte `i` down.
    fn move_bits(&mut self, bits: usize, i: usize) {
        let leftover = BITS - bits;
        let (from, to) = (self.bits[i], self.bits[i - 1]);
        self.bits[i - 1] = to ^ (from << leftover);
        self.bits[i] = from >> bits;
    }

    /// Iterates over the stored bits in order.
    pub fn iter<'a>(&'a self) -> BitVectorIter<'a> {
        iterate_bits(&self.bits, self.size)
    }

    /// Number of bits stored.
    pub fn len(&self) -> usize {
        self.size
    }

    /// `true` when no bits are stored.
    #[allow(dead_code)]
    pub fn empty(&self) -> bool {
        self.size == 0
    }

    /// Number of bytes used by the packed storage.
    #[allow(dead_code)]
    pub fn byte_len(&self) -> usize {
        self.bits.len()
    }
}
pub struct BitVectorIter<'a> {
iter: Box<std::iter::Iterator<Item = bool> + 'a>,
}
impl<'a> IntoIterator for &'a BitVector {
type Item = bool;
type IntoIter = BitVectorIter<'a>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
impl<'a> Iterator for BitVectorIter<'a> {
type Item = bool;
fn next(&mut self) -> Option<Self::Item> {
self.iter.next()
}
}
Test:
// Property-based round-trip test for compress/decompress.
#[cfg(test)]
mod test {
use quickcheck::TestResult;
use std::io;
use super::*;
use std::io::prelude::*;
// Reads everything left in `reader` into a `String`.
fn read_stream<W: io::Read>(reader: &mut W) -> io::Result<String> {
let mut result = String::new();
reader.read_to_string(&mut result)?;
Ok(result)
}
quickcheck! {
// Compressing and then decompressing any non-empty string must give the
// original string back.
fn prop(text: String) -> super::Result<TestResult> {
// Empty inputs are discarded rather than tested.
if text.is_empty() {
return Ok(TestResult::discard());
}
// In-memory cursors stand in for the input, compressed and output files.
let mut mock_input_stream = io::Cursor::new(text.into_bytes());
let mut mock_compress_stream = io::Cursor::new(Vec::new());
let mut mock_decompress_stream = io::Cursor::new(Vec::new());
mock_input_stream.compress(&mut mock_compress_stream)?;
// Rewind before reading back what was just written.
mock_compress_stream.seek(io::SeekFrom::Start(0))?;
mock_compress_stream.decompress(&mut mock_decompress_stream)?;
mock_input_stream.seek(io::SeekFrom::Start(0))?;
mock_decompress_stream.seek(io::SeekFrom::Start(0))?;
let expected = read_stream(&mut mock_input_stream)?;
let test = read_stream(&mut mock_decompress_stream)?;
Ok(TestResult::from_bool(expected == test))
}
}
}
lang-rust